ComboStoc: Combinatorial Stochasticity for Diffusion Generative Models
Official Code

ComboStoc samples

This repo contains the image diffusion models and training/sampling code for our paper exploring combinatorial stochasticity for diffusion generative models. [Project] [arXiv]

We will add the structured shape generation code later.

Please cite our paper:

@misc{xu2024combostoccombinatorialstochasticitydiffusion,
      title={ComboStoc: Combinatorial Stochasticity for Diffusion Generative Models}, 
      author={Rui Xu and Jiepeng Wang and Hao Pan and Yang Liu and Xin Tong and Shiqing Xin and Changhe Tu and Taku Komura and Wenping Wang},
      year={2024},
      eprint={2405.13729},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2405.13729}, 
}

In this paper we study an under-explored but important factor of diffusion generative models: combinatorial complexity. Data samples are generally high-dimensional, and for various structured generation tasks there are additional attributes associated with data samples in combination. We show that the space spanned by the combinations of dimensions and attributes is insufficiently sampled by the existing training schemes of diffusion generative models, causing degraded test-time performance. We present a simple fix to this problem by constructing stochastic processes that fully exploit the combinatorial structures, hence the name ComboStoc. Using this simple strategy, we show that network training is significantly accelerated across diverse data modalities, including images and 3D structured shapes. Moreover, ComboStoc enables a new mode of test-time generation that uses unsynchronized time steps for different dimensions and attributes, allowing varying degrees of control over them.
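The core idea, an independent time step for each dimension instead of one scalar per sample, can be sketched in a few lines. This is a hedged illustration only: the function name is hypothetical, and the linear-interpolation schedule follows the common stochastic-interpolant form; see train.py for the repo's actual process.

```python
import torch

def asynchronous_noise(x0, noise, sync=False):
    """Sketch of ComboStoc-style noising with per-dimension time steps.

    Standard diffusion training draws ONE scalar t per sample; drawing an
    independent t for every dimension instead lets training cover the
    combinatorial space of partially-noised coordinates.
    Uses the linear interpolation x_t = (1 - t) * x0 + t * noise
    (an assumption; the repo's schedule may differ).
    """
    b = x0.shape[0]
    if sync:
        # Conventional scheme: one shared t per sample, broadcast to all dims.
        t = torch.rand(b).view(b, *([1] * (x0.dim() - 1))).expand_as(x0)
    else:
        # ComboStoc-style scheme: an independent t for every dimension.
        t = torch.rand_like(x0)
    xt = (1 - t) * x0 + t * noise
    return xt, t
```

At test time, the same per-dimension time steps can be advanced at different rates to control individual dimensions or attributes independently.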

This repository contains:

  • 🛸 A ComboStoc training script using PyTorch DDP
  • 📫 A pre-trained ComboStoc-XL-2 model with the 'INSYNC_ALL' setting at 800K training steps. [Google Drive]

Setup

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate ComboStoc

Sampling

More ComboStoc samples

You can sample from ComboStoc checkpoints with sample.py. The script takes various arguments to switch between the ODE and SDE samplers, adjust the number of sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 256x256 ComboStoc-XL model with the default SDE settings, you can use:

python sample.py SDE --model ComboStoc-XL/2 --image-size 256 --ckpt /path/to/model.pt

Advanced sampler settings (following SiT)

ODE:
  • --atol (float): absolute error tolerance
  • --rtol (float): relative error tolerance
  • --sampling-method (str): sampling method (refer to [torchdiffeq])

SDE:
  • --diffusion-form (str): form of the SDE's diffusion coefficient (refer to Tab. 2 in the paper)
  • --diffusion-norm (float): magnitude of the SDE's diffusion coefficient
  • --last-step (str): form of the SDE's last step
      • None: single SDE integration step
      • "Mean": SDE integration step without the diffusion coefficient
      • "Tweedie": [Tweedie's denoising] step
      • "Euler": single ODE integration step
  • --sampling-method (str): sampling method
      • "Euler": first-order integration
      • "Heun": second-order integration

There are some more options; refer to train_utils.py for details.
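The difference between the "Euler" and "Heun" sampling methods is the order of the integrator. As a hedged illustration (a toy ODE, not the repo's actual sampler), the two schemes can be compared side by side; Heun evaluates the drift twice per step and is markedly more accurate at the same step count:

```python
import math

def euler(f, x0, t0, t1, n):
    """First-order integration, as in --sampling-method "Euler"."""
    x, t, h = x0, t0, (t1 - t0) / n
    for _ in range(n):
        x = x + h * f(t, x)
        t += h
    return x

def heun(f, x0, t0, t1, n):
    """Second-order (predictor-corrector) integration, as in
    --sampling-method "Heun": average the drift at the current point
    and at an Euler-predicted next point."""
    x, t, h = x0, t0, (t1 - t0) / n
    for _ in range(n):
        k1 = f(t, x)
        k2 = f(t + h, x + h * k1)
        x = x + 0.5 * h * (k1 + k2)
        t += h
    return x

# Toy problem dx/dt = -x with exact solution exp(-t):
f = lambda t, x: -x
print(abs(euler(f, 1.0, 0.0, 1.0, 50) - math.exp(-1.0)))  # ~1e-3 error
print(abs(heun(f, 1.0, 0.0, 1.0, 50) - math.exp(-1.0)))   # ~1e-5 error
```

In diffusion sampling, f would be the learned velocity/score field, so Heun trades one extra network evaluation per step for a higher-order error bound.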

Training ComboStoc

We provide a training script for ComboStoc in train.py. To launch ComboStoc-XL/2 (256x256) training with N GPUs on one node:

torchrun --nnodes=1 --nproc_per_node=N train.py --model ComboStoc-XL/2 --data-path /path/to/imagenet/train

Logging. To enable wandb, first set WANDB_KEY, ENTITY, and PROJECT as environment variables:

export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"

Then add the --wandb flag to the training command:

torchrun --nnodes=1 --nproc_per_node=N train.py --model ComboStoc-XL/2 --data-path /path/to/imagenet/train --wandb

Resume training. To resume training from a custom checkpoint:

torchrun --nnodes=1 --nproc_per_node=N train.py --model ComboStoc-L/2 --data-path /path/to/imagenet/train --ckpt /path/to/model.pt

Caution. Resuming training automatically restores the model, EMA, and optimizer states, as well as the training configuration, from the checkpoint.

Evaluation (FID, Inception Score, etc.)

We include a sample_ddp.py script which samples a large number of images from a ComboStoc model in parallel. This script generates a folder of samples as well as a .npz file which can be used directly with [PyTorch-FID] to compute FID, Inception Score, and other metrics. For example, to sample 50K images from our pre-trained ComboStoc-XL/2 model over N GPUs under the default SDE sampler settings, run:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py SDE --model ComboStoc-XL/2 --num-fid-samples 50000
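Under the hood, the FID that tools like PyTorch-FID report reduces to the Fréchet distance between Gaussian fits (mean and covariance) of Inception activations for the generated and reference sets. A minimal sketch of that final computation (not the repo's evaluation code; scipy is assumed available):

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between two Gaussians N(mu1, sigma1), N(mu2, sigma2):
    ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)).
    """
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        # sqrtm can return tiny imaginary parts from numerical error.
        covmean = covmean.real
    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
```

In practice mu and sigma come from Inception-v3 pool features of 50K samples; identical distributions give a distance of 0.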


