
kyegomez / cm3leon


An open-source implementation of "Scaling Autoregressive Multi-Modal Models: Pretraining and Instruction Tuning", an all-new multimodal AI that uses just a decoder to generate both text and images

Home Page: https://discord.gg/qUtxnK2NMf

License: MIT License

Python 100.00%
attention attention-is-all-you-need dalle imagegeneration multimodal multimodal-learning multimodality

cm3leon's Introduction


CM3Leon: Autoregressive Multi-Modal Model for Text and Image Generation (wip)


CM3Leon is a transformer-based autoregressive model designed for multimodal tasks, specifically text and image generation. The model is trained in two stages: retrieval-augmented pretraining on a large, diverse multimodal dataset, followed by supervised fine-tuning. It also implements contrastive decoding to improve the quality of the generated samples.

CM3LEON, PAPER LINK

  • Please help with this open-source implementation in the Agora Discord.
  • This implementation is still not finished.

Install

pip3 install cm3


Usage & Example

To start with CM3Leon in a PyTorch environment:

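import torch
from cm3.model import CM3

# usage: a random image and a random caption, just to check shapes
img = torch.randn(1, 3, 256, 256)  # (batch, channels, height, width)
caption = torch.randint(0, 20000, (1, 1024))  # (batch, seq_len) token ids in [0, 20000)

model = CM3()

output = model(img, caption)  # argument order: image first, then text tokens
print(output.shape)  # (1, 1024, 20000)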
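Reading the shapes in the example (the caption holds ids drawn from [0, 20000), and the last output dimension is 20000), the output looks like one score (logit) per vocabulary entry at each of the 1024 positions, rather than pixels. That reading is an assumption, not something this README states. As a minimal sketch in plain Python (so it runs without PyTorch), the hypothetical helper `greedy_token_ids` turns such scores into token ids by taking the argmax at each position; on the real tensor, `output.argmax(dim=-1)` does the same in PyTorch.

```python
from typing import List


def greedy_token_ids(logits: List[List[float]]) -> List[int]:
    """Pick the highest-scoring vocabulary id at each sequence position.

    `logits` is one batch element laid out as [seq_len][vocab_size], e.g. what
    output[0].tolist() would give for a (1, 1024, 20000) tensor (assumed layout).
    """
    ids = []
    for position_scores in logits:
        # index of the largest score at this position; max() keeps the first one on ties
        best = max(range(len(position_scores)), key=position_scores.__getitem__)
        ids.append(best)
    return ids


# Toy example: 3 positions over a 4-token vocabulary.
toy_logits = [
    [0.1, 2.0, -1.0, 0.3],
    [1.5, 0.2, 0.2, 0.0],
    [-0.5, -0.2, 0.9, 3.1],
]
print(greedy_token_ids(toy_logits))  # [1, 0, 3]
```

Greedy argmax is only the simplest way to read the scores; the sampling strategies listed in the roadmap (top-p, classifier-free guidance, contrastive decoding) all start from the same per-position logits.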

This repository hosts the open-source implementation of CM3Leon, a state-of-the-art autoregressive multi-modal model for text and image generation. The model is introduced in the paper "Scaling Autoregressive Multi-Modal Models: Pretraining and Instruction Tuning".


Overview

Key Features of CM3Leon:

  • Retrieval-augmented pretraining on a large, diverse multimodal dataset.
  • Two-stage training: pretraining and supervised fine-tuning.
  • Contrastive decoding for enhanced sample quality.

CM3Leon achieves state-of-the-art text-to-image generation, outperforming comparable models while using 5x less training compute.

Getting Started

The following sections provide a detailed analysis of the model architecture, the necessary resources, and the steps needed to replicate the CM3Leon model.

Requirements

Replicating CM3Leon involves several critical components and requires proficiency in the following areas:

  • Large-scale distributed training of transformer models using a significant number of GPUs/TPUs.
  • Efficient data loading and preprocessing to handle extensive multimodal datasets.
  • Memory optimization techniques to fit large models in GPU memory.
  • Custom tokenizer implementation for both text and image modalities.
  • Setting up a retrieval infrastructure for dense retrieval during pretraining.
  • Developing a fine-tuning framework to handle mixed text-image tasks.
  • Inference optimizations such as compiler-accelerated decoders, lower precision computing, and batching.

System Architecture

The CM3Leon implementation comprises:

  • A distributed training framework, preferably TensorFlow or PyTorch.
  • High-performance compute infrastructure (HPC cluster with GPUs/TPUs).
  • A retrieval index and dense retriever module for augmentation.
  • Data pipelines for efficient preprocessing and loading.
  • Custom code for tokenizers and the CM3 model architecture.
  • Fine-tuning framework and relevant task datasets.
  • Serving infrastructure for low-latency inference.

Implementing these components involves challenges such as efficiently utilizing large compute clusters, minimizing data loading and preprocessing bottlenecks, optimizing memory usage during training and inference, and ensuring low-latency serving.

Model Architecture

The architecture of CM3Leon includes:

  • Text and Image Tokenizers: a custom text tokenizer trained on CommonCrawl data, and an image tokenizer that encodes 256x256 images into 1024 tokens.
  • Special Tokens: a <break> token marks transitions between modalities.
  • Retrieval Augmentation: a CLIP-based bi-encoder retrieves relevant text and images from the memory bank.
  • Autoregressive Decoder-only Transformer: Standard transformer architecture similar to GPT models.
  • Two-Stage Training: Pretraining with retrieval augmentation and supervised finetuning on text-image tasks via instruction tuning.
  • Contrastive Decoding: Modified contrastive decoding for better sample quality.

The model size ranges from 350M to 7B parameters.
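A minimal sketch of how one mixed text-image training sequence could be laid out with the <break> token. Assumptions, none taken from the cm3 package: the image tokenizer downsamples by a factor of 8, so a 256x256 image becomes a 32 x 32 grid, i.e. the 1024 tokens mentioned above; `BREAK_ID`, the helper `interleave` and all token ids are hypothetical; and a <break> is inserted only where the modality changes.

```python
from typing import List, Optional

IMAGE_SIZE = 256                 # pixels per side, as in the README
DOWNSAMPLE = 8                   # assumed image-tokenizer downsampling factor
GRID = IMAGE_SIZE // DOWNSAMPLE  # 32 tokens per side
TOKENS_PER_IMAGE = GRID * GRID   # 32 * 32 = 1024 tokens per image

BREAK_ID = 3  # hypothetical vocabulary id reserved for <break>


def interleave(segments: List[List[int]], modalities: List[str]) -> List[int]:
    """Concatenate token segments, inserting <break> wherever the modality changes."""
    if len(segments) != len(modalities):
        raise ValueError("need one modality label per segment")
    sequence: List[int] = []
    previous: Optional[str] = None
    for tokens, modality in zip(segments, modalities):
        if modality == "image" and len(tokens) != TOKENS_PER_IMAGE:
            raise ValueError(f"an image segment must have {TOKENS_PER_IMAGE} tokens")
        if previous is not None and modality != previous:
            sequence.append(BREAK_ID)  # mark the text <-> image transition
        sequence.extend(tokens)
        previous = modality
    return sequence


caption = [10, 11, 12]                              # hypothetical text token ids
image = list(range(1000, 1000 + TOKENS_PER_IMAGE))  # hypothetical image token ids
seq = interleave([caption, image], ["text", "image"])
print(TOKENS_PER_IMAGE, len(seq), seq[3])  # 1024 1028 3
```

Because the decoder only ever sees one flat token stream, the same layout serves captioning (image first) and text-to-image generation (text first); only the segment order changes.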

Data

The datasets used in the paper are listed below, with their domain, size and source:

| Dataset | Domain | Size | Source |
|---|---|---|---|
| Shutterstock | Images and captions | 3 billion text tokens, licensed image data | Proprietary dataset, described in paper |
| MS-COCO | Image captioning | 591K image-caption pairs | Microsoft COCO Captions |
| Flickr30k | Image captioning | 144K image-caption pairs | Flickr30k Entities |
| Image Paragraph | Dense image captioning | 14K images with paragraph captions | Image Paragraph dataset |
| Localized Narratives | Image paragraph captioning | 164K images with localized narratives | Localized Narratives |
| VQA2 | Visual question answering | 1.3M images with question-answer pairs | VQA2 dataset |
| VizWiz | Visual question answering for blind users | 92K images with question-answer pairs | VizWiz dataset |
| OKVQA | Knowledge-based VQA | 26K images with question-answer pairs | OK-VQA dataset |
| ScienceQA | Scientific visual QA | 6K images with multi-choice QA pairs | ScienceQA |

The model was trained and evaluated on several datasets including MS-COCO [...] (Chen et al., 2015), Flickr30k [...] (Young et al., 2014), etc.

For successful implementation, CM3Leon requires:

  • A large (100M+ examples), diverse multimodal dataset, such as Shutterstock, for pretraining.
  • A mixture of text and image tasks with accompanying datasets for finetuning.
  • Efficient and scalable data loading that does not bottleneck model training.
  • Preprocessing steps like resizing images to 256x256 pixels and text tokenization.
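The README only says images are resized to 256x256 pixels. One common way to do that without distorting the aspect ratio, assumed here rather than taken from the paper, is to center-crop to a square first. The hypothetical helper `center_square_crop_box` computes the crop box in the (left, upper, right, lower) order that Pillow's `Image.crop` expects; the cropped image would then go through `Image.resize((256, 256))`.

```python
from typing import Tuple

TARGET_SIZE = 256  # side length the README resizes images to


def center_square_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """(left, upper, right, lower) box of the largest centred square in a width x height image."""
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


print(center_square_crop_box(640, 480))  # (80, 0, 560, 480)
# With Pillow (not imported here): img.crop(center_square_crop_box(*img.size)).resize((TARGET_SIZE, TARGET_SIZE))
```

Cropping discards the image borders; padding to a square instead would keep them at the cost of blank margins. Which one the paper used is not stated here.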

Training

CM3Leon's training process involves:

  • Pretraining with retrieval augmentation and the CM3 objective.
  • Supervised finetuning on text-image tasks.
  • Efficient distributed training infrastructure for large-scale model training.
  • Hyperparameter tuning for learning rates, batch sizes, optimizers, etc.
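The CM3 objective named in the first bullet turns a sequence into an infilling example: a span is cut out, replaced by a mask sentinel, and appended at the end, so an ordinary left-to-right next-token loss teaches the decoder to fill in the middle. A single-span sketch on string tokens; `MASK`, `cm3_transform` and `sample_span` are hypothetical names, and the real objective (and this repository) may mask several spans and use numbered sentinels.

```python
import random
from typing import List, Tuple

MASK = "<mask>"  # hypothetical sentinel; a real tokenizer would reserve an id for it


def cm3_transform(tokens: List[str], span: Tuple[int, int]) -> List[str]:
    """Replace tokens[start:end] with MASK and move the span to the end, after a second MASK."""
    start, end = span
    if not 0 <= start < end <= len(tokens):
        raise ValueError("span must be a non-empty slice of the sequence")
    with_hole = tokens[:start] + [MASK] + tokens[end:]
    return with_hole + [MASK] + tokens[start:end]


def sample_span(length: int, rng: random.Random, max_span: int = 3) -> Tuple[int, int]:
    """Pick a random span of 1..max_span tokens inside a sequence of the given length."""
    span_len = rng.randint(1, min(max_span, length))
    start = rng.randrange(0, length - span_len + 1)
    return start, start + span_len


tokens = ["a", "photo", "of", "a", "cat"]
print(cm3_transform(tokens, (3, 5)))
# ['a', 'photo', 'of', '<mask>', '<mask>', 'a', 'cat']
```

The transformed sequence is two tokens longer than the input (the hole marker plus the sentinel before the moved span), and training on it needs no special loss: the model simply predicts every next token.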

Inference

For efficient inference, consider:

  • Using compiler-accelerated decoders like FasterTransformer.
  • Other optimizations like lower precision (FP16/INT8) and batching.
  • Efficient implementation of contrastive decoding.
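The roadmap's Contrastive Decoding TopK rule keeps only tokens whose expert probability is at least α times the k-th largest expert probability. A plain-Python sketch of that filter, followed by the usual contrastive-decoding choice among the survivors (highest log p_expert minus log p_amateur). The toy distributions, the helper names, and which models play expert and amateur are all assumptions, not details of this repository.

```python
import math
from typing import List


def cd_k_candidates(p_expert: List[float], alpha: float, k: int) -> List[int]:
    """Ids whose expert probability is >= alpha times the k-th largest expert probability."""
    if not 1 <= k <= len(p_expert):
        raise ValueError("k must lie between 1 and the vocabulary size")
    kth_largest = sorted(p_expert, reverse=True)[k - 1]
    threshold = alpha * kth_largest
    # p > 0 guards the log below when the threshold itself is 0
    return [i for i, p in enumerate(p_expert) if p > 0 and p >= threshold]


def contrastive_pick(p_expert: List[float], p_amateur: List[float], alpha: float, k: int) -> int:
    """Among the CD-K candidates, pick the id maximising log p_expert - log p_amateur.

    Assumes the amateur gives every candidate a strictly positive probability (e.g. a softmax).
    """
    candidates = cd_k_candidates(p_expert, alpha, k)
    return max(candidates, key=lambda i: math.log(p_expert[i]) - math.log(p_amateur[i]))


p_exp = [0.5, 0.3, 0.12, 0.08]  # toy expert distribution
p_ama = [0.6, 0.1, 0.2, 0.1]    # toy amateur distribution
print(cd_k_candidates(p_exp, alpha=0.5, k=2))          # [0, 1]
print(contrastive_pick(p_exp, p_ama, alpha=0.5, k=2))  # 1
```

Setting k=1 recovers the original max-based plausibility constraint; a larger k lowers the threshold and lets more tokens through. Note that greedy decoding on the expert alone would pick id 0 here, while the contrastive score prefers id 1, which the amateur finds much less likely.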

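Hyperparameters

| Model | # Layers | d_model | Seq Length | Batch | LR | Warm-up Steps | # GPUs | # Tokens |
|---|---|---|---|---|---|---|---|---|
| 350M | 24 | 1024 | 4096 | 8M | 6e-04 | 1500 | 256 | 1.4T |
| 760M | 24 | 1536 | 4096 | 8M | 5e-04 | 1500 | 256 | 1.9T |
| 7B | 32 | 4096 | 4096 | 8M | 1.2e-04 | 1500 | 512 | 2.4T |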

Supervised Fine-Tuning Parameters

| Model | # GPUs | Seq Length | Batch Size | LR | Warm-up Steps | # Tokens |
|---|---|---|---|---|---|---|
| CM3Leon-760m | 64 | 4096 | 2M | 5e-05 | 150 | 30B |
| CM3Leon-7b | 128 | 4096 | 2M | 5e-05 | 150 | 30B |
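A back-of-the-envelope check on the hyperparameter rows above: dividing each token budget by its batch size gives the approximate number of optimizer steps. This assumes the "8M" and "2M" batch entries are tokens per step and reads M, B and T as 10^6, 10^9 and 10^12; the results are approximations, not figures reported by the paper.

```python
# Token budgets and batch sizes from the hyperparameter rows, read as plain integers (an assumption).
RUNS = {
    "350M pretraining": (1_400_000_000_000, 8_000_000),
    "760M pretraining": (1_900_000_000_000, 8_000_000),
    "7B pretraining": (2_400_000_000_000, 8_000_000),
    "SFT (760m and 7b)": (30_000_000_000, 2_000_000),
}


def optimizer_steps(total_tokens: int, tokens_per_step: int) -> int:
    """Number of steps needed to consume total_tokens, rounded up."""
    return -(-total_tokens // tokens_per_step)


for name, (total, per_step) in RUNS.items():
    print(f"{name}: ~{optimizer_steps(total, per_step):,} steps")
# 350M pretraining: ~175,000 steps
# 760M pretraining: ~237,500 steps
# 7B pretraining: ~300,000 steps
# SFT (760m and 7b): ~15,000 steps
```

Under the same reading, the 1500 pretraining warm-up steps cover well under 1% of each run, while the 150 SFT warm-up steps are about 1% of fine-tuning.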

Innovations in the paper:

  • Conditional text and image generation, using the objective function together with contrastive top-k decoding.

  • Multimodal models need to be dynamic: they can't just generate the types of data they were trained on, they need to adapt to user needs. Multimodal models should therefore be conditional: when prompted, the model generates text and/or images. This is the future.

Contributing

This repository welcomes contributions. Feel free to submit pull requests, create issues, or suggest any enhancements.

Support

If you encounter any issues or need further clarification, please create an issue in the GitHub issue tracker.

License

CM3Leon is open-sourced under the MIT license.

Roadmap

  • Implement the objective function, in which multimodal inputs are transformed into an infilling instance by masking specific spans and relocating them to the end.

  • Implement a next-token prediction loss, $-\log p(x_{\text{input}})$.

  • Implement top-p sampling.

  • Implement classifier-free guidance (CFG): steer an unconditional sample towards a conditional one. For unconditional sampling, replace the text with the mask token from the CM3 objective, so that during inference two concurrent token streams are generated: a conditional stream, contingent on the input text, and an unconditional stream, conditioned on a mask token, where

$$\text{logits}_{\text{cond}} = T(t_y \mid t_x), \qquad \text{logits}_{\text{uncond}} = T(t_y \mid \langle\text{mask}\rangle)$$
$$\text{logits}_{\text{cf}} = \text{logits}_{\text{uncond}} + \alpha_c \cdot (\text{logits}_{\text{cond}} - \text{logits}_{\text{uncond}})$$

T = transformer
t_y = output tokens
t_x = conditional input text
<mask> = no input text, replaced with a mask token
α_c = scaling factor
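  • Implement Contrastive Decoding TopK, which restricts sampling to the candidate set

$$V(t_{y_{<i}}) = \{\, t_{y_i} \in V : p_{\text{EXP}}(t_{y_i} \mid t_{y_{<i}}) \geq \alpha \cdot \operatorname{kmax}_{w}\big(p_{\text{EXP}}(w \mid t_{y_{<i}})\big) \,\}$$

where kmax takes the k-th largest value.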
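A minimal sketch of one classifier-free-guided decoding step for the roadmap items above: the conditional and unconditional logits are combined with the scaling factor α_c, then a token is drawn with top-p (nucleus) sampling. Everything here is an assumption rather than this repository's code: the logits are toy lists standing in for the transformer's two streams, and `cfg_logits`, `softmax` and `top_p_sample` are hypothetical helpers.

```python
import math
import random
from typing import List


def cfg_logits(cond: List[float], uncond: List[float], alpha_c: float) -> List[float]:
    """logits_cf = logits_uncond + alpha_c * (logits_cond - logits_uncond), elementwise."""
    if len(cond) != len(uncond):
        raise ValueError("both streams must score the same vocabulary")
    return [u + alpha_c * (c - u) for c, u in zip(cond, uncond)]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax (subtracts the max before exponentiating)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_sample(logits: List[float], p: float, rng: random.Random) -> int:
    """Nucleus sampling: draw from the smallest set of top tokens whose mass reaches p."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus, mass = [], 0.0
    for i in order:
        nucleus.append(i)
        mass += probs[i]
        if mass >= p:
            break
    # random.choices renormalises the weights, so the nucleus needs no explicit rescaling
    return rng.choices(nucleus, weights=[probs[i] for i in nucleus], k=1)[0]


cond = [2.0, 0.5, -1.0]   # toy logits standing in for T(t_y | t_x)
uncond = [1.0, 1.0, 0.0]  # toy logits standing in for T(t_y | <mask>)
guided = cfg_logits(cond, uncond, alpha_c=3.0)
print(guided)                                            # [4.0, -0.5, -3.0]
print(top_p_sample(guided, p=0.9, rng=random.Random(0)))  # 0
```

With α_c = 1 the guided logits equal the conditional ones and with α_c = 0 the unconditional ones; values above 1 push the sample further toward the text condition, which here concentrates almost all probability on token 0.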

Citation

cm3leon's People

Contributors

kyegomez, rsxdalv


cm3leon's Issues

Errors on library import and readme example running

Hi, I've been getting an error on a simple cm3 library import, and even when running the README example. Could you help me out with this one? Thanks!

[screenshot]

[screenshot]

The weird thing is that if I run it once, then take the line of code that raised the error:

[screenshot]

and run it separately, it runs without error:

[screenshot]

but if I then try to run the example code or the cm3 import again, it returns another error:

[screenshot]


Generate not working

I'm trying to train any of the vision models you have around, but I haven't been lucky finding one that works both for training and for prediction.

File /data/conda/envs/eqbench/lib/python3.10/site-packages/zeta/structs/auto_regressive_wrapper.py:186, in AutoregressiveWrapper.generate(self, start_tokens, seq_len, eos_token, strategy, temperature, filter_logits_fn, filter_thres, min_p_pow, min_p_ratio, gamma, **kwargs)
    182 b, t = start_tokens.shape
    184 out = start_tokens
--> 186 if self.speculative:
    187     for _ in range(seq_len):
    188         x = out[:, -self.max_seq_len]

File /data/conda/envs/eqbench/lib/python3.10/site-packages/torch/nn/modules/module.py:1695, in Module.__getattr__(self, name)
   1693     if name in modules:
   1694         return modules[name]
-> 1695 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'AutoregressiveWrapper' object has no attribute 'speculative'

I'm keen on pretraining a model. Which one do you think will give good performance, with code that is reproducible from your repo? I also tried kosmos-x and had some issues running it as well.


Error while processing rearrange-reduction pattern

Code:

import torch
from cm3.model import CM3

model = CM3(
    depth=24, dim=1536
)
model_gpu = model.to('cuda:0')

img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 1024))
output2 = model_gpu(text, img)

Error:
Failed in forward method: Error while processing rearrange-reduction pattern "b c (h p1) (w p2) -> b (h w) (p1 p2 c)".
Input tensor shape: torch.Size([1, 1024]). Additional info: {'p1': 32, 'p2': 32}.
Expected 4 dimensions, got 2

EinopsError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/einops/einops.py in reduce(tensor, pattern, reduction, **axes_lengths)
411 recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths)
--> 412 return _apply_recipe(recipe, tensor, reduction_type=reduction)
413 except EinopsError as e:

8 frames
EinopsError: Expected 4 dimensions, got 2

During handling of the above exception, another exception occurred:

EinopsError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/einops/einops.py in reduce(tensor, pattern, reduction, **axes_lengths)
418 message += '\n Input is list. '
419 message += 'Additional info: {}.'.format(axes_lengths)
--> 420 raise EinopsError(message + '\n {}'.format(e))
421
422

EinopsError: Error while processing rearrange-reduction pattern "b c (h p1) (w p2) -> b (h w) (p1 p2 c)".
Input tensor shape: torch.Size([1, 1024]). Additional info: {'p1': 32, 'p2': 32}.
Expected 4 dimensions, got 2
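The pattern in this error expects a 4-D image tensor laid out as (batch, channels, height, width), yet it received the 2-D (1, 1024) text tensor. That suggests the two inputs reached the model in swapped order: the README example calls model(img, caption), image first, while this snippet passes text first. Separately, the inputs are created on the CPU while the model is moved to cuda:0, so they would most likely also need .to('cuda:0'). A plain-Python sketch of the shape bookkeeping the pattern performs, using a hypothetical helper name:

```python
from typing import Tuple


def patchify_shape(shape: Tuple[int, ...], p1: int, p2: int) -> Tuple[int, int, int]:
    """Output shape of einops rearrange 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'."""
    if len(shape) != 4:
        raise ValueError(f"Expected 4 dimensions, got {len(shape)}")
    b, c, height, width = shape
    if height % p1 or width % p2:
        raise ValueError("height and width must be divisible by the patch size")
    h, w = height // p1, width // p2
    return (b, h * w, p1 * p2 * c)


print(patchify_shape((1, 3, 256, 256), 32, 32))  # (1, 64, 3072)
# patchify_shape((1, 1024), 32, 32) raises ValueError: Expected 4 dimensions, got 2
```

With 32x32 patches, a 256x256 RGB image becomes an 8 x 8 grid, i.e. 64 patches of 32 * 32 * 3 = 3072 values each; a (1, 1024) token tensor has no channel or spatial axes to split, hence the error.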


What is the output of the model?

In the example.py file, the model output should have the shape shown by print(output[0].shape) # (1, 1024, 20000).
What is this (1,1024,20000) matrix? I thought the CM3Leon model is an image generator, but (1,1024,20000) doesn't seem like an image.
Thank you for your great work and I'm looking forward to hearing from you.


Cannot access your paper file?


Is CM3Leon not gonna be open source?


how to generate an image

Could you provide a demo of generating an image from text? For example:

prompt = "xxx"
image = model.generate(prompt)
image.save("xxx.png")


Has this repo been tested?

I've encountered lots of errors, such as attributes not being found. Has this repo been tested?


Checkpoint

I was reading other issues and walking through the code, and I realized that example.py doesn't load any checkpoint. So, as far as I understand, example.py and the code so far just use the pretrained ViT, transformers, and so on, and those components were not retrained in the MMM setting proposed by the paper, right?

I'm asking because I only have a few days to do some testing: as far as I can tell, even if I implemented the image and text bi-encoder, I would get the same performance the current model has, right?


Finetuning on custom dataset

Hi @kyegomez,
Thank you for sharing this code.
Could you give an example of fine-tuning on a custom dataset for image captioning or text-to-image tasks?

Thanks!


Generate an image

The README describes how to load the model, but how can we actually use it to generate an image from a given prompt?

Could you provide a code example?

prompt = "xxx"
image = model.generate(prompt)
image.save("xxx.png")


Citation?

Will the citation section be updated with the required bibtex?


zeta library does not compile

These two import lines in model.py fail for me, even though I have already pip-installed cm3.

from zeta.nn.architecture.auto_regressive_wrapper import AutoregressiveWrapper
from zeta.nn.architecture.transformer import (

