
Comments (3)

wpeebles commented on May 20, 2024

Hi @tankche1, yeah this is normal behavior. It happens because of how we construct the target images early in training: we gradually interpolate the latent code that synthesizes the target images from the input image's code to the learned congealing vector over the first 150K gradient steps. This stabilizes training by letting the STN predict only small warps at the start. In other words, the learning task gets harder for the STN over the early training steps, which produces the loss curve you're seeing. Once you hit ~100K to 150K iterations, the STN will start making a lot of progress.
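
For intuition, the annealing amounts to a simple lerp between latent codes. A minimal sketch, assuming a linear schedule and illustrative names (`w_input`, `w_congeal`), not the repo's exact code:

```python
import torch

def target_latent(w_input, w_congeal, step, anneal_steps=150_000):
    # Latent code used to synthesize the target image at a given step.
    # Early in training (alpha ~ 0) the target is nearly the input image
    # itself, so the STN only needs to predict small warps; once annealing
    # finishes (alpha = 1) the target is the fully congealed image.
    alpha = min(step / anneal_steps, 1.0)
    return torch.lerp(w_input, w_congeal, alpha)  # (1 - alpha) * w_input + alpha * w_congeal
```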

You don't have to train for the full 1.5M steps unless you want to squeeze out every last bit of performance. In my experience, you get most of the benefit after about 400K to 712K steps. Make sure you use one of the checkpoints saved when the learning rate hits zero; those seem to be the best.


tankche1 commented on May 20, 2024

Thanks! The training looks great now. One question: I find that both the transformed sample and the target sample (truncated sample) gradually move from the original image to the final congealed version.

If I put the original image (e.g., a white cat) into the STN and set the final congealing output (e.g., the head of a white cat) as the target in the perceptual loss, the STN cannot learn the transform.

Does this mean that the congealing algorithm relies on gradual improvement from both the STN and the learned latent embedding? Also, is t_ema only used for visualization?


wpeebles commented on May 20, 2024

Yeah, the gradual transformation you see is a result of this annealing of the target latent code. We have an ablation in Table 4 of the supplementary materials where we omit the annealing, and it drops PCK@0.1 of the cat model from 67% --> 59%, so it definitely makes a significant impact.

That's an interesting experiment you ran. I'd guess some images can be successfully congealed without the gradual annealing (otherwise the ablation would probably be closer to 0% PCK :), but I don't have great intuition for the specific subset of images it helps the most with.

At the end of training, t is effectively discarded and t_ema is the final model used for everything (that's the reason we visualize it during training, since we care about t_ema's final performance more than t's). It's an exponential moving average of t's parameters over training, which is a trick that a lot of generative models (DDPMs, GANs, etc.) use to improve performance.
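
The EMA update itself is just a per-parameter running average. A minimal sketch, assuming a typical decay constant (the value gangealing actually uses may differ):

```python
import torch

@torch.no_grad()
def update_ema(t_ema, t, decay=0.999):
    # After each optimizer step, blend t's current parameters into t_ema:
    #   p_ema <- decay * p_ema + (1 - decay) * p
    for p_ema, p in zip(t_ema.parameters(), t.parameters()):
        p_ema.lerp_(p, 1.0 - decay)
```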

Btw, when you use your trained models at test time, I'd recommend passing iters=3 when calling the STN (e.g., stn(x, iters=3)). The iters argument recursively applies the similarity STN to its own output, which helps a lot on harder datasets like LSUN. If you're using the testing scripts in the applications folder, you can specify this from the command line with --iters 3. The visualizations made during training all use iters=1, so they show a lower bound on performance.
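
Conceptually, the recursion just re-feeds the STN its own output. A rough sketch of the idea (assuming the STN returns the warped image; the real stn(x, iters=3) call handles the warp composition internally):

```python
def apply_stn_recursively(stn, x, iters=3):
    # Each pass predicts a similarity warp for the current image and
    # applies it, so later passes can correct residual misalignment.
    for _ in range(iters):
        x = stn(x)
    return x
```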

