fn-ssl's Introduction

Full-band and Narrow-band fusion Network for SSL

Introduction

This repository provides sound source localization (SSL) methods based on a full-band and narrow-band fusion network. The narrow-band module processes the along-time sequences to learn narrow-band spatial information. The full-band module processes the along-frequency sequences to learn the full-band correlation of spatial cues, such as the linear relation between the direct-path inter-channel phase difference (DP-IPD) and frequency.
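The linear relation between the DP-IPD and frequency can be illustrated with a small sketch. It assumes a two-microphone pair in the far field, a hypothetical microphone distance of 0.08 m, a speed of sound of 343 m/s, and a 257-bin frequency grid up to 8 kHz; none of these values or names are taken from the repository, and the sign convention of the phase is also an assumption.

```python
import math

SPEED_OF_SOUND = 343.0  # m/s (assumed)


def dp_ipd(theta_deg, freqs_hz, mic_distance=0.08):
    """Direct-path phase difference of a far-field source for one mic pair.

    Returns the unwrapped phase per frequency plus its cos/sin encoding.
    mic_distance is a hypothetical value chosen for illustration.
    """
    # Time difference of arrival for a source at angle theta.
    tdoa = mic_distance * math.cos(math.radians(theta_deg)) / SPEED_OF_SOUND
    phases = [2.0 * math.pi * f * tdoa for f in freqs_hz]
    real = [math.cos(p) for p in phases]
    imag = [math.sin(p) for p in phases]
    return phases, real, imag


# Assumed frequency grid: 257 bins from 0 to 8 kHz.
freqs = [k * 8000.0 / 256 for k in range(257)]
phases, real, imag = dp_ipd(60.0, freqs)

# The unwrapped phase grows by the same amount from one bin to the next,
# which is the linear relation a full-band (along-frequency) model can exploit.
steps = [phases[k + 1] - phases[k] for k in range(len(phases) - 1)]
print(min(steps), max(steps))
```

The cos/sin pair is bounded in [-1, 1] at every frequency, while the underlying phase is a straight line in frequency whose slope depends only on the source direction.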

Methods

Two officially implemented sound source localization methods are included:

Datasets

Quick start (will be updated soon)

  • Preparation

    • Download the required dataset and organize the data according to the data_org in the data folder.

    • Generate the multi-channel data. Set data_num (in Simu.py) to control the size of the dataset. The --train, --test, and --dev flags generate the training, test, and validation datasets, respectively. Their source data paths are specified by dirs['sousig_train'] in Opt.py.

    python Simu.py --train/--test/--dev
    
    • For DP-IPD regression, set is_doa = False (Model.FN_SSL) and use the MSE loss function. For DOA classification, set is_doa = True (Model.FN_SSL) and use the cross-entropy (CE) loss function; meanwhile, the predgt2doa needs to be replaced synchronously so that it matches the classification output. The initial learning rate for DOA classification is set to 5e-4.
    net = at_model.FN_SSL(is_doa=True/False)
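The two output modes imply two different decoding steps. The following is a minimal sketch of that idea, not the repository's predgt2doa: the function names, the 0 to 180 degree candidate grid, the microphone distance, and the speed of sound are all assumptions made for illustration.

```python
import math


def dp_ipd_template(theta_deg, freqs_hz, mic_distance=0.08, c=343.0):
    """Ideal (cos, sin) DP-IPD per frequency for a hypothetical mic pair."""
    tdoa = mic_distance * math.cos(math.radians(theta_deg)) / c
    return [(math.cos(2 * math.pi * f * tdoa), math.sin(2 * math.pi * f * tdoa))
            for f in freqs_hz]


def doa_from_regression(pred, freqs_hz, candidates=range(0, 181)):
    """Regression mode: pick the candidate angle whose template best
    matches the predicted DP-IPD (largest inner product)."""
    def score(theta):
        template = dp_ipd_template(theta, freqs_hz)
        return sum(pr * tr + pi * ti
                   for (pr, pi), (tr, ti) in zip(pred, template))
    return max(candidates, key=score)


def doa_from_classification(class_scores):
    """Classification mode: the DOA class is the index of the largest score.
    Mapping the index to degrees depends on the grid used in training."""
    return max(range(len(class_scores)), key=class_scores.__getitem__)


# Regression example: a perfect prediction for a source at 47 degrees.
freqs = [k * 31.25 for k in range(1, 257)]
pred = dp_ipd_template(47, freqs)
print(doa_from_regression(pred, freqs))

# Classification example: 180 class scores with a peak at index 90.
scores = [0.1] * 180
scores[90] = 0.9
print(doa_from_classification(scores))
```

This is why the decoding function has to change together with is_doa: a regression output is compared against templates, while a classification output is read off directly with an argmax.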
    
  • Training

    • For the training step, --gpu-id specifies the GPU, and the three values of --bz set the batch sizes of the training, validation, and test processes, respectively.
    python Train.py --train --gpu-id [*] --bz * * * 
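To make the three-value --bz convention concrete, here is a hypothetical sketch of how such flags could be parsed with argparse. It is an assumption about the interface, not the actual contents of Train.py, and the batch sizes shown are placeholders.

```python
import argparse


def build_parser():
    # Hypothetical re-creation of the flags described above.
    parser = argparse.ArgumentParser(description="sketch of Train.py-style flags")
    parser.add_argument("--train", action="store_true",
                        help="run the training step")
    parser.add_argument("--gpu-id", type=str, default="0",
                        help="GPU to use")
    parser.add_argument("--bz", type=int, nargs=3, default=[1, 1, 1],
                        metavar=("TRAIN", "VAL", "TEST"),
                        help="batch sizes for training, validation, and test")
    return parser


# Equivalent to: python Train.py --train --gpu-id 0 --bz 2 2 2
args = build_parser().parse_args(["--train", "--gpu-id", "0", "--bz", "2", "2", "2"])
train_bz, val_bz, test_bz = args.bz
print(args.gpu_id, train_bz, val_bz, test_bz)
```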
    
  • Evaluation

    • In the inference stage, set checkpoints_dir (in Predict.py) to select the weights. Inference is provided for both the simulated dataset and the LOCATA dataset.
    • For simulated data evaluation
    python Predict.py --test --datasetMode simulate --bz * * *
    
    • For LOCATA dataset evaluation
    python Predict.py --test --datasetMode locata
    
  • PyTorch Lightning version

    • We have re-implemented FN-SSL using the PyTorch Lightning framework, which improves training speed compared to the plain PyTorch version.
    • For training,
    python main.py fit --data.batch_size=[*,*] --trainer.devices=*,*
    
    • For testing,
    python main.py test  --ckpt_path logs/MyModel/version_x/checkpoints/**.ckpt --trainer.devices=*,*
    
  • Pretrained models

    • Use the FN_lightning model to load a Lightning checkpoint in the plain PyTorch framework.
Framework    Task                  Checkpoint
Lightning    DP-IPD regression     https://pan.baidu.com/s/1zRKpiqbSuo80Xu5ZRoS1gQ?pwd=6w51
Lightning    DOA classification    https://pan.baidu.com/s/1U1Wl5ZBZBItc2Vku7AyqNA?pwd=ceqm

More checkpoints will be added soon.
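When a Lightning checkpoint is loaded outside Lightning, the weights sit under the "state_dict" key of the checkpoint dictionary, and parameter names usually carry the attribute name of the wrapped network as a prefix. The sketch below shows one way to strip such a prefix. The prefix "arch." is a hypothetical placeholder; the real one depends on how the Lightning module stores the network, and the repository's FN_lightning model may make this step unnecessary.

```python
def extract_model_state(checkpoint, prefix="arch."):
    """Return the weights stored under `prefix` with the prefix removed.

    `checkpoint` is the dictionary obtained from loading a .ckpt file.
    Keys that do not start with `prefix` are skipped.
    """
    state = checkpoint["state_dict"]
    return {key[len(prefix):]: value
            for key, value in state.items()
            if key.startswith(prefix)}


# Stand-in for a loaded checkpoint (values would normally be tensors).
fake_ckpt = {
    "epoch": 3,
    "state_dict": {
        "arch.block1.weight": [0.1, 0.2],
        "arch.block1.bias": [0.0],
        "loss.weight": [1.0],
    },
}
model_state = extract_model_state(fake_ckpt)
print(sorted(model_state))

# With PyTorch installed, the typical usage would be:
#   ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
#   net.load_state_dict(extract_model_state(ckpt))
```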

Citation

If you find our work useful in your research, please consider citing:

@InProceedings{wang2023fnssl,
    author = "Yabo Wang and Bing Yang and Xiaofei Li",
    title = "FN-SSL: Full-Band and Narrow-Band Fusion for Sound Source Localization",
    booktitle = "Proceedings of INTERSPEECH",
    year = "2023",
    pages = ""}

Reference code

Licence

MIT

fn-ssl's People

Contributors

bingyang-20, wangyabo123

fn-ssl's Issues

change target to location classification

Dear author, thanks for sharing the code and for the state-of-the-art DOA estimation performance. Two things confuse me:

  1. In Section 2.1.2 of the paper, when the target changes from DP-IPD regression to location classification, the activation changes from tanh to softmax. The readme says "set is_doa = True (Model.FN_SSL)", but when I set is_doa to True, the activation in Model.py still seems to be tanh. The output shape is as expected (B T 180) and training runs normally. Do I need to change the activation in Model.py to softmax?
  2. The readme says "the predgt2doa needs to be replaced synchronously". Do I need to change it to predgt2DOA_cls? When I do, I get errors.

I hope to hear your answers.
Best regards
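For question 1, a small numerical sketch may help frame the concern. It assumes that the 180 outputs are bounded to [-1, 1] by tanh and then passed to a softmax cross-entropy loss; whether the repository actually does this is not confirmed here. Note that PyTorch's nn.CrossEntropyLoss expects raw scores and applies log-softmax internally, so an explicit softmax layer is not needed before that loss.

```python
import math


def softmax(logits):
    # Subtract the maximum for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Softmax cross-entropy for one frame and one target class."""
    return -math.log(softmax(logits)[target])


num_classes = 180

# Most confident output possible when every score is limited to [-1, 1]:
# +1 for the correct class and -1 for all others.
best_case = [-1.0] * num_classes
best_case[42] = 1.0

p_max = softmax(best_case)[42]          # about 0.04
min_loss = cross_entropy(best_case, 42)  # about 3.2
print(p_max, min_loss)
```

Under this assumption the argmax is unaffected, so training can still run and localize, but the probability assigned to the correct class cannot exceed roughly 4 percent and the loss cannot approach zero.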

Addressing Bottlenecks in Training

I am using two RTX 3090 GPUs to run train.py following the provided guide.

I only used one FullNarrowBlock, but one of the 15 epochs alone took 5 hours to complete. Training takes much longer than I expected.

When I checked the GPU utilization, it seemed there was a bottleneck somewhere in the code. I suspect it is in the data loading and processing part.

I am wondering if this is normal.
If something is wrong, could you advise me on how to deal with it?
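One way to test the data-loading suspicion is to time how long the loop waits for each batch versus how long it spends in the training step. The helper below is a hypothetical sketch, not part of the repository; it works with any iterable of batches. When the step runs on a GPU, the step function should synchronize the device before returning (for example with torch.cuda.synchronize()), otherwise the asynchronous GPU work is not counted.

```python
import time


def data_wait_fraction(loader, step_fn, max_batches=50):
    """Fraction of wall-clock time spent waiting for batches.

    A value close to 1 suggests the data pipeline is the bottleneck;
    a value close to 0 suggests the training step dominates.
    """
    wait = 0.0
    compute = 0.0
    batches = iter(loader)
    for _ in range(max_batches):
        t0 = time.perf_counter()
        try:
            batch = next(batches)
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)
        t2 = time.perf_counter()
        wait += t1 - t0
        compute += t2 - t1
    total = wait + compute
    return wait / total if total > 0 else 0.0


def slow_loader():
    # Stand-in for an expensive data pipeline: 10 ms per batch.
    for i in range(5):
        time.sleep(0.01)
        yield i


fraction = data_wait_fraction(slow_loader(), lambda batch: None)
print(f"waiting for data: {fraction:.0%} of the time")
```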

no 0 and 180 degrees

Hello, thanks for this algorithm. I would like to ask if it can locate 0 and 180 degrees. I don’t know why the positioning results never show 0 and 180 degrees.
