
MaMMUT - Pytorch

Implementation of MaMMUT, a simple vision-encoder text-decoder architecture for multimodal tasks from Google, in Pytorch. Blog post

This work is essentially a simplified CoCa. I copied the code from that repository and applied the change from the paper, which is to make two passes through the text decoder: one with cross attention for the generative loss, and one without for the contrastive loss.
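The two-pass idea can be sketched with a toy decoder layer (hypothetical names, not the library's actual modules): the same weights are run once with cross attention to image tokens for captioning, and once without it for the contrastive text embedding.

```python
import torch
from torch import nn

# Toy sketch of MaMMUT's key idea: one shared text decoder layer is run
# twice - with cross attention to image tokens for the generative pass,
# and without it for the contrastive pass. Module names are illustrative.

class ToyDecoderLayer(nn.Module):
    def __init__(self, dim, heads = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x, image_tokens = None):
        # causal self attention over the text tokens
        n = x.shape[1]
        causal = torch.triu(torch.ones(n, n, dtype = torch.bool), diagonal = 1)
        x = x + self.self_attn(x, x, x, attn_mask = causal)[0]

        # cross attention to image tokens, only on the generative pass
        if image_tokens is not None:
            x = x + self.cross_attn(x, image_tokens, image_tokens)[0]
        return x

dim = 32
layer = ToyDecoderLayer(dim)
text = torch.randn(2, 7, dim)
image = torch.randn(2, 5, dim)

contrastive_out = layer(text)         # no cross attention -> contrastive branch
generative_out = layer(text, image)   # with cross attention -> captioning branch
```

Both passes share the same parameters; only the presence of the cross-attention step differs.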

This is also a good time to plug an open-sourced version of CoCa from the folks at OpenCLIP!

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

Install

$ pip install mammut-pytorch

Usage

First install vit-pytorch for the image encoder, which would need to be pretrained:

$ pip install "vit-pytorch>=0.40.2"

Then

import torch

# import vision transformer

from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor

vit = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    patch_dropout = 0.5  # https://arxiv.org/abs/2212.00794
)

vit = Extractor(vit, return_embeddings_only = True, detach = False)

# the extractor makes the vision transformer return its embeddings

# import MaMMUT and instantiate it

from mammut_pytorch.mammut_pytorch import MaMMUT

mammut = MaMMUT(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
    num_tokens = 20000,            # number of text tokens
    depth = 6,                     # depth of the transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).cuda()

# mock text and images

text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train by giving MaMMUT your text and images with `return_loss = True`

loss = mammut(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

# do the above for as many text-image pairs as needed...
# then you can get the caption logits like so

logits = mammut(
    text = text,
    images = images
) # (4, 512, 20000)

# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = mammut(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)
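These CLIP-like embeddings can be used for retrieval by cosine similarity. A hedged usage sketch, with the embeddings mocked as random tensors of the shapes returned above:

```python
import torch
import torch.nn.functional as F

# Illustrative retrieval with the CLIP-style embeddings. In practice these
# would come from mammut(..., return_embeddings = True); here they are mocked.
text_embeds = torch.randn(4, 512)
image_embeds = torch.randn(4, 512)

# l2-normalize so the dot product is cosine similarity
text_embeds = F.normalize(text_embeds, dim = -1)
image_embeds = F.normalize(image_embeds, dim = -1)

sim = text_embeds @ image_embeds.t()        # (4, 4) similarity matrix
best_image_per_text = sim.argmax(dim = -1)  # nearest image for each text
```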

One of the main findings of the paper is that different tasks perform differently depending on the amount of cross attention. This repository will give you full control over how much cross attention you want to place in the network.

mammut = MaMMUT(
    dim = 512,
    img_encoder = vit,
    image_dim = 1024,
    num_tokens = 20000,
    depth = 6,
    cross_attend_every = 2,   # say you want to cross attend only every 2 layers
    dim_head = 64,
    heads = 8,
    caption_loss_weight = 1.,
    contrastive_loss_weight = 1.
).cuda()

# or you can finely specify which layers to do cross attention

mammut = MaMMUT(
    dim = 512,
    img_encoder = vit,
    image_dim = 1024,
    num_tokens = 20000,
    depth = 6,
    cross_attend_layers = (4, 5, 6),  # only last three layers have cross attention
    dim_head = 64,
    heads = 8,
    caption_loss_weight = 1.,
    contrastive_loss_weight = 1.
).cuda()

Todo

  • offer masked mean pooling of text embeddings and mean pooling for images for contrastive latents
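For reference, the masked mean pooling mentioned above could look like the following sketch (this is an assumption about the intended feature, not code from the library): average token embeddings while ignoring padding positions.

```python
import torch

# Hedged sketch of masked mean pooling: mean over the sequence dimension,
# counting only positions where the mask is True (i.e. real tokens).

def masked_mean(embeds, mask):
    # embeds: (batch, seq, dim), mask: (batch, seq) boolean
    mask = mask.unsqueeze(-1).float()
    summed = (embeds * mask).sum(dim = 1)
    denom = mask.sum(dim = 1).clamp(min = 1e-5)  # avoid division by zero
    return summed / denom

embeds = torch.randn(2, 4, 8)
mask = torch.tensor([[True, True, False, False],
                     [True, True, True,  True]])

pooled = masked_mean(embeds, mask)  # (2, 8)
```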

Citations

@article{Kuo2023MaMMUTAS,
    title   = {MaMMUT: A Simple Architecture for Joint Learning for MultiModal Tasks},
    author  = {Weicheng Kuo and A. J. Piergiovanni and Dahun Kim and Xiyang Luo and Benjamin Caine and W. Li and Abhijit S. Ogale and Luowei Zhou and Andrew M. Dai and Zhifeng Chen and Claire Cui and Anelia Angelova},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2303.16839}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}

mammut-pytorch's People

Contributors

  • lucidrains


mammut-pytorch's Issues

causal mask when cross-attn

Hi, thanks for this great repo. I wonder whether the causal mask should also be used in the second forward pass, when cross attention is applied, or am I missing something? Thank you.

Minimum GPU requirement

What is the minimum GPU requirement to run the MaMMUT model properly?
Is it possible to run MaMMUT on a machine with a 12 GB GPU?
