
enhancing-transformers's Introduction

Table of Contents
  1. About The Project
  2. Getting Started
  3. Roadmap
  4. Contributing
  5. License
  6. Contact
  7. Acknowledgments

News

09/09

  1. Released the weights of ViT-VQGAN small trained on ImageNet here

16/08

  1. First release of the weights of ViT-VQGAN base trained on ImageNet here
  2. Added a Colab notebook here

About The Project

This is an unofficial PyTorch implementation of both ViT-VQGAN and RQ-VAE. ViT-VQGAN is a simple ViT-based vector-quantized autoencoder, while RQ-VAE introduces a new residual quantization scheme. Further details can be found in the papers.

Getting Started

For ease of installation, use Anaconda to set up this repo.

Installation

A suitable conda environment named enhancing can be created and activated with:

conda env create -f environment.yaml
conda activate enhancing

Training

Training is easy with one line: python3 main.py -c config_name -lr learning_rate -e epoch_nums

Roadmap

  • Add ViT-VQGAN
    • Add ViT-based encoder and decoder
    • Add factorized codes
    • Add l2-normalized codes
    • Replace PatchGAN discriminator with StyleGAN one
  • Add RQ-VAE
    • Add Residual Quantizer
    • Add RQ-Transformer
  • Add dataloaders for some common datasets
    • ImageNet
    • LSUN
    • COCO
      • Add COCO Segmentation
      • Add COCO Caption
    • CC3M
  • Add pretrained models
    • ViT-VQGAN small
    • ViT-VQGAN base
    • ViT-VQGAN large
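The roadmap items "factorized codes" and "l2-normalized codes" refer to the ViT-VQGAN codebook design: encoder outputs are projected into a low-dimensional lookup space (the stage 1 config in the issues below uses embed_dim: 32 with n_embed: 8192), and both the projected vectors and the code vectors are l2-normalized before the nearest-neighbour search. The NumPy sketch below only illustrates that lookup; it is not this repo's quantizer, and the names `proj` and `factorized_l2_lookup` are hypothetical. It omits training details such as the straight-through gradient.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row (last axis) to unit l2 norm."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def factorized_l2_lookup(z: np.ndarray, proj: np.ndarray, codebook: np.ndarray):
    """Hypothetical illustration of a factorized, l2-normalized code lookup.

    z:        (N, D_model) encoder outputs
    proj:     (D_model, D_code) projection into the low-dimensional code space
    codebook: (K, D_code) code vectors
    Returns the chosen code indices (N,) and the normalized code vectors (N, D_code).
    """
    z_low = l2_normalize(z @ proj)   # factorize: look codes up in a small space
    cb = l2_normalize(codebook)      # compare directions only, not magnitudes
    sim = z_low @ cb.T               # cosine similarity, shape (N, K)
    # For unit vectors ||a - b||^2 = 2 - 2 a.b, so max cosine == min l2 distance.
    idx = sim.argmax(axis=1)
    return idx, cb[idx]
```

Because every vector has unit norm, the search reduces to a single matrix product plus an argmax.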

Contributing

Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are greatly appreciated.

If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". Don't forget to give the project a star! Thanks again!

  1. Fork the Project
  2. Create your Feature Branch (git checkout -b feature/AmazingFeature)
  3. Commit your Changes (git commit -m 'Add some AmazingFeature')
  4. Push to the Branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

License

Source code and pretrained weights are distributed under the MIT License. See LICENSE for more information.

Contact

Thuan H. Nguyen - @leejohnthuan - [email protected]

Acknowledgements

This project would not be possible without the generous sponsorship of Stability AI and helpful discussions with folks in the LAION Discord.

This repo is heavily inspired by the following repos and papers:

enhancing-transformers's People

Contributors

thuangb, thuanz123


enhancing-transformers's Issues

Learning rate and scheduler for stage 1 training.

Hi, @thuanz123

Thank you for reproducing and open-sourcing ViT-VQGAN in PyTorch. I would like to ask: is the learning rate fixed at 4.5e-6 during stage 1 training in your code? I noticed that the paper mentions a base_lr of 1e-4, along with linear warm-up and cosine decay.
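For reference, the schedule described in the question (linear warm-up to base_lr = 1e-4, then cosine decay) can be sketched as a per-step function. This is not the repo's code. The warmup_steps, total_steps, and min_lr defaults are placeholders chosen for illustration, not values taken from the paper or the repo.

```python
import math


def warmup_cosine_lr(step: int, base_lr: float = 1e-4, warmup_steps: int = 5000,
                     total_steps: int = 500_000, min_lr: float = 0.0) -> float:
    """Learning rate at `step`: linear warm-up, then cosine decay to `min_lr`.

    warmup_steps / total_steps / min_lr are placeholder values for illustration.
    """
    if step < warmup_steps:
        # ramp linearly from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # fraction of the decay phase completed, clamped so lr stays at min_lr afterwards
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

In PyTorch, one way to apply this is `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda s: warmup_cosine_lr(s) / 1e-4)` with the optimizer's lr set to 1e-4, since LambdaLR multiplies the initial lr by the returned factor. In PyTorch Lightning, the scheduler would be returned from `configure_optimizers` with `"interval": "step"` so that it updates every step.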

Why not use a pretrained discriminator?

Hi,

This is my first project with VQGAN. I wonder what the disadvantage would be of using a pretrained StyleGAN discriminator instead of training one from scratch. That way, you would not need to optimize the discriminator.

stage2 transformer

Hi @manuelknott, the code for the stage 2 transformer is currently buggy, so after I fix everything, I will try to train and release a pretrained model. This will take a while, though, since I'm still learning about autoregressive modeling with transformers.

Originally posted by @thuanz123 in #8 (comment)

Hi, I'm sorry to bother you. Has this pre-trained model been released?

OOM for imagenet_gpt_vitvq_base and a 100M params GPT on A100 40G

model:
    target: enhancing.modules.stage2.transformer.CondTransformer
    params:
        cond_key: class
        cond: 
            target: enhancing.modules.cond.dummycond.ClassCond
            params:
                image_size: 256
                class_name: assets/class/imagenet.txt
        stage1:
            target: enhancing.modules.stage1.vitvqgan.ViTVQ
            params:
                image_key: image
                path: weight/imagenet_vitvq_base.ckpt
                image_size: 256
                patch_size: 8
                encoder:
                    dim: 768
                    depth: 12
                    heads: 12
                    mlp_dim: 3072
                decoder:
                    dim: 768
                    depth: 12
                    heads: 12
                    mlp_dim: 3072
                quantizer:
                    embed_dim: 32
                    n_embed: 8192
                loss:
                    target: enhancing.losses.vqperceptual.DummyLoss
        transformer:
            target: enhancing.modules.stage2.layers.GPT
            params:
                vocab_cond_size: 1000
                vocab_img_size: 8192
                embed_dim: 768
                cond_num_tokens: 1
                img_num_tokens: 1024 
                n_heads: 12
                n_layers: 12
                
dataset:
    target: enhancing.dataloader.DataModuleFromConfig
    params:
        batch_size: 1
        num_workers: 4
        train:
            target: enhancing.dataloader.imagenet.ImageNetTrain
            params:
                root: /fsx/ilsvrc2012
                resolution: 256
                
        validation:
            target: enhancing.dataloader.imagenet.ImageNetValidation
            params:
                root: /fsx/ilsvrc2012
                resolution: 256

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, required=True)
parser.add_argument('-s', '--seed', type=int, default=0)
parser.add_argument('-nn', '--num_nodes', type=int, default=1)
parser.add_argument('-ng', '--num_gpus', type=int, default=1)
parser.add_argument('-u', '--update_every', type=int, default=1)
parser.add_argument('-e', '--epochs', type=int, default=100)
parser.add_argument('-lr', '--base_lr', type=float, default=4.5e-4)
parser.add_argument('-a', '--use_amp', default=False, action='store_true')
parser.add_argument('-b', '--batch_frequency', type=int, default=750)
parser.add_argument('-m', '--max_images', type=int, default=4)
args = parser.parse_args()
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/fsx/BlinkDL/CODE/enhancing-transformers/main.py", line 61, in <module>
    trainer.fit(model, data)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1345, in _run_train
    self._run_sanity_check()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1413, in _run_sanity_check
    val_loop.run()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 155, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 134, in advance
    self._on_evaluation_batch_end(output, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 267, in _on_evaluation_batch_end
    self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1636, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/utils/callback.py", line 141, in on_validation_batch_end
    self.log_img(pl_module, batch, batch_idx, split="val")
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/utils/callback.py", line 108, in log_img
    images = pl_module.log_images(batch, split=split, pl_module=pl_module)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/transformer.py", line 204, in log_images
    log["first samples"] = self.sample(cond_codes, return_pixels=True)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/transformer.py", line 87, in sample
    logits, codes = self.transformer.sample(conds=conds, top_k=top_k, top_p=top_p,
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 214, in sample
    logits_, presents = self.sample_step(codes_, conds, pos_code, use_fp16, past)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 278, in sample_step
    x, present = block.sample(x, layer_past= past[i])
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 122, in sample
    attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 70, in forward
    att = F.softmax(att, dim=-1)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/torch/nn/functional.py", line 1834, in softmax
    ret = input.softmax(dim)
RuntimeError: CUDA out of memory. Tried to allocate 11.21 GiB (GPU 0; 39.59 GiB total capacity; 24.12 GiB already allocated; 6.96 GiB free; 30.91 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
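The failing call is the softmax over attention scores inside the stage 2 sampling loop. As a rough sanity check, the size of one layer's attention score tensor of shape (batch, heads, q_len, k_len) can be estimated from the config above. Only the token counts and head count come from the config. fp32 scores are an assumption, and the helper name is hypothetical.

```python
def attention_scores_bytes(batch: int, heads: int, q_len: int, k_len: int,
                           bytes_per_elem: int = 4) -> int:
    """Bytes needed for one (batch, heads, q_len, k_len) attention score tensor."""
    return batch * heads * q_len * k_len * bytes_per_elem


GIB = 1024 ** 3
seq_len = 1 + 1024  # cond_num_tokens + img_num_tokens from the config above
heads = 12          # n_heads from the config above

full = attention_scores_bytes(1, heads, seq_len, seq_len)    # whole sequence at once
cached_step = attention_scores_bytes(1, heads, 1, seq_len)  # one step with a key/value cache
print(f"full sequence: {full / GIB:.4f} GiB, one cached step: {cached_step / GIB:.6f} GiB")
```

With batch 1, a full-sequence score tensor comes to about 48 MiB, and a single cached sampling step needs far less. A single 11.21 GiB allocation would therefore imply a much larger batch or query length in `sample_step` than the config suggests. This is an inference from the arithmetic, not a confirmed diagnosis.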

about training loss

Hi, I'm trying to reproduce the results, but I get a higher loss on the stage 2 image-token prediction task.

Could you please provide your training loss curves for reference?

Thanks a lot

stage2 transform

When you trained this, after how many steps did you start getting good images? So far I have run 20,000 steps and still have no good images. I wonder whether there is a problem with the released code?

Pretrained Stage 2 Transformer for ViT-VQGAN

Hi,

Thank you for the great work and especially for publishing the pre-trained models.

I was wondering: do you also plan to publish the weights of the autoregressive transformer (stage 2 training) of ViT-VQGAN?

An inplace operation in the forward process

Hi,

I found an inplace operation in the forward process of ViTDecoder:

token += self.de_pos_embedding

which may cause the error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

This error can occur in such use cases:

decoder_a = ViTDecoder(**decoder_config_a)
decoder_b = ViTDecoder(**decoder_config_b)
rec_a = decoder_a(quant)
rec_b = decoder_b(quant)        # The error above is raised during backward: this call modifies quant in place, invalidating tensors that decoder_a saved for gradient computation

Maybe it would be better to change it to:
token = token + self.de_pos_embedding

P.S. This pattern may also occur elsewhere; I did not check other places.
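The failure described above can be reproduced without the repo. The sketch below uses a hypothetical `toy_decoder` as a stand-in for ViTDecoder, not the repo's module. It deliberately makes `token` alias `quant`. Whether the two alias in the real ViTDecoder depends on the ops before the addition: views share storage, copies do not.

```python
import torch


def toy_decoder(quant: torch.Tensor, weight: torch.Tensor, pos: torch.Tensor,
                inplace: bool) -> torch.Tensor:
    """Hypothetical stand-in for a decoder that adds a positional embedding."""
    token = quant                  # no copy: `token` aliases the caller's tensor
    if inplace:
        token += pos               # modifies `quant` itself, like `token += self.de_pos_embedding`
    else:
        token = token + pos        # allocates a new tensor; `quant` is left untouched
    return token * weight          # mul saves `token` to compute the gradient w.r.t. `weight`


def backward_through_two_decoders(inplace: bool) -> bool:
    """Return True if backward succeeds when one `quant` feeds two decoders."""
    x = torch.randn(4, 8, requires_grad=True)
    quant = x * 1.0                # non-leaf tensor, like an encoder/quantizer output
    w_a = torch.randn(8, requires_grad=True)
    w_b = torch.randn(8, requires_grad=True)
    pos_a, pos_b = torch.randn(8), torch.randn(8)
    rec_a = toy_decoder(quant, w_a, pos_a, inplace)
    rec_b = toy_decoder(quant, w_b, pos_b, inplace)  # in-place mode bumps quant's version again
    try:
        (rec_a.sum() + rec_b.sum()).backward()
        return True
    except RuntimeError:           # "... has been modified by an inplace operation"
        return False


print("out-of-place ok:", backward_through_two_decoders(inplace=False))
print("in-place ok:", backward_through_two_decoders(inplace=True))
```

The in-place version also silently changes what the second decoder sees, because `pos_a` has already been added to `quant`. The out-of-place form `token = token + self.de_pos_embedding` avoids both problems.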

raise CalledProcessError(retcode, process.args, subprocess.CalledProcessError: Command '['where', 'cl']' returned non-zero exit status 1.

Hi, I get this error when trying to run the simple training. It seems to be linked to Visual Studio or my environment variables.
I'm using Anaconda on Windows (I had to install the requirements manually):
pytorch==1.12.0
python==3.10.4

Thanks a lot

C:\ProgramData\Anaconda3\envs\enhancing\lib\site-packages\torch\utils\cpp_extension.py:346: UserWarning: Error checking compiler version for cl: [WinError 2] The system cannot find the file specified
  warnings.warn(f'Error checking compiler version for {compiler}: {error}')
INFO: Could not find files for the given pattern(s).
Traceback (most recent call last):
  File "C:\Users\BVA\Documents\enhancing-transformers\main.py", line 41, in <module>
    model = initialize_from_config(config.model)
  File "C:\Users\BVA\Documents\enhancing-transformers\enhancing\utils\general.py", line 40, in initialize_from_config
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
  File "C:\Users\BVA\Documents\enhancing-transformers\enhancing\modules\stage1\vitvqgan.py", line 36, in __init__
    self.loss = initialize_from_config(loss)
  File "C:\Users\BVA\Documents\enhancing-transformers\enhancing\utils\general.py", line 40, in initialize_from_config
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
  File "C:\Users\BVA\Documents\enhancing-transformers\enhancing\utils\general.py", line 36, in get_obj_from_str
    return getattr(importlib.import_module(module, package=None), cls)
  File "C:\ProgramData\Anaconda3\envs\enhancing\lib\importlib\__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "C:\Users\BVA\Documents\enhancing-transformers\enhancing\losses\vqperceptual.py", line 14, in <module>
    from .layers import *
  File "C:\Users\BVA\Documents\enhancing-transformers\enhancing\losses\layers.py", line 19, in <module>
    from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
  File "C:\Users\BVA\Documents\enhancing-transformers\enhancing\losses\op\__init__.py", line 1, in <module>
    from .fused_act import FusedLeakyReLU, fused_leaky_relu
  File "C:\Users\BVA\Documents\enhancing-transformers\enhancing\losses\op\fused_act.py", line 11, in <module>
    fused = load(
  File "C:\ProgramData\Anaconda3\envs\enhancing\lib\site-packages\torch\utils\cpp_extension.py", line 1202, in load
    return _jit_compile(
  File "C:\ProgramData\Anaconda3\envs\enhancing\lib\site-packages\torch\utils\cpp_extension.py", line 1425, in _jit_compile
    _write_ninja_file_and_build_library(
  File "C:\ProgramData\Anaconda3\envs\enhancing\lib\site-packages\torch\utils\cpp_extension.py", line 1524, in _write_ninja_file_and_build_library
    _write_ninja_file_to_build_library(
  File "C:\ProgramData\Anaconda3\envs\enhancing\lib\site-packages\torch\utils\cpp_extension.py", line 1963, in _write_ninja_file_to_build_library
    _write_ninja_file(
  File "C:\ProgramData\Anaconda3\envs\enhancing\lib\site-packages\torch\utils\cpp_extension.py", line 2090, in _write_ninja_file
    cl_paths = subprocess.check_output(['where',
  File "C:\ProgramData\Anaconda3\envs\enhancing\lib\subprocess.py", line 420, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "C:\ProgramData\Anaconda3\envs\enhancing\lib\subprocess.py", line 524, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['where', 'cl']' returned non-zero exit status 1.
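The traceback shows the failure path. Importing `enhancing.losses.op` calls `torch.utils.cpp_extension.load`, which JIT-compiles the fused ops. On Windows, the build runs `where cl`, and an exit status of 1 means the MSVC compiler `cl.exe` is not on PATH. The small check below (a sketch; the function name is hypothetical) reports whether the tools that build relies on are visible from the current environment.

```python
import os
import shutil


def jit_build_tool_report() -> dict:
    """Locate tools that torch's JIT C++/CUDA extension build relies on (hypothetical helper)."""
    return {
        "cl": shutil.which("cl"),        # MSVC compiler; the failing `where cl` looks for this
        "ninja": shutil.which("ninja"),  # build backend used by torch.utils.cpp_extension.load
        # torch looks for the CUDA toolkit via CUDA_HOME, then CUDA_PATH, among other places
        "CUDA_HOME": os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH"),
    }


if __name__ == "__main__":
    for tool, location in jit_build_tool_report().items():
        print(f"{tool:10s} {location or 'NOT FOUND'}")
```

If `cl` shows as NOT FOUND, a common remedy is to launch the training from a Visual Studio "x64 Native Tools Command Prompt" (or to run `vcvarsall.bat` first) so that the compiler is on PATH. Whether this alone is sufficient for these particular ops on Windows has not been verified here.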

stage1 pretraining

Thank you for your great work! I want to train my own stage 1 model. If I scale the stage 1 training set up to billions of images, will it improve reconstruction quality?

ImageNet version.

@thuanz123 Hi, which ImageNet version did you use: ImageNet-1k, ImageNet-21k, full ImageNet, or Open Images?
Thank you.

Incomplete implementation of RQ-VAE

It seems that ViTVQ only handles the original VQ-VAE case. It does not consider codes with 3 dimensions, e.g. H×W×N, where N is the number of residual quantization steps in RQ-VAE:

def decode_codes(self, code: torch.LongTensor) -> torch.FloatTensor:

RQ-VAE is only mentioned here (and the flag use_residual is always False).

Did I miss something? Looking forward to your reply!
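For comparison, residual quantization as described for RQ-VAE quantizes each position N times. Each step quantizes whatever is left over from the previous steps, which is where the H×W×N code tensor comes from. The NumPy sketch below is illustrative only and is not this repo's code. Sharing one codebook across all depths and using plain nearest-neighbour assignment are assumptions of the sketch.

```python
import numpy as np


def residual_quantize(z: np.ndarray, codebook: np.ndarray, depth: int):
    """Residual quantization sketch.

    z:        (H, W, D) feature map
    codebook: (K, D) code vectors, shared across all depths (an assumption)
    Returns codes of shape (H, W, depth) and the quantized features (H, W, D).
    """
    H, W, D = z.shape
    residual = z.reshape(-1, D).copy()
    quant = np.zeros_like(residual)
    codes = np.empty((residual.shape[0], depth), dtype=np.int64)
    for n in range(depth):
        # squared l2 distance from each remaining residual to every code vector
        dist = ((residual[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
        idx = dist.argmin(axis=1)
        codes[:, n] = idx
        quant += codebook[idx]     # running sum of the selected code vectors
        residual -= codebook[idx]  # what is still left to quantize at the next depth
    return codes.reshape(H, W, depth), quant.reshape(H, W, D)


def decode_residual_codes(codes: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Map (H, W, depth) codes back to features by summing the selected vectors."""
    return codebook[codes].sum(axis=-2)  # (H, W, depth, D) -> (H, W, D)
```

A `decode_codes` that supports RQ-VAE would need this extra summation over the last code axis. A plain VQ lookup, by contrast, expects a single index per position.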

Model license

Hi.

I am aware of the fact that the repository itself is under the MIT license.

Since the models are external to the repository, if you trained the models from scratch, would you consider specifying a license for the trained models (or specify that the license applies to the models too)?

Results

Hello. Thanks for your work.
What training results and metrics have you obtained on any dataset with this code?
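One common metric for stage 1 reconstruction quality is PSNR between the input and the reconstructed image. The sketch below computes it for images scaled to [0, max_val]. It is a generic definition and is not taken from this repo, which does not state which metrics it reports.

```python
import numpy as np


def psnr(reference: np.ndarray, reconstruction: np.ndarray, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for images with values in [0, max_val]."""
    diff = reference.astype(np.float64) - reconstruction.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))
```

Higher is better. For example, a uniform error of 0.1 on a [0, 1] image gives an MSE of 0.01, which corresponds to 20 dB.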

Smaller images

I want to try training it on a dataset of small 32×32 images. This should make training relatively fast, and it is fine for my task. I have already changed the dataset code. Can you please suggest how to estimate the correct hyperparameters and what else I will need to change? Changing image_size and patch_size in the enhancing.modules.stage1.vitvqgan.ViTVQ config does not seem to be enough.
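Part of the answer follows from the configs posted in these issues. With image_size: 256 and patch_size: 8, the ViT sees a 32×32 grid, which matches img_num_tokens: 1024 in the stage 2 GPT config. The helper below (hypothetical, not from the repo) computes that grid for other sizes. Which other fields in the code depend on it (for example positional-embedding shapes or dataset resolution) is an assumption to verify against the source.

```python
def vit_token_grid(image_size: int, patch_size: int) -> tuple:
    """Return (tokens per side, total tokens) for a square image cut into square patches."""
    if image_size % patch_size != 0:
        raise ValueError(f"image_size {image_size} is not divisible by patch_size {patch_size}")
    side = image_size // patch_size
    return side, side * side


for image_size, patch_size in [(256, 8), (32, 8), (32, 4), (32, 2)]:
    side, n_tokens = vit_token_grid(image_size, patch_size)
    print(f"image_size={image_size} patch_size={patch_size}: "
          f"{side}x{side} grid, img_num_tokens={n_tokens}")
```

For 32×32 inputs, patch_size 8 leaves only a 4×4 grid (16 tokens). A smaller patch such as 4 (64 tokens) keeps more spatial detail per code at the cost of a longer stage 2 sequence.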

Reconstruction results

Hi, first of all, thanks for your work.

Working with ViT-small, I see that the results are far from VQGAN's. Did you stop training once it reached convergence? Do you think there is more room to improve the model's performance?

Results with ViT-small:
[image: cat, input image]

Training time and number of GPUS

Hi, thanks for sharing the implementation. I wonder how many GPUs (and which kind) you used, and how long it took to train stage 1 and stage 2. Since I don't have many GPUs, I want to see whether I can afford training or fine-tuning.

Reconstruction Visualization

Hi,
Thank you for this awesome project. I tried the weights you released and found that the reconstructed images look strange.
This is the input image:
[image]
And this is the output:
[image]

I only followed the Colab notebook and your source code:


import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from enhancing.utils.general import initialize_from_config  # package path as in the repo's tracebacks

config = OmegaConf.load('configs/imagenet_vitvq_base.yaml')
model = initialize_from_config(config.model)
model.init_from_ckpt('configs/vq-f8/imagenet_vitvq_base.ckpt')
model.eval()  # inference mode (disables dropout)

def preprocess(img):
    img = img.convert('RGB')  # drop alpha / expand grayscale to 3 channels
    s = min(img.size)

    if s < 256:
        raise ValueError(f'min dim for image {s} < 256')

    # resize so the shorter side is 256, then center-crop to 256x256
    r = 256 / s
    s = (round(r * img.size[1]), round(r * img.size[0]))  # (height, width)
    img = TF.resize(img, s, interpolation=T.InterpolationMode.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [256])
    img = torch.unsqueeze(T.ToTensor()(img), 0)  # shape (1, 3, 256, 256), values in [0, 1]
    return img

original = Image.open(r"C:\Users\All\Desktop\1e8fb4e75b444e99a9b7cd47e4157b4a.jpg")
image = preprocess(original)  # already batched; a second unsqueeze would make it 5-D

with torch.no_grad():
    quant, _ = model.encode(image)
    dec = model.decode(quant)

dec = dec.clamp(0, 1)
dec = dec[0].cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
plt.imshow(dec)
plt.show()
