iterinpaint's Introduction

IterInpaint

The code for IterInpaint, a new baseline for layout-guided image generation, as described in the paper:

Diagnostic Benchmark and Iterative Inpainting for Layout-Guided Image Generation (CVPR 2024 Workshop)

Jaemin Cho, Linjie Li, Zhengyuan Yang, Zhe Gan, Lijuan Wang, Mohit Bansal

[Project Page] [Paper] [Gradio Demo] [Colab Demo]

Setup Environment

conda create -n iterinpaint python=3.9
conda activate iterinpaint

pip install torch torchvision
pip install -r requirements.txt

🧨 Diffusers support

We provide Hugging Face Diffusers checkpoints for IterInpaint, which you can load as follows:

from diffusers import StableDiffusionInpaintPipeline

# CLEVR checkpoint
pipe = StableDiffusionInpaintPipeline.from_pretrained('j-min/IterInpaint-CLEVR')

# COCO checkpoint
pipe = StableDiffusionInpaintPipeline.from_pretrained('j-min/iterinpaint_sd15inpaint_coco')
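Both checkpoints load as inpainting pipelines, so a layout has to be turned into one mask per bounding box before each inpainting step. The following is a minimal sketch of that bookkeeping in plain Python, not the repository's implementation. It assumes boxes are given as normalized (x0, y0, x1, y1) coordinates and a 512 px square image; `LayoutBox`, `box_to_pixels`, and `box_mask` are hypothetical names. The prompt format used for each step is not shown here; see the demo notebooks for that.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class LayoutBox:
    """One object in a layout (hypothetical helper, not part of the repo)."""
    caption: str
    box: Tuple[float, float, float, float]  # normalized (x0, y0, x1, y1)


def box_to_pixels(box: Tuple[float, float, float, float],
                  size: int = 512) -> Tuple[int, int, int, int]:
    """Convert a normalized box to integer pixel coordinates."""
    x0, y0, x1, y1 = box
    if not (0.0 <= x0 < x1 <= 1.0 and 0.0 <= y0 < y1 <= 1.0):
        raise ValueError(f"box must be normalized and non-empty, got {box}")
    return (round(x0 * size), round(y0 * size),
            round(x1 * size), round(y1 * size))


def box_mask(box: Tuple[float, float, float, float],
             size: int = 512) -> List[List[int]]:
    """Build a size x size mask: 255 inside the box, 0 elsewhere.

    Diffusers inpainting pipelines repaint white mask pixels and keep
    black ones, so 255 marks the region to be generated.
    """
    px0, py0, px1, py1 = box_to_pixels(box, size)
    return [
        [255 if (px0 <= x < px1 and py0 <= y < py1) else 0
         for x in range(size)]
        for y in range(size)
    ]
```

A mask built this way can be converted to an image (for example with `PIL.Image.fromarray` on a uint8 array) and passed as `mask_image` to the pipeline, one box at a time.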

Inference Demos

We provide demos for IterInpaint inference, where you can generate images with your own custom layouts.

Gradio

Gradio Demo with Diffusers

Notebooks

Inference with Diffusers - You can run this notebook on Colab.

Inference with original LDM codebase - You need 12GB+ of CPU memory to build the model (on Colab, this requires Colab Pro).

Training IterInpaint on CLEVR

We provide pretrained checkpoints for IterInpaint on CLEVR.

Below, we provide the instructions for training IterInpaint on CLEVR.

1) Download SD checkpoint

mkdir preload_model_checkpoints
cd preload_model_checkpoints

# By default, we use the SD v1.5 inpainting checkpoint as the starting point (https://huggingface.co/runwayml/stable-diffusion-inpainting).
wget https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt

# You can replace it with another checkpoint, such as SD text2image from https://huggingface.co/CompVis/stable-diffusion-v-1-4-original
# wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt

2) Run Training

We train IterInpaint on 16 V100 GPUs (2 nodes x 8 GPUs per node) with a batch size of 1 per GPU and gradient accumulation of 8, which gives an effective batch size of 128 (= 16 x 8). Training runs for 20K steps.
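If you train on different hardware, the effective batch size can be recomputed from the same four quantities. A minimal sketch of the arithmetic follows; `effective_batch_size` is a hypothetical helper, not a function from this codebase.

```python
def effective_batch_size(nodes: int, gpus_per_node: int,
                         per_gpu_batch: int, grad_accum: int) -> int:
    """Samples contributing to one optimizer update across all GPUs."""
    for name, value in [("nodes", nodes), ("gpus_per_node", gpus_per_node),
                        ("per_gpu_batch", per_gpu_batch),
                        ("grad_accum", grad_accum)]:
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    return nodes * gpus_per_node * per_gpu_batch * grad_accum


# Setup described above: 2 nodes x 8 GPUs, batch size 1, accumulation 8.
paper_setup = effective_batch_size(nodes=2, gpus_per_node=8,
                                   per_gpu_batch=1, grad_accum=8)
```

For example, a single node with 8 GPUs would need a gradient accumulation of 16 at batch size 1 to reach the same effective batch size.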

We update the U-Net and CLIP text encoder parameters while freezing the autoencoder.

You can check and change other hyperparameters in the config file (configs/stable-diffusion/v1-finetune_clevr_iterinpaint_SD15.yaml).

config='configs/stable-diffusion/v1-finetune_clevr_iterinpaint_SD15.yaml'
SD_starting_checkpoint='preload_model_checkpoints/sd-v1-5-inpainting.ckpt'
data_root='datasets/clevr_data'
lr=1e-4
fg_task_ratio='030'
job_name='iterinpaint_CLEVR_FG30'
save_dir='results'
batch_size=1

python main.py \
  --base $config \
  --train \
  --nodes 2 \
  --gpus 0,1,2,3,4,5,6,7 \
  --actual_resume $SD_starting_checkpoint \
  --name $job_name \
  --data_root $data_root \
  --val_data_root $data_root \
  --no-test true \
  --lr $lr \
  --batch_size $batch_size \
  --logdir $save_dir/$job_name \
  --fg_task_ratio $fg_task_ratio \
  --seed 42

(optional) Convert LDM-based checkpoint to HF diffusers format

# checkpoint output path from training
ckpt_path=xxxx.ckpt
config_file=xxxx.project.yaml

# output path for HF diffusers checkpoint
dump_path=DUMP_PATH

python convert_iterinpaint_ldm_checkpoint_to_diffusers.py \
  --checkpoint_path $ckpt_path \
  --original_config_file $config_file \
  --image_size 512 \
  --prediction_type 'epsilon' \
  --pipeline_type 'FrozenCLIPEmbedder' \
  --extract_ema \
  --dump_path $dump_path

CLEVR inference

# Set n_nodes, n_gpus, and ckpt_path before running.
dump_dir='eval_images_dump/clevr'
config='configs/stable-diffusion/v1-inference-iterinpaint.yaml'
run_name='iterinpaint_guidance4.0'

torchrun \
  --nnodes=$n_nodes \
  --nproc_per_node=$n_gpus \
  scripts/clevr_inference.py \
  --eval_data 'clevr' \
  --plms \
  --scale 4.0 \
  --config $config \
  --ckpt $ckpt_path \
  --clevr_dump_dir $dump_dir \
  --save_bbox_viz \
  --name $run_name

LayoutBench inference

skill_split='number_few' # Change to other skill splits
dump_dir='eval_images_dump/layoutbench'
config='configs/stable-diffusion/v1-inference-iterinpaint.yaml'
run_name='iterinpaint_guidance4.0'

torchrun \
  --nnodes=$n_nodes \
  --nproc_per_node=$n_gpus \
  scripts/clevr_inference.py \
  --eval_data 'layoutbench' \
  --plms \
  --scale 4.0 \
  --config $config \
  --ckpt $ckpt_path \
  --layoutbench_dump_dir $dump_dir \
  --skill_split $skill_split \
  --save_bbox_viz \
  --name $run_name
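To evaluate several skill splits in one go, the command above can be assembled programmatically. The sketch below only builds the argument list from the flags shown above and does not launch anything. `layoutbench_command` is a hypothetical helper; `xxxx.ckpt` is the same placeholder used earlier, the `n_nodes=1` and `n_gpus=8` values are illustrative, and only `number_few` is a split name taken from this README, so add the other split names yourself.

```python
import shlex
from typing import List


def layoutbench_command(
    skill_split: str,
    ckpt_path: str,
    n_nodes: int,
    n_gpus: int,
    dump_dir: str = "eval_images_dump/layoutbench",
    config: str = "configs/stable-diffusion/v1-inference-iterinpaint.yaml",
    run_name: str = "iterinpaint_guidance4.0",
    scale: float = 4.0,
) -> List[str]:
    """Build the torchrun argument list for one LayoutBench skill split."""
    return [
        "torchrun",
        f"--nnodes={n_nodes}",
        f"--nproc_per_node={n_gpus}",
        "scripts/clevr_inference.py",
        "--eval_data", "layoutbench",
        "--plms",
        "--scale", str(scale),
        "--config", config,
        "--ckpt", ckpt_path,
        "--layoutbench_dump_dir", dump_dir,
        "--skill_split", skill_split,
        "--save_bbox_viz",
        "--name", run_name,
    ]


if __name__ == "__main__":
    splits = ["number_few"]  # extend with the other skill splits
    for split in splits:
        # Placeholder checkpoint path and illustrative node/GPU counts.
        cmd = layoutbench_command(split, "xxxx.ckpt", n_nodes=1, n_gpus=8)
        print(shlex.join(cmd))
```

Each returned list can be passed to `subprocess.run(cmd, check=True)` to launch the run for that split.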

Citation

If you find our project useful in your research, please cite the following paper:

@inproceedings{Cho2024LayoutBench,
  author    = {Jaemin Cho and Linjie Li and Zhengyuan Yang and Zhe Gan and Lijuan Wang and Mohit Bansal},
  title     = {Diagnostic Benchmark and Iterative Inpainting for Layout-Guided Image Generation},
  booktitle = {The First Workshop on the Evaluation of Generative Foundation Models},
  year      = {2024},
}

iterinpaint's Issues

COCO Experiments

Hi, thanks for sharing the code. I tried your demo on the Hugging Face Space, but it doesn't work well for more general prompts, for example the prompts used in the COCO examples on the project page. Does that experiment use a different checkpoint? If so, is it available somewhere?
