LadaGAN

This repo is the official implementation of "Efficient generative adversarial networks using linear additive-attention Transformers".

By Emilio Morales-Juarez and Gibran Fuentes-Pineda.

Abstract

Although the capacity of deep generative models for image generation, such as Diffusion Models (DMs) and Generative Adversarial Networks (GANs), has dramatically improved in recent years, much of their success can be attributed to computationally expensive architectures. This has limited their adoption and use to research laboratories and companies with large resources, while significantly raising the carbon footprint for training, fine-tuning, and inference. In this work, we present LadaGAN, an efficient generative adversarial network that is built upon a novel Transformer block named Ladaformer. The main component of this block is a linear additive-attention mechanism that computes a single attention vector per head instead of the quadratic dot-product attention. We employ Ladaformer in both the generator and discriminator, which reduces the computational complexity and overcomes the training instabilities often associated with Transformer GANs. LadaGAN consistently outperforms existing convolutional and Transformer GANs on benchmark datasets at different resolutions while being significantly more efficient. Moreover, LadaGAN shows competitive performance compared to state-of-the-art multi-step generative models (e.g. DMs) using orders of magnitude less computational resources.
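The central idea of the abstract, one attention vector per head instead of an N×N dot-product map, can be illustrated with a minimal pure-Python sketch. It assumes a simplified Fastformer-style formulation: each token is scored with a learned vector `w`, the scores are softmaxed over tokens, the queries are pooled into one global query, and that global query is combined element-wise with the keys. The exact Ladaformer equations may differ (see the paper), and the function names here are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # N tokens x d features


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def additive_attention(q: Matrix, k: Matrix, w: Vector) -> Tuple[Matrix, Vector]:
    """Single-head additive attention (simplified, Fastformer-style assumption).

    Each token gets one scalar score, so the attention is a single vector of
    length N and the cost is O(N * d) rather than the O(N^2 * d) of dot-product
    attention.
    """
    n, d = len(q), len(w)
    # One score per token: dot product with the learned vector w, scaled by sqrt(d).
    scores = [sum(qi[j] * w[j] for j in range(d)) / math.sqrt(d) for qi in q]
    alpha = softmax(scores)  # the single attention vector for this head
    # Global query: attention-weighted sum of all per-token queries.
    g = [sum(alpha[i] * q[i][j] for i in range(n)) for j in range(d)]
    # Broadcast the global context to every token via an element-wise product with its key.
    out = [[g[j] * ki[j] for j in range(d)] for ki in k]
    return out, alpha


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, alpha = additive_attention(q, k, w=[1.0, 1.0])
print(len(alpha))  # one attention weight per token
```

Because no N×N matrix is ever formed, memory and compute grow linearly with the number of tokens, which is what makes this kind of attention attractive at higher resolutions.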

Dependencies

  • Python 3.9
  • TensorFlow <= 2.13.1

A conda environment can be created and activated with:

conda create --name tf13 python=3.9.16
conda activate tf13
pip install tensorflow[and-cuda]==2.13.1 
pip install numpy matplotlib pillow scipy tqdm huggingface-hub
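To confirm that the activated environment respects the TensorFlow <= 2.13.1 constraint, a small stdlib-only check like the following can be run. It is a sketch that assumes the installed distribution is named `tensorflow` (as with the pip command above) and uses a simplified parser that ignores pre-release suffixes; `packaging.version` would be more robust.

```python
from importlib import metadata


def parse_version(v: str) -> tuple:
    """Turn '2.13.1' into (2, 13, 1). Simplification: suffixes like 'rc0' are dropped."""
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def tf_version_ok(installed: str, maximum: str = "2.13.1") -> bool:
    """True if the installed version is not newer than the supported maximum."""
    return parse_version(installed) <= parse_version(maximum)


if __name__ == "__main__":
    try:
        # Assumes the distribution name 'tensorflow'; variants such as 'tensorflow-cpu' differ.
        version = metadata.version("tensorflow")
    except metadata.PackageNotFoundError:
        print("TensorFlow is not installed in this environment")
    else:
        print(version, "OK" if tf_version_ok(version) else "is newer than 2.13.1")
```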

Training LadaGAN

Use --file_pattern=<file_pattern> and --eval_dir=<eval_dir> to specify the dataset path and FID evaluation path.

python train.py --file_pattern='./data_path/*png' --eval_dir='./eval_path/*png'
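Because --file_pattern and --eval_dir take glob patterns, a quick way to catch path typos before a long training run is to count what each pattern matches. This is a hypothetical helper, not part of train.py; the default patterns simply mirror the example command.

```python
import glob


def count_matches(pattern: str) -> int:
    """Return how many files a glob pattern matches."""
    return len(glob.glob(pattern))


if __name__ == "__main__":
    # Default patterns mirror the example command; replace them with your own paths.
    for flag, pattern in [("--file_pattern", "./data_path/*png"),
                          ("--eval_dir", "./eval_path/*png")]:
        n = count_matches(pattern)
        status = "OK" if n > 0 else "WARNING: no matches"
        print(f"{flag}={pattern!r}: {n} files ({status})")
```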

FLOPs

Training on CIFAR-10 and CelebA takes less than 40 hours on a single 12 GB GPU (RTX 3080 Ti):

Model (CIFAR-10 32x32)  ADM-IP (80 steps)   StyleGAN2   ViTGAN   LadaGAN
GPUs                    Tesla V100 x 2      -           -        RTX 3080 Ti x 1
#Images                 69M                 -           -        68M
#Params                 57M                 -           -        19M
FLOPs                   9.0B                -           -        0.7B
FID                     2.93                5.79        4.57     3.29

Model (CelebA 64x64)    ADM-IP (80 steps)   StyleGAN2   ViTGAN   LadaGAN
GPUs                    Tesla V100 x 16     -           -        RTX 3080 Ti x 1
#Images                 138M                -           -        72M
#Params                 295M                24M         38M      19M
FLOPs                   103.5B              7.8B        2.6B     0.7B
FID                     2.67                -           3.74     1.81

Model (FFHQ 128x128)    ADM-IP (80 steps)   StyleGAN2   ViTGAN   LadaGAN
#Images                 61M                 -           -        24M
#Params                 543M                -           -        24M
FLOPs                   391.0B              11.5B       11.8B    4.3B
FID                     6.89                -           -        4.48
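The efficiency gap is easier to read as ratios. The following snippet divides each listed FLOPs value by LadaGAN's value for the same dataset. It only uses numbers from the tables; the tables do not state whether the ADM-IP figure counts one sampling step or all 80, so the ratios compare the listed numbers and nothing more.

```python
# FLOPs copied from the tables (billions); '-' entries are omitted.
flops = {
    "CIFAR-10 32x32": {"ADM-IP (80 steps)": 9.0, "LadaGAN": 0.7},
    "CelebA 64x64": {"ADM-IP (80 steps)": 103.5, "StyleGAN2": 7.8,
                     "ViTGAN": 2.6, "LadaGAN": 0.7},
    "FFHQ 128x128": {"ADM-IP (80 steps)": 391.0, "StyleGAN2": 11.5,
                     "ViTGAN": 11.8, "LadaGAN": 4.3},
}


def relative_cost(table: dict, baseline: str = "LadaGAN") -> dict:
    """Each model's FLOPs divided by the baseline's FLOPs, rounded to one decimal."""
    base = table[baseline]
    return {name: round(value / base, 1) for name, value in table.items()}


for dataset, table in flops.items():
    print(dataset, relative_cost(table))
```

For example, on CelebA 64x64 the listed ADM-IP FLOPs are roughly 148 times LadaGAN's, and StyleGAN2's are about 11 times.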

Hparams setting

Adjust hyperparameters in the config.py file.
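One way to keep experiment variants separate from the defaults is to treat the hyperparameters as an immutable record and derive variants from it. This is a generic sketch; the field names and values below are hypothetical and do not reflect the actual contents of config.py.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class HParams:
    # Hypothetical fields and values for illustration only;
    # the real names and defaults are defined in the repository's config.py.
    img_size: int = 32
    batch_size: int = 64
    g_lr: float = 2e-4
    d_lr: float = 2e-4


base = HParams()
# Derive a higher-resolution variant without mutating the defaults.
celeba = replace(base, img_size=64)
print(celeba)
```

Freezing the dataclass means an accidental in-place change raises an error instead of silently altering every run that shares the defaults.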

Implementation notes:

  • This model depends on other files that may be licensed under different open source licenses.
  • LadaGAN uses Differentiable Augmentation, which is licensed under the BSD 2-Clause "Simplified" License.
  • FID evaluation.
  • Efficient patch generation with XLA.
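Differentiable Augmentation applies the same family of augmentations (for example translation and cutout) to both real and generated images before they reach the discriminator. Because the operations are plain indexing and masking, gradients can flow back to the generator. The sketch below illustrates only the pixel operations on small pure-Python grids; the real DiffAugment code works on batched tensors, and the function names and parameters here are hypothetical.

```python
import random
from typing import List

Image = List[List[float]]  # H x W single-channel image (simplification)


def translate(img: Image, dx: int, dy: int) -> Image:
    """Shift the image by (dx, dy) pixels, filling uncovered pixels with zeros."""
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            sy, sx = y - dy, x - dx
            if 0 <= sy < h and 0 <= sx < w:
                out[y][x] = img[sy][sx]
    return out


def cutout(img: Image, cy: int, cx: int, size: int) -> Image:
    """Zero out a size x size square with top-left corner (cy, cx), clipped to the image."""
    out = [row[:] for row in img]
    for y in range(cy, min(cy + size, len(img))):
        for x in range(cx, min(cx + size, len(img[0]))):
            out[y][x] = 0.0
    return out


def diff_augment(img: Image, rng: random.Random, max_shift: int = 1, cut: int = 1) -> Image:
    """Random translation followed by a random cutout square."""
    h, w = len(img), len(img[0])
    img = translate(img, rng.randint(-max_shift, max_shift), rng.randint(-max_shift, max_shift))
    return cutout(img, rng.randrange(h), rng.randrange(w), cut)


rng = random.Random(0)
real = [[1.0] * 4 for _ in range(4)]
fake = [[0.5] * 4 for _ in range(4)]
# Both real and generated images are augmented before the discriminator sees them.
d_inputs = [diff_augment(real, rng), diff_augment(fake, rng)]
```

Augmenting only the real images would let the discriminator learn the augmentations as a "real" cue, which is why both branches go through the same policy.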

Demo

Open In Colab

Attention maps

[Figure: single-head attention maps over the course of training]
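Since each head computes a single attention vector with one weight per token, that vector can be viewed as an image when the tokens are image patches. The sketch below assumes the tokens are ordered row-major over an H x W patch grid; how the repository actually extracts and plots its maps is not shown here, and these helper names are hypothetical.

```python
from typing import List


def attention_to_map(alpha: List[float], height: int, width: int) -> List[List[float]]:
    """Reshape a per-token attention vector (row-major patch order) into a 2D map."""
    if len(alpha) != height * width:
        raise ValueError("attention vector length must equal height * width")
    return [alpha[r * width:(r + 1) * width] for r in range(height)]


def to_ascii(grid: List[List[float]], levels: str = " .:-=+*#%@") -> str:
    """Render a map as ASCII, scaling values linearly to the [min, max] range."""
    flat = [v for row in grid for v in row]
    lo, hi = min(flat), max(flat)
    span = (hi - lo) or 1.0  # avoid division by zero for a constant map
    return "\n".join(
        "".join(levels[int((v - lo) / span * (len(levels) - 1))] for v in row)
        for row in grid
    )


# A toy 4x4 attention vector that peaks on one patch.
alpha = [0.01] * 16
alpha[5] = 0.85
print(to_ascii(attention_to_map(alpha, 4, 4)))
```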

BibTeX

@article{morales2024efficient,
  title={Efficient generative adversarial networks using linear additive-attention Transformers},
  author={Morales-Juarez, Emilio and Fuentes-Pineda, Gibran},
  journal={arXiv preprint arXiv:2401.09596},
  year={2024}
}

License

MIT