Comments (6)

yl4579 commented on August 14, 2024

Sorry for the late reply. I have been quite busy recently. Have you checked #10 and #11? Did you use mixed precision as well?

from styletts2.

zhiqiuiyiye commented on August 14, 2024

Thanks for your reply. I have fixed this issue; it may have been caused by too small a batch size.

akshatgarg99 commented on August 14, 2024

Were you able to do it? I was trying to train but was facing some issues. Can we discuss?

RillmentGames commented on August 14, 2024

Same issue with batch size 2: the generator loss can reach about 100 and then it goes NaN. (EDIT: Didn't work!) I have a preliminary solution, still being tested, but based on #11 (comment) it seems to be discriminator overfitting. So I am trying to force the discriminators' weight decay to a high value to prevent overfitting, in train_first:

for module in ["mpd", "msd"]:
    for g in optimizer.optimizers[module].param_groups:
        g['weight_decay'] = 0.1

and also lowering the gain on the discriminator feature-matching losses by premultiplying them by 0.5, in losses.py:

class GeneratorLoss(torch.nn.Module):

    def __init__(self, mpd, msd):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd
        
    def forward(self, y, y_hat):
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        
        loss_gen_all = loss_gen_s + loss_gen_f + 0.5*loss_fm_s + 0.5*loss_fm_f + loss_rel
        
        return loss_gen_all.mean()

At first I tried decay = 0.01 and gains 1.0, 1.0, but that only delayed the problem.
Then I tried decay = 1.0 and gains 0.1, 0.1; that seemed to prevent NaN, but the audio quality wasn't good.
So now I am trying decay = 0.1 and gains 0.5, 0.5. I should be able to report back the results in a few days.
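For anyone repeating the experiment, the weight-decay override above can be wrapped in a small helper. This is only a sketch of the mechanics, not a fix (the edit above notes it did not stop the NaN). The helper name `set_weight_decay` is hypothetical, and the `SimpleNamespace` objects are stand-ins for real optimizers; the only assumption is that each optimizer exposes `param_groups` as a list of dicts, as in the snippet above.

```python
from types import SimpleNamespace


def set_weight_decay(optimizers, modules, weight_decay):
    """Override weight_decay in every param group of the named optimizers.

    `optimizers` maps a module name to an object with a `param_groups`
    list of dicts. Returns the number of param groups that were changed.
    """
    changed = 0
    for name in modules:
        for group in optimizers[name].param_groups:
            group["weight_decay"] = weight_decay
            changed += 1
    return changed


# Stand-ins for real optimizers (hypothetical values, for illustration only).
optimizers = {
    "mpd": SimpleNamespace(param_groups=[{"lr": 1e-4, "weight_decay": 1e-4}]),
    "msd": SimpleNamespace(param_groups=[{"lr": 1e-4, "weight_decay": 1e-4}]),
    "decoder": SimpleNamespace(param_groups=[{"lr": 1e-4, "weight_decay": 1e-4}]),
}

# Only the two discriminators are touched; the decoder keeps its setting.
changed = set_weight_decay(optimizers, ["mpd", "msd"], 0.1)
```

With the training script's optimizer wrapper, the first argument would presumably be `optimizer.optimizers`, as in the loop above.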

RillmentGames commented on August 14, 2024

No, that didn't work :( The loss made some strange moves and eventually ended in NaN.
[Image: Stts2NanProblem]

RillmentGames commented on August 14, 2024

Integrating PhaseAug and using batch_percentage=1.0 with batch size 2 fixed it for me.
PhaseAug tries to address the overfitting issue by randomly rotating the phase of each frequency bin.
The generator error still creeps up, but very slowly now, and audio quality becomes quite nice after 2 epochs:

from phaseaug.phaseaug import PhaseAug  # assumes the phaseaug package is installed
...
    aug = PhaseAug()
    gl = GeneratorLoss(model.mpd, model.msd, aug).to(device)
    dl = DiscriminatorLoss(model.mpd, model.msd, aug).to(device)
...
class GeneratorLoss(torch.nn.Module):

    def __init__(self, mpd, msd, aug):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd
        self.aug = aug  
        
    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat)  # <--- Augment here
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        
        loss_gen_all = loss_gen_s + loss_gen_f + 1.0*loss_fm_s + 1.0*loss_fm_f + loss_rel
        
        return loss_gen_all.mean()
    
class DiscriminatorLoss(torch.nn.Module):

    def __init__(self, mpd, msd, aug):
        super(DiscriminatorLoss, self).__init__()
        self.aug = aug
        self.mpd = mpd
        self.msd = msd
        
    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat.detach())  # <--- Augment here
        # MPD
        y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
        # MSD
        y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
        
        loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)


        d_loss = loss_disc_s + loss_disc_f + loss_rel
        
        return d_loss.mean()
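To make "randomly rotating the phase of each frequency bin" concrete, here is a toy, standard-library-only illustration. It is not the PhaseAug implementation: it uses a naive single-frame DFT on Python lists, and the names `random_angles` and `rotate_phases` are made up for this sketch. It also assumes that `forward_sync` means the same random rotation is applied to both the real and the generated signal, which is what the sketch does by sharing one set of angles.

```python
import cmath
import math
import random


def dft(x):
    """Naive discrete Fourier transform of a real or complex sequence."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
            for k in range(n)]


def idft(spectrum):
    """Naive inverse DFT."""
    n = len(spectrum)
    return [sum(spectrum[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / n
            for t in range(n)]


def random_angles(n, rng):
    """One random rotation angle per positive-frequency bin 1..(n-1)//2."""
    return [rng.uniform(-math.pi, math.pi) for _ in range((n - 1) // 2)]


def rotate_phases(signal, angles):
    """Rotate each frequency bin by its angle, keeping every magnitude.

    Mirror bins get the conjugate so the result stays real; the DC bin
    (and the Nyquist bin for even lengths) are left untouched.
    """
    n = len(signal)
    spectrum = dft(signal)
    rotated = list(spectrum)
    for k, phi in enumerate(angles, start=1):
        rotated[k] = spectrum[k] * cmath.exp(1j * phi)
        rotated[n - k] = rotated[k].conjugate()
    augmented = [value.real for value in idft(rotated)]
    return augmented, spectrum, rotated


n = 16
real = [math.sin(2 * math.pi * 3 * t / n) + 0.5 * math.cos(2 * math.pi * 5 * t / n)
        for t in range(n)]
fake = [0.9 * math.sin(2 * math.pi * 3 * t / n + 0.2) for t in range(n)]

rng = random.Random(0)
angles = random_angles(n, rng)  # shared, so both signals get the same rotation
real_aug, real_spec, real_rot = rotate_phases(real, angles)
fake_aug, fake_spec, fake_rot = rotate_phases(fake, angles)
```

The waveform changes while every bin's magnitude stays the same, so a discriminator can no longer latch onto exact phase patterns of the training audio. That is the intuition behind using it against discriminator overfitting.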
