SEGAN (NNabla)

Implementation of Speech Enhancement GAN (SEGAN) using NNabla

Japanese version of this README -> Link

Original Paper
SEGAN: Speech Enhancement Generative Adversarial Network
https://arxiv.org/abs/1703.09452

Requirements

Python

  • Python 3.6
  • CUDA 10.0 & cuDNN 7.6
    • Choose the CUDA and cuDNN versions that match your NNabla version.

Packages

Install the following packages with pip (upgrade pip to the latest version first if necessary).

  • nnabla (v1.0.19 or later)
  • nnabla-ext-cuda (v1.0.19 or later)
  • scipy
  • numba
  • joblib
  • PyQt5
  • pyqtgraph (install after PyQt5)
  • pypesq (see "install with pip" on the official site)
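
Before running segan.py, it can help to confirm that every package above is importable. The sketch below is a minimal, hypothetical helper (not part of this repository); it assumes the usual import names, e.g. PyQt5 for the pip package pyQT5 and nnabla_ext.cuda for nnabla-ext-cuda.

```python
import importlib.util

# ASSUMPTION: import names for the pip packages listed above.
REQUIRED_MODULES = [
    "nnabla", "nnabla_ext.cuda", "scipy", "numba", "joblib",
    "PyQt5", "pyqtgraph", "pypesq",
]


def missing_modules(names):
    """Return the subset of `names` that cannot be found for import."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except ImportError:
            # find_spec("a.b") imports "a" first and raises if "a" is absent
            found = False
        if not found:
            missing.append(name)
    return missing


if __name__ == "__main__":
    result = missing_modules(REQUIRED_MODULES)
    print("Missing:", ", ".join(result) if result else "none")
```

Any name reported as missing can then be installed with pip before moving on.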

Contents

  • segan.py
    The main source code. Run this file.

  • data.py
    Creates the batch data. Before running, download the wav datasets as described below.

  • settings.py
    Contains the setting parameters.

  • display.py
    Contains functions for displaying the results.

Download & Create Database

  1. Download segan.py, settings.py, data.py, and display.py, and save them in the same directory.

  2. In that directory, create three folders: data, pkl, and params.

    • data folder: stores the wav data.
    • pkl folder: stores the pickled databases "~.pkl".
    • params folder: stores parameters, including network models.
  3. Download the following 4 datasets and unzip them.

  4. Move the 4 unzipped folders into the data folder.

  5. Resample all the wav data to 16 kHz (for example, this site is useful). After converting, you can delete the original wav data.
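
Steps 2 and 5 can also be scripted. The sketch below is one possible approach, not part of the repository: it assumes 16-bit PCM input files and uses SciPy's polyphase resampler (scipy.signal.resample_poly) in place of the site mentioned in step 5. The function names, and the choice to write resampled copies into a separate folder tree, are illustrative.

```python
import os
from math import gcd

import numpy as np
from scipy.io import wavfile
from scipy.signal import resample_poly

TARGET_FS = 16000


def make_folders(root="."):
    """Create the data, pkl and params folders from step 2 (no error if they exist)."""
    for name in ("data", "pkl", "params"):
        os.makedirs(os.path.join(root, name), exist_ok=True)


def resample_wav(src_path, dst_path, target_fs=TARGET_FS):
    """Resample one 16-bit PCM wav file (ASSUMED format) to target_fs."""
    fs, x = wavfile.read(src_path)
    if fs != target_fs:
        g = gcd(fs, target_fs)
        # e.g. 48 kHz -> 16 kHz gives up=1, down=3
        y = resample_poly(x.astype(np.float64), target_fs // g, fs // g, axis=0)
        x = np.clip(np.round(y), -32768, 32767).astype(np.int16)
    wavfile.write(dst_path, target_fs, x)


def resample_tree(src_root, dst_root, target_fs=TARGET_FS):
    """Resample every .wav under src_root into the same layout under dst_root."""
    for dirpath, _, filenames in os.walk(src_root):
        out_dir = os.path.join(dst_root, os.path.relpath(dirpath, src_root))
        os.makedirs(out_dir, exist_ok=True)
        for fname in filenames:
            if fname.lower().endswith(".wav"):
                resample_wav(os.path.join(dirpath, fname),
                             os.path.join(out_dir, fname), target_fs)
```

For example, with the unzipped originals in a hypothetical raw folder, make_folders() followed by resample_tree("raw", "data") fills the data folder with 16 kHz copies. Keeping the originals separate lets you check the output before deleting them.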

Settings

settings.py

settings.py contains the parameters for training and prediction. The special parameters are explained below.

  • self.epoch_from :
    The epoch from which to resume training. If self.epoch_from > 0, training resumes after loading the pre-trained models "discriminator_param_xxxx.h5" and "generator_param_xxxx.h5", where "xxxx" must correspond to the value of self.epoch_from.
    If self.epoch_from = 0, no pre-trained models are loaded and training starts from scratch.

  • self.model_save_cycle :
    Interval, in epochs, at which the network models are saved. If 1, the models are saved every epoch.
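
The interplay of the two parameters can be sketched as follows. This is a hypothetical illustration, not code from settings.py or segan.py: the Settings class and the helper functions are invented for the example, epochs are assumed to be counted from 1, and the zero-padded 4-digit form of "xxxx" is an assumption.

```python
class Settings:
    """Hypothetical stand-in for the two settings.py fields described above."""

    def __init__(self, epoch_from=0, model_save_cycle=1):
        self.epoch_from = epoch_from              # 0 = train from scratch
        self.model_save_cycle = model_save_cycle  # save every N epochs


def checkpoint_names(epoch):
    # ASSUMPTION: "xxxx" is the epoch number zero-padded to 4 digits;
    # check the actual file names written to the params folder.
    tag = "{:04d}".format(epoch)
    return ("discriminator_param_{}.h5".format(tag),
            "generator_param_{}.h5".format(tag))


def files_to_resume(settings):
    """Return the checkpoint files to load, or None when training from scratch."""
    if settings.epoch_from > 0:
        return checkpoint_names(settings.epoch_from)
    return None


def should_save(epoch, settings):
    """True on epochs that are multiples of model_save_cycle (epochs counted from 1)."""
    return epoch % settings.model_save_cycle == 0
```

In NNabla, parameter files like these are normally read with nn.load_parameters(path) and written with nn.save_parameters(path); whether segan.py uses exactly those calls is not stated in this README.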

16-bit Float (Half-Precision Floating-Point Mode)

If you run into GPU out-of-memory errors, try half-precision floating-point mode, which lowers the computation precision and thus reduces memory usage. To use it, run the following commands before defining the network.

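import nnabla as nn
from nnabla.ext_utils import get_extension_context
ctx = get_extension_context('cudnn', device_id=args.device_id, type_config='half')
nn.set_default_context(ctx)  # layers defined after this use the half-precision cuDNN context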

In segan.py, this mode is enabled by default. Refer to the "nnabla-ext-cuda" documentation for more information.
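
To see why half precision saves memory: a float16 value takes 2 bytes versus 4 bytes for float32, so every buffer stored in half precision shrinks by half. The batch shape below is purely illustrative (it is not taken from settings.py), and the actual savings with type_config='half' depend on which arrays NNabla keeps in half precision.

```python
import numpy as np

# Illustrative batch shape only: (batch, channels, samples).
# 16384 samples is roughly 1 s of audio at 16 kHz.
shape = (100, 1, 16384)

full = np.zeros(shape, dtype=np.float32)  # 4 bytes per value
half = np.zeros(shape, dtype=np.float16)  # 2 bytes per value

print("float32 buffer: %d bytes" % full.nbytes)
print("float16 buffer: %d bytes" % half.nbytes)
```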

Run

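  1. For training, set Train = 1 in the main function of segan.py; for prediction, set Train = 0.
    Train = 0  # 1: train, 0: predict (test)
    if Train:
        # Training (ctx is defined elsewhere in segan.py)
        nn.set_default_context(ctx)
        train(args)
    else:
        # Test
        # nn.set_default_context(ctx)  # commented out in segan.py, so ctx is not applied here
        test(args)
        pesq_score('clean.wav', 'output_segan.wav')  # PESQ of enhanced vs. clean speech
  2. Run segan.py.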

During Training

When you run train(args), the training dataset (xxxx.pkl) is created in the pkl folder at the start of the first run only. The network models (xxxx.h5) are then saved in the params folder at the interval set by self.model_save_cycle.
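
The first-run-only dataset creation described above follows a load-or-create caching pattern. A minimal sketch, assuming plain pickle (data.py may store its databases differently); load_or_create and build are hypothetical names:

```python
import os
import pickle


def load_or_create(pkl_path, build):
    """Load a pickled dataset if it exists; otherwise build it once and cache it.

    `build` is a zero-argument callable (hypothetical) that reads the wav
    files and returns the dataset object.
    """
    if os.path.exists(pkl_path):
        with open(pkl_path, "rb") as f:
            return pickle.load(f)
    data = build()
    os.makedirs(os.path.dirname(pkl_path) or ".", exist_ok=True)
    with open(pkl_path, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    return data
```

Under this pattern the cache is never rebuilt automatically, so after changing the wav data you would delete the stale .pkl file to force a rebuild.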

During Predicting

When you run test(args), the test dataset (xxxx.pkl) is created in the pkl folder at the start of the first run only. The following wav files are then generated as results, and the PESQ score is displayed.

  • clean.wav : clean speech wav file
  • noisy.wav : noisy speech wav file
  • output_segan.wav : reconstructed (enhanced) speech wav file
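
A minimal sketch of how a PESQ score between clean.wav and the enhanced output could be computed, assuming the pypesq package listed above and its pesq(ref, deg, fs) function. This is not the repository's pesq_score; the metric parameter exists only so the file loading and length alignment can be exercised without pypesq. Both files must share one sampling rate, and PESQ is defined for 8 kHz and 16 kHz audio.

```python
import numpy as np
from scipy.io import wavfile


def pesq_score(clean_path, enhanced_path, metric=None):
    """Hypothetical sketch: PESQ between a clean and an enhanced wav file.

    `metric(ref, deg, fs)` defaults to pypesq.pesq; it is a parameter only
    so the loading/alignment logic can be tested without pypesq installed.
    """
    if metric is None:
        from pypesq import pesq as metric  # pypesq: pesq(ref, deg, fs)
    fs_ref, ref = wavfile.read(clean_path)
    fs_deg, deg = wavfile.read(enhanced_path)
    if fs_ref != fs_deg:
        raise ValueError("sampling rates differ: %d vs %d" % (fs_ref, fs_deg))
    # The two files may differ in length; compare only the overlapping samples.
    n = min(len(ref), len(deg))
    score = metric(ref[:n].astype(np.float64), deg[:n].astype(np.float64), fs_ref)
    print("PESQ: %.3f" % score)
    return score
```

Calling pesq_score('clean.wav', 'output_segan.wav') mirrors the call in the Run step.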

Contributors

yosukesugiura

Issues

error

I have a problem with training. I followed the steps but got the following error.

runfile('C:/Users/snow/Desktop/SEGAN-master/segan.py', wdir='C:/Users/snow/Desktop/SEGAN-master')
2020-12-23 20:58:27,899 [nnabla][INFO]: Initializing CPU extension...
2020-12-23 20:58:29,379 [nnabla][INFO]: Initializing CUDA extension...
2020-12-23 20:58:29,382 [nnabla][INFO]: Initializing cuDNN extension...
Load Clean Pkl...
Load Noisy Pkl...
Traceback (most recent call last):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsObject.py", line 40, in itemChange
ret = sip.cast(ret, QtGui.QGraphicsItem)

TypeError: cast() argument 1 must be sip.simplewrapper, not PlotItem

view= None
dt= None
px= 0
py= 0
view= None
dt= None
px= 0
py= 0
v= <pyqtgraph.graphicsWindows.GraphicsWindow object at 0x000002AD03BDBE58>
obj= <pyqtgraph.graphicsWindows.GraphicsWindow object at 0x000002AD03BDBE58>
Traceback (most recent call last):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\ViewBox\ViewBox.py", line 438, in resizeEvent
self.updateAutoRange()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\ViewBox\ViewBox.py", line 890, in updateAutoRange
childRange = self.childrenBounds(frac=fractionVisible)

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\ViewBox\ViewBox.py", line 1355, in childrenBounds
px, py = [v.length() if v is not None else 0 for v in self.childGroup.pixelVectors()]

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 191, in pixelVectors
dt = self.deviceTransform()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 109, in deviceTransform
view = self.getViewWidget()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 66, in getViewWidget
if v is not None and not isQObjectAlive(v):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\Qt.py", line 329, in isQObjectAlive
return not sip.isdeleted(obj)

TypeError: isdeleted() argument 1 must be sip.simplewrapper, not GraphicsWindow

Traceback (most recent call last):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsObject.py", line 40, in itemChange
ret = sip.cast(ret, QtGui.QGraphicsItem)

TypeError: cast() argument 1 must be sip.simplewrapper, not PlotDataItem

Traceback (most recent call last):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsObject.py", line 40, in itemChange
ret = sip.cast(ret, QtGui.QGraphicsItem)

TypeError: cast() argument 1 must be sip.simplewrapper, not PlotDataItem

v= <pyqtgraph.graphicsWindows.GraphicsWindow object at 0x000002AD03BDBE58>
obj= <pyqtgraph.graphicsWindows.GraphicsWindow object at 0x000002AD03BDBE58>
Traceback (most recent call last):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsObject.py", line 26, in itemChange
self.parentChanged()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 466, in parentChanged
self._updateView()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 483, in _updateView
view = self.getViewBox()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 89, in getViewBox
vb = self.getViewWidget()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 66, in getViewWidget
if v is not None and not isQObjectAlive(v):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\Qt.py", line 329, in isQObjectAlive
return not sip.isdeleted(obj)

TypeError: isdeleted() argument 1 must be sip.simplewrapper, not GraphicsWindow

Traceback (most recent call last):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsObject.py", line 40, in itemChange
ret = sip.cast(ret, QtGui.QGraphicsItem)

TypeError: cast() argument 1 must be sip.simplewrapper, not ChildGroup

v= <pyqtgraph.graphicsWindows.GraphicsWindow object at 0x000002AD03BDBE58>
obj= <pyqtgraph.graphicsWindows.GraphicsWindow object at 0x000002AD03BDBE58>
Traceback (most recent call last):

File "C:\Users\snow\Desktop\SEGAN-master\segan.py", line 430, in <module>
train(args)

File "C:\Users\snow\Desktop\SEGAN-master\segan.py", line 288, in train
fig = figout()

File "C:\Users\snow\Desktop\SEGAN-master\display.py", line 109, in __init__
self.c11 = self.p1.plot(pen=(255, 0, 0), name="Input")

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\PlotItem\PlotItem.py", line 653, in plot
self.addItem(item, params=params)

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\PlotItem\PlotItem.py", line 530, in addItem
self.vb.addItem(item, *args, **vbargs)

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\ViewBox\ViewBox.py", line 409, in addItem
self.updateAutoRange()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\ViewBox\ViewBox.py", line 890, in updateAutoRange
childRange = self.childrenBounds(frac=fractionVisible)

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\ViewBox\ViewBox.py", line 1355, in childrenBounds
px, py = [v.length() if v is not None else 0 for v in self.childGroup.pixelVectors()]

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 191, in pixelVectors
dt = self.deviceTransform()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 109, in deviceTransform
view = self.getViewWidget()

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\graphicsItems\GraphicsItem.py", line 66, in getViewWidget
if v is not None and not isQObjectAlive(v):

File "D:\anaconda3\envs\SEGAN\lib\site-packages\pyqtgraph\Qt.py", line 329, in isQObjectAlive
return not sip.isdeleted(obj)

TypeError: isdeleted() argument 1 must be sip.simplewrapper, not GraphicsWindow
