Rectified Flow

This is the official implementation of the ICLR 2023 Spotlight paper "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow"

by Xingchao Liu, Chengyue Gong, and Qiang Liu from UT Austin

InstaFlow

Rectified Flow can be applied to Stable Diffusion to turn it into a one-step generator. See here

Introduction

Rectified Flow is a novel method for learning transport maps between two distributions $\pi_0$ and $\pi_1$, by connecting straight paths between the samples and learning an ODE model.

Then, by a reflow operation, we iteratively straighten the ODE trajectories to eventually achieve one-step generation, with higher diversity than GAN and better FID than fast diffusion models.
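As a minimal sketch of the straight-path construction, assuming the linear interpolation $X_t = (1-t) X_0 + t X_1$ with the regression target $X_1 - X_0$, the idea can be written out on toy vectors. The helper names below are illustrative and are not taken from this repository; a real run would replace the perfect prediction with the output of a neural network.

```python
import random

def interpolate(x0, x1, t):
    """Point at time t on the straight path from x0 to x1."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]

def target_velocity(x0, x1):
    """Constant velocity of the straight path; the ODE model is regressed onto it."""
    return [b - a for a, b in zip(x0, x1)]

def regression_loss(v_pred, x0, x1):
    """Mean squared error between a predicted velocity and the straight-path target."""
    v_star = target_velocity(x0, x1)
    return sum((p - q) ** 2 for p, q in zip(v_pred, v_star)) / len(v_star)

random.seed(0)
x0 = [random.gauss(0.0, 1.0) for _ in range(4)]  # toy sample from pi_0 (Gaussian noise)
x1 = [1.0, -2.0, 0.5, 3.0]                       # toy sample from pi_1 (a fixed "data" point)
t = random.random()                              # t drawn uniformly from [0, 1)
x_t = interpolate(x0, x1, t)                     # input the velocity model would see

# A perfect velocity prediction has zero loss.
loss_at_optimum = regression_loss(target_velocity(x0, x1), x0, x1)
```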

An introductory website can be found here and the main idea is illustrated in the following figure:

Rectified Flow can be applied to both generative modeling and unsupervised domain transfer, as shown in the following figure:

For a more thorough discussion of its theoretical properties and its relationship to optimal transport, please refer to "Rectified Flow: A Marginal Preserving Approach to Optimal Transport".

Interactive Colab notebooks

We provide interactive tutorials as Colab notebooks that walk you through the whole Rectified Flow pipeline. There are two versions with different velocity models: a neural network version and a non-parametric version.

Image Generation

The code for image generation is in ./ImageGeneration. Run the following command first:

cd ./ImageGeneration

Dependencies

The following instructions have been tested on a Lambda Labs "1x A100 (40 GB SXM4)" instance, i.e. gpu_1x_a100_sxm4, running Ubuntu 20.04.2 (Driver Version: 495.29.05, CUDA Version: 11.5, CuDNN Version: 8.1.0). We suggest using Anaconda.

Run the following commands to install the dependencies:

conda create -n rectflow python=3.8
conda activate rectflow
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install tensorflow==2.9.0 tensorflow-probability==0.12.2 tensorflow-gan==2.0.0 tensorflow-datasets==4.6.0
pip install jax==0.3.4 jaxlib==0.3.2
pip install numpy==1.21.6 ninja==1.11.1 matplotlib==3.7.0 ml_collections==0.1.1

Train 1-Rectified Flow

Run the following command to train a 1-Rectified Flow from scratch

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode train --workdir ./logs/1_rectified_flow
  • --config The configuration file for this run.

  • --eval_folder The generated images and other files from each evaluation during training will be stored in ./workdir/eval_folder. In this command, that is ./logs/1_rectified_flow/eval/

  • --mode Mode selection for main.py. Select from train, eval and reflow.

Sampling and Evaluation

We follow the evaluation pipeline as in Score SDE. You can download cifar10_stats.npz and save it to assets/stats/. Then run

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode eval --workdir ./logs/1_rectified_flow --config.eval.enable_sampling  --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.eval.begin_ckpt 2

which uses a batch size of 1024 to sample 50000 images, starting from checkpoint-2.pth, and computes the FID and IS.

Generate Data Pair $(Z_0, Z_1)$ with 1-Rectified Flow

To prepare data for reflow, run the following command

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_generate_data.py  --eval_folder eval --mode reflow --workdir ./logs/tmp --config.reflow.last_flow_ckpt "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/1_rectified_flow/" --config.reflow.total_number_of_samples 100000 --config.seed 0
  • --config.reflow.last_flow_ckpt The checkpoint for data generation.

  • --config.reflow.data_root The location where you would like the generated pairs to be saved. The $(Z_0, Z_1)$ pairs will be saved to ./data_root/seed/

  • --config.reflow.total_number_of_samples The total number of pairs you would like to generate

  • --config.seed The random seed. Change random seed to perform data generation in parallel.

For CIFAR10, we suggest generating at least 1M pairs for reflow.
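Since the example command above generates 100000 pairs per seed, reaching 1M pairs takes ten runs with different seeds. The sketch below only builds the command strings, reusing the flags shown above; the function name and its arguments are hypothetical. Whether concurrent runs can safely share --workdir ./logs/tmp is not stated in this README, so treat that as an assumption to check before launching them in parallel.

```python
import math

def build_generation_commands(total_pairs, pairs_per_run,
                              ckpt="./logs/1_rectified_flow/checkpoints/checkpoint-10.pth",
                              data_root="./assets/reflow_data/1_rectified_flow/"):
    """Return one data-generation command per seed, covering at least total_pairs."""
    num_runs = math.ceil(total_pairs / pairs_per_run)
    commands = []
    for seed in range(num_runs):
        commands.append(
            "python ./main.py"
            " --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_generate_data.py"
            " --eval_folder eval --mode reflow --workdir ./logs/tmp"
            f' --config.reflow.last_flow_ckpt "{ckpt}"'
            f' --config.reflow.data_root "{data_root}"'
            f" --config.reflow.total_number_of_samples {pairs_per_run}"
            f" --config.seed {seed}"
        )
    return commands

commands = build_generation_commands(total_pairs=1_000_000, pairs_per_run=100_000)
for cmd in commands:
    print(cmd)  # paste into separate shells or a job scheduler
```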

Reflow to get 2-Rectified Flow with the Generated Data Pair

After the data pairs are generated, run the following command to train 2-rectified flow

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_train.py  --eval_folder eval --mode reflow --workdir ./logs/2_rectified_flow --config.reflow.last_flow_ckpt "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/1_rectified_flow/"

This command fine-tunes the 1-Rectified Flow checkpoint on the data pairs generated in the previous step and saves the logs of 2-Rectified Flow to ./logs/2_rectified_flow. 2-Rectified Flow should perform much better in one-step generation $z_1=z_0 + v(z_0, 0)$, as shown in the following figure:

To evaluate with step N=1, run

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode eval --workdir ./logs/2_rectified_flow --config.sampling.use_ode_sampler "euler" --config.sampling.sample_N 1 --config.eval.enable_sampling  --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.eval.begin_ckpt 2

where sample_N refers to the number of sampling steps.
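To make the role of the step count concrete, here is a minimal sketch of a fixed-step Euler sampler, assuming integration of $dz/dt = v(z, t)$ from $t=0$ to $t=1$. The function names and the toy velocity field are hypothetical stand-ins for a trained network, and the repository's sampler may differ in details such as the starting time. With a single step the update reduces to $z_1 = z_0 + v(z_0, 0)$.

```python
def euler_sample(velocity, z0, n_steps):
    """Integrate dz/dt = velocity(z, t) from t=0 to t=1 with n_steps Euler steps."""
    z = list(z0)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = i * dt
        v = velocity(z, t)
        z = [zi + dt * vi for zi, vi in zip(z, v)]
    return z

def straight_field(z, t):
    """Toy velocity field: every coordinate moves at constant speed 2."""
    return [2.0 for _ in z]

z0 = [0.0, 1.0]
one_step = euler_sample(straight_field, z0, n_steps=1)      # z0 + v(z0, 0)
many_steps = euler_sample(straight_field, z0, n_steps=100)  # same endpoint up to rounding
```

Because the toy trajectories are perfectly straight, one step and one hundred steps land on the same point, which is the property reflow tries to approach for a learned velocity field.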

We can further improve the quality of 2-Rectified Flow in one-step generation with distillation.

Distill to get one-step 2-Rectified Flow

Before distillation, we need new data pairs from 2-Rectified Flow. This can be simply done with

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_generate_data.py  --eval_folder eval --mode reflow --workdir ./logs/tmp --config.reflow.last_flow_ckpt "./logs/2_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/2_rectified_flow/" --config.reflow.total_number_of_samples 100000 --config.seed 0

Then we can distill 2-Rectified Flow with

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_distill_k=1.py  --eval_folder eval --mode reflow --workdir ./logs/2_rectified_flow_k=1_distill --config.reflow.last_flow_ckpt "./logs/2_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/2_rectified_flow/"

Distill to get k-step 2-Rectified Flow (k>1)

Similarly, we can distill 2-Rectified Flow for k-step generation (k>1) with

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_distill_k_g_1.py  --eval_folder eval --mode reflow --config.reflow.reflow_t_schedule 2 --workdir ./logs/2_rectified_flow_k=2_distill --config.reflow.last_flow_ckpt "./logs/2_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/2_rectified_flow/"

Here, we use k=2 as an example. Change --config.reflow.reflow_t_schedule to accommodate a different k.

Reflow and Distillation with Online Data Generation

To save storage space and simplify the reflow training pipeline, data pairs can also be generated by a teacher model during training, at the cost of a longer training time. To perform reflow with online data generation, run

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_train_online.py --eval_folder eval --mode reflow --workdir ./logs/2_rectified_flow --config.reflow.last_flow_ckpt "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth"

Then, to distill 1-Rectified Flow with online data generation, run

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_distill_k=1_online.py  --eval_folder eval --mode reflow --workdir ./logs/1_rectified_flow_k=1_distill --config.reflow.last_flow_ckpt "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth"

High-Resolution Generation

We provide code and pre-trained checkpoints for high-resolution generation on four $256 \times 256$ datasets, LSUN Bedroom, LSUN Church, CelebA-HQ and AFHQ-Cat.

We use CelebA-HQ as an example. To train 1-rectified flow on CelebA-HQ, run

python ./main.py --config ./configs/rectified_flow/celeba_hq_pytorch_rf_gaussian.py --eval_folder eval --mode train --workdir ./logs/celebahq

To sample images from pre-trained rectified flow, run

python ./main.py --config ./configs/rectified_flow/celeba_hq_pytorch_rf_gaussian.py --eval_folder eval --mode eval --workdir ./logs/celebahq --config.eval.enable_figures_only --config.eval.begin_ckpt 10 --config.eval.end_ckpt 10 --config.training.data_dir YOUR_DATA_DIR

The images will be stored in ./logs/celebahq/eval/ckpt/figs

Pre-trained Checkpoints

As an example, to use pre-trained checkpoints, download the checkpoint_8.pth from CIFAR10 1-Rectified Flow, put it in ./logs/1_rectified_flow/checkpoints/, then run

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode eval --workdir ./logs/1_rectified_flow --config.eval.enable_sampling  --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.eval.begin_ckpt 8

The pre-trained checkpoints are listed here:

Citation

If you use the code or our work is related to yours, please cite us:

@article{liu2022flow,
  title={Flow straight and fast: Learning to generate and transfer data with rectified flow},
  author={Liu, Xingchao and Gong, Chengyue and Liu, Qiang},
  journal={arXiv preprint arXiv:2209.03003},
  year={2022}
}

Thanks

A large portion of this codebase is built upon Score SDE.

rectifiedflow's Issues

problems about the dependencies

Hello, I followed the commands to install the dependencies, but I have a problem with this code. How can I solve the following issue?
2021-06-02 10:02:23.367252: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at cudnn_rnn_ops.cc:1553 : Unknown: Fail to find the dnn implementation.
2021-06-02 10:02:23.369234: E tensorflow/stream_executor/cuda/cuda_dnn.cc:352] Loaded runtime CuDNN library: 8.0.5 but source was compiled with: 8.1.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
2021-06-02 10:02:23.370402: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at cudnn_rnn_ops.cc:1553 : Unknown: Fail to find the dnn implementation.

error of Downloading TF-Hub Module 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'.

I0619 12:16:40.841244 140569191580864 resolver.py:419] Downloading TF-Hub Module 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'.
Traceback (most recent call last):
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/urllib/request.py", line 1354, in do_open
    h.request(req.get_method(), req.selector, req.data, headers,
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1256, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1302, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1251, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1011, in _send_output
    self.send(msg)
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 951, in send
    self.connect()
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1418, in connect
    super().connect()
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 922, in connect
    self.sock = self._create_connection(
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/socket.py", line 808, in create_connection
    raise err
  File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/socket.py", line 796, in create_connection
    sock.connect(sa)
TimeoutError: [Errno 110] Connection timed out

Bug in Colab: Tutorial: Rectified Flow with Neural Network.ipynb

Hi, thanks for your great work. I especially appreciate your intuitive blog post. However, I just want to let you know that there's a tiny bug in the attached colab example, Tutorial: Rectified Flow with Neural Network.ipynb.

image

The selected variable should be diffusion, otherwise it overwrites the rectified_flow_1 in the subsequent blocks.

Implementation of feature loss, Equation 4 in the paper.

Hi @gnobitab, I implemented the feature loss myself, but it did not work properly.
Could you comment on my pseudocode?

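import torch
import torch.nn.functional as F

# Assumed to be defined elsewhere: feature_extractor, a network mapping
# [batch_size, dim, H, W] -> [batch_size, feature_dim, H, W], and the tensors
# z_t, target and pred, each of shape [batch_size, dim, H, W].

def get_feature_weight(S):
    """Jacobian of the summed features with respect to the input S."""

    def _feature_func(x):
        feature = feature_extractor(x)  # shape [batch_size, feature_dim, H, W]
        feature = feature.sum(dim=(0, 2, 3))
        return feature  # shape [feature_dim]

    S = S.requires_grad_(True)  # shape [batch_size, dim, H, W]
    w = torch.autograd.functional.jacobian(_feature_func, S)   # shape [feature_dim, batch_size, dim, H, W]
    return w.transpose(0, 1).detach()   # shape [batch_size, feature_dim, dim, H, W]

w = get_feature_weight(z_t)
w_target = torch.einsum('bdchw,bchw->bdhw', w, target)
w_pred = torch.einsum('bdchw,bchw->bdhw', w, pred)
# Compare the feature-weighted prediction and target.
loss = F.mse_loss(w_pred, w_target)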

Had an error about protobuf, what should I do?

Traceback (most recent call last):
  File "E:\Googledownload\RectifiedFlow-main\ImageGeneration\main.py", line 18, in <module>
    import run_lib
  File "E:\Googledownload\RectifiedFlow-main\ImageGeneration\run_lib.py", line 25, in <module>
    import tensorflow as tf
  File "D:\conda\lib\site-packages\tensorflow\__init__.py", line 37, in <module>
    from tensorflow.python.tools import module_util as module_util
  File "D:\conda\lib\site-packages\tensorflow\python\__init__.py", line 37, in <module>
    from tensorflow.python.eager import context
  File "D:\conda\lib\site-packages\tensorflow\python\eager\context.py", line 28, in <module>
    from tensorflow.core.framework import function_pb2
  File "D:\conda\lib\site-packages\tensorflow\core\framework\function_pb2.py", line 16, in <module>
    from tensorflow.core.framework import attr_value_pb2 as tensorflow_dot_core_dot_framework_dot_attr__value__pb2
  File "D:\conda\lib\site-packages\tensorflow\core\framework\attr_value_pb2.py", line 16, in <module>
    from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
  File "D:\conda\lib\site-packages\tensorflow\core\framework\tensor_pb2.py", line 16, in <module>
    from tensorflow.core.framework import resource_handle_pb2 as tensorflow_dot_core_dot_framework_dot_resource__handle__pb2
  File "D:\conda\lib\site-packages\tensorflow\core\framework\resource_handle_pb2.py", line 16, in <module>
    from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2
  File "D:\conda\lib\site-packages\tensorflow\core\framework\tensor_shape_pb2.py", line 36, in <module>
    _descriptor.FieldDescriptor(
  File "D:\conda\lib\site-packages\google\protobuf\descriptor.py", line 561, in __new__
    _message.Message._CheckCalledFromGeneratedFile()
TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:

  1. Downgrade the protobuf package to 3.20.x or lower.
  2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
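The second workaround named in the error message can be applied from Python, as a minimal sketch: the environment variable has to be set before TensorFlow (or anything else that imports protobuf) is imported. Placing these lines at the very top of main.py is an assumption for illustration, not something the repository documents, and as the message notes, pure-Python parsing will be slower.

```python
import os

# Must run before `import tensorflow` or any other import that pulls in protobuf.
# The variable name and value are taken from the error message itself.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Later imports (for example `import run_lib`) would follow from here.
```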

How to fine-tune on a custom dataset with a pretrained model?

I tried to fine-tune a pretrained model on a custom dataset.
But a dataset_stats.npz file like assets/stats/celeba_stats.npz is needed,
and there is no explanation of how to generate dataset_stats.npz.
So how can I fine-tune on a custom dataset with a pretrained model?

Image-to-Image Translation Code

Hi~
Excellent work!!
When will you release the code for Image-to-Image Translation? I'm looking forward to it.

Thanks!!

best,

reproduce environment

Hi. Do you have an env.yaml file for us to reproduce your environment? I tried to install the required packages using the requirements.txt provided in this repo, but many bugs occurred. A yaml file for creating a new conda env would be much more straightforward and easier for those who want to run your code :) Any other alternative would also be great, as long as it makes reproducing the environment easier!

How long does it take to train on LSUN bedroom?

Thanks for your great work!
If I want to train on LSUN bedroom and get a LSUN bedroom checkpoint like in RectifiedFlow pre-trained-checkpoints, how long will it take?

I am training on 8xA100, and going from step 100 to step 1400 takes 22 minutes. In configs/default_lsun_configs.py, training.n_iters = 2400001, so it will take about 690 hours (29 days) in total. Is 29 days on 8xA100 similar to your training time?
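For reference, the estimate can be reproduced with a short back-of-the-envelope calculation, assuming the throughput measured between step 100 and step 1400 stays constant for the whole run (evaluation and checkpointing overhead are ignored). The variable names are illustrative.

```python
steps_measured = 1400 - 100   # steps covered by the timing above
minutes_measured = 22         # wall-clock minutes for those steps
total_steps = 2_400_001       # training.n_iters from configs/default_lsun_configs.py

minutes_per_step = minutes_measured / steps_measured
total_hours = total_steps * minutes_per_step / 60
total_days = total_hours / 24

print(f"{total_hours:.0f} hours, {total_days:.1f} days")  # roughly 677 hours, about 28 days
```

This lands close to the figure quoted above, so the order of magnitude of about four weeks holds under the constant-throughput assumption.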

out of memory

Hi, I ran the inference example python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode eval --workdir ./logs/1_rectified_flow --config.eval.enable_sampling --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.eval.begin_ckpt 8 and got an OOM error.
I then set batch size = 1 and num_samples = 1, and still got the error.
My GPU has 24576MiB.

Is there any way to bypass the OOM?

I0220 07:32:34.000750 139647609180864 resolver.py:106] Using /tmp/tfhub_modules to cache modules.
I0220 07:32:36.494689 139647609180864 run_lib.py:273] begin checkpoint: 8
Traceback (most recent call last):
  File "main.py", line 74, in <module>
    app.run(main)
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "main.py", line 66, in main
    run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder)
  File "/root/RectifiedFlow/ImageGeneration/run_lib.py", line 286, in evaluate
    state = restore_checkpoint(ckpt_path, state, device=config.device)
  File "/root/RectifiedFlow/ImageGeneration/utils.py", line 14, in restore_checkpoint
    loaded_state = torch.load(ckpt_dir, map_location=device)
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 712, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 1049, in _load
    result = unpickler.load()
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 1019, in persistent_load
    load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 1001, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 973, in restore_location
    return default_restore_location(storage, str(map_location))
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
    result = fn(storage, location)
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 157, in _cuda_deserialize
    return obj.cuda(device)
  File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/_utils.py", line 78, in _cuda
    return torch._UntypedStorage(self.size(), device=torch.device('cuda')).copy_(self, non_blocking)
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.70 GiB total capacity; 918.46 MiB already allocated; 13.56 MiB free; 968.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
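Reading the numbers in the error message, a short calculation shows how much of the GPU is occupied by something other than PyTorch's allocator. This is only arithmetic on the figures printed above; it does not identify what holds the memory. One possible cause, offered as an assumption rather than a confirmed diagnosis, is another process on the same GPU or TensorFlow (which run_lib.py imports) reserving GPU memory in the same process.

```python
GIB_IN_MIB = 1024.0

total_mib = 23.70 * GIB_IN_MIB    # "23.70 GiB total capacity"
reserved_by_pytorch_mib = 968.00  # "968.00 MiB reserved in total by PyTorch"
free_mib = 13.56                  # "13.56 MiB free"

# Whatever is neither free nor reserved by PyTorch is held elsewhere.
held_elsewhere_mib = total_mib - reserved_by_pytorch_mib - free_mib
held_elsewhere_gib = held_elsewhere_mib / GIB_IN_MIB

print(f"{held_elsewhere_gib:.2f} GiB held outside PyTorch's allocator")
```

Since nearly all of the card is taken before the checkpoint is loaded, reducing the batch size alone would not be expected to help.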

how to accelerate the training process?

python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode train --workdir ./logs/1_rectified_flow

It seems that the training process needs 600k iterations.

The memory usage of each GPU does not seem very high during training (4.3 GB on a 24 GB RTX 3090).
Is there any way to increase the memory usage and thereby accelerate training?

about the formula

t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps

You mentioned that xt = (1-t)x + ty, and in the code t is sampled from [eps, 1] with eps = 1e-3 during training, so t = eps corresponds to the input x_0 * (1-eps) + y*eps. At inference, however, t also starts at eps, but there is no y available to compute x_t at that point, so the model input cannot be x_0 * (1-eps) + y*eps. This seems quite strange to me. What is the reason for this design?
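As an illustration of the quantities in this question (not the authors' answer), the sketch below mirrors the time-sampling line with the standard library, assuming sde.T = 1 and eps = 1e-3, and measures how far the training-time input at t = eps is from the pure noise sample x. The function names and the toy scalar values are hypothetical.

```python
import random

def sample_training_time(eps=1e-3, T=1.0):
    """Mirror of `rand * (T - eps) + eps`: t lies in [eps, T]."""
    return random.random() * (T - eps) + eps

def interpolate(x, y, t):
    """Straight-line interpolation xt = (1 - t) x + t y."""
    return (1.0 - t) * x + t * y

random.seed(0)
ts = [sample_training_time() for _ in range(10_000)]

eps = 1e-3
x, y = 0.7, -1.5                # toy scalar noise sample x and data sample y
x_eps = interpolate(x, y, eps)  # what training sees at t = eps
gap = abs(x_eps - x)            # equals eps * |y - x|
```

In this toy case the gap is eps * |y - x| = 0.0022, so feeding x itself at t = eps differs from the training-time input only by a term of order eps.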
