Comments (6)
Thanks a lot! And I will try again :)
from nbss.
You need to change the dataset's __getitem__ function. As we can see from its comments,
NBSS/data_loaders/ss_semi_online_dataset.py
Lines 69 to 78 in dfd2587
what you should return are xm (the mixture), ys (the ground-truth speeches), and paras.
To adapt to the dataset of Fasnet, the __getitem__ of your Dataset can be something like:
def __getitem__(self, index: Dict[str, int]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:  # type: ignore
    sidx = index['speech_index']  # an index in [0, 20000), [0, 5000), [0, 3000)
    # xm = read the 'sidx'-th speech mixture in the time domain
    # ys = read the 'sidx'-th speech targets in the time domain
    paras = {
        "index": sidx,
        # any other paras you need for evaluation
    }
    return xm, ys, paras
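To make the expected shapes concrete, here is a hedged, self-contained stand-in for such a __getitem__ (the random tensors, speaker/channel counts, and signal length are illustrative placeholders, not NBSS's actual data reading):

```python
import torch

def dummy_getitem(index: int, spk: int = 2, channel: int = 4, time: int = 8000):
    """Stand-in for a real __getitem__: random tensors in the shapes NBSS expects."""
    xm = torch.randn(channel, time)        # mixture, [channel, time]
    ys = torch.randn(spk, channel, time)   # targets, [spk, channel, time]
    paras = {"index": index}               # anything else you need for evaluation
    return xm, ys, paras

xm, ys, paras = dummy_getitem(3)  # xm: [4, 8000], ys: [2, 4, 8000]
```

In a real implementation, the two `torch.randn` lines would be replaced by reading the mixture and target waveforms from disk.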
Also, you need to remove the unnecessary code (like unnecessary parameters) in ss_semi_online_data_module.py and ss_semi_online_sampler.py to make your own datamodule and sampler. But collate_func_train, collate_func_val, and collate_func_test are necessary in the datamodule, as they are linked to the corresponding functions of class NBSS_ifp in
Lines 26 to 28 in dfd2587
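For context, a collate function of this kind typically stacks the per-item tensors into batched tensors and collects the paras dicts into a list. A minimal illustrative sketch (the real ones are the NBSS_ifp methods referenced above; this name and body are assumptions):

```python
import torch

def collate_func(batch):
    """Illustrative collate: stack xm/ys across the batch, collect paras into a list.
    The actual collate functions are defined on class NBSS_ifp."""
    xms, yss, paras = zip(*batch)  # batch is a list of (xm, ys, paras) tuples
    return torch.stack(xms), torch.stack(yss), list(paras)

# example batch of 3 items, each with xm [4, 8000] and ys [2, 4, 8000]
batch = [(torch.randn(4, 8000), torch.randn(2, 4, 8000), {"index": i}) for i in range(3)]
xm_b, ys_b, paras_b = collate_func(batch)  # xm_b: [3, 4, 8000], ys_b: [3, 2, 4, 8000]
```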
Another, simpler implementation could look like the following code. As long as your dataset returns xm, ys, and paras with the correct shapes and types, it will work with the rest of the NBSS code.
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule


class YourDatasetClass(Dataset):

    def __init__(self, speech_paths) -> None:
        super().__init__()
        self.speech_paths = speech_paths

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:  # type: ignore
        """returns the indexed item

        Args:
            index: index

        Returns:
            Tensor: xm of shape [channel, time] in time domain
            Tensor: ys of shape [spk, channel, time] in time domain
            dict: paras used
        """
        # read the mixture and targets
        # xm = read the 'index'-th speech mixture, whose path is self.speech_paths[index]['mix']
        # ys = read the 'index'-th speech targets, whose paths are self.speech_paths[index]['spk_1'] and self.speech_paths[index]['spk_2'] for the 2-speaker case
        paras = {
            "index": index,
            # any other paras you need for evaluation
        }
        return xm, ys, paras

    def __len__(self):
        return len(self.speech_paths)
class SS_SemiOnlineDataModule(LightningDataModule):

    def __init__(
        self,
        speech_dir_path: str,  # your speech dir containing the generated multi-channel mixtures and multi-channel target speeches
        batch_size: List[int] = [5, 5],
        speaker_num: int = 2,
        num_workers: int = 5,
        collate_func_train: Callable = None,
        collate_func_val: Callable = None,
        collate_func_test: Callable = None,
    ):
        super().__init__()
        self.speech_dir_path = speech_dir_path
        self.batch_size = batch_size[0]
        self.batch_size_val = batch_size[1]
        self.speaker_num = speaker_num
        self.num_workers = num_workers
        self.collate_func_train = collate_func_train
        self.collate_func_val = collate_func_val
        self.collate_func_test = collate_func_test
        self.prepare_data()

    def prepare_data(self):
        """prepare data to self.speech_paths"""
        self.speech_paths = dict()
        self.speech_paths['train'] = ...  # the paths of the training mixtures and targets; each element can be a Dict containing the path of a mixture and its corresponding targets
        self.speech_paths['val'] = ...  # the paths of the validation mixtures and targets, in the same format
        self.speech_paths['test'] = ...  # the paths of the test mixtures and targets, in the same format

    def setup(self, stage=None):
        # YourDatasetClass is your Dataset implementation: it receives the paths, and its __getitem__ returns xm, ys, and paras
        self.train = YourDatasetClass(speech_paths=self.speech_paths['train'])
        self.val = YourDatasetClass(speech_paths=self.speech_paths['val'])
        self.test = YourDatasetClass(speech_paths=self.speech_paths['test'])
        # The sampler is removed, as it is used for online shuffling of RIRs and speeches, which is unnecessary for an already generated dataset

    def train_dataloader(self) -> DataLoader:
        prefetch_factor = self.batch_size
        persistent_workers = False
        return DataLoader(self.train,
                          batch_size=self.batch_size,
                          shuffle=True,  # shuffle the pre-generated data each epoch, since the sampler is removed
                          collate_fn=self.collate_func_train,
                          num_workers=self.num_workers,
                          prefetch_factor=prefetch_factor,
                          pin_memory=True,
                          persistent_workers=persistent_workers)

    def val_dataloader(self) -> DataLoader:
        prefetch_factor = self.batch_size_val
        persistent_workers = False
        return DataLoader(self.val,
                          batch_size=self.batch_size_val,
                          collate_fn=self.collate_func_val,
                          num_workers=self.num_workers,
                          prefetch_factor=prefetch_factor,
                          pin_memory=True,
                          persistent_workers=persistent_workers)

    def test_dataloader(self) -> DataLoader:
        prefetch_factor = 2
        return DataLoader(
            self.test,
            batch_size=1,
            collate_fn=self.collate_func_test,
            num_workers=1,
            prefetch_factor=prefetch_factor,
        )
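Before wiring everything into Lightning, the dataset pattern above can be sanity-checked through a plain DataLoader. A tiny self-contained version (random tensors stand in for file reads; the speaker/channel counts and lengths are arbitrary):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class TinySpeechDataset(Dataset):
    """Mimics YourDatasetClass, with random tensors in place of file reads."""

    def __init__(self, n_items: int, spk: int = 2, channel: int = 2, time: int = 1600):
        self.n_items, self.spk, self.channel, self.time = n_items, spk, channel, time

    def __getitem__(self, index: int):
        xm = torch.randn(self.channel, self.time)            # mixture [channel, time]
        ys = torch.randn(self.spk, self.channel, self.time)  # targets [spk, channel, time]
        return xm, ys, {"index": index}

    def __len__(self):
        return self.n_items

# PyTorch's default collate stacks the tensors and merges the paras dicts
loader = DataLoader(TinySpeechDataset(8), batch_size=4)
xm_b, ys_b, paras_b = next(iter(loader))  # xm_b: [4, 2, 1600], ys_b: [4, 2, 2, 1600]
```

If the batched shapes come out as [batch, channel, time] and [batch, spk, channel, time], the dataset's conventions match what the rest of the pipeline expects.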
Another thing you need to change is the config file: remove/add the parameters in the config file that you removed/added in the code.
Thank you for your quick response and advice!
I've tried with my dataset, but I still can't run it :( Maybe it is because I couldn't run your original code, so it's hard for me to make modifications.
I've noticed that you also did a comparison experiment with FaSNet, so maybe it would be much easier for you to publish another version of the code that trains your model on the FaSNet dataset? If you haven't generated the dataset with the script used in FaSNet (https://github.com/yluo42/TAC/tree/master/data), I'll soon email the dataset to [email protected], which is displayed on your home page.
I would really appreciate it if you could help with that!
Looking forward to your reply and thanks again!
You are welcome. ^^
I've tried with my dataset, but I still can't run it :( Maybe it is because I couldn't run your original code, so it's hard for me to make modifications.
You could post your error messages, or send me your code. Maybe I can help.
I've noticed that you also did a comparison experiment with FaSNet, so maybe it would be much easier for you to publish another version of the code that trains your model on the FaSNet dataset?
I did train my model on the FaSNet dataset, but that code is not the latest version.
To adapt to the FaSNet dataset, you need to implement your own data components, such as the dataset and the datamodule.
You could refer to the PyTorch Lightning docs to implement your own datamodule:
https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.datamodule.html#pytorch_lightning.core.datamodule.LightningDataModule
And the Lightning CLI:
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html