
t-gate's Introduction

T-GATE: Temporally Gating Attention to Accelerate Diffusion Model for Free! 🥳


TGATE-V1: Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models
Wentian Zhang*  Haozhe Liu1*  Jinheng Xie2*  Francesco Faccio1,3  Mike Zheng Shou2  Jürgen Schmidhuber1,3 

1 AI Initiative, King Abdullah University of Science and Technology

2 Show Lab, National University of Singapore   3 The Swiss AI Lab, IDSIA

TGATE-V2: Faster Diffusion Through Temporal Attention Decomposition
Haozhe Liu1,4*  Wentian Zhang*  Jinheng Xie2*  Francesco Faccio1,3  Mengmeng Xu4  Tao Xiang4  Mike Zheng Shou2  Juan-Manuel Pérez-Rúa4  Jürgen Schmidhuber1,3 

1 AI Initiative, King Abdullah University of Science and Technology

2 Show Lab, National University of Singapore   3 The Swiss AI Lab, IDSIA   4 Meta

Code and Technical Report will be released soon!

Quick Introduction

We explore the role of the attention mechanism during inference in text-conditional diffusion models. Empirical observations suggest that cross-attention outputs converge to a fixed point after several inference steps. The convergence time naturally divides the entire inference process into two phases: an initial phase for planning text-oriented visual semantics, which are then translated into images in a subsequent fidelity-improving phase. Cross-attention is essential in the initial phase but almost irrelevant thereafter. Self-attention, however, initially plays a minor role but becomes increasingly important in the second phase. These findings yield a simple and training-free method called TGATE which efficiently generates images by caching and reusing attention outputs at scheduled time steps. Experiments show TGATE’s broad applicability to various existing text-conditional diffusion models which it speeds up by 10-50%.

The images generated by the diffusion model with or without TGATE. Our method accelerates the diffusion model without degrading generation quality. It is training-free and complementary to a wide range of existing work.

🚀 Major Features

  • Training-free.
  • Easy to integrate into existing frameworks.
  • Requires only a few lines of code.
  • Supports CNN-based U-Nets, Transformers, and Consistency Models.
  • 10%-50% speedup for different diffusion models.

📄 Updates

  • 2024/05/22: We have successfully extended TGATE to self-attention modules for greater acceleration! Stay tuned for a major update, which will be released in the coming weeks.

  • 2024/04/17: TGATE v0.1.1 has been officially added to diffusers.

  • 2024/04/14: We release TGATE v0.1.1 to support the playground-v2.5-1024 model.

  • 2024/04/10: We release our package on PyPI. See the Usage section for details.

  • 2024/04/04: The technical report is available on arXiv.

  • 2024/04/04: TGATE for DeepCache (SD-XL) is released.

  • 2024/03/30: TGATE for SD-1.5/2.1/XL is released.

  • 2024/03/29: TGATE for LCM (SD-XL) and PixArt-Alpha is released.

  • 2024/03/28: TGATE is open-sourced.

📖 Key Observation

The images generated by the diffusion model at different denoising steps. The first row feeds the text embedding to the cross-attention modules at all steps. The second row uses the text embedding only from the 1st to the 10th step, and the third row uses it only from the 11th to the 25th step.

We summarize our observations as follows:

  • Cross-attention converges early during inference, and the inference process can be divided into a semantics-planning stage and a fidelity-improving stage. The impact of cross-attention differs between these two stages.

  • The semantics-planning stage embeds text through cross-attention to obtain visual semantics.

  • The fidelity-improving stage improves generation quality without requiring cross-attention. In fact, using a null text embedding in this stage can improve performance.
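One way to make the convergence observation measurable is to track how much a cross-attention layer's output changes between consecutive denoising steps. The sketch below is an assumption, not the authors' Figure 2 code: it uses a relative L2 difference and a hypothetical fixed threshold, the function names are illustrative, and it runs on synthetic outputs that decay geometrically toward a fixed point rather than on real model activations.

```python
import numpy as np


def cross_attention_differences(outputs):
    """Relative L2 change between the outputs of consecutive denoising steps.

    `outputs` is a list of arrays (one per step) that all share the same shape.
    """
    diffs = []
    for prev, cur in zip(outputs[:-1], outputs[1:]):
        # Normalize by the previous output's norm so layers of different scale are comparable.
        diffs.append(float(np.linalg.norm(cur - prev) / (np.linalg.norm(prev) + 1e-12)))
    return diffs


def convergence_step(diffs, threshold=1e-2):
    """Hypothetical criterion: the first 1-indexed step after which every
    step-to-step change stays below `threshold`; None if that never happens."""
    for i in range(len(diffs)):
        if all(d < threshold for d in diffs[i:]):
            return i + 1
    return None


# Synthetic stand-in for one layer's outputs: they decay geometrically toward a fixed point.
rng = np.random.default_rng(0)
fixed_point = rng.standard_normal((4, 8))
noise = rng.standard_normal((4, 8))
outputs = [fixed_point + 0.5 ** t * noise for t in range(25)]

diffs = cross_attention_differences(outputs)
step = convergence_step(diffs, threshold=1e-2)
print(f"output stops changing by more than 1% after step {step}")
```

On real models the threshold would need tuning per architecture; a relative rather than absolute difference is used so that layers with different activation scales can share one criterion.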

🖊️ Method

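  • Step 1: TGATE caches the attention outputs from the semantics-planning stage.
if gate_step == cur_step:
    # With classifier-free guidance the batch is [unconditional, text-conditional].
    hidden_uncond, hidden_pred_text = hidden_states.chunk(2)
    # Cache the average of the two halves.
    cache = (hidden_uncond + hidden_pred_text) / 2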
  • Step 2: TGATE reuses them throughout the fidelity-improving stage.
if cross_attn and (gate_step < cur_step):
    # Skip cross-attention and reuse the cached output.
    hidden_states = cache
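The two steps can be exercised end to end with a toy, NumPy-only sketch. `TGateCrossAttention` and `toy_cross_attention` are hypothetical names, not the tgate package API, and the single-head softmax attention is only a stand-in for a real cross-attention layer. The sketch assumes classifier-free guidance, so the batch holds [unconditional, text-conditional] halves up to the gate step; afterwards the caller passes a single-half batch and the layer returns the cached average.

```python
import numpy as np


def toy_cross_attention(hidden_states, text_emb):
    """Stand-in for a cross-attention layer: single-head softmax attention.

    hidden_states: (batch, num_tokens, dim); text_emb: (batch, text_len, dim).
    """
    scores = hidden_states @ np.swapaxes(text_emb, -1, -2) / np.sqrt(hidden_states.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ text_emb


class TGateCrossAttention:
    """Hypothetical wrapper illustrating the cache-then-reuse schedule."""

    def __init__(self, gate_step):
        self.gate_step = gate_step
        self.cache = None
        self.num_computed = 0  # how many times attention was actually evaluated

    def __call__(self, hidden_states, text_emb, cur_step):
        if self.cache is not None and cur_step > self.gate_step:
            # Step 2 (fidelity-improving stage): skip cross-attention entirely.
            return self.cache
        out = toy_cross_attention(hidden_states, text_emb)
        self.num_computed += 1
        if cur_step == self.gate_step:
            # Step 1: average the unconditional and text-conditional halves, then cache.
            hidden_uncond, hidden_pred_text = np.split(out, 2, axis=0)
            self.cache = (hidden_uncond + hidden_pred_text) / 2
        return out


rng = np.random.default_rng(0)
layer = TGateCrossAttention(gate_step=10)
text_emb = rng.standard_normal((2, 4, 8))  # [unconditional, text-conditional]
for step in range(1, 26):  # 25 denoising steps, 1-indexed
    batch = 2 if step <= layer.gate_step else 1  # CFG batch only before the gate
    hidden = rng.standard_normal((batch, 16, 8))
    out = layer(hidden, text_emb[:batch], step)
print(layer.num_computed, out.shape)
```

With `gate_step=10` and 25 steps, attention is evaluated only in the first 10 steps; the remaining 15 return the cached tensor without touching the text embedding.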

📄 Results

| Model | MACs | Params | Latency | Zero-shot 10K FID on MS-COCO |
|---|---|---|---|---|
| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
| SD-1.5 w/ TGATE | 9.875T | 815.557M | 4.313s | 20.789 |
| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
| SD-2.1 w/ TGATE | 22.208T | 815.433M | 9.878s | 19.940 |
| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
| SD-XL w/ TGATE | 84.438T | 2.024B | 27.932s | 22.738 |
| PixArt-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
| PixArt-Alpha w/ TGATE | 65.318T | 462.585M | 37.867s | 35.825 |
| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
| DeepCache w/ TGATE | 43.868T | - | 14.666s | 23.999 |
| LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
| LCM (SD-XL) w/ TGATE | 11.171T | 2.024B | 3.533s | 25.028 |
| LCM (PixArt-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
| LCM (PixArt-Alpha) w/ TGATE | 7.623T | 462.585M | 4.543s | 37.048 |

Latency is measured on a consumer NVIDIA GTX 1080 Ti GPU.

MACs and parameter counts are computed with calflops.

FID is computed with PytorchFID.
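The latency column can be turned into relative savings with simple arithmetic. This sketch only reuses the numbers reported in the results table (measured on a 1080 Ti) and computes the fraction of latency removed, 1 - t_TGATE / t_baseline, and the speedup factor, t_baseline / t_TGATE; the helper names are illustrative.

```python
# (baseline, with TGATE) latency in seconds, copied from the results table.
LATENCY = {
    "SD-1.5": (7.032, 4.313),
    "SD-2.1": (16.121, 9.878),
    "SD-XL": (53.187, 27.932),
    "PixArt-Alpha": (61.502, 37.867),
    "DeepCache (SD-XL)": (19.931, 14.666),
    "LCM (SD-XL)": (3.805, 3.533),
    "LCM (PixArt-Alpha)": (4.733, 4.543),
}


def latency_reduction(baseline, gated):
    """Fraction of latency removed: 1 - t_gated / t_baseline."""
    return 1.0 - gated / baseline


def speedup(baseline, gated):
    """How many times faster the gated run is: t_baseline / t_gated."""
    return baseline / gated


for model, (base, gated) in LATENCY.items():
    print(f"{model:<20} {latency_reduction(base, gated):6.1%} less latency, "
          f"{speedup(base, gated):.2f}x faster")
```

The ratios differ noticeably between models; the two 4-step LCM rows show the smallest savings.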

🛠️ Requirements

  • pytorch>=2.0.0
  • diffusers>=0.27.2
  • transformers==4.37.2
  • DeepCache==0.1.1
  • accelerate

🌟 Usage

Examples

To accelerate the denoising process with TGATE, run main.py. For example:

  • SD-2.1 w/ TGATE: generate an image with the caption: "A coral reef bustling with diverse marine life."
python main.py \
--prompt 'A coral reef bustling with diverse marine life.' \
--model 'sd_2.1' \
--gate_step 10 \
--saved_path './sd_2_1/' \
--inference_step 25
  • SD-XL w/ TGATE: generate an image with the caption: "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
python main.py \
--prompt 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k' \
--model 'sd_xl' \
--gate_step 10 \
--saved_path './sd_xl/' \
--inference_step 25
  • Pixart-Alpha w/ TGATE: generate an image with the caption: "An alpaca made of colorful building blocks, cyberpunk."
python main.py \
--prompt 'An alpaca made of colorful building blocks, cyberpunk.' \
--model 'pixart' \
--gate_step 8 \
--saved_path './pixart_alpha/' \
--inference_step 25
  • LCM-SDXL w/ TGATE: generate an image with the caption: "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
python main.py \
--prompt 'Self-portrait oil painting, a beautiful cyborg with golden hair, 8k' \
--model 'lcm_sdxl' \
--gate_step 1 \
--saved_path './lcm_sdxl/' \
--inference_step 4
  • SDXL-DeepCache w/ TGATE: generate an image with the caption: "A haunted Victorian mansion under a full moon."
python main.py \
--prompt 'A haunted Victorian mansion under a full moon.' \
--model 'sd_xl' \
--gate_step 10 \
--saved_path './sd_xl_deepcache/' \
--inference_step 25 \
--deepcache
  1. For LCMs, set gate_step to 1 or 2 and inference_step to 4.

  2. To use DeepCache, add the --deepcache flag.
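When launching these runs from Python, passing an argument list to `subprocess.run` sidesteps shell line-continuation pitfalls such as a stray space after a trailing backslash. `build_tgate_command` below is a hypothetical helper: its flag names are copied from the main.py examples, the LCM check mirrors note 1, and the `0 < gate_step < inference_step` sanity check is an added assumption rather than a documented constraint of main.py.

```python
import shlex


def build_tgate_command(prompt, model, gate_step, saved_path, inference_step, deepcache=False):
    """Hypothetical helper that assembles a main.py invocation as an argv list."""
    if model == "lcm_sdxl" and (inference_step != 4 or gate_step not in (1, 2)):
        # Recommended LCM settings: gate_step 1 or 2 with 4 inference steps.
        raise ValueError("for LCMs, use gate_step 1 or 2 with inference_step 4")
    if not 0 < gate_step < inference_step:
        # Assumption: the gate should fall inside the denoising schedule.
        raise ValueError("gate_step must be between 1 and inference_step - 1")
    argv = [
        "python", "main.py",
        "--prompt", prompt,
        "--model", model,
        "--gate_step", str(gate_step),
        "--saved_path", saved_path,
        "--inference_step", str(inference_step),
    ]
    if deepcache:
        argv.append("--deepcache")  # store-style flag, no value
    return argv


cmd = build_tgate_command(
    "A haunted Victorian mansion under a full moon.",
    "sd_xl", 10, "./sd_xl_deepcache/", 25, deepcache=True,
)
print(shlex.join(cmd))  # shell-safe rendering of the same command
```

To actually launch the run, call `subprocess.run(cmd, check=True)` from the repository root.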

Third-party Usage

📖 Related works:

We encourage users to read DeepCache and Adaptive Guidance.

Methods U-Net Transformer Consistency Model
DeepCache -
Adaptive Guidance
TGATE (Ours)

Compared with DeepCache:

  • TGATE caches once and reuses the cached features until sampling ends.
  • TGATE is friendlier to Transformer-based architectures and mobile devices since it drops the high-resolution cross-attention.
  • TGATE is complementary to DeepCache.

Compared with Adaptive Guidance:

  • TGATE can reduce the parameters in the second stage.
  • TGATE can further improve the inference efficiency.
  • TGATE is also complementary to non-CFG frameworks, e.g., latent consistency models.
  • TGATE is open source.

Acknowledgment

Citation

If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.

@article{tgate,
  title={Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models},
  author={Zhang, Wentian and Liu, Haozhe and Xie, Jinheng and Faccio, Francesco and Shou, Mike Zheng and Schmidhuber, J{\"u}rgen},
  journal={arXiv preprint arXiv:2404.02747},
  year={2024}
}

t-gate's People

Contributors

eltociear, haozheliu-st, jetthu, sierkinhane, wentianzhang-ml


t-gate's Issues

TGATE v0.1.1 encounters a ValueError when performing multiple forward inferences

Hi! Thank you for the amazing work.

I encounter a ValueError when performing multiple forward inferences:

Here's the testing code I used:

import logging
import time

from tgate import TgateSDLoader

pipe = TgateSDLoader(
    pipe,
    gate_step=gate_step,
    num_inference_steps=inference_step,
).to("cuda")

start_time = time.time()
for _ in range(infer_times):
    tgate_image = pipe.tgate(
        prompt,
        gate_step=gate_step,
        num_inference_steps=inference_step,
    ).images
# Report the average latency once, after all runs have finished.
latency = (time.time() - start_time) / infer_times
logging.info("T-GATE: {:.2f} seconds".format(latency))

Hope you can resolve this issue.

How to reproduce FID from paper?

Hi! I'm trying to reproduce the T-GATE FID results described in your technical report using the SDXL model, the DPM scheduler with 25 inference steps, and a gate step of 10. I'm using the MS-COCO 256x256 benchmark from the https://github.com/Nota-NetsPresso/BK-SDM.git repository and got a very large FID instead of the 22.738 reported in your paper on arXiv. Other metrics I measure, such as Inception Score and CLIP score, are normal. Could you please provide more information about the hyperparameters (guidance scale, for example) and image resolution? Which captions were used for generation (the full MS-COCO 2014 or 2017 validation set, or a subset of one of them), and which real images were used to compute the FID between real and generated samples?

Cross-attention difference code

Hi,
Thank you for your in-depth analysis.
Could you open-source the code for computing the cross-attention difference shown in Figure 2?

confusion about speedup

Hello, thank you for your excellent work. Your work identifies redundancy in cross-attention and uses caching to remove it, which speeds up generation. As far as I know, self-attention and the FFN cost more compute than cross-attention. However, the paper reports that T-GATE achieves a nearly 40% speedup by caching only cross-attention. How is such a high speedup achieved? Perhaps I have misunderstood something about the technique. I would appreciate it if you could help clear up this confusion. Thanks! 🌹

About Playground-v2.5-1024 model.

Hi!
Thanks for your amazing work.

Playground-v2.5-1024 is a stronger T2I model based on the SD-XL architecture.
(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic)
I tried to use the following code to speed up the model, but the result looks terrible.

import torch
from diffusers import StableDiffusionXLPipeline
from tgate import TgateSDXLLoader

pipe = StableDiffusionXLPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

gate_step = 10
inference_step = 25
pipe = TgateSDXLLoader(
    pipe,
    gate_step=gate_step,
    num_inference_steps=inference_step,
)
pipe = pipe.to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k."

image = pipe.tgate(
    prompt,
    gate_step=gate_step,
    num_inference_steps=inference_step,
).images[0]
image.save(f"{prompt}.png")

[Generated image: Astronaut in a jungle, cold color palette, muted colors, detailed, 8k]

Is there any way to solve this problem?
I look forward to your reply.
