vpgen's Introduction

VPGen: Step-by-Step Text-to-Image Generation with Interpretable Visual Programming

The code for VPGen, a new framework for text-to-image generation, as described in the paper:

Visual Programming for Text-to-Image Generation and Evaluation

Jaemin Cho, Abhay Zala, Mohit Bansal

[Project Page] [Paper] [Code for VPEval] [Colab Demo]


VPGen: Step-by-Step T2I Generation

VPGen is a visual programming framework for interpretable, step-by-step text-to-image (T2I) generation. It decomposes the T2I generation task into three steps: (1) object/count generation, (2) layout generation, and (3) image generation. An LM handles the first two steps, and a layout-to-image module then generates images from the predicted layouts. For the layout generation LM, we finetune Vicuna 13B on text-layout pair annotations from three public datasets: Flickr30K Entities, MS COCO, and PaintSkills. For layout-to-image generation, we use GLIGEN.
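The three-step decomposition can be sketched in a few lines of Python. This is only an illustration of the data flow, not the repository's implementation: the `PlacedObject` type, the `(x1, y1, x2, y2)` box convention, and the three toy functions are hypothetical stand-ins for the LM and the layout-to-image module.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

# Hypothetical box convention: normalized (x1, y1, x2, y2).
Box = Tuple[float, float, float, float]


@dataclass
class PlacedObject:
    name: str
    box: Box


def generate_image_stepwise(
    prompt: str,
    count_objects: Callable[[str], Dict[str, int]],
    place_objects: Callable[[str, Dict[str, int]], List[PlacedObject]],
    render: Callable[[str, List[PlacedObject]], object],
):
    """Run the three steps in order and keep every intermediate result."""
    counts = count_objects(prompt)          # step 1: object/count generation (LM)
    layout = place_objects(prompt, counts)  # step 2: layout generation (LM)
    image = render(prompt, layout)          # step 3: layout-to-image generation
    return counts, layout, image


# Toy stand-ins so the sketch runs without any model.
def toy_count(prompt):
    return {"dog": 2}


def toy_place(prompt, counts):
    layout = []
    for name, n in counts.items():
        for i in range(n):
            left = i / n  # spread the instances side by side
            layout.append(PlacedObject(name, (left, 0.25, left + 1 / n, 0.75)))
    return layout


def toy_render(prompt, layout):
    return f"<image of {len(layout)} objects for '{prompt}'>"


counts, layout, image = generate_image_stepwise(
    "two dogs", toy_count, toy_place, toy_render
)
```

Because the counts and the layout are returned alongside the image, each intermediate step can be inspected on its own, which is the sense in which the generation is interpretable.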


Code Structure

# Training & Inference Vicuna
utils/
task_utils.py
llama.py
lora_finetune.py
text2layout_inference.py

# Image inference with GLIGEN
inference_images.py
viz_utils.py

Setup

Setup Environment

conda create -n vpgen python=3.9
conda activate vpgen

pip install torch torchvision
pip install -r requirements.txt

Setup Vicuna 13B

Download pre-processed Vicuna 13B + LoRA Checkpoints from HF Hub

We currently provide a Vicuna 13B + LoRA checkpoint finetuned on Flickr30K + COCO + PaintSkills. More checkpoints will be added in the future.

print("Installing HF hub")
# !pip install -q --upgrade huggingface_hub

print("Downloading Vicuna13B weights")

from huggingface_hub import snapshot_download
snapshot_download(repo_id="j-min/vicuna-13b-v0-merged",
                  repo_type="model",
                  local_dir="vicuna_13b_checkpoint",
                  force_download=True,
)

print("Downloading LoRA weights")

from huggingface_hub import hf_hub_download

for filename in ['adapter_config.json', 'adapter_model.bin']:
  hf_hub_download(repo_id="j-min/VPGen",
                  filename=filename,
                  subfolder="vicuna13B_GPU4_flickr30k_coco_paintskills_epoch2_mbatch32_lora16_cutoff256",
                  local_dir="lora_checkpoint/",
  )
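After the downloads finish, it can be useful to confirm that the LoRA files landed where the later inference command expects them. The check below is a minimal sketch; the helper name `missing_files` is hypothetical, and it assumes `hf_hub_download` placed each file under `local_dir/subfolder/`, which matches the `lora_model_path` used in the layout inference step.

```python
from pathlib import Path

# Names copied from the download snippet above.
LORA_SUBFOLDER = (
    "vicuna13B_GPU4_flickr30k_coco_paintskills_epoch2_mbatch32_lora16_cutoff256"
)
LORA_FILES = ["adapter_config.json", "adapter_model.bin"]


def missing_files(directory, filenames):
    """Return the names in `filenames` that are not regular files in `directory`."""
    directory = Path(directory)
    return [name for name in filenames if not (directory / name).is_file()]


if __name__ == "__main__":
    lora_dir = Path("lora_checkpoint") / LORA_SUBFOLDER
    missing = missing_files(lora_dir, LORA_FILES)
    if missing:
        print(f"Missing in {lora_dir}: {missing}")
    else:
        print("All LoRA files are present.")
```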

(Optional; Guideline to obtain Merged Vicuna 13B weights)

1) Download the LLaMA 13B checkpoint

Weights for the LLaMA models can be obtained by filling out this form.

2) Convert the weights into Huggingface Transformers compatible version, following https://huggingface.co/docs/transformers/main/model_doc/llama.

git clone https://github.com/huggingface/transformers
cd transformers
pip install -e .

python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights \
    --model_size 13B \
    --output_dir /output/path

3) Download the Vicuna 13B v0 delta weights and merge them with the LLaMA weights to obtain the Vicuna weights.

This conversion command needs around 60 GB of CPU RAM. If you do not have enough memory, see the "Low CPU Memory Conversion" section of the FastChat documentation linked below. Replace /path/to/* with the real paths.

Check https://github.com/lm-sys/FastChat#model-weights for more details.

# for v0 weights
pip install fschat==0.1.10

python -m fastchat.model.apply_delta \
    --base-model-path /path/to/llama-13b \
    --target-model-path vicuna_13b_checkpoint \
    --delta-path lmsys/vicuna-13b-delta-v0
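Conceptually, applying a delta means adding the delta weights to the base weights, parameter by parameter. The toy sketch below illustrates that idea with plain Python lists in place of tensors. It is not FastChat's code, and real checkpoints may need extra checkpoint-specific handling that is ignored here.

```python
def apply_delta(base, delta):
    """Return target weights where target[name] = base[name] + delta[name].

    `base` and `delta` map parameter names to flat lists of floats; this is a
    toy stand-in for model state dicts.
    """
    if base.keys() != delta.keys():
        raise ValueError("base and delta must have the same parameter names")
    target = {}
    for name in base:
        if len(base[name]) != len(delta[name]):
            raise ValueError(f"shape mismatch for parameter {name!r}")
        target[name] = [b + d for b, d in zip(base[name], delta[name])]
    return target


# Tiny example with a single two-element "parameter".
merged = apply_delta({"w": [1.0, 2.0]}, {"w": [0.5, -1.0]})
```

This also explains the memory requirement: the base and the delta both have to be loaded at the same time before the sum can be written out.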

Setup GLIGEN

Check https://github.com/gligen/diffusers/tree/gligen/examples/gligen for more details.

git clone https://github.com/gligen/diffusers gligen_diffusers
cd gligen_diffusers
pip install -e .

Finetuning Vicuna with LoRA

Flickr30K + COCO + PaintSkills training

n_gpus=4
model='vicuna13B'
base_model_path='vicuna_13b_checkpoint'

micro_batch_size=24
batch_size=96
lora_r=16
epochs=2
cutoff_len=512

# https://huggingface.co/j-min/VPGen/blob/main/flickr30k_coco_paintskills_text2box_train.json
data='flickr30k_coco_paintskills'

run_name=$model"_GPU$n_gpus"_epoch"$epochs"_mbatch"$micro_batch_size"_lora"$lora_r"_cutoff"$cutoff_len"
data_path='TRAIN_FILE'

torchrun --nproc_per_node=$n_gpus \
    lora_finetune.py \
    --base_model $base_model_path \
    --data_path $data_path \
    --output_dir './output/'$run_name \
    --prompt_template_name text2box \
    --num_epochs $epochs \
    --batch_size $batch_size \
    --cutoff_len $cutoff_len \
    --group_by_length \
    --lora_target_modules '[q_proj,k_proj,v_proj,o_proj]' \
    --lora_r $lora_r \
    --micro_batch_size=$micro_batch_size
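The relationship between `batch_size`, `micro_batch_size`, and the number of GPUs can be checked with a little arithmetic. The sketch below assumes that `batch_size` is the global batch size and `micro_batch_size` is the per-GPU batch size per forward pass, as in common LoRA finetuning scripts. Whether `lora_finetune.py` follows exactly this convention is an assumption, and the function name is hypothetical.

```python
def gradient_accumulation_steps(batch_size, micro_batch_size, n_gpus=1):
    """Number of micro-batches each GPU accumulates before an optimizer step.

    Assumes batch_size is global and micro_batch_size is per GPU.
    """
    per_step = micro_batch_size * n_gpus
    if batch_size % per_step != 0:
        raise ValueError("batch_size must be divisible by micro_batch_size * n_gpus")
    return batch_size // per_step


# Values from the training script above: 96 / (24 * 4) = 1.
steps = gradient_accumulation_steps(batch_size=96, micro_batch_size=24, n_gpus=4)
```

Under this assumption, running the same settings on a single GPU would need 4 accumulation steps to keep the global batch size at 96.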

Layout Inference with Vicuna

It takes 10-15 minutes to load the Vicuna weights. In our experiments, Vicuna 13B inference requires around 35 GB of CPU memory and 30 GB of GPU memory.
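As a rough sanity check on the GPU figure, the memory needed for the weights alone can be estimated from the parameter count. The sketch below assumes about 13 billion parameters stored in 16-bit precision; the precision actually used by `text2layout_inference.py` is not stated here, so treat the result as an order-of-magnitude estimate.

```python
def weight_memory_gb(n_params, bytes_per_param):
    """Memory for the model weights alone, in decimal gigabytes."""
    return n_params * bytes_per_param / 1e9


# Assumption: ~13e9 parameters at 2 bytes each (16-bit floats).
fp16_gb = weight_memory_gb(13e9, 2)
print(f"{fp16_gb:.0f} GB")  # prints "26 GB"
```

The gap between this estimate and the reported usage would be taken up by activations, generation caches, and framework overhead.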

gpu_id=0

base_model_path='vicuna_13b_checkpoint'

# LoRA checkpoint path
lora_model_path='lora_checkpoint/vicuna13B_GPU4_flickr30k_coco_paintskills_epoch2_mbatch32_lora16_cutoff256'

# Where to load the prompts from
prompts_path='DATA_PATH'

# Where to save the generated layouts
layout_dump_path='LAYOUT_DUMP_PATH'

echo $gpu_id

echo $base_model_path
echo $lora_model_path
echo $prompts_path
echo $layout_dump_path

python text2layout_inference.py \
	--llm_device "cuda:$gpu_id" \
	--base_model $base_model_path \
	--lora_model $lora_model_path \
	--data_path $prompts_path \
	--layout_dump_path $layout_dump_path

Image Generation with GLIGEN

GLIGEN inference requires around 6GB of GPU RAM.

gpu_id=0

model='gligen'

# Layouts generated by Vicuna
layout_path='LAYOUT_DUMP_PATH'

# Where to save the images
image_dump_dir='IMAGE_DUMP_PATH'

# Where to save the bounding box images
layout_image_dump_dir='LAYOUT_IMAGE_DUMP_PATH'

echo $gpu_id

echo $layout_path
echo $image_dump_dir
echo $layout_image_dump_dir

CUDA_VISIBLE_DEVICES=$gpu_id \
python inference_images.py \
    --model $model \
    --layout_path $layout_path \
    --image_dump_dir $image_dump_dir \
    --layout_image_dump_dir $layout_image_dump_dir
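The layout and image stages can also be chained from a single Python script. The wrapper below is a sketch rather than part of the repository: the function names are hypothetical, and it uses only the script names and flags shown in the two shell snippets above.

```python
import os
import subprocess


def layout_command(gpu_id, base_model, lora_model, prompts_path, layout_dump_path):
    """Build the argv list for the Vicuna layout inference stage."""
    return [
        "python", "text2layout_inference.py",
        "--llm_device", f"cuda:{gpu_id}",
        "--base_model", base_model,
        "--lora_model", lora_model,
        "--data_path", prompts_path,
        "--layout_dump_path", layout_dump_path,
    ]


def image_command(layout_path, image_dump_dir, layout_image_dump_dir, model="gligen"):
    """Build the argv list for the GLIGEN image generation stage."""
    return [
        "python", "inference_images.py",
        "--model", model,
        "--layout_path", layout_path,
        "--image_dump_dir", image_dump_dir,
        "--layout_image_dump_dir", layout_image_dump_dir,
    ]


def run_pipeline(gpu_id, base_model, lora_model, prompts_path,
                 layout_dump_path, image_dump_dir, layout_image_dump_dir):
    """Run both stages in order; stop if the first one fails."""
    subprocess.run(
        layout_command(gpu_id, base_model, lora_model, prompts_path, layout_dump_path),
        check=True,
    )
    # The image stage selects its GPU through CUDA_VISIBLE_DEVICES.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
    subprocess.run(
        image_command(layout_dump_path, image_dump_dir, layout_image_dump_dir),
        check=True,
        env=env,
    )
```

Passing the arguments as a list avoids shell quoting problems when paths contain spaces, and `check=True` prevents the image stage from running on a missing or partial layout file.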

Citation

If you find our project useful in your research, please cite the following paper:

@article{Cho2023VPT2I,
  author    = {Jaemin Cho and Abhay Zala and Mohit Bansal},
  title     = {Visual Programming for Text-to-Image Generation and Evaluation},
  year      = {2023},
}

vpgen's Issues

Lack of dataset description/examples

While trying to reproduce the results, I found that the README does not describe the structure of the dataset and does not provide the flickr30k_coco_paintskills dataset, so the results cannot be replicated.
Can you provide a demo?

Checkpoint for LoRA

Hi there,
Thanks for this great work, it is making an excellent contribution to the community.
I want to try your inference stage. Could you provide the checkpoint for the LoRA?
Thank you for your time!
Sincerely,
ZiAng

Can you release the training data & training prompts?

I noticed that some important files for training the Vicuna LoRA are missing, e.g., the data you used (and its format) and the prompt template (text2box). Could you please release them and provide a download link so that I can train the LoRA on my machine? I don't think those files are too big, and providing the detailed training data/files would be helpful for follow-up work.
