
nn-similarity-diarization's Introduction

Neural network based similarity scoring for diarization

PyTorch implementation of neural-network-based similarity scoring for diarization, based on the INTERSPEECH 2019 paper "LSTM Based Similarity Measurement with Spectral Clustering for Speaker Diarization" [1]: https://arxiv.org/abs/1907.10393, https://www.isca-speech.org/archive/Interspeech_2019/pdfs/1388.pdf

I am not affiliated with the paper authors.

The basic concept behind this method is to learn the similarity scoring matrix needed for diarizing a recording. Pairs of speaker embeddings (such as x-vectors) are concatenated and fed through an LSTM or another architecture to predict the similarity of each pair.

Figure: model overview (taken from [1])
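To make the idea concrete, here is a minimal sketch of the scoring pattern. All shapes, layer sizes, and names below are illustrative assumptions, not the repo's actual code: each row of the similarity matrix is produced by pairing one embedding with every other embedding and running the resulting sequence of concatenated pairs through a bidirectional LSTM.

import torch
import torch.nn as nn

# Illustrative sizes: N segments, d-dimensional embeddings (hypothetical values)
N, d = 10, 128
xvecs = torch.randn(N, d)

# Pair every embedding with every other embedding: (N, N, 2d)
rows = xvecs.unsqueeze(1).expand(N, N, d)
cols = xvecs.unsqueeze(0).expand(N, N, d)
pairs = torch.cat([rows, cols], dim=-1)

# Treat each row i as a length-N sequence of concatenated pairs
lstm = nn.LSTM(2 * d, 64, batch_first=True, bidirectional=True)
head = nn.Linear(2 * 64, 1)

out, _ = lstm(pairs)          # (N, N, 128)
sim = head(out).squeeze(-1)   # (N, N) matrix of similarity scores (logits)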

Requirements

Kaldi, Python, kaldi_io, scipy, scikit-learn, torch (tested on torch 1.3.0), CALLHOME dataset

TL;DR

You can run most of the steps (make train/test folds -> train -> predict -> cluster) with run.sh.

NOTE: The Kaldi data preparation must be run first. Follow those instructions up until 'Make train/test folds', after which run.sh can be run from inside the repo folder. Make sure to configure the variables at the top of run.sh as well as your .cfg file (see configs/example.cfg):

xvector_dir=/PATH/TO/XVECS #path to extracted xvectors, same as run_data_prep.sh
KALDI_PATH=/PATH/TO/KALDI_ROOT # path to kaldi root, needed for finding egs folder
folds_path=/PATH/TO/FOLDS_DATA # path to where the train/test split folds will be stored
cfg_path=/PATH/TO/CFG # path to cfg file, $folds_path is data_path in the cfg

Kaldi Data preparation

The data preparation involves the following steps:

  1. Make kaldi data folder for CALLHOME
  2. Feature extraction (MFCCs)
  3. X-vector extraction (using the pre-trained CALLHOME model available on the Kaldi website)
  4. As in the paper, make a 5-fold train/test split to train and evaluate on

First, some variables need to be configured in run_data_prep.sh. These are located at the top of the file and are as follows:

callhome_path=/PATH/TO/CALLHOME #path to raw callhome data
xvector_dir=/PATH/TO/XVECS #path to extracted xvectors

These need to point to where the CALLHOME dataset is and where you would like the extracted x-vectors to reside (make sure to use an absolute path). Once this is done, copy the script to the Kaldi recipe folder for CALLHOME, since existing data prep scripts are leveraged.

The location of the Kaldi installation will be referred to as $KALDI_PATH in the following instructions.

cp run_data_prep.sh $KALDI_PATH/egs/callhome_diarization/v2
cd $KALDI_PATH/egs/callhome_diarization/v2
source path.sh
./run_data_prep.sh

Make train/test folds

Change directory back to this repo and run the following command to make the train/test folds, replacing the variables as necessary. Here $xvector_dir is as above, and $folds_path is the location where the splits will reside (num_folds=5 is recommended):

python -m scripts.make_kfold_callhome $xvector_dir $KALDI_PATH/egs/callhome_diarization/v2/data/callhome/fullref.rttm $folds_path $num_folds

cp $KALDI_PATH/egs/callhome_diarization/v2/data/callhome/fullref.rttm $folds_path

which creates a folder structure like this:

folds_path
├── fullref.rttm
├── ch0
|   ├── train
|   |   ├── ref.rttm
|   |   ├── segments
|   |   ├── utt2spk
|   |   └── xvector.scp
|   └── test
├── ch1
|   ├── train
|   └── test
├── ...
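As a quick sanity check, the xvector.scp files in each fold can be read with the kaldi_io package listed in the requirements. A small sketch (the path below is illustrative):

import kaldi_io

# Read all x-vectors for one fold's train split (substitute your $folds_path)
xvecs = {utt: vec for utt, vec in
         kaldi_io.read_vec_flt_scp('/PATH/TO/FOLDS_DATA/ch0/train/xvector.scp')}
print(len(xvecs), next(iter(xvecs.values())).shape)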

Training

The behaviour of the primary training script train.py is mostly determined by the config file it reads. An example config file is shown in configs/example.cfg. The fields relevant to this section are shown below:

[Datasets]
# this is $folds_path in the data preparation step (also in run.sh)
data_path = /PATH/TO/FOLDS_PATH

[Model]
# Supported models: 'lstm', 'lstm_cos_ws', 'lstm_cos_res'
# 'lstm' is the same as in the original paper
model_type = lstm

[Hyperparams]
lr = 0.01
max_len = 400
no_cuda = False
seed = 1234
num_epochs = 100
# at the epoch numbers in scheduler_steps, the lr will be multiplied by scheduler_lambda
scheduler_steps = [40, 80]
scheduler_lambda = 0.1

[Outputs]
# this is where models will be saved
base_model_dir = exp/example_models_folder
# Interval at which models will be stored for checkpointing purposes
checkpoint_interval = 1

The main fields which need to be configured are data_path and base_model_dir. The first corresponds to $folds_path used above and the latter will be the place in which the models are stored.
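The scheduler semantics described above match PyTorch's MultiStepLR. A minimal sketch of the equivalent behaviour (an assumption about how the repo implements it, not confirmed here):

import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(256, 1)  # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.01)

# lr is multiplied by 0.1 at epochs 40 and 80, mirroring
# scheduler_steps = [40, 80] and scheduler_lambda = 0.1
sched = MultiStepLR(opt, milestones=[40, 80], gamma=0.1)

for epoch in range(100):
    # ... one epoch of training, with opt.step() per batch ...
    sched.step()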

Once this cfg file is configured, a model can be trained on a fold like so:

python train.py --cfg configs/<your_config>.cfg --fold 0

This will need to be run (in parallel or sequentially) for each fold [0,1,2,3,4].
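For example, a small driver script (hypothetical, simply equivalent to launching the command five times) could be:

import subprocess

# Train one model per fold, sequentially
for fold in range(5):
    subprocess.run(
        ['python', 'train.py', '--cfg', 'configs/example.cfg', '--fold', str(fold)],
        check=True,
    )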

This will store .pt models into base_model_dir in a structure very similar to the one above:

base_model_dir
├── ch0
|   ├── epoch_1.pt
|   ├── epoch_2.pt
|   ├── ...
|   └── final_100.pt
├── ch1
|   ├── epoch_1.pt
|   ├── ...
├── ...

Inference

Processing the folds of data using the final model is done with predict.py. This script assumes the file structure produced above. The similarity matrix predictions for each recording are stored as <recording_id>.npy files in subfolders called ch*/<tr|te>_preds.

To produce predictions:

python predict.py --cfg configs/<your_config>.cfg
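Each saved prediction is an ordinary NumPy array, so the similarity matrices can be inspected directly. A sketch, assuming the predictions live under the folds directory (the exact parent folder depends on your configuration):

import glob
import numpy as np

# Hypothetical location; substitute wherever your ch*/te_preds folders live
for path in glob.glob('/PATH/TO/FOLDS_DATA/ch0/te_preds/*.npy'):
    sim = np.load(path)
    print(path, sim.shape)  # (num_segments, num_segments) similarity matrix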

Evaluation

To obtain a diarization prediction, clustering is performed (using cluster.py) with the similarity matrix enhancement described in [1]. As in the paper, spectral clustering is included; agglomerative clustering is also available.

For each fold of the CALLHOME dataset, a configurable range of cluster parameter values is evaluated to find the best-performing value on the train set. That value is then used to cluster the fold's test set, and the per-fold test hypotheses are combined to create the overall system hypothesis for CALLHOME.

The relevant sections in the configuration file for clustering are as follows:

[Clustering]
# Only 'sc' and 'ahc' are supported
cluster_type = sc

# The following values are fed into np.linspace to produce a range of parameters to try clustering the train portion over
# Note: cparam_start must be positive if spectral clustering is used.
cparam_start = 0.95
cparam_end = 1.0
cparam_steps = 20
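For reference, the cparam sweep maps onto np.linspace, and spectral clustering on a precomputed similarity matrix can be done with scikit-learn. The following is only a sketch: cluster.py's actual matrix enhancement and cluster-count estimation may differ, and n_clusters is hard-coded here purely for illustration.

import numpy as np
from sklearn.cluster import SpectralClustering

sim = np.load('some_recording.npy')  # a predicted similarity matrix (hypothetical file)
sim = (sim + sim.T) / 2.0            # symmetrize

for cparam in np.linspace(0.95, 1.0, 20):  # cparam_start, cparam_end, cparam_steps
    # One simple percentile-style enhancement: zero out entries below a threshold
    thresh = np.quantile(sim, cparam)
    affinity = np.clip(np.where(sim > thresh, sim, 0.0), 0.0, None)
    labels = SpectralClustering(n_clusters=2,
                                affinity='precomputed').fit_predict(affinity)
    # ...score `labels` against the reference to compute the train DER...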

Before running the clustering step, md-eval.pl needs to be obtained; it can be downloaded with:

wget https://raw.githubusercontent.com/foundintranslation/Kaldi/master/tools/sctk-2.4.0/src/md-eval/md-eval.pl

Finally:

python cluster.py --cfg configs/<your_config>.cfg

which produces output similar to this:

Fold 0, cparam 0.9       Train DER: 15.5
5%|###########8                                  | 1/20 [00:55<17:33, 55.44s/it]
Fold 0, cparam 0.9052631578947369        Train DER: 15.07
10%|#######################6                    | 2/20 [01:43<15:59, 53.29s/it]
Fold 0, cparam 0.9105263157894737        Train DER: 14.44
...

Results

More results incoming...

Similarity Scoring Method                  CALLHOME DER
-----------------------------------------  ------------
x-vector + PLDA + AHC [1]                  8.64%
x-vector + LSTM + SC [1]                   7.73%
x-vector + LSTM + SC [this repo]           8.83%
x-vector + LSTM_CosRes + SC [this repo]    9.12%
x-vector + TDNNs + SC [this repo]          ----%

Rows marked [1] are the results reported in the paper.

Other issues/todos

  • Test changing num_folds
  • I-vectors instead of x-vectors - with system fusion
  • Transformer and other architectures (some of which are in models.py)
  • logspace option for cluster thresholds, or some other spacing options
  • Conv architectures
  • Investigate data augmentation: adding noise to x-vectors, swapping same-speaker x-vectors around, sampling from a GMM learned over each recording's speakers

References

# [1]
@inproceedings{Lin2019,
  author={Qingjian Lin and Ruiqing Yin and Ming Li and Hervé Bredin and Claude Barras},
  title={{LSTM Based Similarity Measurement with Spectral Clustering for Speaker Diarization}},
  year={2019},
  booktitle={Proc. Interspeech 2019},
  pages={366--370},
  doi={10.21437/Interspeech.2019-1388},
  url={http://dx.doi.org/10.21437/Interspeech.2019-1388}
}


nn-similarity-diarization's Issues

The paper used the label of the center 750 ms as the label of the whole 1.5 s segment, but it seems that your code uses the whole 1.5 s segment to compute the label.

import os

import numpy as np
from tqdm import tqdm

# load_n_col, change_base_paths and assign_overlaps are helpers defined elsewhere in the repo

def segment_labels(segments, rttm, xvectorscp, xvecbase_path=None):
    segment_cols = load_n_col(segments, numpy=True)
    segment_rows = np.array(list(zip(*segment_cols)))  # unzip into per-segment rows
    rttm_cols = load_n_col(rttm, numpy=True)
    vec_utts, vec_paths = load_n_col(xvectorscp, numpy=True)
    if not xvecbase_path:
        xvecbase_path = os.path.dirname(xvectorscp)
    assert sum(vec_utts == segment_cols[0]) == len(segment_cols[0])
    vec_paths = change_base_paths(vec_paths, new_base_path=xvecbase_path)

    rttm_cols.append(rttm_cols[3].astype(float) + rttm_cols[4].astype(float))  # start time + duration = end time
    recording_ids = sorted(set(segment_cols[1]))  # recording IDs, e.g. 'iaaa'
    events0 = np.array(segment_cols[2:4]).astype(float).transpose()  # segment start/end times
    events1 = np.vstack([rttm_cols[3].astype(float), rttm_cols[-1]]).transpose()  # reference (ground truth) start/end times

    rec_batches = []

    for rec_id in tqdm(recording_ids):
        seg_indexes = segment_cols[1] == rec_id   # mask of segments belonging to this recording
        rttm_indexes = rttm_cols[1] == rec_id
        ev0 = events0[seg_indexes]    # start/end times of each segment for this rec_id
        ev1 = events1[rttm_indexes]   # start/end times of each reference event for this rec_id
        ev1_labels = rttm_cols[7][rttm_indexes]  # speaker of each reference event for this rec_id
        ev0_labels = assign_overlaps(ev0, ev1, ev1_labels)  # note: the original snippet's `evnew` was undefined
        ev0_labels = ['{}_{}'.format(rec_id, l) for l in ev0_labels]  # form speaker IDs, e.g. 'iaaa_A'
        batch = (segment_cols[0][seg_indexes], ev0_labels, vec_paths[seg_indexes], segment_rows[seg_indexes])
        rec_batches.append(batch)

    return recording_ids, rec_batches
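If the issue's reading of the paper is right, one way to reproduce its labelling would be to shrink each segment to its central 750 ms before assigning overlaps. Purely a sketch, not the repo's code:

import numpy as np

def center_window(events, width=0.75):
    """Shrink (start, end) rows to a `width`-second window around each center."""
    centers = events.mean(axis=1)
    return np.stack([centers - width / 2, centers + width / 2], axis=1)

# e.g. inside the loop above:
# ev0_labels = assign_overlaps(center_window(ev0), ev1, ev1_labels)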

How can the datasets be obtained for free?

Hello,

I'm glad you published the code! But the lack of a dataset is a headache. How can the corresponding dataset be obtained? Datasets for speaker diarization are genuinely hard to come by, and without one, many experiments simply cannot be done.

Any suggestions would be greatly appreciated!

Why does the LSTM model not have a sigmoid function?

import torch.nn as nn

class LSTMSimilarity(nn.Module):

    def __init__(self, input_size=256, hidden_size=256, num_layers=2):
        super(LSTMSimilarity, self).__init__()
        self.lstm = nn.LSTM(input_size,
                            hidden_size,
                            num_layers=num_layers,
                            bidirectional=True,
                            batch_first=True)
        self.fc1 = nn.Linear(hidden_size*2, 64)
        self.nl = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x = self.fc1(x)
        x = self.nl(x)
        x = self.fc2(x).squeeze(2)
        return x

This is your LSTM model, but in the original paper a sigmoid function is applied before the output, like this:

x = torch.sigmoid(self.fc2(x).squeeze(2))
return x

But when I make this change, the total error goes from 9.7% to 37%, and I don't know why.
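One plausible explanation (an assumption; the thread does not confirm which loss the repo uses): if training uses nn.BCEWithLogitsLoss, the sigmoid is already folded into the loss, so the model is meant to output raw logits, and adding an explicit sigmoid in forward() applies it twice during training:

import torch
import torch.nn as nn

logits = torch.randn(4, 400)                     # raw model outputs
targets = torch.randint(0, 2, (4, 400)).float()  # 0/1 similarity labels

# BCEWithLogitsLoss applies the sigmoid internally...
loss_a = nn.BCEWithLogitsLoss()(logits, targets)

# ...so it is equivalent to an explicit sigmoid followed by BCELoss:
loss_b = nn.BCELoss()(torch.sigmoid(logits), targets)
print(loss_a.item(), loss_b.item())  # numerically the same

# If forward() already returns sigmoid(logits), pairing it with
# BCEWithLogitsLoss squashes the outputs twice and hurts training.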
