
Comments (5)

zcemycl commented on May 31, 2024

@NicoMigenda
In the AAE.m script, I do update the encoder on every iteration.

The updates are listed in the while loop.
From lines 56-76,

        [GradEn,GradDe,GradDis] = ...
                dlfeval(@modelGradients,XBatch,...
                paramsEn,paramsDe,paramsDis,settings);

        % Update Discriminator network parameters
        [paramsDis,avgG.Dis,avgGS.Dis] = ...
            adamupdate(paramsDis, GradDis, ...
            avgG.Dis, avgGS.Dis, global_iter, ...
            settings.lrD, settings.beta1, settings.beta2);

        % Update Encoder network parameters
        [paramsEn,avgG.En,avgGS.En] = ...
            adamupdate(paramsEn, GradEn, ...
            avgG.En, avgGS.En, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        
        % Update Decoder network parameters
        [paramsDe,avgG.De,avgGS.De] = ...
            adamupdate(paramsDe, GradDe, ...
            avgG.De, avgGS.De, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
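For readers unfamiliar with what each of these calls does, here is a rough sketch of one textbook Adam step for a single scalar parameter, written in Python. The name `adam_step` and the default `eps` are assumptions for illustration; the argument order only mirrors the calls above, and the toolbox's `adamupdate` may differ in detail.

```python
import math

def adam_step(param, grad, avg_g, avg_sq_g, iteration, lr, beta1, beta2, eps=1e-8):
    """One textbook Adam update (hypothetical helper, not the toolbox code)."""
    # Exponential moving averages of the gradient and of the squared gradient
    avg_g = beta1 * avg_g + (1 - beta1) * grad
    avg_sq_g = beta2 * avg_sq_g + (1 - beta2) * grad ** 2
    # Bias correction; iteration counts from 1, like global_iter above
    m_hat = avg_g / (1 - beta1 ** iteration)
    v_hat = avg_sq_g / (1 - beta2 ** iteration)
    # Parameter step scaled by the corrected moments
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, avg_g, avg_sq_g
```

Each network keeps its own pair of running averages, which is why the loop above carries separate `avgG` and `avgGS` fields for the discriminator, encoder and decoder.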

Details are in modelGradients, lines 106-125. The generator gradients (encoder and decoder together) are computed in a single call at line 123:
[GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);

Edit:
Maybe I get what you mean now. You mean the gradient of the encoder should depend not only on the reconstruction loss, but also on the binary cross-entropy of the discriminator's true/false output. I do include this, at line 120:

g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));
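To make the two terms easier to see, here is a small pure-Python sketch of the same weighted sum. The function name `generator_loss`, the argument names, and the samples-as-rows layout are assumptions for illustration; only the 0.999 / 0.001 weights and the form of the two terms come from the line above.

```python
import math

EPS = 2.0 ** -52  # same value as MATLAB's eps for doubles

def generator_loss(x, x_rec, d_fake, w_rec=0.999, w_adv=0.001):
    """Weighted sum of a reconstruction term and an adversarial term.

    x, x_rec: lists of samples, each a list of feature values.
    d_fake:   discriminator outputs in (0, 1) for the encoded samples.
    """
    # Reconstruction: mean over features of 0.5 * squared error, then mean over samples
    per_sample = [
        sum(0.5 * (r - t) ** 2 for r, t in zip(rec, tgt)) / len(tgt)
        for rec, tgt in zip(x_rec, x)
    ]
    rec_term = sum(per_sample) / len(per_sample)
    # Adversarial: cross-entropy that rewards the encoder for fooling the discriminator
    adv_term = -sum(math.log(d + EPS) for d in d_fake) / len(d_fake)
    return w_rec * rec_term + w_adv * adv_term
```

With these weights the reconstruction error dominates, but the encoder still receives a gradient through the discriminator output.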

Hope this answers your question.

from matlab-gan.

NicoMigenda commented on May 31, 2024


zcemycl commented on May 31, 2024

@NicoMigenda

Oh, I see. What you mention is common to all GANs. Usually we do not train the generator twice per iteration, but the concept is still worth keeping in mind.

For example in Keras,

# Train the discriminator 
d_loss_real = self.discriminator.train_on_batch(latent_real, valid) 
d_loss_fake = self.discriminator.train_on_batch(latent_fake, fake) 
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) 
# --------------------- 
# Train Generator 
# --------------------- 
# Train the generator 
g_loss = self.adversarial_autoencoder.train_on_batch(imgs, [imgs, valid])

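They also train the generator only once per iteration, but they set the discriminator's trainable flag to false before training the generator.

In MATLAB, the same effect comes from calling dlgradient on the generator loss with respect to paramsEn and paramsDe only, so the discriminator parameters get no update from that loss. Setting 'RetainData' to true keeps the trace so that dlgradient can be called again in the same function.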
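The idea of a generator step that leaves the discriminator untouched can be sketched with a toy scalar model in plain Python. Everything here (`generator_step`, the one-weight encoder `w_e`, the one-weight discriminator `w_d`) is a hypothetical illustration, not code from either repository.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def adversarial_loss(w_e, w_d, x):
    """-log D(E(x)) for a one-weight encoder and a one-weight discriminator."""
    return -math.log(sigmoid(w_d * w_e * x))

def generator_step(w_e, w_d, x, lr=0.1):
    """One gradient step on -log D(E(x)) that updates only the encoder weight."""
    d = sigmoid(w_d * w_e * x)
    # d/dw_e of -log(sigmoid(w_d * w_e * x)) = -(1 - d) * w_d * x
    grad_w_e = -(1.0 - d) * w_d * x
    # The loss also depends on w_d, but that gradient is never requested or applied here
    return w_e - lr * grad_w_e, w_d
```

The discriminator weight is used in the forward pass and in the chain rule, yet it comes back unchanged; only the encoder moves toward fooling it.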


NicoMigenda commented on May 31, 2024

Thank you very much for that explanation. Is there further literature about that?


zcemycl commented on May 31, 2024

@NicoMigenda
I would suggest reading how others implement VAEs and GANs in Keras and PyTorch, and comparing their implementations with the MATLAB version.

Keras AAE:
https://github.com/eriklindernoren/Keras-GAN/blob/master/aae/aae.py
Pay attention to discriminator.trainable and train_on_batch.

PyTorch AAE:
https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/aae/aae.py
Pay attention to the steps between zero_grad, backward and step.

Matlab dlgradient:
https://uk.mathworks.com/help/deeplearning/ref/dlarray.dlgradient.html

Also, compare the documentation for the MATLAB Deep Learning Toolbox, Keras and PyTorch.

