Comments (7)
The default batch size of 64 (see configs/default_lsun_configs.py) is intended for training on multiple GPUs. You can reduce memory usage by lowering the batch size. The batch size can be set either in the config files or on the command line with --config.training.batch_size.
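As a rough illustration of what a dotted override such as --config.training.batch_size does, the sketch below sets a nested field on a config object. It uses types.SimpleNamespace as a stand-in for the project's actual config class, and apply_override is a hypothetical helper, not a function from the repository; only the field name training.batch_size and its default of 64 come from the comment above.

```python
from types import SimpleNamespace


def apply_override(config, dotted_key, value):
    """Set a nested attribute such as 'training.batch_size' on a config object."""
    *parents, leaf = dotted_key.split(".")
    node = config
    # Walk down to the object that owns the final field.
    for name in parents:
        node = getattr(node, name)
    # Refuse to create new fields, so typos in the key fail loudly.
    if not hasattr(node, leaf):
        raise AttributeError(f"unknown config field: {dotted_key}")
    setattr(node, leaf, value)
    return config


# Hypothetical stand-in for the project's config object; only the
# training.batch_size field and its default of 64 come from the thread.
config = SimpleNamespace(training=SimpleNamespace(batch_size=64))

apply_override(config, "training.batch_size", 16)
print(config.training.batch_size)  # 16
```

Rejecting unknown keys is a design choice of this sketch: a misspelled field then raises an error instead of silently leaving the default batch size in place.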
from score_sde_pytorch.
Great, thank you!
For (future) Colab users: I am now using a batch size of 16 on a P100 with 16 GB of memory at 128x128 resolution.
Hey,
the training worked really well, thanks for that. Now I am trying to do the evaluation. I managed to create my stats file and now I want to calculate the FID for 50k images. The process seems really slow on one V100 16GB. My eval config is:
evaluate = config.eval
evaluate.begin_ckpt = 1
evaluate.end_ckpt = 20
evaluate.batch_size = 16
evaluate.enable_sampling = True
evaluate.num_samples = 50000
evaluate.enable_loss = False
evaluate.enable_bpd = False
evaluate.bpd_dataset = 'test'
Is there any chance to optimize this for one GPU?
Thanks a lot!
P.S.: The PyTorch requirements.txt does not include jax and jaxlib, but you need them to run the code. I am not sure whether they are leftover imports or really needed by the code, but their absence led to errors for me.
You can increase evaluate.batch_size by quite a large factor, since evaluating a model does not require backpropagation and therefore needs much less GPU memory. jax and jaxlib can be refactored out; technically, the evaluation code shouldn't depend on them. Thanks for catching these imports. I will remove them in the next revision.
Thanks for the help. I increased it to 64; anything above that runs out of memory. It takes really long either way: I have 782 (50000//64+1) sampling rounds and each round takes about 35 minutes, so getting the 50k FID of one model takes about 19 days 😂 Do you have any experience with reducing the sample size and how that affects FID accuracy?
Thanks!
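The arithmetic behind that estimate can be checked with a short sketch. The function names are hypothetical; the rounds formula (num_samples // batch_size + 1) and the figure of about 35 minutes per round are the ones quoted in this comment.

```python
def sampling_rounds(num_samples: int, batch_size: int) -> int:
    """Number of sampling rounds, using the formula quoted in the comment."""
    return num_samples // batch_size + 1


def total_days(rounds: int, minutes_per_round: float) -> float:
    """Convert a number of rounds into wall-clock days."""
    return rounds * minutes_per_round / (60 * 24)


rounds = sampling_rounds(50_000, 64)
days = total_days(rounds, 35)
print(rounds, round(days, 1))  # 782 19.0
```

Note that this formula adds one round even when the batch size divides the sample count exactly, so it can count one round more than strictly needed.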
Yeah, that’s unfortunately due to the slow sampling of score-based diffusion models. Using the JAX version can be slightly better, since the JAX code can sample faster than the PyTorch code. In my previous papers I also reported FID scores on 1k samples for some experiments, but in that case the FID score will be much larger than one evaluated on 50k samples.
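For context, the standard definition of FID (not spelled out in this thread) compares Gaussian fits to Inception features of reference and generated images. The symbols below are introduced here for illustration: \mu_r, \Sigma_r are the feature mean and covariance of the reference set, and \mu_g, \Sigma_g those of the generated samples.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Because \mu_g and \Sigma_g are estimated from a finite set of samples, a small set such as 1k images gives a noisy covariance estimate in a high-dimensional feature space, and the resulting score generally comes out higher than with 50k samples. This is consistent with the remark above, and it means scores computed with different sample counts should not be compared directly.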
Thanks for the info👍