
Comments (7)

yang-song commented on July 19, 2024

The default batch size of 64 (see configs/default_lsun_configs.py) is intended for training on multiple GPUs. You can reduce memory usage by lowering the batch size, either in the config files or on the command line via --config.training.batch_size.
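For concreteness, a minimal sketch of the config-file option (field names follow this thread; the actual layout of the repo's config files may differ slightly):

import ml_collections

def get_config():
    # Sketch of lowering the batch size in a config file such as
    # configs/default_lsun_configs.py; only the relevant field is shown.
    config = ml_collections.ConfigDict()
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 16  # reduced from the multi-GPU default of 64
    return config

# Equivalent command-line override (assuming the ml_collections config_flags
# setup used by main.py): --config.training.batch_size=16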


pbizimis commented on July 19, 2024

Great, thank you!

For (future) Colab users: I am now using a batch size of 16 on a 16GB P100 at 128x128 resolution.


pbizimis commented on July 19, 2024

Hey,
the training worked really well, thanks for that. Now I am trying to do the evaluation. I managed to create my stats file and now I want to calculate the FID for 50k images. The process seems really slow on one V100 16GB. My eval config is:

evaluate = config.eval
evaluate.begin_ckpt = 1          # first checkpoint to evaluate
evaluate.end_ckpt = 20           # last checkpoint to evaluate
evaluate.batch_size = 16         # samples generated per sampling round
evaluate.enable_sampling = True  # generate samples for FID
evaluate.num_samples = 50000     # total samples used for FID
evaluate.enable_loss = False     # skip loss evaluation
evaluate.enable_bpd = False      # skip bits-per-dim evaluation
evaluate.bpd_dataset = 'test'    # split that would be used for bits/dim

Is there any chance to optimize this for one GPU?
Thanks a lot!

P.S.: The PyTorch requirements.txt does not include jax and jaxlib, but you need them to run the code. I am not sure whether they are just leftover imports or actually needed, but their absence led to errors for me.


yang-song commented on July 19, 2024

You can increase evaluate.batch_size by quite a large factor, since evaluation does not require backpropagation and therefore needs much less GPU memory. jax and jaxlib can be refactored out; technically the evaluation code shouldn't depend on them. Thanks for catching these imports, I will remove them in the next revision.
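As a generic illustration of why larger eval batches fit (not code from this repo; score_model and sampling_fn are hypothetical names), sampling can be run without autograd so activations are not kept for a backward pass:

import torch

@torch.no_grad()  # no backward pass is needed for sampling, so skip autograd bookkeeping
def generate_batch(score_model, sampling_fn):
    score_model.eval()                           # disable dropout / batch-norm updates
    samples, n_steps = sampling_fn(score_model)  # hypothetical sampler returning (samples, n_steps)
    return samples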


pbizimis commented on July 19, 2024

Thanks for the help. I increased it to 64; anything above that runs out of memory. It still takes really long: there are 782 (50000 // 64 + 1) sampling rounds and each round takes about 35 minutes, so getting the 50k FID of one model takes about 19 days 😂 Do you have any experience with reducing the sample size and how that affects FID accuracy?
Thanks!
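For reference, the 19-day figure follows directly from the numbers quoted above:

import math

num_samples = 50_000
batch_size = 64
minutes_per_round = 35                        # observed per-round time on one V100 16GB

rounds = math.ceil(num_samples / batch_size)  # 782 sampling rounds
days = rounds * minutes_per_round / 60 / 24   # ~19 days
print(rounds, round(days, 1))                 # 782 19.0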


yang-song commented on July 19, 2024

Yeah, that’s unfortunately due to the slow sampling of score-based diffusion models. Using the JAX version can be slightly better, since the JAX code samples faster than the PyTorch code. In my previous papers I also reported FID scores on 1k samples for some experiments, but in that case the FID score will be much larger than the one evaluated on 50k samples.
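For context on the 1k-vs-50k gap: FID is the Fréchet distance between Gaussians fitted to Inception features of real and generated samples, and with only 1k samples the estimated statistics are noisier, which is consistent with the much larger score mentioned above. A minimal sketch of the distance itself, assuming the feature means and covariances (mu, sigma) have already been computed:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # Squared Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):
    # ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    covmean = covmean.real  # discard small imaginary parts from numerical error
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))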


pbizimis commented on July 19, 2024

Thanks for the info👍

