Comments (6)
Sorry for the late reply. I was quite busy recently. Have you checked #10 and #11? Did you use mixed precision as well?
from styletts2.
Thanks for your reply. I have fixed this issue; it may have been caused by too small a batch size.
from styletts2.
Were you able to do it? I was trying to train but ran into some issues. Can we discuss?
from styletts2.
Same issue with batch size 2: the generator loss reaches about 100 and then goes to NaN. (EDIT: this didn't work!) I have a preliminary solution (still testing), but based on #11 (comment) it seems to be discriminator overfitting. So I am forcing the discriminators' weight decay to a high value to prevent overfitting, in train_first:
```python
for module in ["mpd", "msd"]:
    for g in optimizer.optimizers[module].param_groups:
        g['weight_decay'] = 0.1
```
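For reference, the same effect can be achieved at construction time rather than by patching an existing optimizer. A minimal sketch, assuming the discriminator optimizers are AdamW (the `Linear` module here is just a stand-in for mpd/msd):

```python
import torch

disc = torch.nn.Linear(16, 1)  # stand-in for a real discriminator
# set the decay up front...
opt = torch.optim.AdamW(disc.parameters(), lr=2e-4, weight_decay=0.1)
# ...or patch it afterwards, exactly as in the snippet above
for g in opt.param_groups:
    g["weight_decay"] = 0.1
```

Either way, every parameter group of the discriminator optimizers ends up with the raised decay.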
and also lowering the gain on the feature-matching losses by pre-multiplying them by 0.5, in losses.py:
```python
class GeneratorLoss(torch.nn.Module):
    def __init__(self, mpd, msd):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        loss_gen_all = loss_gen_s + loss_gen_f + 0.5*loss_fm_s + 0.5*loss_fm_f + loss_rel
        return loss_gen_all.mean()
```
At first I tried decay = 0.01 and gains 1.0, 1.0, but that only delayed the problem.
Then I tried decay = 1.0 and gains 0.1, 0.1, which seemed to prevent NaN, but the audio quality wasn't good.
So now I am trying decay = 0.1 and gains 0.5, 0.5. I should be able to report back the results in a few days.
from styletts2.
No, that didn't work :( The loss made some strange moves and eventually ended in NaN.
from styletts2.
Integrating PhaseAug and using batch_percentage=1.0 with a batch size of 2 fixed it for me.
PhaseAug addresses the overfitting issue by randomly rotating the phase of each frequency bin.
The generator error still creeps up, but very slowly now, and the audio quality becomes quite nice after 2 epochs:
```python
...
aug = PhaseAug()
gl = GeneratorLoss(model.mpd, model.msd, aug).to(device)
dl = DiscriminatorLoss(model.mpd, model.msd, aug).to(device)
...
```
```python
class GeneratorLoss(torch.nn.Module):
    def __init__(self, mpd, msd, aug):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd
        self.aug = aug

    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat)  # <--- Augment here
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        loss_gen_all = loss_gen_s + loss_gen_f + 1.0*loss_fm_s + 1.0*loss_fm_f + loss_rel
        return loss_gen_all.mean()


class DiscriminatorLoss(torch.nn.Module):
    def __init__(self, mpd, msd, aug):
        super(DiscriminatorLoss, self).__init__()
        self.aug = aug
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat.detach())  # <--- Augment here
        # MPD
        y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
        # MSD
        y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
        loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        d_loss = loss_disc_s + loss_disc_f + loss_rel
        return d_loss.mean()
```
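To make the augmentation concrete: the phase rotation PhaseAug applies can be sketched with a plain STFT round-trip. This is a toy illustration of the idea, not the actual phaseaug library (the function name and hyperparameters here are made up); each frequency bin gets a random phase angle, and the same rotation is applied to the real and generated waveforms so the discriminator cannot latch onto absolute phase:

```python
import math
import torch

def phase_rotate_sync(y, y_hat, n_fft=1024, hop=256):
    """Toy sketch of PhaseAug-style augmentation (hypothetical helper)."""
    win = torch.hann_window(n_fft)
    # STFT of both the real and generated waveforms, shape (B, F, frames)
    Y = torch.stft(y, n_fft, hop_length=hop, window=win, return_complex=True)
    Y_hat = torch.stft(y_hat, n_fft, hop_length=hop, window=win, return_complex=True)
    # one random rotation angle per frequency bin, shared by both signals
    phi = torch.rand(Y.shape[-2], 1) * 2 * math.pi
    rot = torch.exp(1j * phi)
    # rotate and resynthesize
    y_aug = torch.istft(Y * rot, n_fft, hop_length=hop, window=win, length=y.shape[-1])
    y_hat_aug = torch.istft(Y_hat * rot, n_fft, hop_length=hop, window=win, length=y_hat.shape[-1])
    return y_aug, y_hat_aug
```

The real library implements this differentiably and with its own rotation scheme; the sketch only shows why the augmented pair stays "synchronized" — both waveforms see exactly the same phase perturbation.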
from styletts2.