👉 PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation


This repo contains PyTorch model definitions, pre-trained weights and inference/sampling code for our paper exploring Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation. You can find more visualizations on our project page.

PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation
Junsong Chen*, Chongjian Ge*, Enze Xie*†, Yue Wu*, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, Zhenguo Li
Huawei Noah’s Ark Lab, DLUT, HKU, HKUST


Everyone is welcome to contribute 🔥🔥!!

Learning from the previous PixArt-α project, we will try to keep this repo as simple as possible so that everyone in the PixArt community can use it.


Breaking News 🔥🔥!!

  • (🔥 New) Apr. 6, 2024. 💥 PixArt-Σ 256px & 512px checkpoints are released!
  • (🔥 New) Mar. 29, 2024. 💥 PixArt-Σ training & inference code & toy data are released!!!

🔧 Dependencies and Installation

conda create -n pixart python==3.9.0
conda activate pixart
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

git clone https://github.com/PixArt-alpha/PixArt-sigma.git
cd PixArt-sigma
pip install -r requirements.txt
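
After installation, a quick check that the pinned versions actually landed in the environment can save debugging time later. The following is a minimal sketch, not part of the repo: the helper name `check_versions` is hypothetical, and it assumes the packages expose standard metadata to `importlib.metadata` (which may not hold for every conda build).

```python
from importlib import metadata

# Versions pinned by the conda command above.
EXPECTED = {"torch": "2.0.1", "torchvision": "0.15.2", "torchaudio": "2.0.2"}


def check_versions(expected=EXPECTED):
    """Return {package: (installed_or_None, expected)} for every mismatch."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            have = None
        # Builds may append a local suffix such as "+cu117",
        # so only the leading part of the version string is compared.
        if have is None or not have.startswith(want):
            mismatches[name] = (have, want)
    return mismatches


if __name__ == "__main__":
    problems = check_versions()
    for name, (have, want) in problems.items():
        print(f"{name}: found {have}, expected {want}")
    if not problems:
        print("All pinned packages match.")
```

An empty result means every pinned package was found with the expected version prefix.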

🔥 How to Train

1. PixArt Training

First of all, we started a new repo to build a more user-friendly and more compatible codebase. The main model structure is the same as PixArt-α, so you can still develop your own functions based on the original repo. This repo will also support PixArt-alpha in the future.

Now you can train your model without prior feature extraction. We restructured the data layout of the PixArt-α codebase so that everyone can start training, inference, and visualization right away without any pain.

1.1 Downloading the toy dataset

Download the toy dataset first. The dataset structure for training is:

cd ./pixart-sigma-toy-dataset

Dataset Structure
├──InternImgs/  (images are saved here)
│  ├──000000000000.png
│  ├──000000000001.png
│  ├──......
├──InternData/
│  ├──data_info.json    (meta data)
Optional(👇)
│  ├──img_sdxl_vae_features_1024resolution_ms_new    (pre-extracted SDXL-VAE image features, same name as images except .npy extension)
│  │  ├──000000000000.npy
│  │  ├──000000000001.npy
│  │  ├──......
│  ├──caption_features_new
│  │  ├──000000000000.npz
│  │  ├──000000000001.npz
│  │  ├──......
│  ├──sharegpt4v_caption_features_new    (run tools/extract_caption_feature.py to generate caption T5 features, same name as images except .npz extension)
│  │  ├──000000000000.npz
│  │  ├──000000000001.npz
│  │  ├──......
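
If you use the optional pre-extracted features, each image needs a feature file with the same name and a different extension. The sketch below checks that pairing for the layout shown above. It is an illustration only: the function name `missing_features` is hypothetical, and the default directory and extension are taken from the tree rather than from the training code.

```python
from pathlib import Path


def missing_features(dataset_root,
                     feature_dir="InternData/caption_features_new",
                     ext=".npz"):
    """List image stems in InternImgs/ that lack a matching feature file."""
    root = Path(dataset_root)
    # Image names such as 000000000000.png map to 000000000000.npz.
    stems = sorted(p.stem for p in (root / "InternImgs").glob("*.png"))
    feats = root / feature_dir
    return [s for s in stems if not (feats / f"{s}{ext}").is_file()]


if __name__ == "__main__":
    missing = missing_features("./pixart-sigma-toy-dataset")
    print(f"{len(missing)} image(s) without caption features")
```

For the VAE features, the same function can be called with `feature_dir="InternData/img_sdxl_vae_features_1024resolution_ms_new"` and `ext=".npy"`.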

1.2 Download pretrained checkpoints

# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pixart_sigma_sdxlvae_T5_diffusers

# PixArt-Sigma checkpoints
python tools/download.py
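
Before training, it can help to confirm that both downloads ended up where the later commands expect them. This is a minimal sketch with a hypothetical helper name (`missing_paths`); the paths are copied from the commands in this README, and it assumes `tools/download.py` places the PixArt-Sigma checkpoint under `output/pretrained_models`.

```python
from pathlib import Path

# Paths used by the commands in this README.
REQUIRED = [
    "output/pixart_sigma_sdxlvae_T5_diffusers",
    "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth",
]


def missing_paths(repo_root=".", required=REQUIRED):
    """Return the required paths that do not exist under repo_root."""
    root = Path(repo_root)
    return [rel for rel in required if not (root / rel).exists()]


if __name__ == "__main__":
    for rel in missing_paths():
        print(f"missing: {rel}")
```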

1.3 You are ready to train!

Select your desired config file from the config files directory.

python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 \
          train_scripts/train.py \
          configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py \
          --load-from output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth \
          --work-dir output/your_first_pixart-exp \
          --debug
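
If you launch runs from Python (for example, to sweep over configs), the command above can be assembled as an argument list, which avoids shell line-continuation mistakes. The sketch below is only an illustration: `build_train_command` is a hypothetical helper, and it passes through exactly the flags shown in the command above without assuming any others.

```python
import shlex
import sys


def build_train_command(config, load_from, work_dir,
                        nproc_per_node=1, master_port=12345, debug=False):
    """Build the launch command shown above as a list of arguments."""
    cmd = [
        sys.executable, "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--master_port={master_port}",
        "train_scripts/train.py",
        config,
        "--load-from", load_from,
        "--work-dir", work_dir,
    ]
    if debug:
        cmd.append("--debug")
    return cmd


if __name__ == "__main__":
    cmd = build_train_command(
        "configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py",
        "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth",
        "output/your_first_pixart-exp",
        debug=True,
    )
    # Print the command for inspection; pass `cmd` to subprocess.run to execute it.
    print(shlex.join(cmd))
```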

💻 How to Test

1. Quick start with Gradio

To get started, first install the required dependencies. Make sure you've downloaded the checkpoint files from models (coming soon) to the output/pretrained_models folder, then run the following on your local machine:

# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pixart_sigma_sdxlvae_T5_diffusers

# PixArt-Sigma checkpoints
python tools/download.py

# demo launch
python scripts/interface.py --model_path output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth --image_size 512 --port 11223

2. Integration in diffusers

(Coming soon)

💪To-Do List

We will try our best to release the following before April 10th:

  • Training code
  • Inference code
  • Model zoo
  • Diffusers
  • Training & inference code for One Step Sampling with DMD
