
vit-prisma's Introduction

Vision Transformer (ViT) Prisma Library


For a full introduction, including Open Problems in vision mechanistic interpretability, see the original Less Wrong post here.

ViT Prisma is an open-source mechanistic interpretability library for vision and multimodal models. Currently, the library supports ViTs and CLIP. This library was created by Sonia Joseph. ViT Prisma is largely based on TransformerLens by Neel Nanda.

Contributors: Praneet Suresh, Yash Vadi, Rob Graham [and more coming soon]

We welcome new contributors. Check out our contributing guidelines here and our open Issues.

Installing Repo

Installing with pip:

pip install vit_prisma

To install as an editable repo from source:

git clone https://github.com/soniajoseph/ViT-Prisma
cd ViT-Prisma
pip install -e .

How do I use this repo?

Check out our guide.

Check out our tutorial notebooks for using the repo; a minimal loading example follows the list below.

  1. Main ViT Demo - Overview of the main mechanistic interpretability techniques on a ViT, including direct logit attribution, attention head visualization, and activation patching. The activation patching switches the net's prediction from tabby cat to Border collie with a minimal ablation.
  2. Emoji Logit Lens - Deeper dive into layer- and patch-level predictions with interactive plots.
  3. Interactive Attention Head Tour - Deeper dive into the various types of attention heads a ViT contains with interactive JavaScript.
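If you want to poke at a model outside the notebooks, the snippet below is a minimal sketch of loading a model and caching its activations. It assumes a TransformerLens-style API (`HookedViT.from_pretrained` and `run_with_cache`); the import path and the timm model name are illustrative, so check the guide and notebooks above for the exact names.

```python
import torch
from vit_prisma.models.base_vit import HookedViT  # import path is assumed; see the guide

# Load a supported pretrained ViT as a HookedViT (model name is illustrative).
model = HookedViT.from_pretrained(model_name="vit_base_patch16_224", is_timm=True)

# A preprocessed image batch of shape (batch, channels, height, width).
images = torch.randn(1, 3, 224, 224)

# TransformerLens-style call: returns logits plus a cache of intermediate activations,
# which is what direct logit attribution, attention visualization, etc. operate on.
logits, cache = model.run_with_cache(images)
```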

Features

For a full demo of Prisma's features, including the visualizations below with interactivity, check out the demo notebooks above.

Attention head visualization


Activation patching

Direct logit attribution

Emoji logit lens

Supported Models

Training Code

Prisma contains training code to train your own custom ViTs. Training small ViTs can be very useful when isolating specific behaviors in the model.

For training your own models, check out our guide.
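Prisma's own training entry points are described in the guide; purely as a rough orientation, the sketch below is a generic PyTorch training loop for a small ViT classifier. The dataset path, hyperparameters, and checkpoint names are placeholders, and `model` can be any `nn.Module` image classifier.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def train(model, data_dir, epochs=5, lr=3e-4, batch_size=128, device="cuda"):
    # Basic preprocessing; match whatever the model expects.
    tfm = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    loader = DataLoader(datasets.ImageFolder(data_dir, tfm),
                        batch_size=batch_size, shuffle=True, num_workers=4)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    model.to(device).train()
    for epoch in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            loss = F.cross_entropy(model(images), labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
        # Save a checkpoint every epoch so training dynamics can be analyzed later.
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pth")
```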

Custom Models & Checkpoints

ImageNet-1k classification checkpoints (patch size 32)

This model was trained by Praneet Suresh. All models include training checkpoints, in case you want to analyze training dynamics.

This larger patch size yields fewer patches per image, so its attention heads are small enough to inspect in the interactive JavaScript visualizations; the patch size 16 attention patterns are too large to easily render in JavaScript.

| Size | NumLayers | Attention+MLP (top-1 / top-5) | AttentionOnly | Model Link |
|------|-----------|-------------------------------|---------------|---------------|
| tiny | 3         | 0.22 / 0.42                   | N/A           | Attention+MLP |

ImageNet-1k classification checkpoints (patch size 16)

The detailed training logs and metrics can be found here. These models were trained by Yash Vadi.

Table of Results

Accuracy is reported as [ top-1 / top-5 ].

| Size   | NumLayers | Attention+MLP | AttentionOnly | Model Link |
|--------|-----------|---------------|---------------|------------------------------|
| tiny   | 1         | 0.16 / 0.33   | 0.11 / 0.25   | AttentionOnly, Attention+MLP |
| base   | 2         | 0.23 / 0.44   | 0.16 / 0.34   | AttentionOnly, Attention+MLP |
| small  | 3         | 0.28 / 0.51   | 0.17 / 0.35   | AttentionOnly, Attention+MLP |
| medium | 4         | 0.33 / 0.56   | 0.17 / 0.36   | AttentionOnly, Attention+MLP |

dSprites Shape Classification training checkpoints

Original dataset is here.

Full results and training setup are here. These models were trained by Yash Vadi.

Table of Results

| Size   | NumLayers | Attention+MLP | AttentionOnly | Model Link |
|--------|-----------|---------------|---------------|------------------------------|
| tiny   | 1         | 0.535         | 0.459         | AttentionOnly, Attention+MLP |
| base   | 2         | 0.996         | 0.685         | AttentionOnly, Attention+MLP |
| small  | 3         | 1.000         | 0.774         | AttentionOnly, Attention+MLP |
| medium | 4         | 1.000         | 0.991         | AttentionOnly, Attention+MLP |

Guidelines for training + uploading models

Upload your trained models to Huggingface. Follow the Huggingface guidelines and also create a model card. Document as much of the training process as possible, including links to loss and accuracy curves on Weights & Biases, the dataset (and the order of the training data), hyperparameters, optimizer, learning rate schedule, hardware, and any other details that may be relevant.

Include frequent checkpoints throughout training, which will help other researchers understand training dynamics.
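For example, with the `huggingface_hub` client, uploading a folder containing the final weights and intermediate checkpoints might look like the sketch below (the repo id and folder path are placeholders):

```python
from huggingface_hub import HfApi

api = HfApi()
repo_id = "your-username/vit-tiny-imagenet1k"        # placeholder repo id
api.create_repo(repo_id, exist_ok=True)
api.upload_folder(
    folder_path="checkpoints/vit-tiny-imagenet1k",   # final weights + intermediate checkpoints
    repo_id=repo_id,
    commit_message="Upload ViT-tiny ImageNet-1k checkpoints",
)
```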

Citation

Please cite this repository if you use it in papers or research projects.

@misc{joseph2023vit,
  author = {Sonia Joseph},
  title = {ViT Prisma: A Mechanistic Interpretability Library for Vision Transformers},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/soniajoseph/vit-prisma}}
}

vit-prisma's People

Contributors

alik-git, jain18ayush, praneetneuro, soniajoseph, stevinson, themachinefan, yashvadi


vit-prisma's Issues

Make sure we have documentation for functions in Emoji Logit Lens

Docstrings are incomplete for many functions. Prioritize the public-facing tutorials, and the functions they use, when deciding what to document first.

Document all the vit_prisma functions imported into the Emoji Logit Lens Demo: https://colab.research.google.com/drive/1yAHrEoIgkaVqdWC4GY-GQ46ZCnorkIVo#scrollTo=8nN9gDe0OcCW

Ensure that we document the type and shape of all inputs and outputs for these functions. Follow best documentation practices.

Patch-level labels with SAM

  • Given patch × dimension activations and a segmentation map, return more granular patch-level labels (a conversion sketch follows after this list)
  • High-level design decisions need to be made about whether to compute our own segmentation maps or use pre-computed datasets, and whether to return hierarchical labels (e.g. eye -> face -> person).
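A minimal sketch of the patch-label conversion, assuming we already have a per-pixel segmentation map of class ids; the function name and output format are just one possible design:

```python
import numpy as np

def patch_labels_from_segmentation(seg_map: np.ndarray, patch_size: int):
    """seg_map: (H, W) integer array of segment/class ids.
    Returns a list over patches in row-major order; entry i is the sorted list
    of ids that appear anywhere inside patch i."""
    H, W = seg_map.shape
    assert H % patch_size == 0 and W % patch_size == 0, "image must tile evenly into patches"
    labels = []
    for top in range(0, H, patch_size):
        for left in range(0, W, patch_size):
            patch = seg_map[top:top + patch_size, left:left + patch_size]
            labels.append(sorted(np.unique(patch).tolist()))
    return labels
```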

Adapt TinyCLIP to HookedViT

This will allow for multimodal models!

  • See code for adapting timm vits as reference
  • Ensure that this version of TinyCLIP is not built from BertBlocks (a BertBlock has LayerNorms after the attention and MLP layers; we want LayerNorms before the attention and MLP layers). I believe both versions are floating around the internet.
  • Ensure that folding the LayerNorm, weight centering, etc. work correctly (see the folding sketch after this list).
  • If you truly can't find a non-BertBlock TinyCLIP, then consult this TransformerLens PR for how to fold the LayerNorm: TransformerLensOrg/TransformerLens#509
  • Compare the logit output of the HookedViT to the original TinyCLIP
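On the LayerNorm-folding point, the identity being relied on is that a pre-LN block's affine parameters can be absorbed into the linear layer that reads the LN output, leaving only the normalization. The sketch below is a generic PyTorch illustration of that fold, not Prisma's internal implementation:

```python
import torch

def fold_layernorm_into_linear(ln: torch.nn.LayerNorm, linear: torch.nn.Linear):
    """For linear(LN(x)) with LN(x) = gamma * x_hat + beta:
    W @ (gamma * x_hat + beta) + b == (W * gamma) @ x_hat + (W @ beta + b),
    so gamma and beta can be folded into the linear layer and the LayerNorm
    reduced to pure normalization."""
    gamma, beta = ln.weight.data, ln.bias.data
    W, b = linear.weight.data, linear.bias.data   # W: (out_features, in_features)
    linear.weight.data = W * gamma                # scale each input column by gamma
    linear.bias.data = b + W @ beta               # absorb the LN bias into the linear bias
    ln.weight.data = torch.ones_like(gamma)
    ln.bias.data = torch.zeros_like(beta)
```

The final bullet's logit check is then a direct comparison, e.g. `torch.testing.assert_close(hooked_logits, original_logits, atol=1e-4, rtol=1e-4)`.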

Create an ImageNet-1k subset for tutorial notebooks and upload to Huggingface

UPDATE: hold off on this until we have patch-level labels via SAM

Create a ~10k train/val/test image subset of ImageNet-1k for the tutorial notebooks and upload it to Huggingface (depending on whether this dataset size is computationally reasonable; I can also run a script on a CC cluster if you have compute limitations).

Structure patch-level labels so that they can accommodate any patch size (e.g. include just segmentation maps that can be converted to patch-level labels).

Implement various types of loss and subtypes

Types of loss

Classification

  • CLS token prediction (already implemented)
  • GAP (Global Average Pooling) (already implemented)

Reconstruction

  • Next-patch prediction (do this for the entire image... feed in image autoregressively)
  • Masked reconstruction (mask out n% of the patches and predict those patches)

Sub-types (Need better wording for this)

  • Cross-Entropy loss (already implemented)
  • MSE loss (implemented?)
  • MAE (mean absolute error) loss

Update the config to expose the options above.
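As a concrete reference point for the reconstruction options, here is a hedged sketch of a masked-patch reconstruction loss supporting both MSE and MAE; in practice the mask would be sampled before the forward pass so the model cannot see the masked patches:

```python
import torch

def masked_patch_loss(pred_patches, target_patches, mask_ratio=0.25, loss_type="mse"):
    """pred_patches, target_patches: (batch, n_patches, patch_dim).
    Masks a random fraction of patches per image and computes the reconstruction
    loss only over the masked positions."""
    batch, n_patches, _ = target_patches.shape
    n_masked = max(1, int(mask_ratio * n_patches))
    scores = torch.rand(batch, n_patches, device=target_patches.device)
    mask = torch.zeros(batch, n_patches, dtype=torch.bool, device=target_patches.device)
    mask.scatter_(1, scores.topk(n_masked, dim=1).indices, True)  # per-image random mask

    diff = pred_patches[mask] - target_patches[mask]
    if loss_type == "mse":
        return diff.pow(2).mean()
    if loss_type == "mae":
        return diff.abs().mean()
    raise ValueError(f"unknown loss_type: {loss_type}")
```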

Add hook functions to the ViT

  • Add hook functions so you can get all intermediate activations of an image with one forward pass (probably returning the activations in a dictionary; see the sketch after this list)
  • Add a function so you can get intermediate activations for multiple images
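Since Prisma is based on TransformerLens, the eventual implementation will likely use TransformerLens-style hook points; purely as an illustration of the idea, the plain-PyTorch sketch below registers forward hooks and returns all intermediate activations from one pass (selecting modules by type is just an example heuristic):

```python
import torch

def collect_activations(model, images, layer_types=(torch.nn.LayerNorm, torch.nn.Linear)):
    """One forward pass over a batch of images; returns {module_name: activation}.
    Because the input is a batch, multiple images are handled by the same call."""
    activations, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            activations[name] = output.detach()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(images)               # a single forward pass fills the dict
    finally:
        for handle in handles:
            handle.remove()             # always detach the hooks afterwards
    return activations
```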

Direct Logit Attribution not replicating


Context
I am replicating the Exploratory Analysis notebook, which goes through Reverse Engineering GPT-2 Circuit and also basic TransformerLens functionality.

I am replicating the functionality in this notebook here with localizing cat/non-cat in the net as a test case.
So far, the replication is acting as a good test of our functionality.

I am not able to replicate the “average logit difference.” This is under the "Direct Logit Attribution" section of both notebooks. Projecting the intermediate activations onto the output head is supposed to equal taking the logits directly. However, I am getting a mean of 9.21 for the first method and 7.31 when taking the logits directly.

What could be going wrong:

  1. Layer norm folding, centering the weight matrices, or centering the unembed may have bugs or may be incorrectly implemented. A first step is checking each of these operations in unit tests.
  2. There may be a bug in my implementation in that notebook

Request
Double-check my work and find the bug. (I have been staring at this for a while and extra eyes would be immensely helpful! Without getting this notebook to work, we cannot be sure that the basic functionality is set up properly.)
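For reference, here is a plain-torch sketch of the identity being tested, written against hypothetical handles rather than Prisma's API: `resid_components` stands for the stacked per-component contributions (embeddings, attention heads, MLPs) to the final residual stream at the position read by the classification head, `gamma`/`beta` for the final LayerNorm parameters, and `W_U`/`b_U` for the head weights and bias. The two usual pitfalls are exactly the suspects above: every component must be centered and scaled by the LayerNorm scale of the summed stream (not its own), and the LN bias plus head bias contribute a constant to the logit difference.

```python
import torch

def check_direct_logit_attribution(resid_components, gamma, beta, W_U, b_U,
                                    logits, cat_idx, non_cat_idx, atol=1e-2):
    """resid_components: (n_components, d_model); their sum is the final residual.
    logits: (n_classes,) model output for the same image and position."""
    direction = W_U[:, cat_idx] - W_U[:, non_cat_idx]

    # The LN scale belongs to the *summed* stream; each component is centered on its
    # own mean but divided by this one shared scale, then weighted by gamma.
    resid_final = resid_components.sum(dim=0)
    scale = (resid_final - resid_final.mean()).pow(2).mean().add(1e-5).sqrt()
    centered = resid_components - resid_components.mean(dim=-1, keepdim=True)
    per_component = (centered / scale * gamma) @ direction

    # The LN bias and head bias add a constant to the logit difference.
    const = beta @ direction + (b_U[cat_idx] - b_U[non_cat_idx])

    direct = logits[cat_idx] - logits[non_cat_idx]
    assert torch.isclose(per_component.sum() + const, direct, atol=atol), \
        (per_component.sum() + const, direct)
```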

Debug visualization/visualize_attn_js.py code + test in notebook

  • [ ] Load the interactive attention head visualization and notebook and thoroughly test for bugs
  • [ ] Test mouse-out functionality (I believe there's a bug here; it's laggy)
  • [ ] Test that the CLS token reverts back properly when you mouse over it (the CLS column does not on Google Colab)

Code: https://github.com/soniajoseph/ViT-Prisma/blob/main/src/vit_prisma/visualization/visualize_attention_js.py
Example in Colab: https://colab.research.google.com/drive/1xyNa2ghlALC7SejHNJYmAHc9wBYWUhZJ?usp=sharing

Fix the ImageNet checkpoints so they properly load from Huggingface

Right now, the toy ImageNet checkpoints need ImagenetConfig in the same folder to load. This is not expected behavior.

One possible fix is loading all the checkpoints again, saving them again, ensuring that they can load without ImagenetConfig in the folder, and then reuploading them to Huggingface.

@YashVadi; someone else is also welcome to pick this up.
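A minimal sketch of the re-saving approach described above; the paths and the checkpoint key are placeholders and depend on how the checkpoints were originally saved:

```python
import torch

# Placeholders for illustration; adjust to the actual checkpoint layout.
ckpt = torch.load("old_checkpoint.pth", map_location="cpu")

# Keep only plain tensors (the state dict) so that no custom class such as
# ImagenetConfig has to be importable when the file is loaded later, then
# re-upload the cleaned file to Huggingface.
state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
torch.save(state_dict, "clean_checkpoint.pth")
```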

Test standard mech interp techniques on Video ViT

@themachinefan adapted the video vision transformer VivitForVideoClassification to Prisma from transformers (thank you!), but it's not yet clear whether the standard mech interp techniques apply or break.

Try running direct logit attribution, etc. on the video ViT. Put your results in a Jupyter notebook as research code.

Create polygenic induction dataset

  • Instead of each image having 1 pair of objects (monogenic), each image has 2 pairs (polygenic) that are directly adjacent (either horizontally or vertically). This gives patterns over four object slots (e.g. AAAA, ABAB, ABBA, BBAA, BBBA, BBBB, etc.). A generation sketch follows after this list.
  • This would result in a 12-way classification problem, with the following classes (for each of the horizontal and vertical orientations):
  • AAAA
  • ABAB
  • ABBA
  • AABB
  • ABBB
  • AAAB
  • Balance the dataset so there is the same number of images in every class
  • Ensure that the categories look sufficiently different from each other by visually inspecting them
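A rough generation sketch, with heavy assumptions: objects A/B are rendered here as bright/dark squares and the four adjacent slots are laid out as a row (horizontal) or a column (vertical). The real dataset should reuse whatever object sprites the monogenic induction dataset uses, and images should be padded to a square canvas before being fed to a ViT.

```python
import itertools
import numpy as np

PATTERNS = ["AAAA", "ABAB", "ABBA", "AABB", "ABBB", "AAAB"]
ORIENTATIONS = ["horizontal", "vertical"]
CLASSES = list(itertools.product(PATTERNS, ORIENTATIONS))  # 6 patterns x 2 orientations = 12 classes

def render(pattern, orientation, slot=16):
    # Placeholder rendering: each slot is a solid square, bright for A, dark for B.
    tiles = [np.full((slot, slot), 1.0 if obj == "A" else 0.3) for obj in pattern]
    return np.concatenate(tiles, axis=1 if orientation == "horizontal" else 0)

def balanced_dataset(n_per_class=100, seed=0):
    rng = np.random.default_rng(seed)
    images, labels = [], []
    for class_idx, (pattern, orientation) in enumerate(CLASSES):
        for _ in range(n_per_class):                       # same count in every class
            images.append(render(pattern, orientation))
            labels.append(class_idx)
    order = rng.permutation(len(labels))
    return [images[i] for i in order], [labels[i] for i in order]
```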

Put patch-level labels in a Dataloader, ensure it loads in notebook

ImageNet labels are way too coarse-grained. @themachinefan put ImageNet through a SAM pipeline to get a label for each patch.

The results are here: https://huggingface.co/datasets/Prisma-Multimodal/segmented-imagenet1k-subset

We need to make sure these results load cleanly into a DataLoader that we can query to get the labels per patch.

Some things to keep in mind are:

  • The DataLoader should account for the fact that patch size is a hyperparameter that can vary; it is usually 8, 16, or 32.
  • The above labels are currently in the form of boolean masks. Could we make this code more efficient?
  • Ideally, there is a function that takes the patch number and returns the labels for that patch as a list of strings (see the sketch after this list).
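As a starting point, a hedged sketch of the per-patch query; the mask and label layout here is assumed, not the actual schema of the Huggingface dataset above:

```python
import numpy as np

def labels_for_patch(masks, label_names, patch_idx, patch_size, image_size):
    """masks: (n_segments, H, W) boolean array, one mask per segment;
    label_names: list of n_segments strings. Returns the label of every segment
    that overlaps patch `patch_idx` (row-major patch order)."""
    patches_per_row = image_size // patch_size
    row, col = divmod(patch_idx, patches_per_row)
    top, left = row * patch_size, col * patch_size
    window = masks[:, top:top + patch_size, left:left + patch_size]
    overlapping = window.any(axis=(1, 2))       # which segments touch this patch
    return [name for name, hit in zip(label_names, overlapping) if hit]
```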

Convert my own pre-trained model into HookedViT object

Hi Sonia,

Thank you for sharing your work!

I would like to visualize my own model (vit_cxr), which was fine-tuned on the google/vit-base-patch16-224-in21k, using ViT Prisma. The model architecture is ViT-Base, and I have uploaded the fine-tuned model to HuggingFace.

I was playing around with the ViT Prisma Main Demo by importing vit_cxr from HuggingFace as a HookedViT using the following code:

model = HookedViT.from_pretrained(model_name="Lancelottery/cxr-race", is_timm=False)

However, it raised the following value error:

'n_layers': 12, 'd_model': 768, 'd_head': 64, 'model_name': 'Lancelottery/cxr-race', 'n_heads': 12, 'd_mlp': 3072, 'activation_name': 'gelu', 'eps': 1e-12, 'original_architecture': ['ViTForImageClassification'], 'initializer_range': 0.02, 'n_channels': 3, 'patch_size': 16, 'image_size': 224, 'n_classes': None, 'n_params': None


ValueError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/vit_prisma/prisma_tools/loading_from_pretrained.py in get_pretrained_state_dict(official_model_name, is_timm, is_clip, cfg, hf_model, dtype, **kwargs)
287 )
--> 288 raise ValueError
289

ValueError:

During handling of the above exception, another exception occurred:

ValueError Traceback (most recent call last)
2 frames
/usr/local/lib/python3.10/dist-packages/vit_prisma/prisma_tools/loading_from_pretrained.py in get_pretrained_state_dict(official_model_name, is_timm, is_clip, cfg, hf_model, dtype, **kwargs)
295
296 except:
--> 297 raise ValueError(
298 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
299 )

ValueError: Loading weights from the architecture is not currently supported: ['ViTForImageClassification'], generated from model name Lancelottery/cxr-race. Feel free to open an issue on GitHub to request this feature.


I found that as soon as I set is_timm=False, it raises a ValueError when I use an hf_model.


I was wondering if the ViT Prisma repository supports loading pretrained models with weights from HuggingFace. If not, is there an alternative way to import my pretrained model, such as using the pytorch_model.bin file?

Thank you in advance for your support and guidance!

Make sure all the functions used in the Interactive Attention Head Tour are documented

Docstrings are incomplete for many functions. Prioritize the public-facing tutorials, and the functions they use, when deciding what to document first.

Document all the vit_prisma functions imported into the Interactive Attention Head Tour: https://colab.research.google.com/drive/1P252fCvTHNL_yhqJDeDVOXKCzIgIuAz2#scrollTo=xrAzVMmb-DmG

Ensure that we document the type and shape of all inputs and outputs for these functions. Follow best documentation practices.

Make sure we have documentation for functions in ViT Prisma Main Demo

Docstrings are incomplete for many functions. Prioritize the public-facing tutorials, and the functions they use, when deciding what to document first.

Document all the vit_prisma functions imported into ViT Prisma Main Demo: https://colab.research.google.com/drive/1TL_BY1huQ4-OTORKbiIg7XfTyUbmyToQ#scrollTo=eCO9BA7EqUaE

Ensure that we document the type and shape of all inputs and outputs for these functions. Follow best documentation practices.
