
Diffusion Models Distillation

This repository implements distillation of diffusion models into fewer sampling steps, based on On Distillation of Guided Diffusion Models and Progressive Distillation for Fast Sampling of Diffusion Models. The implementation builds on diffusers and distills a classifier-free guidance model on ImageNet down to 1-2 sampling steps.

For more papers on diffusion model distillation, see diffusion models distillation papers.

The following is a sample of images: the left is generated by the 2-step distilled model, and the right by the original diffusion model with 4 DDIM steps.

[sample images]

Requirements

pip install accelerate einops_exts diffusers datasets transformers
# Clone this repository, then clone stable-diffusion into third_party
git clone https://github.com/YongfeiYan/diffusion_models_distillation.git
cd diffusion_models_distillation/third_party && git clone https://github.com/CompVis/stable-diffusion.git
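An optional sanity check that the main dependencies (and the PyTorch backend they pull in) are importable:

# Optional: confirm the dependencies installed correctly and print their versions.
import torch, accelerate, datasets, transformers, diffusers
print(torch.__version__, diffusers.__version__, transformers.__version__)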

Run

Download ImageNet and the pretrained model

The ImageNet data is downloaded through the stable-diffusion code. On first use it downloads all images into the cache directory to create the ImageNetTrain dataset. Use the following to download:

PYTHONPATH=.:third_party/stable-diffusion python diffdstl/data/get_imagenet.py

Download pretrained model:

dst=data/ldm/cin256-v2
mkdir -p $dst && cd $dst
wget https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt 
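To confirm the checkpoint downloaded intact, it can be opened with PyTorch (a quick check, not part of the repository's scripts):

# Inspect the downloaded LDM checkpoint.
import torch

ckpt = torch.load("data/ldm/cin256-v2/model.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
print(f"{len(state_dict)} tensors in state_dict")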

Finetune v_prediction

The first step is to finetune the original model to use v_prediction, which stabilizes the distillation process.
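For reference, v_prediction changes the network target from the noise eps to v = alpha_t * eps - sigma_t * x0 (Salimans & Ho, 2022). A minimal sketch of the conversion, with illustrative helper names not taken from the repository:

# v-prediction target and its inverse, assuming z_t = alpha_t * x0 + sigma_t * eps
# with alpha_t^2 + sigma_t^2 = 1.
import torch

def v_target(x0, eps, alpha_t, sigma_t):
    return alpha_t * eps - sigma_t * x0

def x0_from_v(z_t, v, alpha_t, sigma_t):
    # Recover the clean-sample estimate from a predicted v.
    return alpha_t * z_t - sigma_t * v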

# Convert the pretrained model into diffusers pipeline
PYTHONPATH=.:third_party/stable-diffusion/ python scripts/progressdstl/ldm_ckpt_to_pipeline.py configs/imagenet/cin256-v2.yaml data/ldm/cin256-v2/model.ckpt data/test-pipeline
# Finetune
CUDA_VISIBLE_DEVICES=0 bash scripts/progressdstl/finetune_v_prediction.sh

Stage one: classifier-free guidance removal

The second step is to remove classifier-free guidance from sampling, so each denoising step needs only one forward pass:

CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/progressdstl/stage_one.sh
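Conceptually, stage one trains a guidance-conditioned student to match the two-pass classifier-free-guidance output of the frozen teacher in a single pass (Meng et al., 2023). The sketch below is illustrative; the call signatures do not match the repository's scripts:

import torch
import torch.nn.functional as F

def stage_one_loss(student, teacher, z_t, t, labels, w):
    # Two-pass classifier-free guidance from the frozen teacher
    # (guidance convention from Meng et al.: (1 + w) * cond - w * uncond).
    with torch.no_grad():
        v_cond = teacher(z_t, t, labels)
        v_uncond = teacher(z_t, t, None)  # null / unconditional label
        v_guided = (1 + w) * v_cond - w * v_uncond
    # The student takes the guidance scale w as an extra input and makes one pass.
    return F.mse_loss(student(z_t, t, labels, w), v_guided)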

Stage two: distilling to fewer sampling steps

The third step iteratively halves the number of sampling steps. To reduce training time, the script begins with 64 DDIM sampling steps and runs 5 times to distill the student model down to 1 sampling step.

# Convert sampling scheduler
PYTHONPATH=.:third_party/stable-diffusion/ python scripts/progressdstl/convert_pipeline_scheduler.py data/log/imagenet/stage_one/pipeline data/log/imagenet/stage_one/pipeline-converted
# Distill
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/progressdstl/stage_two.sh &> stage_two.log & 
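Schematically, each round of stage two trains the student to match two DDIM steps of its teacher with a single step; the trained student then becomes the teacher for the next round, halving the step count each time (Salimans & Ho, 2022). The sketch below uses a toy cosine schedule and takes the loss directly in latent space for brevity, unlike the weighted loss used in the paper and the repository's scripts:

import torch
import torch.nn.functional as F

def alpha_sigma(t):
    # Toy cosine schedule with alpha_t^2 + sigma_t^2 = 1, t in [0, 1].
    return torch.cos(0.5 * torch.pi * t), torch.sin(0.5 * torch.pi * t)

def ddim_step(model, z, t, s):
    # One deterministic DDIM step from time t to an earlier time s (v-prediction model).
    a_t, s_t = alpha_sigma(t)
    v = model(z, t)
    x0 = a_t * z - s_t * v
    eps = s_t * z + a_t * v
    a_s, s_s = alpha_sigma(s)
    return a_s * x0 + s_s * eps

def stage_two_loss(student, teacher, x0, n_student_steps):
    # Noise the data at a random student timestep and match two teacher steps with one.
    i = torch.randint(1, n_student_steps + 1, (x0.shape[0],) + (1,) * (x0.dim() - 1))
    t = i.float() / n_student_steps
    a_t, s_t = alpha_sigma(t)
    z_t = a_t * x0 + s_t * torch.randn_like(x0)
    with torch.no_grad():
        z_mid = ddim_step(teacher, z_t, t, t - 0.5 / n_student_steps)
        z_target = ddim_step(teacher, z_mid, t - 0.5 / n_student_steps, t - 1.0 / n_student_steps)
    return F.mse_loss(ddim_step(student, z_t, t, t - 1.0 / n_student_steps), z_target)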

References

@inproceedings{DBLP:conf/cvpr/MengRGKEHS23,
  author       = {Chenlin Meng and
                  Robin Rombach and
                  Ruiqi Gao and
                  Diederik P. Kingma and
                  Stefano Ermon and
                  Jonathan Ho and
                  Tim Salimans},
  title        = {On Distillation of Guided Diffusion Models},
  booktitle    = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
                  {CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023},
  pages        = {14297--14306},
  publisher    = {{IEEE}},
  year         = {2023},
  url          = {https://doi.org/10.1109/CVPR52729.2023.01374},
  doi          = {10.1109/CVPR52729.2023.01374},
  timestamp    = {Tue, 29 Aug 2023 15:44:40 +0200},
  biburl       = {https://dblp.org/rec/conf/cvpr/MengRGKEHS23.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}
@inproceedings{DBLP:conf/iclr/SalimansH22,
  author       = {Tim Salimans and
                  Jonathan Ho},
  title        = {Progressive Distillation for Fast Sampling of Diffusion Models},
  booktitle    = {The Tenth International Conference on Learning Representations, {ICLR}
                  2022, Virtual Event, April 25-29, 2022},
  publisher    = {OpenReview.net},
  year         = {2022},
  url          = {https://openreview.net/forum?id=TIdIXIpzhoI},
  timestamp    = {Sat, 20 Aug 2022 01:15:42 +0200},
  biburl       = {https://dblp.org/rec/conf/iclr/SalimansH22.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}
