
Comments (36)


rom1504 commented on May 18, 2024

https://huggingface.co/rom1504/dalle2-diffusion-prior/resolve/main/1651432174.5708027_saved_model.pth here's a first checkpoint, by @krish240574. thanks to him for building the training code and running that first training!

it's time for the evaluation to start.


crowsonkb commented on May 18, 2024

btw I think CoCa should be trained dropping out the image features so you can generate captions with superconditioning (it should work well there for the same reason it works well when generating images: the typical image/caption pair in most training sets doesn't match all that well).
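For reference, superconditioning means classifier-free guidance: drop the image features some fraction of the time during training, then at sampling time extrapolate away from the unconditional prediction. A rough sketch of the combination step, with a stand-in decoder (names and the guidance scale are illustrative assumptions):

import torch

scale = 2.0  # guidance scale; > 1 pushes captions toward the image content

def decoder_logits(tokens, image_features=None):
    # stand-in for a CoCa-style caption decoder returning next-token logits;
    # image_features=None is the dropped-out (unconditional) branch
    logits = torch.zeros(tokens.shape[0], 1000)
    if image_features is not None:
        logits = logits + image_features.mean(dim=-1, keepdim=True)
    return logits

tokens = torch.zeros(4, 1, dtype=torch.long)
image_features = torch.randn(4, 768)

logits_cond = decoder_logits(tokens, image_features)
logits_uncond = decoder_logits(tokens)  # image features dropped, as in training
logits = logits_uncond + scale * (logits_cond - logits_uncond)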


lucidrains commented on May 18, 2024

@rom1504 update: Romain reported that a research group out there has replicated the prior using the code in this repository for their CLIP generations. in other words, the code in this repository works, and we have confirmation that the prior is effective, per the paper


rom1504 commented on May 18, 2024

ohhh you're right, well that makes things much easier.
I'll remove that from the plan.


nousr commented on May 18, 2024

I was able to overfit a small subset of LAION-2B locally (using the new CLIP-less DiffusionPrior class). I'll be working on the data loader next so we can do the first "real" training run.
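For context, a rough sketch of the kind of embedding-only training step that enables. The image_embed_dim argument and the text_embed/image_embed forward kwargs are my assumption of the CLIP-less interface, so check the repo for the real signatures:

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(
    dim = 768,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# no clip passed in; the prior trains directly on precomputed embeddings
diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 768,  # assumed kwarg for the CLIP-less mode
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

text_embeds = torch.randn(4, 768).cuda()   # precomputed ViT-L/14 text embeddings
image_embeds = torch.randn(4, 768).cuda()  # matching image embeddings

loss = diffusion_prior(text_embed = text_embeds, image_embed = image_embeds)  # assumed kwargs
loss.backward()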

referencing:


rom1504 commented on May 18, 2024

Cool!
The corresponding text embeddings are now in https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/


rom1504 commented on May 18, 2024

You can use two instances of embedding_reader and zip them to get paired image/text embedding batches.
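A sketch of that, using the embedding_reader package (it assumes the img_emb/ and text_emb/ folders are index-aligned, which they are for these dumps):

from embedding_reader import EmbeddingReader

base = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/"
image_reader = EmbeddingReader(embeddings_folder=base + "img_emb/", file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=base + "text_emb/", file_format="npy")

# zipping the two readers yields paired batches: img_batch[i] and txt_batch[i]
# are the embeddings of the same image/caption pair
for (img_batch, _), (txt_batch, _) in zip(
    image_reader(batch_size=10 ** 4, start=0, end=image_reader.count),
    text_reader(batch_size=10 ** 4, start=0, end=text_reader.count),
):
    ...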


krish240574 commented on May 18, 2024

> https://huggingface.co/rom1504/dalle2-diffusion-prior/resolve/main/1651432174.5708027_saved_model.pth here's a first checkpoint, by @krish240574. thanks to him for building the training code and running that first training!
>
> it's time for the evaluation to start.

Here is a run happening, 100 million data points:
https://wandb.ai/laion/diffusion-prior/runs/3o0ic6ou?workspace=user-krish240574

Hyperparameters:
learning_rate = 0.001
max_grad_norm = 0.5
weight_decay = 0.01
batch_size = 10 ** 4
train/val/test split = 0.7 / 0.2 / 0.1

Refer to https://github.com/lucidrains/DALLE2-pytorch/blob/main/train_diffusion_prior.py for training details.
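In PyTorch terms those settings map roughly onto the following (a sketch; AdamW as the optimizer choice is an assumption, see the training script above for the real setup):

import torch

model = torch.nn.Linear(768, 768)  # stand-in for the diffusion prior network

optimizer = torch.optim.AdamW(model.parameters(), lr = 0.001, weight_decay = 0.01)

loss = model(torch.randn(16, 768)).pow(2).mean()  # placeholder loss
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = 0.5)  # max_grad_norm
optimizer.step()
optimizer.zero_grad()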


nousr commented on May 18, 2024

@lucidrains awesome, thanks for doing that!

Since we're starting to pull out meaningful results from our mini-prior, what would you recommend in terms of network hyperparameters next?

I think I saw somewhere there was a discussion about moving to something like...?

from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

# change to 12 layers, 128 dim_head, 16 heads
prior_network = DiffusionPriorNetwork(
    dim = 768,
    depth = 12,
    dim_head = 128,
    heads = 16
).cuda()

# i'd also like to try 1000 steps and compare results (open to thoughts on this)
diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,  # a pretrained CLIP wrapper, e.g. OpenAIClipAdapter()
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()


lucidrains commented on May 18, 2024

@nousr no problem! :)

yea we can aim for maybe the same size as GPT2 small? https://huggingface.co/docs/transformers/model_doc/gpt2

so it would translate to

prior_network = DiffusionPriorNetwork(
    dim = 768,
    depth = 12,
    dim_head = 64,
    heads = 12
).cuda()


crowsonkb commented on May 18, 2024

Classifier-free guidance works fine with predicting x_0, I trained mine that way. :) As for training with the text encodings, I am feeding in the hidden states at the end of the frozen CLIP text encoder to my prior (along with the corresponding padding mask for attention) instead of trying to learn language from scratch. This works pretty well, a lot better than trying to feed in only the CLIP text embedding!
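For anyone wanting to reproduce that conditioning, a minimal sketch of extracting the frozen hidden states plus padding mask with the Hugging Face CLIP text encoder (the model choice here is an assumption, not necessarily Katherine's exact setup):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()

tokens = tokenizer(["a photo of a dog"], padding=True, return_tensors="pt")
with torch.no_grad():
    # per-token hidden states from the frozen text encoder, shape (batch, seq, 768)
    text_encodings = text_model(**tokens).last_hidden_state
# padding mask to hand to the prior's attention alongside the encodings
text_mask = tokens.attention_mask.bool()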


crowsonkb commented on May 18, 2024

> i was also puzzled why they didn't take the output token of the image embedding

I tried this and it was worse/didn't learn as well. I suspect it might be easier to predict a clean x_0 without the residual bias from the noisy input (because it would have to generate exact anti-noise in the ffns and add it to that token's residual stream), whereas if you have a separate output token you just have to copy information into it via attention.
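In code, the separate-output-token choice looks roughly like this (an illustrative sketch, not the repo's exact implementation):

import torch
import torch.nn as nn

b, n, dim = 4, 77, 768
text_enc = torch.randn(b, n, dim)         # text encoder hidden states
noised_image_embed = torch.randn(b, dim)  # the noisy x_t token

# append a separate learned output token; the x_0 prediction is read off its
# position, so attention only has to copy information into it, instead of the
# ffns having to synthesize exact anti-noise on top of the noisy input token
learned_query = nn.Parameter(torch.randn(1, 1, dim))
tokens = torch.cat([text_enc, noised_image_embed[:, None, :], learned_query.expand(b, -1, -1)], dim=1)

layer = nn.TransformerEncoderLayer(d_model=dim, nhead=12, batch_first=True)
transformer = nn.TransformerEncoder(layer, num_layers=2)  # stand-in for the prior
pred_x0 = transformer(tokens)[:, -1]  # prediction from the learned output token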


crowsonkb commented on May 18, 2024

I think so too, because it can't just throw away as much information about the contents of the individual tokens, since it needs to do the next-token prediction task too.


lucidrains commented on May 18, 2024

> btw I think CoCa should be trained dropping out the image features so you can generate captions with superconditioning (it should work well there for the same reason it works well when generating images: the typical image/caption pair in most training sets doesn't match all that well).

yes! agreed on the superconditioning!

last thought i've had that is worth sharing is perhaps there can be one more level of indirection. If one were to train CLIP with a multi-stage efficient ViT, we can do text -> text embed + text encoding -> image embed + image encoding (~49-64 tokens) -> image. in the ideal world, we have a 540B parameter PaLM + CoCa (with FILIP fine interactions between the text tokens and image tokens) + cascading DDPM generator of about 3B parameters conditioned on the image tokens from the hierarchical ViT. One can dream 😆


nousr commented on May 18, 2024

@lucidrains have you started working on config-file support for the prior?

clarification: (I was thinking about tackling that to simplify the training script, but didn't wanna duplicate work!)


lucidrains commented on May 18, 2024

@nousr hey, yea i did, only partly. but the general scaffold is there, so it shouldn't take too much code to convert it to be config-based (by looking at the decoder training script and what was done in dalle2_pytorch/train_configs.py as an example)

this week i'm back to attending a bunch of meetings, so will be generally unproductive. feel free to jump in with a PR!


rom1504 commented on May 18, 2024

I believe we got everything done here.
One last thing that would be helpful, in my opinion, would be a prior.md with this information:

  • What the prior is and what it can do
  • How to prepare a dataset for it
  • How to run a training
  • How to use a pretrained model
  • What good metrics are
  • How to plug it into the decoder

I think that would close this topic.
I believe most of the information is already present in readme.md, but having it in its own dedicated file would help, as the readme is getting very large now.


lucidrains commented on May 18, 2024

@rom1504 ah, looks like a great plan :) so i think the PCA portion was only for the autoregressive prior method, and not the diffusion prior (they needed to quantize so they could use a straightforward cross entropy loss). but correct me if i'm wrong, i'm happy to build that into the framework
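For reference, my reading of the paper's autoregressive-prior preprocessing, sketched below; the 1024 -> 319 PCA reduction and the 1024 buckets are the paper's numbers as I recall them, so double-check before relying on this:

import numpy as np
from sklearn.decomposition import PCA

image_embeds = np.random.randn(10000, 1024).astype(np.float32)  # placeholder CLIP image embeddings

pca = PCA(n_components=319)           # paper: 1024 dims -> 319 dims
reduced = pca.fit_transform(image_embeds)

# quantize each dimension into 1024 discrete buckets, turning every embedding
# into 319 integer tokens that a plain cross entropy loss can be applied to
bins = np.linspace(reduced.min(), reduced.max(), 1023)
tokens = np.digitize(reduced, bins)   # (10000, 319) ints in [0, 1023]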


rom1504 commented on May 18, 2024

I'll keep this issue updated as we make progress


taki0112 commented on May 18, 2024

@rom1504
The text embeddings you mentioned come from which network? Is it CLIP?


rom1504 commented on May 18, 2024

Yes, the image and text embeddings linked above are CLIP ViT-L/14.
Check the LAION-5B blog post for details


lucidrains commented on May 18, 2024

@krish240574 so just a word of caution, that learning rate of 1e-3 is quite high for transformers. a more conservative value would be Karpathy's favorite, 3e-4 (however, I don't know how transformers behave within a DDPM framework, so I could be wrong)


rom1504 commented on May 18, 2024

@krish240574 is running that new run with the newly added metrics https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace=user-rom1504

seems cosine similarity is going up.

@lucidrains do you have any opinion on what would be the best way to know if the prior is doing its job? (except for plugging it into a generator and training of course)
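For what it's worth, the cheap check reflected in that run is cosine similarity between the predicted and ground-truth image embeddings, ideally compared against CLIP's own text-to-image similarity as a baseline. A minimal sketch with placeholder tensors:

import torch
import torch.nn.functional as F

pred_image_embed = torch.randn(64, 768)  # embeddings sampled from the prior
true_image_embed = torch.randn(64, 768)  # ground-truth CLIP image embeddings
text_embed = torch.randn(64, 768)        # CLIP text embeddings for the same pairs

sim = F.cosine_similarity(pred_image_embed, true_image_embed, dim=-1).mean()
# the prior should beat the raw text->image similarity CLIP already provides
baseline = F.cosine_similarity(text_embed, true_image_embed, dim=-1).mean()
print(f"prior: {sim:.3f}  baseline: {baseline:.3f}")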


xiankgx commented on May 18, 2024

Seems like the prior is doing a great job here. Please share the training details and checkpoint and dataset if possible. Thanks.


rom1504 commented on May 18, 2024

Dataset is:

parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")

i.e. the laion2B-en CLIP ViT-L/14 text/image embeddings


krish240574 commented on May 18, 2024

Here is a sample run, 600 million data points: https://wandb.ai/laion/diffusion-prior/runs/ar65uq6n?workspace=user-krish240574, with hyperparameters as in the script train_diffusion_prior.py (default values)


nousr commented on May 18, 2024

Just wanted to make a note that Katherine recommended we use EMA while training the prior as well @lucidrains
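A minimal version of that EMA bookkeeping, for reference (the 0.999 decay is an assumed value):

import copy
import torch
import torch.nn as nn

model = nn.Linear(768, 768)        # stand-in for the prior network
ema_model = copy.deepcopy(model)   # shadow copy; evaluate and checkpoint this one
decay = 0.999

@torch.no_grad()
def update_ema():
    # call after every optimizer step: ema <- decay * ema + (1 - decay) * online
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)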


lucidrains commented on May 18, 2024

> Just wanted to make a note that Katherine recommended we use EMA while training the prior as well @lucidrains

ohh got it, i can take care of that tomorrow morning 👍


lucidrains commented on May 18, 2024

> Just wanted to make a note that Katherine recommended we use EMA while training the prior as well @lucidrains

@nousr

740d644 ok all done! @crowsonkb thank you for the advice yet again 🙏


lucidrains commented on May 18, 2024

in light of #71, the prior should definitely be retrained at the latest version! i've also turned off classifier-free guidance, since i'm uncertain whether it works well with the predict-x0 objective. final thought is that we should be training with the text encodings + corresponding text mask if possible. i think the paper did have this (although if Katherine was able to get good results without, let us aim for that first)


crowsonkb commented on May 18, 2024

Actually do you think I need learned queries if I am feeding in a sequence of text encoder hidden states? I don't have them and don't know how much they help. There are always at least two text encoder hidden states from the SOT and EOT tokens (for the null condition), more for actual prompts.


lucidrains commented on May 18, 2024

> Classifier-free guidance works fine with predicting x_0, I trained mine that way. :) As for training with the text encodings, I am feeding in the hidden states at the end of the frozen CLIP text encoder to my prior (along with the corresponding padding mask for attention) instead of trying to learn language from scratch. This works pretty well, a lot better than trying to feed in only the CLIP text embedding!

thanks for sharing! that's great to know! i was wondering if this would work well given the l2norm constraint. i'll revert the commit next week so others can use it

and yes, i have things setup exactly like you did (the text encodings being the output of the final layer of the CLIP text transformer), so it should be ready to go, provided Laion can save and dataload the text encodings efficiently in addition to the text embeddings


lucidrains commented on May 18, 2024

> Actually do you think I need learned queries if I am feeding in a sequence of text encoder hidden states? I don't have them and don't know how much they help. There are always at least two text encoder hidden states from the SOT and EOT tokens (for the null condition), more for actual prompts.

so the literature is scant on this, but https://arxiv.org/abs/2006.11527 does suggest adding learned queries (up to 16) can be beneficial. i was also puzzled why they didn't take the output token of the image embedding, and one possibility is that they projected the "noised | predicted" image embedding into multiple tokens for attention (which i'll eventually build into the repository as a setting). the more realistic answer is that no matter which token you choose, as long as the position information is there and the transformer is big enough, it won't matter that much 😆

tldr: it doesn't hurt to add a few memory (learned query) tokens. with a big enough transformer, probably matters little


lucidrains commented on May 18, 2024

this is valuable information! thanks!

at the end of the day, I think the team still struggled to get the generator to treat the text as more than a bag of words, and the bottleneck is the CLIP text encoder. I suspect the new CoCa, which is an LM and CLIP trained end to end, would help


rom1504 commented on May 18, 2024

pretty much done now

