
dall-e's Introduction

Overview

[Blog] [Paper] [Model Card] [Usage]

This is the official PyTorch package for the discrete VAE used for DALL·E. The transformer used to generate the images from the text is not part of this code release.

Installation

Before running the example notebook, you will need to install the package using

pip install DALL-E
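
For orientation, here is a minimal sketch of what the example notebook does with the package. The load_model, map_pixels, and unmap_pixels helpers come from the dall_e package; the CDN checkpoint URLs are the ones I recall from notebooks/usage.ipynb, and a random tensor stands in for a preprocessed 256x256 image:

import torch
import torch.nn.functional as F
from dall_e import map_pixels, unmap_pixels, load_model

dev = torch.device('cpu')  # or 'cuda:0'
enc = load_model("https://cdn.openai.com/dall-e/encoder.pkl", dev)
dec = load_model("https://cdn.openai.com/dall-e/decoder.pkl", dev)

x = map_pixels(torch.rand(1, 3, 256, 256, device=dev))    # stand-in for a preprocessed image
z_logits = enc(x)                                          # [1, 8192, 32, 32] logits over the code vocabulary
z = torch.argmax(z_logits, dim=1)
z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()
x_rec = unmap_pixels(torch.sigmoid(dec(z)[:, :3]))         # keep the first 3 of the 6 output channels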

dall-e's People

Contributors

adityaramesh

dall-e's Issues

Why does the dVAE decoder output 6 channels instead of 3?

The example code (DALL-E/notebooks/usage.ipynb) seems to use only the first 3 channels of the dVAE decoder output as image data, so why does the decoder output 6 channels? It seems like a waste of compute to me, since it needs more parameters than a 3-channel output would. Is there a reason for this?
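
For context, this is how the usage notebook consumes the 6-channel output (an excerpt, with dec, z, and unmap_pixels as loaded there). My reading of the paper's logit-Laplace appendix, which may be worth double-checking, is that the decoder predicts a mean and a log-scale map for each RGB channel and only the means are used for reconstruction, which would account for the extra 3 channels:

x_stats = dec(z).float()                              # [B, 6, H, W]
mu, log_b = x_stats[:, :3], x_stats[:, 3:]            # per-channel mean and (as I read it) log-scale
x_rec = unmap_pixels(torch.sigmoid(mu))               # only the means become the reconstructed image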

Implementation Doubts

Although this codebase is only for the VAE part, I would appreciate it if you could help me understand a few components of the transformer part as well.
In the paper and blog post, you mention that you use the sparse attention from the Child et al. paper. Can you elaborate on what block size you use: 8, 16, or 32?
If we use a block size of 16, for example, how do you implement the convolutional attention kernel? It has gaps of only one block, which doesn't seem to make sense if the sparse blocks have size 16.

Also, when training the GPT-style model: even though the loss and perplexity decrease, how do you identify when the perplexity/loss of the 1.2B-parameter model is sufficient? For instance, is a loss of 4-5 good, or should it be below 1?

Discrete Bottleneck Method

Thank you for letting us take an advance peek into the model! I'm sure folks will have fun doing some interesting art and research with this.

The blog post references the Gumbel-Softmax trick, the Concrete distribution, and previous VQ-VAEs, but it is left ambiguous which of these are used in the discrete VAE and how. Looking through the model in this repo, I do not see anything that fully clarifies this: there's an encoder that predicts distributions over the vocab (z_logits), which are argmax-ed to get codes z and then run through a decoder. That still doesn't clarify, though, how the logits predicted by the encoder are translated into codes during training.

How does this work in training? Are the codes hard-sampled (argmax) with noise added to the logits, soft-sampled (softmax) with noise added to the logits, or something else?
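
For concreteness, the two alternatives being asked about would look roughly like this (a sketch only, not the authors' training code; z_logits is assumed to be the encoder output with the code vocabulary on dimension 1, and tau a relaxation temperature):

import torch

tau = 1.0                                                   # relaxation temperature
g = -torch.log(-torch.log(torch.rand_like(z_logits)))       # Gumbel(0, 1) noise
z_hard = torch.argmax(z_logits + g, dim=1)                  # hard sample: discrete codes
z_soft = torch.softmax((z_logits + g) / tau, dim=1)         # soft sample: relaxed one-hot over the vocabulary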

Training the discrete VAE

Hi,

I have looked at usage.ipynb, but the pipeline in the notebook seems quite different from what is described in the paper. I am wondering how to transform the encoder output into the decoder input during training. From the paper, I guess the only thing to modify is to replace the argmax in the notebook with a Gumbel-softmax. Is my understanding correct?

Thanks!
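
For concreteness, the modification described above would look roughly like this, using PyTorch's built-in Gumbel-softmax in place of the notebook's hard argmax/one-hot step (a sketch of my understanding, not released training code; z_logits and dec are the notebook's encoder output and decoder):

import torch.nn.functional as F

# Training-time forward pass: relax the codes instead of taking a hard argmax.
z = F.gumbel_softmax(z_logits, tau=1.0, hard=False, dim=1)   # [B, vocab, H, W], sums to 1 over dim 1
x_stats = dec(z)                                             # decoder sees the soft code map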

How to sample or generate a new image?

Hi, it's great work! But I am a little confused about how to generate a new image. Should I provide the sentence tokens and then use them to predict the image tokens? And where should the noise be injected? I would really appreciate it if you could answer these questions, thank you!

License

Will there be a license? Currently there is none.
MIT license?

Decoder weights for reconstructing 16x16 patches instead of 8x8

Hi,

thanks a lot for your great work!

I wanted to ask whether you have experimented with patch dimensions other than 8x8, such as 16x16 (i.e., splitting a 224x224 image into 14x14 visual tokens instead of 28x28), and if so, whether you could share the resulting weights.

Thank you

Jupyter Notebook with Preprocess error

Hello,
A long while back I cloned the Dall-E repo and watched the penguin appear -> yay! :)
So now I did it again, I have a new GPU so why not. Figured I'd run the default CPU first and then flip to CUDA.
This time I got unanticipated errors. (I originally had a path problem caused by having two versions of Python installed; I uninstalled the old one, pip-installed numpy and torch, and was good to go.) I've attached multiple screenshots to provide greater context. I feel like this is still somehow related to my original path issue, or possibly I'm missing a site-package that I can't think of.

Thank you for your time and assistance.

The comment "# Do not call functions when jit is used" is shown multiple times in the last run, which finishes with the following as a finale:
C:\Python310\lib\site-packages\torch\nn\modules\upsampling.py in forward(self, input)
152 def forward(self, input: Tensor) -> Tensor:
153 return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners,
--> 154 recompute_scale_factor=self.recompute_scale_factor)
155
156 def extra_repr(self) -> str:

C:\Python310\lib\site-packages\torch\nn\modules\module.py in __getattr__(self, name)
1183 if name in modules:
1184 return modules[name]
-> 1185 raise AttributeError("'{}' object has no attribute '{}'".format(
1186 type(self).__name__, name))
1187

AttributeError: 'Upsample' object has no attribute 'recompute_scale_factor'

(Six screenshots from 2022-03-16 attached for context.)
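
If it helps, this AttributeError typically appears when a model pickled with an older PyTorch is run under a newer PyTorch (1.11+, as I understand it), where nn.Upsample expects a recompute_scale_factor attribute that the old pickled modules don't have. A common workaround, offered here as an assumption rather than an official fix, is to patch the loaded decoder (dec as loaded in the notebook):

import torch.nn as nn

# Give old pickled Upsample modules the attribute newer PyTorch looks for.
for m in dec.modules():
    if isinstance(m, nn.Upsample) and not hasattr(m, 'recompute_scale_factor'):
        m.recompute_scale_factor = None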

Where's the image generation code

A page on the official OpenAI website led me here. It says that "DALL·E is a 12-billion parameter version of GPT-3 trained to generate images from text descriptions", but this repo, called DALL-E, says "The transformer used to generate the images from the text is not part of this code release."

Where should I go if I want to do real image generation? Is OpenAI going to release the pre-trained model?

PyPI package

Could you please release this as a PyPI package so that:

  1. No one else claims it and provides something that isn't your original implementation
  2. We downstream can include it in our packages (since installing from git repo URLs is unsupported within the context of a package's dependencies)

Help

Hello, I am brand new to the GitHub community and to coding, and I have zero idea how to install this. But I'm an artist and it would be an excellent resource for non-copyrighted images. I know it's a lot to ask, but can someone please tell me how to install this code? I made my account specifically for this.

Pretraining Model

I want to know if you have any plans to release the pre-trained model and training code, as was done for the CLIP model.

Decoder-only architecture

Not so much an issue as a question - I read the paper but didn't see any discussion or justification for the decoder-only architecture.

Maybe this is a dumb question, but there are two discrete modalities (text and dVAE tokens) with entirely separate token distributions, so it would intuitively make sense to use an encoder-decoder architecture instead. There's already a text encoder in CLIP, so it would be even more convenient to do it that way.

I'm sure there are solid reasons for this architecture, just curious what considerations went into this architecture choice.

KL Loss

I am having trouble getting the dVAE to train properly if I include the KL loss term with a uniform prior over the visual tokens. Has anyone here had similar experiences or problems? The paper mentions an increasing schedule for the KL weight factor, but I can't get it to work properly, and results are always better if I set the KL loss to zero altogether.

Maybe someone can help?
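
In case a concrete form helps the discussion, this is the KL term against a uniform prior as I understand it from the paper, not code from this release; z_logits is assumed to be the encoder output with the K-way code vocabulary on dimension 1, and recon_loss and kl_weight are placeholders for the reconstruction term and the annealed weight mentioned above:

import math
import torch.nn.functional as F

K = 8192                                               # code vocabulary size
log_q = F.log_softmax(z_logits, dim=1)                 # log q(z|x), shape [B, K, H, W]
# KL(q || uniform) = sum_k q_k * (log q_k + log K), averaged over batch and positions
kl = (log_q.exp() * (log_q + math.log(K))).sum(dim=1).mean()
loss = recon_loss + kl_weight * kl                     # kl_weight annealed upward from 0 per the paper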

Training code

Hi,
Thanks for your great work.
I'm wondering whether there is any plan to release the training code for the dVAE? I want to fine-tune the model on a different dataset.

Error on executing usage.ipynb notebook on a cuda:0 device

I changed this line as suggested to use the GPU:

# This can be changed to a GPU, e.g. 'cuda:0'.
dev = torch.device('cuda:0')

And I tried to execute the notebook. I got the following error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_11/3257249919.py in <module>
      1 import torch.nn.functional as F
      2 
----> 3 z_logits = enc(x)
      4 z = torch.argmax(z_logits, axis=1)
      5 z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.8/site-packages/dall_e/encoder.py in forward(self, x)
     91                         raise ValueError('input must have dtype torch.float32')
     92 
---> 93                 return self.blocks(x)

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.8/site-packages/dall_e/utils.py in forward(self, x)
     41                         w, b = self.w, self.b
     42 
---> 43                 return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
     44 
     45 def map_pixels(x: torch.Tensor) -> torch.Tensor:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper___slow_conv2d_forward)
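
For anyone hitting the same error: it usually means the encoder/decoder weights are still on the CPU while x has been moved to cuda:0 (or the reverse). Either pass dev to load_model when loading the models, as the notebook does, or move everything explicitly:

# Make sure the models and the input tensor share the same device.
enc = enc.to(dev)
dec = dec.to(dev)
x = x.to(dev)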

Transformer missing

It says that the transformer used to generate the images from text is missing. Can this project be used to generate images from text as-is, or do I need the transformer? Please help, I'm lost. Is this a ready-to-use project?

About dVAE

Hi, thank you for your great work. I want to know what the "discrete" in dVAE means, exactly.

How do I run this from the files?

I have everything I need installed, but I still don't know what to do after that. Can someone give me a link to a video, or tell me what commands to run?

Licensing

As usual, modified licenses create confusion.

We don’t claim ownership of the content you create with the DALL-E discrete VAE, so it is yours to
do with as you please. We only ask that you use the model responsibly and clearly indicate that it
was used.

The first sentence seems reasonable, and although it's probably not needed, it's nice that it's being clarified.

As for using the model "responsibly" and "clearly indicating that it was used", the intent is not at all clear to me.

Is it meant as a condition? If so, on what legal basis (since no derived work is being made or distributed)?
If it is a condition, that's quite unfortunate, since it makes the license both non-free and incompatible with every copyleft license; it's legitimate, of course, just a bit unfortunate.
Is it just a kind request? If so, it really, really does not belong in a license.

How do we specify the text prompt?

On Wikipedia, on blogs, and in the media we see that we can specify a text prompt and get a set of images matching the prompt (for example "a baby daikon radish in a tutu walking a dog", the first example from https://www.openai.com/blog/dall-e/). How do we specify an arbitrary text prompt using this code and get the set of images?

A TypeError

PyTorch 1.7.1, torchvision 0.8.2. When running the code, I get the following error:

Traceback (most recent call last):
File "E:/github/DALL-E-master/test.py", line 46, in
display(T.ToPILImage(mode='RGB')(x[0]))
File "C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py", line 185, in call
return F.to_pil_image(pic, self.mode)
File "C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\functional.py", line 202, in to_pil_image
'not {}'.format(type(npimg)))
TypeError: Input pic must be a torch.Tensor or NumPy ndarray, not <class 'numpy.ndarray'>
(Original image attached.)

Hyperparameters of the bottleneck

Thanks for releasing the paper as well as the code!

Could you give some hints on the hyperparameters of the bottleneck that might affect the performance?

  1. The downsampling ratio. In the original VQ-VAE paper, they only use 4x downsampling (compared to 8x in DALL-E), and their generated images seem to lack global structure (I assume also because they didn't use a powerful prior model). Is it that with a higher downsampling rate the global structure is better preserved, or that it becomes easier for the prior model to learn?

  2. The codebook size is set to 2^13; have you tried a smaller codebook? Presumably, as the codebook size shrinks, the VAE can hardly reconstruct the image. What does a reconstruction with a small codebook look like? Is the texture still preserved but the global structure distorted, or something else?

I also have an additional question about the inference stage of the model. Are the image tokens sampled from the prior transformer autoregressively, or with some other search technique? Also, how do you control the number of generated image tokens to be exactly 32 * 32?
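
Since the transformer itself is not in this release, the following is only a generic sketch of how an autoregressive prior is usually constrained to produce exactly 32 * 32 = 1024 image tokens; prior and text_tokens are hypothetical stand-ins for the unreleased model and the BPE-encoded caption:

import torch

def sample_image_tokens(prior, text_tokens, n_tokens=32 * 32, temperature=1.0):
    # Sample one image token at a time for a fixed number of steps.
    tokens = text_tokens
    for _ in range(n_tokens):
        logits = prior(tokens)[:, -1] / temperature           # logits for the next token
        next_tok = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        tokens = torch.cat([tokens, next_tok], dim=1)
    return tokens[:, -n_tokens:]                               # exactly 1024 image tokens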

What's the purpose of it?

If it doesn't include an image generator, what are the tasks where I can use this project?

Make it clear that this package does not allow full DALL-E functionality

Many of us are arriving here with the expectation that we'll be able to replicate the text-to-image functionality exhibited by DALL-E:
https://openai.com/blog/dall-e/

We can't. Despite naming this repo "DALL-E", you have only released a minor part of the DALL-E functionality. This isn't immediately clear.

Please make clear the limitations of this package, so people don't have to dig around and waste their time before discovering that text-to-image functionality isn't here and will never be released.

About the codebook

So, in the whole architecture, no explicit codebook vectors are used, right? Only categorical logits as the input to the decoder when you train your dVAE?

Questions on the notebook

I just downloaded the repo to my local file system, launched Jupyter Notebook, and opened and ran the notebook. I also downloaded the encoder and decoder to the same folder for ease of loading. It says that 'preprocess' is not defined, but it seems to be defined. Admittedly, I'm a bit rusty. Running Python 3.9 on macOS. Also, I may be way out of line with respect to the purpose, but I was expecting to see code that takes natural-language input (e.g. "Show me a penguin on snow") after which DALL·E returns the corresponding image.

(Screenshot attached.)

Originally posted by @metaphorz in #5 (comment)
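
For what it's worth, this NameError usually just means the earlier cell that defines preprocess was not executed, so re-running all cells in order fixes it. A minimal stand-in, which is my approximation rather than the notebook's exact function, would be:

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from dall_e import map_pixels

target_image_size = 256

def preprocess(img):
    # Resize the short side to 256, center-crop to 256x256, and map the pixels
    # into the range the dVAE expects.
    img = TF.resize(img, target_image_size)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)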

Any plan on releasing the text encoder?

The current release only contains the CNN encoder and decoder trained with the dVAE. The repo looks like a release for the dVAE paper rather than for DALL-E. The only task we can achieve with this release is the image reconstruction provided by the dVAE, rather than the cool applications shown in the blog post.

Do the authors plan to release the text encoder? Or does anyone have thoughts on how to work around this?

Docker image

Do we have a Docker image for this that I can use?

Training dynamics of the prior transformer

I noticed that a footnote of the paper says the dVAE model is underfit and thus the codes are sampled with argmax. Does that mean the distribution over the code vocabulary is flatter rather than sharper, i.e., less confident? Also, I am wondering what the final NLL of the prior transformer is on the validation set.

How to generate vocabulary for the BPE-encoding of text?

Hi, @adityaramesh, This is an amazing work!
In the paper, it is mentioned that the lowercased text caption is BPE-encoded using at most 256 tokens with a vocabulary size of 16384. Here are my two questions:

  1. How is the vocabulary obtained? Is it built by collecting all subword units from the captions in the training data? (See the sketch below.)
  2. How was the vocabulary size of 16384 decided? I cannot find any other information about this value in the paper.

Thanks
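
For question 1, here is a sketch of one common way to build such a vocabulary, not the authors' actual pipeline: train a byte-pair encoding on the lowercased training captions, for example with the Hugging Face tokenizers library. captions.txt is a hypothetical file with one caption per line:

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

# Learn a BPE vocabulary of size 16384 from the caption corpus.
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size=16384, special_tokens=["[UNK]", "[PAD]"])
tokenizer.train(files=["captions.txt"], trainer=trainer)
tokenizer.save("bpe-16384.json")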

Error with multiple devices (but dev is set to 'cuda:0'?)

Getting an error on the last block; not sure if this is related: nlp-with-transformers/notebooks#31
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Untitled-1.ipynb Cell 5' in <cell line: 2>()
1 import torch.nn.functional as F
----> 2 z_logits = enc(x)
3 z = torch.argmax(z_logits, axis=1)
4 z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()

File c:\Users\aaken.venv\lib\site-packages\torch\nn\modules\module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\aaken.venv\lib\site-packages\dall_e\encoder.py:93, in Encoder.forward(self, x)
90 if x.dtype != torch.float32:
91 raise ValueError('input must have dtype torch.float32')
---> 93 return self.blocks(x)

File c:\Users\aaken.venv\lib\site-packages\torch\nn\modules\module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\aaken.venv\lib\site-packages\torch\nn\modules\container.py:141, in Sequential.forward(self, input)
139 def forward(self, input):
140 for module in self:
--> 141 input = module(input)
142 return input

File c:\Users\aaken.venv\lib\site-packages\torch\nn\modules\module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\aaken.venv\lib\site-packages\dall_e\utils.py:43, in Conv2d.forward(self, x)
39 x = x.float()
41 w, b = self.w, self.b
---> 43 return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper___slow_conv2d_forward)
