
ml-aim's Introduction

AIM: Autoregressive Image Models

Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar, Joshua M Susskind, and Armand Joulin

To appear at ICML 2024

[Paper] [BibTeX]

This software project accompanies the research paper, Scalable Pre-training of Large Autoregressive Image Models.

We introduce AIM, a collection of vision models pre-trained with an autoregressive generative objective. We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its textual counterpart (i.e., Large Language Models). Specifically, we highlight two findings:

  1. the model capacity can be trivially scaled to billions of parameters, and
  2. AIM effectively leverages large collections of uncurated image data.
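For intuition, below is a minimal sketch of the autoregressive patch-prediction objective described above. This is not the project's training code (which is not released here); the causal model, raster patch ordering, and per-patch normalization are simplifying assumptions based on the paper.

import torch.nn.functional as F

def autoregressive_patch_loss(model, images, patch_size=14):
    # Split images into a raster-ordered sequence of flattened patches: (B, N, C*P*P).
    B, C, H, W = images.shape
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * patch_size * patch_size)

    # Per-patch normalized pixel values serve as regression targets.
    targets = (patches - patches.mean(-1, keepdim=True)) / (patches.std(-1, keepdim=True) + 1e-6)

    # A causally-masked model (hypothetical here) sees patches 1..N-1 and predicts patches 2..N.
    preds = model(patches[:, :-1])
    return F.mse_loss(preds, targets[:, 1:])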

Installation

Please install PyTorch using the official installation instructions. Afterward, install the package as:

pip install git+https://git@github.com/apple/ml-aim.git

We also offer MLX backend support for research and experimentation on Apple silicon. To enable MLX support, simply run:

pip install mlx
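
As a quick sanity check (not part of the original instructions), you can verify that MLX is importable and reports a device on Apple silicon:

import mlx.core as mx

print(mx.default_device())  # e.g. Device(gpu, 0) on Apple silicon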

Usage

Below we provide an example of usage in PyTorch:

from PIL import Image

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
logits, features = model(inp)

and in MLX:

from PIL import Image
import mlx.core as mx

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, features = model(inp)

and in JAX:

from PIL import Image
import jax.numpy as jnp

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = jnp.array(inp)
(logits, features), _ = model.apply(params, inp, mutable=['batch_stats'])
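
In each backend, the call returns classification logits alongside the extracted features. As an illustrative follow-up (not from the original README), assuming the PyTorch backend and that the logits cover the 1,000 ImageNet-1k classes, the top predictions can be read off as:

import torch

probs = torch.softmax(logits, dim=-1)       # (1, 1000) class probabilities
top5 = torch.topk(probs, k=5, dim=-1)
print(top5.indices.tolist(), top5.values.tolist())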

Pre-trained checkpoints

The pre-trained models can be accessed via PyTorch Hub as:

import torch

aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b   = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b   = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b   = torch.hub.load("apple/ml-aim", "aim_7B")

or via HuggingFace Hub as:

from aim.torch.models import AIMForImageClassification

aim_600m = AIMForImageClassification.from_pretrained("apple/aim-600M")
aim_1b   = AIMForImageClassification.from_pretrained("apple/aim-1B")
aim_3b   = AIMForImageClassification.from_pretrained("apple/aim-3B")
aim_7b   = AIMForImageClassification.from_pretrained("apple/aim-7B")
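
A minimal end-to-end sketch with the HuggingFace checkpoints, assuming they accept the same val_transforms preprocessing and return (logits, features) like the models loaded above:

from PIL import Image

from aim.torch.data import val_transforms
from aim.torch.models import AIMForImageClassification

img = Image.open(...)
model = AIMForImageClassification.from_pretrained("apple/aim-600M")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
logits, features = model(inp)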

Pre-trained backbones

The following table lists the pre-trained backbones used in our paper.

model      #params   attn probe top-1 (best layer)   backbone checkpoint (SHA256)
AIM-0.6B   0.6B      79.4%                           link, 0d6f6b8f
AIM-1B     1B        82.3%                           link, d254ecd3
AIM-3B     3B        83.3%                           link, 8475ce4e
AIM-7B     7B        84.0%                           link, 184ed94c
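
To guard against corrupted downloads, a checkpoint can be checked against the SHA256 prefixes listed above; the sketch below uses a hypothetical file name:

import hashlib

def sha256_prefix(path, n=8):
    # Stream the file in 1 MiB chunks and return the first n hex digits of its SHA256.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()[:n]

print(sha256_prefix("aim_7b_backbone.pth"))  # hypothetical file name; expect 184ed94c for AIM-7B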

Pre-trained attention heads

The table below contains the attention-probe classification results on the ImageNet-1k validation set, along with links to the corresponding head checkpoints.

model      top-1 (last layer)   top-1 (best layer)   head, SHA256 (last layer)   head, SHA256 (best layer)
AIM-0.6B   78.5%                79.4%                link, 5ce5a341              link, ebd45c05
AIM-1B     80.6%                82.3%                link, db3be2ad              link, f1ed7852
AIM-3B     82.2%                83.3%                link, 5c057b30              link, ad380e16
AIM-7B     82.4%                84.0%                link, 1e5c99ba              link, 73ecd732

Reproducing the IN-1k classification results

The commands below reproduce the attention-probe results on the ImageNet-1k validation set. We run the evaluation on a single node with 8 GPUs:

torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
  --model=aim-7B \
  --batch-size=64 \
  --data-path=/path/to/imagenet \
  --probe-layers=best \
  --backbone-ckpt-path=/path/to/backbone_ckpt.pth \
  --head-ckpt-path=/path/to/head_ckpt.pth

By default, we probe features from the 6 intermediate layers that provide the best performance. To probe only the last layer instead, pass --probe-layers=last.
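
For intuition, here is a conceptual sketch of such an attention probe over frozen features (not the repository's implementation; the pooling and layer-aggregation details are simplified assumptions): a learned query cross-attends to the patch tokens of the probed layers, and a linear classifier maps the pooled token to class logits.

import torch
import torch.nn as nn

class AttentionProbe(nn.Module):
    def __init__(self, dim, num_classes=1000, num_heads=12):
        super().__init__()
        self.query = nn.Parameter(torch.zeros(1, 1, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, layer_feats):
        # layer_feats: list of (B, N, dim) patch-token features from the probed (frozen) layers.
        tokens = torch.cat(layer_feats, dim=1)
        q = self.query.expand(tokens.shape[0], -1, -1)
        pooled, _ = self.attn(q, tokens, tokens)   # learned query attends over all tokens
        return self.fc(pooled.squeeze(1))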

Citation

If you find our work useful, please consider citing us as:

@article{el2024scalable,
  title={Scalable Pre-training of Large Autoregressive Image Models},
  author={El-Nouby, Alaaeldin and Klein, Michal and Zhai, Shuangfei and Bautista, Miguel Angel and Toshev, Alexander and Shankar, Vaishaal and Susskind, Joshua M and Joulin, Armand},
  journal={International Conference on Machine Learning},
  year={2024}
}


ml-aim's Issues

Releasing Training Codes

Thank you so much for sharing your great work!

However, it would be very useful if the training code and procedures could be shared as well.
This would encourage more follow-up work and research building on top of your great work.

Looking forward to hearing from you!
Thanks in advance!

Using as an image probability model

I was wondering if you could upload the weights of the final layers used to minimise the NLL, so that this model could be used as an image probability model.

Mismatches between ViT-H/14 in AIM and ViT-H/14 in MAE

AIM-600M:

def aim_600M(img_size: Union[int, Tuple[int, int]] = 224, **kwargs: Any) -> AIM:
    preprocessor, trunk, head = _aim(
        img_size=img_size,
        patch_size=14,
        embed_dim=1536,
        num_blocks=24,
        num_heads=12,
        **kwargs,
    )
    return AIM(preprocessor, trunk, head)

(See ml-aim/aim/torch/models.py, lines 176 to 185 at commit 0b1dea9.)

MAE ViT-H/14:

def vit_huge_patch14(**kwargs):
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_vit.py#L70-L74

The models have very different embedding dimensions, depths, and numbers of heads, and are incompatible with each other. However, in Tab. 6 of the paper, these two works share the same architecture in the "Arch." column. Are the two architectures actually different, as the code suggests? If so, the difference should probably be clarified in the paper in terms of the number of parameters.
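
For context (not part of the original issue), a rough transformer-block parameter estimate (about 12 * depth * dim^2, assuming an MLP ratio of 4 and ignoring embeddings and heads) shows that the two configurations nevertheless land at a similar overall size:

def approx_trunk_params(depth, dim):
    # ~4*dim^2 for attention projections + ~8*dim^2 for an MLP with ratio 4, per block.
    return 12 * depth * dim ** 2

print(approx_trunk_params(24, 1536) / 1e6)  # AIM-600M trunk: roughly 680M
print(approx_trunk_params(32, 1280) / 1e6)  # MAE ViT-H/14 trunk: roughly 630M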

what "offline tokenizers" have you tried?

Hi there. Thanks for the great work.
In the paper, you mentioned that "We also consider a cross-entropy loss with patches converted to discrete tokens using an offline tokenizer".
Recently, I've been working on a relevant project in which I want to tokenize images into discretized tokens in an unsupervised way (e.g. VQ-GAN).
I'm wondering what offline tokenizers you've tried.
Any help will be appreciated. Thanks!😀
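
For reference (a minimal sketch, not from the issue thread), the cross-entropy variant quoted above would look roughly like this, with a frozen offline tokenizer such as a VQ-GAN; the vq_tokenizer.encode API and the causal model are hypothetical placeholders:

import torch
import torch.nn.functional as F

def token_prediction_loss(model, vq_tokenizer, images):
    # Map each image to a raster-ordered sequence of discrete patch codes (B, N) with a frozen tokenizer.
    with torch.no_grad():
        codes = vq_tokenizer.encode(images)      # hypothetical API
    # A causally-masked model predicts the next code: logits over the vocabulary, shape (B, N-1, V).
    logits = model(codes[:, :-1])
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), codes[:, 1:].reshape(-1))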
