
k-diffusion

An implementation of Elucidating the Design Space of Diffusion-Based Generative Models (Karras et al., 2022) for PyTorch, with enhancements and additional features, such as improved sampling algorithms and transformer-based diffusion models.

Hourglass diffusion transformer

k-diffusion contains a new model type, image_transformer_v2, that uses ideas from Hourglass Transformer and DiT.

Requirements

To use the new model type you will need to install custom CUDA kernels:

  • NATTEN for the sparse (neighborhood) attention used at the low levels of the hierarchy. There is a shifted window attention version of the model type that does not require a custom CUDA kernel, but it does not perform as well and is slower for both training and inference.

  • FlashAttention-2 for global attention. It will fall back to plain PyTorch if it is not installed.

Also, make sure your PyTorch installation can use torch.compile(). The model will fall back to eager mode if torch.compile() is not available, but training will be slower and use more memory.
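
A quick way to verify this before training (a hedged sanity check, not part of the training scripts; torch.compile() was introduced in PyTorch 2.0):

import torch

# torch.compile() was added in PyTorch 2.0; without it the model runs in eager mode.
print("PyTorch version:", torch.__version__)
print("torch.compile available:", hasattr(torch, "compile"))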

Usage

Demo

To train a 256x256 RGB model on Oxford Flowers without installing custom CUDA kernels, install Hugging Face Datasets:

pip install datasets

and run:

python train.py --config configs/config_oxford_flowers_shifted_window.json --name flowers_demo_001 --evaluate-n 0 --batch-size 32 --sample-n 36 --mixed-precision bf16

If you run out of memory, try adding --checkpointing or reducing the batch size. If you are using an older GPU (pre-Ampere), omit --mixed-precision bf16 to train in FP32. It is not recommended to train in FP16.

If you have NATTEN installed and working (preferred), you can train with neighborhood attention instead of shifted window attention by specifying --config configs/config_oxford_flowers.json.

Config file

In the "model" key of the config file:

  1. Set the "type" key to "image_transformer_v2".

  2. The base patch size is set by the "patch_size" key, like "patch_size": [4, 4].

  3. Model depth for each level of the hierarchy is specified by the "depths" config key, like "depths": [2, 2, 4]. This constructs a model with two transformer layers at the first level (4x4 patches), followed by two at the second level (8x8 patches), followed by four at the highest level (16x16 patches), followed by two more at the second level, followed by two more at the first level.

  4. Model width for each level of the hierarchy is specified by the "widths" config key, like "widths": [192, 384, 768]. The widths must be multiples of the attention head dimension.

  5. The self-attention mechanism for each level of the hierarchy is specified by the "self_attns" config key, like:

    "self_attns": [
        {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
        {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
        {"type": "global", "d_head": 64},
    ]

    If not specified, all levels of the hierarchy except for the highest use neighborhood attention with 64 dim heads and a 7x7 kernel, and the highest level uses global attention with 64 dim heads. Because the lower levels use local attention, the token count at every level except the highest can be very large; only the highest (most downsampled) level pays the quadratic cost of global attention.

  6. As a fallback if you or your users cannot use NATTEN, you can also train a model with shifted window attention at the low levels of the hierarchy. Shifted window attention does not perform as well as neighborhood attention and is slower for both training and inference, but it does not require custom CUDA kernels. Specify it like:

    "self_attns": [
        {"type": "shifted-window", "d_head": 64, "window_size": 8},
        {"type": "shifted-window", "d_head": 64, "window_size": 8},
        {"type": "global", "d_head": 64},
    ]

    The window size at each level must evenly divide the image size at that level. Models trained with one attention type must be fine-tuned to be used with a different type.
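
Putting the keys above together, here is a minimal sketch of the "model" section, written as a Python dict purely for illustration. The values mirror the examples above; the other required sections of a config (dataset, optimizer, and so on) are omitted and are best copied from an existing file in configs/.

import json

# Illustrative "model" section combining the keys described above.
model_config = {
    "type": "image_transformer_v2",
    "patch_size": [4, 4],
    "depths": [2, 2, 4],        # transformer layers per level, lowest to highest
    "widths": [192, 384, 768],  # must be multiples of the head dimension (d_head)
    "self_attns": [
        {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
        {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
        {"type": "global", "d_head": 64},
    ],
}

print(json.dumps({"model": model_config}, indent=4))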

Inference

TODO: write this section
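
In the meantime, the following is a rough sketch of sampling with the library's sampler functions. It assumes you already have a trained denoiser model (for example, the wrapped model produced by the training script) that maps a noised batch and a sigma value to a denoised prediction; the noise levels, shape, and step count below are placeholders that should match your model. See k_diffusion/sampling.py and train.py for the exact APIs.

import torch
import k_diffusion as K

@torch.no_grad()
def sample(model, shape, steps=50, sigma_min=1e-2, sigma_max=80.0, device="cuda"):
    # model: a trained denoiser wrapper taking (x, sigma) and returning the denoised x.
    # Karras noise schedule from sigma_max down to sigma_min, with a final zero.
    sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma_max, device=device)
    # Start from pure noise at the highest noise level.
    x = torch.randn(shape, device=device) * sigmas[0]
    # DPM-Solver++(2M); other samplers in k_diffusion.sampling share this interface.
    return K.sampling.sample_dpmpp_2m(model, x, sigmas)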

Installation

k-diffusion can be installed via PyPI (pip install k-diffusion), but this installs only the library code that other projects can depend on, not the training and inference scripts. To use the training and inference scripts, clone this repository and run pip install -e <path to repository>.

Training

To train models:

$ ./train.py --config CONFIG_FILE --name RUN_NAME

For instance, to train a model on MNIST:

$ ./train.py --config configs/config_mnist_transformer.json --name RUN_NAME

The configuration file allows you to specify the dataset type. Currently supported types are "imagefolder" (finds all images in that folder and its subfolders, recursively), "cifar10" (CIFAR-10), "mnist" (MNIST), and "huggingface" (Hugging Face Datasets).
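
For example, a minimal "dataset" section for an image folder might look like the sketch below, written as a Python dict for illustration. The "type" and "location" keys follow the bundled configs, but check an existing file in configs/ for the exact schema of the dataset type you use.

# Illustrative "dataset" section; verify key names against the files in configs/.
dataset_config = {
    "type": "imagefolder",          # or "cifar10", "mnist", "huggingface"
    "location": "/path/to/images",  # for "imagefolder": root directory, searched recursively
}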

Multi-GPU and multi-node training is supported with Hugging Face Accelerate. You can configure Accelerate by running:

$ accelerate config

then running:

$ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME

Enhancements/additional features

  • k-diffusion supports a highly efficient hierarchical transformer model type.

  • k-diffusion supports a soft version of Min-SNR loss weighting for improved training at high resolutions, with fewer hyperparameters than the loss weighting used in Karras et al. (2022).

  • k-diffusion has wrappers for v-diffusion-pytorch, OpenAI diffusion, and CompVis diffusion models, allowing them to be used with its samplers and ODE/SDE solvers.

  • k-diffusion implements DPM-Solver, which produces higher quality samples at the same number of function evaluations as Karras Algorithm 2, and supports adaptive step size control. DPM-Solver++(2S) and (2M) are also implemented for improved quality with low numbers of steps.

  • k-diffusion supports CLIP guided sampling from unconditional diffusion models (see sample_clip_guided.py).

  • k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models.

  • k-diffusion can calculate, during training, the FID and KID versus the training set.

  • k-diffusion can calculate, during training, the gradient noise scale (1 / SNR), from An Empirical Model of Large-Batch Training (https://arxiv.org/abs/1812.06162).

To do

  • Latent diffusion
