
ldm's Introduction

3D version of "High-Resolution Image Synthesis with Latent Diffusion Models" or "Stable Diffusion"

Train command

torchrun --nproc_per_node $N_GPU main.py --base $CFG_FILE -t --name $EXP_NAME --gpus 0,1...

Stage -1: prepare environment

# fork this repository to your github
git clone $GIT_ADDR_TO_YOUR_REPO ./latentdiffusion
cd ./latentdiffusion
pip install -r requirements.txt

Stage 0: write config file

Config files reside under configs/

  • autoencoder config files:
model:
** take KL-AE as an example **
  base_learning_rate: $LR
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    ** ldm.models.autoencoder.AutoencoderKL **'s keyword arguments
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        ...

    ddconfig:
      ** ldm.modules.diffusionmodules.model.Encoder/Decoder **'s keyword arguments
      ...
    use_checkpoint: true          # use gradient checkpointing to save GPU memory

data:
** take BraTS21 as an example **
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 2
    wrap: False
    train:
      target: ldm.data.brats2021.BraTS2021_3D
      params:
        split: train
        crop_to: [64, 64, 64]
    validation:
      target: ldm.data.brats2021.BraTS2021_3D
      params:
        split: val
        crop_to: [64, 64, 64]

-- logging arguments and PyTorch Lightning callbacks --
...
  • latentdiffusion config files:
model:
  base_learning_rate: $LR
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    ...
    first_stage_key: image      # data key
    cond_stage_key: mask        # condition key
    conditioning_key: concat    # conditional type
    image_size: [8, 16, 16]     # after first-stage encoding
    channels: 4                 # after first-stage encoding
    ...

    unet_config:
      ** diffusion UNet config **
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64  # not used
        ...
        use_checkpoint: True    # always use gradient checkpointing to save GPU memory

    first_stage_config:
      ** your autoencoder settings, copied from the autoencoder's config, with the loss replaced by nn.Identity **
      ckpt_path: /path/to/your/pretrained/ae
      ...

    cond_stage_config: 
      ** use __is_unconditional__ if running without conditions, else follow the same format as above **
      ...

data:
  ** same as in autoencoder config **
  ...

-- logging arguments and PyTorch Lightning callbacks --
...
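
The target / params pairs in the configs above follow the convention of the original latent-diffusion code base: a dotted import path names the class, and params holds its keyword arguments. Below is a minimal sketch of how such an entry can be resolved using only the standard library. The helper names mirror the upstream utility, but the bodies are an illustration rather than this repository's code, and fractions.Fraction is only a stand-in target so the snippet runs without the repository installed.

```python
import importlib


def get_obj_from_str(path):
    """Resolve a dotted path such as 'package.module.ClassName' to the object it names."""
    module_name, attr_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)


def instantiate_from_config(config):
    """Build the object described by a {'target': ..., 'params': {...}} mapping."""
    if "target" not in config:
        raise KeyError("config entry needs a 'target' key")
    cls = get_obj_from_str(config["target"])
    # 'params' is optional; a missing block means "call with no arguments".
    return cls(**config.get("params", {}))


# Stand-in entry (assumption): a real config would point at a class such as
# ldm.models.autoencoder.AutoencoderKL instead of fractions.Fraction.
example = {"target": "fractions.Fraction", "params": {"numerator": 3, "denominator": 4}}
obj = instantiate_from_config(example)
print(obj)
```

Nested entries such as lossconfig or unet_config would be resolved the same way, typically inside the constructor of the class that receives them.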

Stage 1: train autoencoder

To train an autoencoder with KL regularization, the code workflow is:
  • main.py ( trainer function ) ->
  • ldm/models/autoencoder.py ( autoencoder wrapper ) ->
  • ldm/modules/losses/contperceptual.py ( LPIPS and GAN loss ) & ldm/modules/diffusionmodules/model.py ( autoencoder model class )
To train an autoencoder with Vector Quantization, the code workflow is:
  • main.py ( trainer function ) ->
  • ldm/models/autoencoder.py ( autoencoder wrapper ) ->
  • ldm/modules/losses/vqperceptual.py ( LPIPS and codebook loss ) & ldm/modules/diffusionmodules/model.py ( autoencoder model class )
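
The KL-regularized variant penalizes the encoder's posterior for drifting away from a standard normal. Below is a minimal sketch of that term for a diagonal Gaussian, written in plain Python so it runs without PyTorch. The function name kl_to_standard_normal and the flat-list inputs are illustrative assumptions; the repository's own loss may weight or reduce the term differently.

```python
import math


def kl_to_standard_normal(mean, logvar):
    """KL( N(mean, exp(logvar)) || N(0, 1) ), summed over all latent elements."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mean, logvar))


# A posterior that already equals N(0, 1) incurs no penalty ...
kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])
# ... while shifting one mean to 1.0 adds 0.5 to the penalty.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])
print(kl_zero, kl_shifted)
```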

Stage 2: train diffusion model

To train the latent diffusion model, the code workflow is:
  • main.py ( trainer function ) ->
  • ldm/models/diffusion/ddpm.py ( diffusion wrapper ) ->
  • ldm/modules/diffusionmodules/openaimodel.py ( diffusion UNet )
  • ldm/models/diffusion/ddim.py is for fast reverse sampling using DDIM
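
DDIM speeds up sampling by visiting only a sub-sequence of the training timesteps and, when no extra noise is injected, by taking a deterministic step between them. Below is a scalar sketch of one such step under the usual notation, where alpha_bar is the cumulative product of (1 - beta). The function names and toy numbers are assumptions for illustration, not code taken from ddim.py.

```python
import math


def ddim_timesteps(num_train_steps, num_sample_steps):
    """Evenly spaced subset of the training timesteps, in ascending order."""
    stride = num_train_steps // num_sample_steps
    return list(range(0, num_train_steps, stride))[:num_sample_steps]


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic (eta = 0) DDIM update from timestep t to the previous kept timestep."""
    # Estimate the clean sample implied by the current noise prediction.
    x0_pred = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Re-noise that estimate to the lower noise level of the previous timestep.
    return math.sqrt(alpha_bar_prev) * x0_pred + math.sqrt(1.0 - alpha_bar_prev) * eps_pred


# Toy check: noise a known x0 with a known eps, then step with a perfect noise prediction.
x0, eps = 2.0, -0.5
a_t, a_prev = 0.3, 0.7
x_t = math.sqrt(a_t) * x0 + math.sqrt(1.0 - a_t) * eps
x_prev = ddim_step(x_t, eps, a_t, a_prev)
expected = math.sqrt(a_prev) * x0 + math.sqrt(1.0 - a_prev) * eps
print(ddim_timesteps(1000, 5), abs(x_prev - expected) < 1e-9)
```

With a perfect noise prediction the step lands exactly on the less noisy sample, which is why a few dozen DDIM steps can stand in for the full reverse chain.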

Inference command

python test.py --base $CFG_FILE --name $EXP_NAME --gpus 0,

  • CFG_FILE is specified the same way as in the training command. Generated images and files are stored under <path_to_$EXP_NAME>/images/test

References

Refer to the following directories for more details
