torch-stft's Introduction

STFT/iSTFT in PyTorch

Author: Prem Seetharaman

An STFT/iSTFT implemented in PyTorch using 1D convolutions. Requirements are a recent version of PyTorch, numpy, and librosa (for loading audio in test_stft.py). Thanks to Shrikant Venkataramani for sharing the code this was based on, and to Rafael Valle for catching bugs and adding the proper windowing logic. Uses Python 3.

Installation

Install easily with pip:

pip install torch-stft

Usage

import torch
from torch_stft import STFT
import numpy as np
import librosa 
import matplotlib.pyplot as plt

audio = librosa.load(librosa.util.example_audio_file(), duration=10.0, offset=30)[0]
device = 'cpu'
filter_length = 1024
hop_length = 256
win_length = 1024  # optional; defaults to filter_length when not specified
window = 'hann'

audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0)
audio = audio.to(device)

stft = STFT(
    filter_length=filter_length, 
    hop_length=hop_length, 
    win_length=win_length,
    window=window
).to(device)

magnitude, phase = stft.transform(audio)
output = stft.inverse(magnitude, phase)
output = output.cpu().data.numpy()[..., :]
audio = audio.cpu().data.numpy()[..., :]
print(np.mean((output - audio) ** 2)) # on order of 1e-16

Output of compare_stft.py:

images/stft.png

Tests

To run the tests, clone this repo and run:

pip install -r requirements.txt
python -m pytest .

Unfortunately, because the transform is implemented with 1D convolutions, some filter_length/hop_length combinations can cause out-of-memory errors on your GPU when run on sufficiently large inputs.
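To get a feel for which combinations are risky, here is a rough back-of-the-envelope sketch. The helper stft_output_bytes is hypothetical and not part of torch-stft. It assumes the convolution emits real and imaginary parts as 2 * (filter_length // 2 + 1) channels of float32 and roughly one frame per hop, and it counts only the convolution output, so actual GPU usage will likely be higher.

```python
def stft_output_bytes(num_samples, filter_length, hop_length,
                      batch_size=1, bytes_per_value=4):
    """Rough size in bytes of the conv1d output alone (hypothetical helper)."""
    # Assumption: real and imaginary parts are stacked as output channels.
    channels = 2 * (filter_length // 2 + 1)
    # Assumption: the input is padded so there is about one frame per hop.
    frames = num_samples // hop_length + 1
    return batch_size * channels * frames * bytes_per_value


ten_minutes = 10 * 60 * 44100  # samples in ten minutes of 44.1 kHz audio
small_hop = stft_output_bytes(ten_minutes, filter_length=4096, hop_length=64)
large_hop = stft_output_bytes(ten_minutes, filter_length=4096, hop_length=1024)

print(f"hop_length=64:   {small_hop / 1e9:.2f} GB")
print(f"hop_length=1024: {large_hop / 1e9:.2f} GB")
```

Under these assumptions the estimate grows with filter_length and shrinks with hop_length, so long inputs with a large filter and a small hop are the ones most likely to run out of memory.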

Contributing

Pull requests welcome.

torch-stft's Issues

Wrong mel conversion

This line tripped me up: mel_spectrogram = F.linear(magnitude.transpose(-1, -2), self.mel_filter_bank)

Instead, we should use this: torch.bmm(self.mel_filter_bank.expand(input_data.size(0), -1, -1), magnitude).transpose(1, 2)
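One way to narrow this down is to run both expressions on the same random input and compare them. The sketch below assumes magnitude has shape (batch, n_freq, n_frames) and the filter bank has shape (n_mels, n_freq); the sizes are made up for illustration. With those shapes the two expressions compute the same matrix product and should agree up to floating-point error, so a difference in practice may point to a filter bank stored in another layout.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_freq, n_frames, n_mels = 2, 513, 20, 80  # illustrative sizes

# Assumed shapes; random stand-ins rather than real audio features.
magnitude = torch.rand(batch, n_freq, n_frames)
mel_filter_bank = torch.rand(n_mels, n_freq)

# Expression from the first line quoted above.
via_linear = F.linear(magnitude.transpose(-1, -2), mel_filter_bank)

# Expression from the proposed replacement.
via_bmm = torch.bmm(
    mel_filter_bank.expand(batch, -1, -1), magnitude
).transpose(1, 2)

same_shape = via_linear.shape == via_bmm.shape
same_values = torch.allclose(via_linear, via_bmm, rtol=1e-4, atol=1e-4)
print(via_linear.shape, via_bmm.shape, same_values)
```

If the shapes printed here do not match what your model expects, checking mel_filter_bank.shape is probably the quickest next step.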

Question: torch-stft VS librosa.stft

Hi, dear author. I came across your project by chance, and I am a newbie in audio processing. Since we already have librosa.stft, I am curious what the benefit or necessity of torch-stft is, and when it is better to use it. Could you please explain this a little, and maybe add it to README.md?

Export to onnx

Hi!

Thanks for this work!

I'm using this in a model and when I try to export the model to onnx, I get:
Missing key(s) in state_dict: "model_g.dec.stft.forward_basis", "model_g.dec.stft.inverse_basis".

If I change the lines below to persistent=False, the export works... But I'm not sure if this will have any effect that I can't foresee as this persistent flag is not clear to me.

Would you please tell me if this is acceptable?

Thank you!

self.register_buffer('forward_basis', forward_basis.float())

and
self.register_buffer('inverse_basis', inverse_basis.float())

to:

        self.register_buffer('forward_basis', forward_basis.float(), persistent=False)
        self.register_buffer('inverse_basis', inverse_basis.float(), persistent=False)
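For reference, a minimal sketch of what the persistent flag changes. TinyModule is a hypothetical stand-in, not the real STFT class, and torch.eye(4) is a placeholder for the computed basis. A non-persistent buffer is still a buffer (it follows .to(device) and is usable in forward), but it is left out of state_dict, so loading a checkpoint neither provides nor requires it.

```python
import torch
import torch.nn as nn


class TinyModule(nn.Module):
    """Hypothetical stand-in for a module that registers a computed basis."""

    def __init__(self, persistent):
        super().__init__()
        basis = torch.eye(4)  # placeholder for a basis built in __init__
        self.register_buffer('forward_basis', basis, persistent=persistent)


persistent_keys = list(TinyModule(persistent=True).state_dict().keys())
non_persistent_keys = list(TinyModule(persistent=False).state_dict().keys())

print(persistent_keys)      # ['forward_basis']
print(non_persistent_keys)  # []
```

Assuming the bases are rebuilt from the constructor arguments each time the module is created, leaving them out of the checkpoint should be harmless as long as the module is constructed with the same parameters it was trained with.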

fmin and fmax

After I changed it to a DFT, I found that the values of fmin and fmax affect the computed amplitudes. For example, I know the signal contains frequencies of 156.25 Hz and 1562.5 Hz, with FFT=512 and sampling_rate=8000. The indices of the largest and second-largest amplitudes should then correspond to those frequencies, that is, 10 and 100. The original DFT gives these indices correctly, but after the linear calculation the result is inconsistent with the actual values and depends on fmin and fmax. Why?
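The bin arithmetic in this question can be checked directly. The helper names below are hypothetical. Note that this mapping only holds for plain DFT bins; if the linear step is a mel-style filter bank spanning fmin to fmax, the output index is a band index rather than a DFT bin, so it would not be expected to match.

```python
sampling_rate = 8000
n_fft = 512


def frequency_to_bin(frequency_hz):
    """Index of the DFT bin whose center is closest to frequency_hz."""
    return round(frequency_hz * n_fft / sampling_rate)


def bin_to_frequency(bin_index):
    """Center frequency in Hz of a DFT bin."""
    return bin_index * sampling_rate / n_fft


print(frequency_to_bin(156.25), frequency_to_bin(1562.5))  # 10 100
print(bin_to_frequency(1))  # bin spacing in Hz: 15.625
```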

The example code on my machine yields ValueError.

I used the exact example code provided in the README, but I get an error like this:

Traceback (most recent call last):
  File "my_test.py", line 29, in <module>
    print(np.mean((output - audio) ** 2)) # on order of 1e-16
ValueError: operands could not be broadcast together with shapes (1,1,220416) (1,220500)

I know I could check the shapes before further processing, but I would like to know how to resolve this properly.
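As a workaround sketch, the two arrays can be brought to a common shape before comparing. The arrays below are zero-filled stand-ins with the shapes from the traceback, not real audio. Note that 220416 is 861 × 256, which suggests the reconstruction was cut to a whole number of hops; that is an inference from the numbers, not something confirmed by the library.

```python
import numpy as np

# Stand-ins with the shapes from the traceback (zeros, not real audio).
output = np.zeros((1, 1, 220416), dtype=np.float32)
audio = np.zeros((1, 220500), dtype=np.float32)

# Drop the extra channel axis: (1, 1, T) -> (1, T).
output = output.reshape(output.shape[0], -1)

# Compare only the samples both arrays have.
n = min(output.shape[-1], audio.shape[-1])
mse = np.mean((output[..., :n] - audio[..., :n]) ** 2)

print(output[..., :n].shape, audio[..., :n].shape, mse)
```

Trimming ignores the last 84 samples of the input, which is usually acceptable for a reconstruction check.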

Spectrogram quality

Hi,
I've been testing your code, and it reconstructs the signal from your STFT.
I'm afraid this is the librosa STFT:
[image: librosa STFT spectrogram]

And this is the spectrogram computed with this method:
[image: spectrogram computed with torch-stft]

Avoid inversion and save some space

Hey Prem,
This snippet looks familiar! :)

You don't need to invert the DFT matrix to get the inverse transform. You can do this instead:

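# Imports needed to run the snippet as-is
import torch
import torch.nn.functional as F
from numpy import hanning, eye, sqrt, vstack, real, imag, pi
from numpy.fft import fft

# Transform constants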
sz,hp = 128,32
wn = hanning( sz+1)[:-1]**.5

# Make DFT matrix, split into real/imag, scale DC/Nyquist to make orthogonal
f  = fft( eye( sz)) / (.5*sqrt( sz) * sqrt( sz/hp))
f = vstack( (real( f[:sz//2+1,:]),imag(f[:sz//2+1,:])))

# This makes the DFT matrix work in both ways using convs
f[0,:] /= sqrt(2)
f[sz//2,:] /= sqrt(2)

# Make the transform kernel
DFT  = torch.FloatTensor( wn * f)[:,None,:]

# Input is bt x dim x time
x = torch.sin( 64 * torch.linspace( 0, 2*pi, 1024, dtype=torch.float32)).view( 1, 1, -1)

# Do DFT
f = F.conv1d( x, DFT, stride=hp, padding=sz)

# Get amplitude and phase
a = torch.sqrt(  f[:,:sz//2+1,:]**2 + f[:,sz//2+1:,:]**2)
p = torch.atan2( f[:,sz//2+1:,:], f[:,:sz//2+1,:])

# ... process ...

# Stack amplitudes and phase
f = torch.cat( [a*torch.cos( p), a*torch.sin( p)], dim=1)

# Inverse STFT using the same kernel
y = F.conv_transpose1d( f, DFT, stride=hp, padding=sz)

Paris

Force alignment for output of STFT transform

I suggest adding a slice to the forward transform in stft.py so that the output stays aligned with the input audio data.

        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)
        forward_transform = forward_transform[:, :, :-1]

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
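The effect of that slice on the frame count can be worked out with the standard conv1d output-length formula. This is a sketch under one assumption: the input is padded by filter_length // 2 on each side before the convolution with padding=0. The helper num_frames is hypothetical.

```python
def num_frames(num_samples, filter_length, hop_length, drop_last=False):
    """Frame count of a strided conv1d over a padded signal (hypothetical helper)."""
    # Assumption: filter_length // 2 samples of padding on each side.
    padded = num_samples + 2 * (filter_length // 2)
    # conv1d output length with padding=0, dilation=1.
    frames = (padded - filter_length) // hop_length + 1
    return frames - 1 if drop_last else frames


print(num_frames(220500, 1024, 256))                  # 862
print(num_frames(220500, 1024, 256, drop_last=True))  # 861
print(num_frames(1024, 1024, 256, drop_last=True))    # 4
```

With the last frame dropped, an input whose length is a multiple of hop_length yields exactly num_samples / hop_length frames, which is the alignment the suggestion is after.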
