ddpo-pytorch's Introduction

ddpo-pytorch

This is an implementation of Denoising Diffusion Policy Optimization (DDPO) in PyTorch with support for low-rank adaptation (LoRA). Unlike our original research code (which you can find here), this implementation runs on GPUs, and if LoRA is enabled, requires less than 10GB of GPU memory to finetune Stable Diffusion!

DDPO

Installation

Requires Python 3.10 or newer.

git clone git@github.com:kvablack/ddpo-pytorch.git
cd ddpo-pytorch
pip install -e .

Usage

accelerate launch scripts/train.py

This will immediately start finetuning Stable Diffusion v1.5 for compressibility on all available GPUs using the config from config/base.py. It should work as long as each GPU has at least 10GB of memory. If you don't want to log into wandb, you can run wandb disabled before the above command.

Please note that the default hyperparameters in config/base.py are not meant to achieve good performance; they are just meant to get the code up and running as fast as possible. I would not expect good results without a much larger number of samples per epoch and gradient accumulation steps.

Important Hyperparameters

A detailed explanation of all the hyperparameters can be found in config/base.py. Here are a few of the important ones.

prompt_fn and reward_fn

At a high level, the problem of finetuning a diffusion model is defined by 2 things: a set of prompts to generate images, and a reward function to evaluate those images. The prompts are defined by a prompt_fn which takes no arguments and generates a random prompt each time it is called. The reward function is defined by a reward_fn which takes in a batch of images and returns a batch of rewards for those images. All of the prompt and reward functions currently implemented can be found in ddpo_pytorch/prompts.py and ddpo_pytorch/rewards.py, respectively.
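
As a rough illustration of the two interfaces described above, here is a minimal sketch. The names `ANIMALS`, `example_prompt_fn` and `example_reward_fn` are hypothetical, and the brightness reward is a toy stand-in; the real functions in ddpo_pytorch/prompts.py and ddpo_pytorch/rewards.py may take additional arguments or return extra metadata.

```python
import random

import numpy as np

# Hypothetical prompt list, used only for this sketch.
ANIMALS = ["cat", "dog", "horse"]


def example_prompt_fn():
    """Takes no arguments and returns one randomly chosen prompt."""
    return f"a photo of a {random.choice(ANIMALS)}"


def example_reward_fn(images):
    """Toy reward: mean pixel brightness of each image.

    images: float array of shape (batch, height, width, channels) in [0, 1].
    Returns one reward per image, shape (batch,).
    """
    images = np.asarray(images, dtype=np.float64)
    return images.reshape(images.shape[0], -1).mean(axis=1)


prompts = [example_prompt_fn() for _ in range(4)]
# Random arrays stand in for generated images.
images = np.random.default_rng(0).random((4, 8, 8, 3))
rewards = example_reward_fn(images)
```

The point of the sketch is only the shape of the contract: one call to the prompt function yields one prompt, and one call to the reward function maps a batch of images to a batch of scalars.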

Batch Sizes and Accumulation Steps

Each DDPO epoch consists of generating a batch of images, computing their rewards, and then doing some training steps on those images. One important hyperparameter is the number of images generated per epoch; you want enough images to get a good estimate of the average reward and the policy gradient. Another important hyperparameter is the number of training steps per epoch.

However, these are not defined explicitly but are instead defined implicitly by several other hyperparameters. First note that all batch sizes are per GPU. Therefore, the total number of images generated per epoch is sample.batch_size * num_gpus * sample.num_batches_per_epoch. The effective total training batch size (if you include multi-GPU training and gradient accumulation) is train.batch_size * num_gpus * train.gradient_accumulation_steps. The number of training steps per epoch is the first number divided by the second number, or (sample.batch_size * sample.num_batches_per_epoch) / (train.batch_size * train.gradient_accumulation_steps).

(This assumes that train.num_inner_epochs == 1. If this is set to a higher number, then training will loop over the same batch of images multiple times before generating a new batch of images, and the number of training steps per epoch will be multiplied accordingly.)
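
The arithmetic above can be checked with a small helper. This is a sketch, not the script's own code: the function name `ddpo_epoch_sizes` and the divisibility check are assumptions added for illustration, and the example values are illustrative.

```python
def ddpo_epoch_sizes(
    sample_batch_size,
    num_batches_per_epoch,
    train_batch_size,
    gradient_accumulation_steps,
    num_gpus,
    num_inner_epochs=1,
):
    """Return (images per epoch, effective train batch size, train steps per epoch)."""
    # All batch sizes are per GPU, so both totals scale with num_gpus.
    samples_per_epoch = sample_batch_size * num_gpus * num_batches_per_epoch
    effective_train_batch = train_batch_size * num_gpus * gradient_accumulation_steps
    # Assumption for this sketch: require an exact division so no samples are left over.
    if samples_per_epoch % effective_train_batch != 0:
        raise ValueError("samples per epoch must be a multiple of the effective batch size")
    steps_per_epoch = num_inner_epochs * (samples_per_epoch // effective_train_batch)
    return samples_per_epoch, effective_train_batch, steps_per_epoch


sizes = ddpo_epoch_sizes(
    sample_batch_size=8,
    num_batches_per_epoch=4,
    train_batch_size=4,
    gradient_accumulation_steps=2,
    num_gpus=8,
)
print(sizes)  # (256, 64, 4)
```

Note that num_gpus cancels in the steps-per-epoch ratio, which is why the formula in the text omits it.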

At the beginning of each training run, the script will print out the calculated value for the number of images generated per epoch, the effective total training batch size, and the number of training steps per epoch. Make sure to double-check these numbers!

Reproducing Results

The image at the top of this README was generated using LoRA! However, I did use a fairly powerful DGX machine with 8xA100 GPUs, on which each experiment took about 4 hours for 100 epochs. In order to run the same experiments with a single small GPU, you would set sample.batch_size = train.batch_size = 1 and multiply sample.num_batches_per_epoch and train.gradient_accumulation_steps accordingly.
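
One way to read "multiply accordingly" is to keep the total number of images per epoch and the effective training batch size unchanged. The sketch below does that; the helper name `single_gpu_equivalent` and the input values are hypothetical, not taken from config/dgx.py.

```python
def single_gpu_equivalent(
    sample_batch_size,
    num_batches_per_epoch,
    train_batch_size,
    gradient_accumulation_steps,
    num_gpus,
):
    """Rescale a multi-GPU setup to one GPU with per-GPU batch size 1.

    Images per epoch and effective training batch size stay the same.
    """
    return {
        "sample.batch_size": 1,
        "sample.num_batches_per_epoch": num_batches_per_epoch * sample_batch_size * num_gpus,
        "train.batch_size": 1,
        "train.gradient_accumulation_steps": gradient_accumulation_steps * train_batch_size * num_gpus,
    }


# Illustrative multi-GPU values: 8 GPUs, 8 x 4 samples and 4 x 2 training items per GPU.
single = single_gpu_equivalent(
    sample_batch_size=8,
    num_batches_per_epoch=4,
    train_batch_size=4,
    gradient_accumulation_steps=2,
    num_gpus=8,
)
```

With these inputs the single-GPU run would use 256 sampling batches per epoch and 64 gradient accumulation steps, at the cost of a much longer wall-clock time per epoch.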

You can find the exact configs I used for the 4 experiments in config/dgx.py. For example, to run the aesthetic quality experiment:

accelerate launch scripts/train.py --config config/dgx.py:aesthetic

If you want to run the LLaVA prompt-image alignment experiments, you need to dedicate a few GPUs to running LLaVA inference using this repo.

Reward Curves

Training using 🤗 trl

🤗 trl provides a DDPOTrainer class which lets you fine-tune Stable Diffusion on different reward functions using DDPO. The integration supports LoRA, too. You can check out the supplementary blog post for additional guidance. The DDPO integration was contributed by @metric-space to trl.

ddpo-pytorch's People

Contributors

desaixie, kvablack, sayakpaul

ddpo-pytorch's Issues

On reproducing LLaVA alignment experiments.

Hi! I have a couple of questions about the LLaVA alignment experiments:

  • Which LLaVA model was used for the experiment? Was it the 7B one or the 13B one?
  • What precision was used for the LoRA-based results? Was it bf16 or full fp32?

Support for other schedulers

This code currently only supports DDIM. In the recently released SD-XL, the default scheduler is EulerDiscrete. From the paper and the code, it seems that prev_sample is no longer sampled from a Gaussian distribution but is an ODE solution instead (correct me if I am wrong here). How should the log_prob of prev_sample given the noise_pred be calculated in this case?
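
For context on the Gaussian case the question contrasts against: when a sampler step draws prev_sample from a Gaussian with a known mean and standard deviation, the log-probability has a closed form. The sketch below assumes a diagonal Gaussian with a positive scalar standard deviation; the function name and the choice to average over non-batch dimensions are illustrative, and it does not cover the ODE-solver case asked about.

```python
import math

import numpy as np


def gaussian_step_log_prob(prev_sample, mean, std):
    """Log-density of prev_sample under N(mean, std^2 I), averaged over non-batch dims.

    prev_sample, mean: arrays of shape (batch, ...). std: positive scalar.
    """
    prev_sample = np.asarray(prev_sample, dtype=np.float64)
    mean = np.asarray(mean, dtype=np.float64)
    log_prob = (
        -((prev_sample - mean) ** 2) / (2.0 * std**2)
        - math.log(std)
        - 0.5 * math.log(2.0 * math.pi)
    )
    # Reduce everything except the batch dimension to one number per sample.
    return log_prob.reshape(log_prob.shape[0], -1).mean(axis=1)


# Sanity check input: the sample equals the mean, with unit standard deviation.
lp = gaussian_step_log_prob(np.zeros((2, 3, 4, 4)), np.zeros((2, 3, 4, 4)), std=1.0)
```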

Gif visualization

Hi,
thanks for the reimplementation! Your GIF visualization using iceberg is super nice! Could you also share the code for it?
Thanks a lot!

Finetuning on google colab

I tried this code (compressibility finetuning) on Colab but ran out of GPU memory. I even reduced the batch size and sample size, but that did not solve the problem. I should mention that I used the free GPU, a T4 with 15G of GPU RAM (the README claims that this code can run with 10G of GPU memory).
Any suggestions or help?

Batch size unrecognized

I am using the DDPO logic to fine-tune my own model.
However, I found that the example reward function (LLaVA BERTScore) uses a fixed batch size.

After reading the source code in this repo and the TRL DDPOTrainer class, I found that this batch size may be related to sample_batch_size.

I recommend replacing the fixed batch size with the one from the config, or at least leaving a comment on it. That way, people who want to design their own reward function have a clearer guide.

Below is the example reward function in this repo that I mentioned above.

def llava_bertscore():
    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
    https://github.com/kvablack/LLaVA-server for server-side code.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 16 
    url = "http://127.0.0.1:8085"
    sess = requests.Session()
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        del metadata
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
...

And this is the code that calls compute_rewards() in the DDPOTrainer class in the TRL repo:

def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.

        """
        samples, prompt_image_data = self._generate_samples(
            iterations=self.config.sample_num_batches_per_epoch,
            batch_size=self.config.sample_batch_size,
        )

        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
        rewards, rewards_metadata = self.compute_rewards(
            prompt_image_data, is_async=self.config.async_reward_computation
        )
...
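
To make the chunking in the quoted reward function concrete, here is a minimal sketch with the chunk size passed in as a parameter. This is a hypothetical refactor for illustration, not the repository's code; the helper name `split_into_batches` is invented.

```python
import math

import numpy as np


def split_into_batches(items, batch_size):
    """Split items along the first axis into chunks of at most batch_size elements."""
    num_chunks = max(1, math.ceil(len(items) / batch_size))
    # np.array_split balances chunk sizes, so batch_size acts as an upper bound.
    return np.array_split(items, num_chunks)


images = np.zeros((40, 4, 4, 3), dtype=np.uint8)  # stand-in for 40 NHWC images
chunks = split_into_batches(images, batch_size=16)
chunk_sizes = [len(chunk) for chunk in chunks]
print(chunk_sizes)  # [14, 13, 13]
```

The value 16 therefore caps the size of each request rather than fixing it, which is worth keeping in mind when relating it to the sampling batch size.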

Prompt Alignment with LLaVA-server: Client-side prompt and image doesn't match server side reward

I am running the prompt alignment experiment with LLaVA-server, although I am using BLIP2 instead of LLaVA.
I wanted to see the VLM's caption of the image alongside the prompt, image, and reward, so I added this additional logging to wandb. To pass the caption strings from the server back to the main training process, I converted the fixed-length strings into ASCII integers with ord(), so that they can be converted to a torch.tensor before calling accelerator.gather at this line, and then back to strings with chr(). As shown in the image below, the prompts and VLM captions that I received from the server do not match.
image
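
The ord()/chr() round trip described above could look roughly like the following sketch. It uses NumPy integer arrays in place of torch tensors to stay self-contained, and padding with NUL characters is an assumption of this sketch, not something stated in the post.

```python
import numpy as np


def encode_strings(strings, length):
    """Pad or truncate each ASCII string to `length` and map characters to code points."""
    rows = []
    for text in strings:
        padded = text[:length].ljust(length, "\0")
        rows.append([ord(ch) for ch in padded])
    return np.array(rows, dtype=np.int64)


def decode_strings(codes):
    """Inverse of encode_strings: rebuild the strings and strip the NUL padding."""
    return ["".join(chr(int(c)) for c in row).rstrip("\0") for row in codes]


captions = ["a pig on a bike", "a cat"]
encoded = encode_strings(captions, length=32)  # shape (2, 32), gatherable like any tensor
decoded = decode_strings(encoded)
```

A fixed length matters because a gather across processes generally needs every process to contribute arrays of the same shape.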

Then I tried a trick to match the input prompt on the client side with the server's response. For each prompt generated with prompt_fn, I generate a random 5-digit id number. This id is passed to the server and prepended to the VLM's outputs. Then I use the prompts' ids to retrieve the corresponding captions. As shown below, the prompts and the captions now match after using my "id" trick. I also appended the computed rewards to the captions on the server side before sending the response to the client. However, the rewards appended at the end of the captions do not match the rewards on the client side (code). It seems that the server's responses don't preserve the order of the queries it receives.
image

Could you verify whether the current code does have this problem, where the order of the server's responses doesn't match that of the client's queries? I am getting clear training progress, which shouldn't be the case if the rewards' order is messed up.
image
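
The "id" trick from this post can be sketched as follows. Everything here is hypothetical: `fake_server` merely simulates a server whose replies arrive in a different order, and the `id:caption` wire format is invented for the example.

```python
import random


def assign_ids(prompts, rng):
    """Pair each prompt with a random 5-digit id that is unique within the batch."""
    ids = rng.sample(range(10000, 100000), k=len(prompts))
    return list(zip(ids, prompts))


def fake_server(tagged_prompts, rng):
    """Hypothetical stand-in for the caption server: replies come back shuffled."""
    responses = [f"{pid}:caption for '{prompt}'" for pid, prompt in tagged_prompts]
    rng.shuffle(responses)
    return responses


def reorder_by_id(tagged_prompts, responses):
    """Use the id prefix of each response to restore the original query order."""
    by_id = {}
    for response in responses:
        pid, caption = response.split(":", 1)
        by_id[int(pid)] = caption
    return [by_id[pid] for pid, _ in tagged_prompts]


rng = random.Random(0)
prompts = ["a pig riding a bike", "a cat washing dishes", "a dog playing chess"]
tagged = assign_ids(prompts, rng)
captions = reorder_by_id(tagged, fake_server(tagged, rng))
```

Any per-item payload returned by the server, rewards included, would need to be keyed by the same id for the reordering to apply to it as well.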

prompt-dependent value function optimization

I saw you mentioned a prompt-dependent value function at #7 (comment). By chance, I happen to be using DDPO for related optimizations. Consider the ideal situation, where there is only one prompt and its corresponding reward function. I still found that in the early stages of training the reward mean fluctuates heavily, even if I increase the training batch size or reduce the learning rate, although the overall reward mean does rise in the end. Are there any optimization techniques to make the optimization of a single prompt stable? Any suggestions or insights would be greatly appreciated.

OOM when using "stabilityai/stable-diffusion-2-1" with batch size of 2

Hi,

Thanks for sharing the code!

I am using your code to fine-tune the model stabilityai/stable-diffusion-2-1 with the aesthetic reward, and I have also enabled LoRA. But training is very memory intensive: on an 80GB A100 it cannot even fit a batch size of 2 per GPU. I always get an OOM error. Below are my settings:

config = compressibility()
config.project_name = "ddpo-aesthetic"
config.pretrained.model = "stabilityai/stable-diffusion-2-1"

config.num_epochs = 20000
config.reward_fn = "aesthetic_score"

# the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch.
config.sample.batch_size = 2
config.sample.num_batches_per_epoch = 1

# this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch.
config.train.batch_size = 2
config.train.gradient_accumulation_steps = 1

config.prompt_fn = "simple_animals"
config.per_prompt_stat_tracking = {
  "buffer_size": 32,
  "min_count": 16,
}

Any suggestions regarding this? I appreciate your help!

Hello, when I trained an aesthetic model using the default configuration on 8 A800 cards, I found that the training process got stuck after completing one epoch, but it worked fine when using a single A800 card. May I ask what could be the cause of this situation?

reproducing the aesthetic experiment

I am trying to reproduce the aesthetic experiment on a single GPU. I made the following changes to the config:

config.sample.batch_size = 1
config.sample.num_batches_per_epoch = 256
config.train.batch_size = 1
config.train.gradient_accumulation_steps = 128

My results are summarized in the following figure:
image

I have a few questions regarding the results:

  1. The paper generates "stylized line drawings". However, neither the reference nor my results for ddpo-pytorch show this behavior.
  2. Why is it that compared with the paper reward curve, my reward (5.5) already starts off higher than the end of the paper's reward curve (5.1)?
  3. Are there reference reward curves corresponding to teaser.jpg that we could compare with for all four experiments?

OOM despite using A100-80GB GPU and 128GB CPU memory (+16 CPUs per task)

Hi,

I am experiencing out-of-memory issues with this codebase even at the start of epoch 0.0, despite using an A100 with 80GB VRAM and 128GB RAM :/

I am using the following changes in the config file:
config.sample.batch_size = 1
config.sample.num_batches_per_epoch = 256
config.train.batch_size = 1
config.train.gradient_accumulation_steps = 128

SDXL Support?

Any chance for Stable Diffusion XL models support in the near future?

Question about the optimized objective.

Many thanks for conducting this excellent work. While reading this repo, I had two questions about the optimized objective.

  1. Does the loss term in Eq. 4 align with the code implementation below? advantages corresponds to r(x_0, c) and ratio corresponds to the importance sampling term. But where is the gradient term \nabla_\theta \log p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{c}, t, \mathbf{x}_t\right)? Is it because taking gradient on ratio will implicitly generate the gradient term?
# ppo logic
advantages = torch.clamp(
    sample["advantages"], -config.train.adv_clip_max, config.train.adv_clip_max
)
ratio = torch.exp(log_prob - sample["log_probs"][:, j])
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(
    ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range
)
loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
  2. clip_range is set to 1e-4 by default, which makes the clipped_loss close to -advantages. The following operation torch.maximum(unclipped_loss, clipped_loss) will bound loss to be approximately greater than -advantages. What is the aim of setting such a small value of clip_range?
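
As a reading aid for the snippet quoted in this question, here is a NumPy re-expression of the same clipped surrogate, offered as a sketch and not as the repository's implementation. One relevant fact from the chain rule: the gradient of ratio = exp(log_prob - old_log_prob) with respect to the parameters is ratio times the gradient of log_prob, so at ratio = 1 differentiating the unclipped term gives the advantage-weighted log-probability gradient. The function name and example values are illustrative.

```python
import numpy as np


def clipped_surrogate_loss(log_prob, old_log_prob, advantages, clip_range, adv_clip_max):
    """PPO-style clipped loss, mirroring the quoted snippet with NumPy arrays."""
    advantages = np.clip(advantages, -adv_clip_max, adv_clip_max)
    ratio = np.exp(log_prob - old_log_prob)
    unclipped = -advantages * ratio
    clipped = -advantages * np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # Elementwise maximum keeps the more pessimistic of the two loss terms.
    return float(np.mean(np.maximum(unclipped, clipped)))


# On-policy case: the ratio is exactly 1, so the loss is minus the mean advantage.
on_policy = clipped_surrogate_loss(
    np.array([-1.0, -2.0]), np.array([-1.0, -2.0]), np.array([2.0, -1.0]), 0.1, 5.0
)

# Ratio of 2 with a positive advantage: the clipped term -1.1 wins over -2.0.
off_policy = clipped_surrogate_loss(
    np.array([np.log(2.0)]), np.array([0.0]), np.array([1.0]), 0.1, 5.0
)
```

In the second case the clipped term is constant in the ratio, so a sample whose ratio has already moved past the clip range in the favourable direction contributes no further gradient.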

About the large dataset and Unet Training

Thanks for your great work!

I have reproduced the LoRA training and got the same results with the supplied prompts. But I have some issues with a large dataset and with UNet training.

Problem 1: Overfitting

When I train the LoRA with 400 prompts, I find that the reward easily overfits to the maximum within 4K steps. Even when I increase the number of prompts to 20K, the same problem happens.
So, does this work for a large dataset? How should DDPO be trained on a large dataset?

400 prompts

image

20K prompts

image

Problem 2: UNet Training

I set "config.use_lora = False" in config/base.py to train the UNet, but the reward drops to zero within tens of steps, and the SD model generates black images.

image

Code logic, thanks

I am a bit confused by the logic in the train script.

A "new" unet is defined as pipeline.unet, unet.parameters() is then passed to the optimizer, and finally the loss is computed from unet. So, do I understand correctly that this new unet is what gets updated?

However, we know that pipeline.unet should be updated, and I can observe that the unet in the pipeline is indeed updated, not the new unet.

Can anybody tell me why this new unet should be defined? Can we just use something like this:

optimizer(pipeline.unet.parameters(), ...)
noise_pred = pipeline.unet(...)

Thank you very much.
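
One general Python point that may bear on this question: plain assignment binds a second name to the same object and does not copy it. The toy classes below are hypothetical stand-ins, and the sketch assumes the script binds `unet` by plain assignment; if the script wraps or copies the module instead, this does not apply.

```python
class TinyModel:
    """Hypothetical stand-in for a network with a single parameter."""

    def __init__(self):
        self.weight = 1.0


class TinyPipeline:
    """Hypothetical stand-in for a pipeline object that owns a model."""

    def __init__(self):
        self.unet = TinyModel()


pipeline = TinyPipeline()
unet = pipeline.unet  # second name for the same object; no copy is made
unet.weight -= 0.25  # stand-in for an optimizer step applied through `unet`

same_object = unet is pipeline.unet
updated_weight = pipeline.unet.weight
```

Under that assumption, an update made through either name is visible through the other, so the two spellings in the question would behave the same.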

About the training with prompt_image_alignment configuration which uses llava_bertscore reward function

The result of training with llava_bertscore reward is:
image
But the result of training with aesthetic reward is:
image
This shows that the reinforcement learning works.
The configuration files, i.e. dgx.py and base.py, were not modified.
The hardware used for training was 8 A800 GPUs. The llava-server ran on GPU0 and GPU1. The training program ran on all 8 GPUs.
GPU0 and GPU1 used 59G of memory, and the others used 30G.
The LLaVA version is liuhaotian/LLaVA-Lightning-MPT-7B-preview.

On reproducibility and LoRA

For the reproducibility experiments, dgx.py currently has use_lora=True. I just want to double-check that this is indeed the intended setting, because the README.md is a bit unclear on this.

Questions about the reward curve and bert.

Many thanks for conducting this excellent work!

I have 2 questions that came up while trying to reproduce the experiments.

  1. Why does the reward curve for prompt-image alignment shown in this repo fluctuate greatly, while the reward curve shown in Figure 5 of the paper is very smooth?
  2. While I was experimenting, the BERT scores were extremely high. For example, for the noisy image below, LLaVA generated a description: "In the image, there is a colorful, abstract, and blurry background with a mix of colors and patterns. The background is filled with a variety of colors, creating a visually interesting and dynamic scene. " The BERT score of this description with the prompt "a pig riding a bike" is about 0.82. I think this may be because the BERT model I used is different from yours. So I would like to know more details about the BERT model used in your experiments.

image

Many thanks if you can give me some help :-)
