Learning Visual Representations via Language-Guided Sampling

Mohamed El Banani, Karan Desai, and Justin Johnson

If you have any questions, please feel free to email me at [email protected].

Environment Setup

We recommend using Anaconda or Miniconda. To set up the environment, follow the instructions below.

conda create -n lgssl python=3.8 --yes
conda activate lgssl
conda install pytorch=1.12.1 torchvision cudatoolkit=11.3 -c pytorch --yes

python -m pip install -r requirements.txt
python setup.py develop

Training Datasets

We train our models on RedCaps and Conceptual Captions (CC3M and CC12M). Note that all three datasets decay over time as images are removed, so you may end up with a different number of instances. Please refer to the original papers for dataset download instructions. In our case, the datasets had the following sizes:

Dataset         Size
RedCaps-2020    3,273,223
RedCaps         12,010,494
CC3M            2,913,035
CC12M           10,958,691

We assume all training datasets are in data/datasets, which is set as the default data_root in the base dataset class. We expect each dataset to follow the format below: the dataset is subdivided into several directories, and each directory contains a set of instances, where each instance consists of an image file and a JSON caption file.

data/datasets/<dataset_name>
    |- directory_0
        |- <instance_0>.jpg     <- image for instance 0
        |- <instance_0>.json    <- caption for instance 0
        |- <instance_1>.jpg
        |- <instance_1>.json
        ...
        |- <instance_n>.jpg
        |- <instance_n>.json
    |- directory_1
    |- directory_2
    ...
    |- directory_m

For RedCaps, the directory names are encoded as <subreddit>_<year>_<id>, e.g., crochet_2017_000001, where each directory holds at most 10000 instances. We rely on this naming convention for some of the experiments: the experiments with RedCaps-2020 and with sampling scope.
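
As a trivial illustration (not repository code), the subreddit used for subset sampling can be recovered from a directory name like this:

# Directory names follow <subreddit>_<year>_<id>, e.g., crochet_2017_000001.
# Splitting from the right keeps subreddit names that contain underscores intact.
directory = "crochet_2017_000001"
subreddit, year, index = directory.rsplit("_", 2)
print(subreddit, year, index)  # crochet 2017 000001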

Generating dataset dictionaries

We create dataset-specific dictionaries that contain the information for each dataset (e.g., image paths and captions) and allow for easy sampling in subsequent steps. To generate a dataset dictionary, run the following commands, where <dataset_name> is the name of the dataset directory in data/datasets.

cd preprocess
python make_imagecaption_dict.py <dataset_name> 
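
As a rough sketch of what such a dictionary contains, the snippet below walks the directory layout described above and collects image paths and captions. The key names (including the "caption" field assumed for the JSON files) are illustrative; the real format is defined in make_imagecaption_dict.py.

import json
from pathlib import Path

def build_dataset_dict(dataset_name, data_root="data/datasets"):
    """Illustrative sketch: map instance IDs to image paths and captions."""
    dataset_dict = {}
    for json_path in Path(data_root, dataset_name).glob("*/*.json"):
        image_path = json_path.with_suffix(".jpg")
        if not image_path.exists():
            continue  # skip instances whose image is missing (datasets decay over time)
        with open(json_path) as f:
            caption = json.load(f)["caption"]  # assumed key; check the actual JSON schema
        dataset_dict[json_path.stem] = {"image_path": str(image_path), "caption": caption}
    return dataset_dict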

Sampling nearest neighbor pairs

Once we have the dataset dictionaries, we can easily sample nearest-neighbor pairs. We provide code for sampling using either language or visual embeddings, as well as sampling within dataset subsets for the experiments reported in the supplementary material. The commands below cover language sampling based on SBERT, visual sampling based on an ImageNet-pretrained model, and language sampling within each subreddit.

python sample_language_nn.py <dataset_name> all-mpnet-base-v2                       # Language - MPNet (SBERT)
python sample_language_nn_subsets.py <dataset_name> all-mpnet-base-v2 subreddit     # Language Subset - MPNet (SBERT) on subreddits

python sample_visual_nn.py <dataset_name> vit_b_32 IMAGENET1K_V1                    # Visual - ImageNet-supervised ViT-B/32
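
Conceptually, language-guided sampling embeds every caption with SBERT and pairs each instance with its most similar (non-identical) caption. The snippet below is a minimal sketch of that idea using sentence-transformers and scikit-learn, not the repository scripts themselves, which also handle batching, caching, and subset restrictions.

from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors

# Toy captions standing in for a full dataset dictionary.
captions = ["a hand-crocheted blanket", "my first crochet project", "sunset over the lake"]

# Encode captions with the SBERT model used for language sampling.
model = SentenceTransformer("all-mpnet-base-v2")
embeddings = model.encode(captions, normalize_embeddings=True)

# For each caption, find its nearest neighbor (excluding itself) by cosine similarity.
nn = NearestNeighbors(n_neighbors=2, metric="cosine").fit(embeddings)
_, indices = nn.kneighbors(embeddings)
pairs = [(i, int(indices[i][1])) for i in range(len(captions))]
print(pairs)  # each instance paired with its most similar caption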

Evaluation Datasets

We use TensorFlow Datasets for our evaluations. This package provides all of the evaluation datasets except FGVC Aircraft. Our code will automatically download and extract all the datasets into data/evaluation_datasets on the first run of the evaluation code, which means the first evaluation run will be much slower than usual.

Note 1: We encountered a bug with SUN 397 where one image could not be decoded correctly. This is a known bug that has not yet been fixed in the stable version. To fix it, simply make the two changes outlined in this commit.

Note 2: TensorFlow Datasets requires you to independently download RESISC45. Please follow the instructions provided here.
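
For reference, the automatic first-run download is equivalent to calling tfds.load with data_dir pointing at data/evaluation_datasets. A minimal sketch, with an illustrative dataset name:

import tensorflow_datasets as tfds

# The first call downloads and prepares the dataset under data/evaluation_datasets;
# subsequent calls reuse the prepared files, so only the first run is slow.
ds, info = tfds.load(
    "caltech101",                      # illustrative dataset name
    split="train",
    data_dir="data/evaluation_datasets",
    with_info=True,
)
print(info.features["label"].num_classes)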

Training models

We use Hydra configs for our training experiments. The configs can all be found here. To run an experiment, you can either define a new experiment config that overrides the default configs, or simply override individual config values on the command line. We provide a few sample training commands for clarity:

python train.py +experiment=ours                        # LG SimCLR
python train.py +experiment=vis_baseline                # SimCLR
python train.py +experiment=vis_baseline model=simsiam  # SimSiam

Evaluation

We use two primary evaluations: linear probe using L-BFGS and few-shot evaluation. The configs for those evaluations can be found here.

Linear Probe: we train a single linear layer using logistic regression and sweep over regularizer weight values. We provide an implementation of logistic regression using PyTorch's L-BFGS; however, you can easily use scikit-learn's implementation instead by setting the use_sklearn flag in the evaluation configs. For datasets without a standard validation split, we randomly split the training set while maintaining the class distribution.
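
The sketch below shows the idea behind one such fit: a multinomial logistic regression on frozen features optimized with torch.optim.LBFGS and swept over the regularization weight. It is a simplified stand-in for the repository's implementation.

import torch

def fit_linear_probe(features, labels, num_classes, weight_decay=1e-4, max_iter=100):
    """Fit a single linear layer on frozen features with L-BFGS (simplified sketch)."""
    linear = torch.nn.Linear(features.shape[1], num_classes)
    optimizer = torch.optim.LBFGS(linear.parameters(), max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(linear(features), labels)
        # L2 penalty on the weights; sweep weight_decay and keep the best validation accuracy.
        loss = loss + weight_decay * linear.weight.pow(2).sum()
        loss.backward()
        return loss

    optimizer.step(closure)
    return linear

# Example sweep over regularizer weights on random placeholder data.
feats, labs = torch.randn(200, 512), torch.randint(0, 10, (200,))
for wd in [1e-6, 1e-4, 1e-2]:
    probe = fit_linear_probe(feats, labs, num_classes=10, weight_decay=wd)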

Few-Shot Evaluation: we also evaluate our frozen features on 5-shot, 5-way classification. The evaluation code can be found here. We sample the training (support) samples from the train/valid splits and the query samples from the test set.
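
As a rough illustration of the 5-shot, 5-way protocol on frozen features, the snippet below uses a nearest-centroid classifier; this is purely an example and may differ from the classifier used in the repository's evaluation.

import torch

def five_way_five_shot_accuracy(support, support_labels, query, query_labels):
    """Nearest-centroid few-shot classification on frozen, L2-normalized features (illustrative)."""
    support = torch.nn.functional.normalize(support, dim=-1)
    query = torch.nn.functional.normalize(query, dim=-1)
    # Average the 5 support embeddings per class to get one prototype per way.
    prototypes = torch.stack([support[support_labels == c].mean(0) for c in range(5)])
    preds = (query @ prototypes.T).argmax(dim=-1)
    return (preds == query_labels).float().mean().item()

# Placeholder episode: 5 ways x 5 shots of 512-d features, 15 queries per way.
support = torch.randn(25, 512)
support_labels = torch.arange(5).repeat_interleave(5)
query = torch.randn(75, 512)
query_labels = torch.arange(5).repeat_interleave(15)
print(five_way_five_shot_accuracy(support, support_labels, query, query_labels))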

The following commands can be used to evaluate checkpoints or baselines. For example, you can evaluate our model or the pretrained SimCLR checkpoint on all the datasets as follows:

python evaluate.py model.name=lgssl_checkpoints model.checkpoint=lgsimclr dataset.name=all
python evaluate.py model.name=simclr dataset.name=all

Pre-trained Checkpoints

You can find all our pretrained checkpoints here. Download them to data/checkpoints. Alternatively, you can load the relevant checkpoint through torch.hub, as shown in the snippet below:

import torch
model = torch.hub.load("mbanani/lgssl", "lgsimclr")

For a list of released models, check hubconf.py.
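
As a quick, illustrative usage check (assuming the hub model maps a batch of images directly to feature vectors; see hubconf.py for the actual interface):

import torch

# Illustrative forward pass with the loaded checkpoint; use properly
# resized and normalized images in practice.
model = torch.hub.load("mbanani/lgssl", "lgsimclr")
model.eval()
with torch.no_grad():
    features = model(torch.randn(4, 3, 224, 224))  # placeholder batch of images
print(features.shape)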

Citation

If you find this code useful, please consider citing:

@inproceedings{elbanani2022languageguided,
  title={{Learning Visual Representations via Language-Guided Sampling}},
  author={El Banani, Mohamed and Desai, Karan and Johnson, Justin},
  booktitle={CVPR},
  year={2023},
}

Acknowledgments

We thank Richard Higgins, Ashkan Kazemi, and Santiago Castro for many helpful discussions. We also thank David Fouhey, Ziyang Chen, Chenhao Zheng, Fahad Kamran, and Dandan Shan for their feedback on early drafts. This project was funded under the Ford-UM Alliance partnership. We thank Alireza Rahimpour, Devesh Upadhyay, and Ali Hassani from Ford Research for their support and discussion.
