
VisDial

Code for the paper

Visual Dialog
Abhishek Das, Satwik Kottur, Khushi Gupta, Avi Singh, Deshraj Yadav, José M. F. Moura, Devi Parikh, Dhruv Batra
arxiv.org/abs/1611.08669
CVPR 2017 (Spotlight)

Visual Dialog requires an AI agent to hold a meaningful dialog with humans in natural, conversational language about visual content. Given an image, dialog history, and a follow-up question about the image, the AI agent has to answer the question.

Demo: demo.visualdialog.org

This repository contains code for training, evaluating and visualizing results for all combinations of encoder-decoder architectures described in the paper. Specifically, there are 3 encoders: Late Fusion (LF), Hierarchical Recurrent Encoder (HRE) and Memory Network (MN); and 2 decoders: Generative (G) and Discriminative (D).


If you find this code useful, consider citing our work:

@inproceedings{visdial,
  title={{V}isual {D}ialog},
  author={Abhishek Das and Satwik Kottur and Khushi Gupta and Avi Singh
    and Deshraj Yadav and Jos\'e M.F. Moura and Devi Parikh and Dhruv Batra},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2017}
}

Setup

All our code is implemented in Torch (Lua). Installation instructions are as follows:

git clone https://github.com/torch/distro.git ~/torch --recursive
cd ~/torch; bash install-deps;
TORCH_LUA_VERSION=LUA51 ./install.sh

Additionally, our code uses the following packages: torch/torch7, torch/nn, torch/nngraph, Element-Research/rnn, torch/image, lua-cjson, loadcaffe, torch-hdf5. After Torch is installed, these can be installed/updated using:

luarocks install torch
luarocks install nn
luarocks install nngraph
luarocks install rnn
luarocks install image
luarocks install lua-cjson
luarocks install loadcaffe
luarocks install luabitop
luarocks install totem

Installation instructions for torch-hdf5 are given here.

Running on GPUs

Although our code should work on CPUs, it is highly recommended to use GPU acceleration with CUDA. You'll also need torch/cutorch and torch/cunn.

luarocks install cutorch
luarocks install cunn

Training your own network

Preprocessing VisDial

The preprocessing script is in Python, and you'll need to install NLTK, NumPy and h5py.

pip install nltk
pip install numpy
pip install h5py
python -c "import nltk; nltk.download('all')"

The VisDial v0.9 dataset can be downloaded and preprocessed as follows:

cd data
python prepro.py -download 1
cd ..

This will generate the files data/visdial_data.h5 (contains tokenized captions, questions, answers, image indices) and data/visdial_params.json (contains vocabulary mappings and COCO image ids).
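
To sanity-check the preprocessed files, you can list their contents with h5py. This is a minimal sketch; the dataset names vary with the preprocessing version, so it simply enumerates whatever is there (the same snippet works for data/data_img.h5 once image features are extracted):

import json
import h5py

# List every dataset, its shape and dtype in the preprocessed HDF5 file.
with h5py.File('data/visdial_data.h5', 'r') as f:
    for name in f.keys():
        print(name, f[name].shape, f[name].dtype)

# Inspect the top-level keys of the vocabulary/params file.
with open('data/visdial_params.json') as f:
    params = json.load(f)
print(sorted(params.keys()))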

Extracting image features

Since we don't finetune the CNN, training is significantly faster if image features are pre-extracted. We use image features from VGG-16. The model can be downloaded and features extracted using:

sh scripts/download_vgg16.sh
cd data
th prepro_img.lua -imageRoot /path/to/coco/images/ -gpuid 0

This should generate data/data_img.h5 containing features for COCO train and val splits corresponding to VisDial v0.9.

Training

Finally, we can get to training models! All supported encoders are in the encoders/ folder (lf-ques, lf-ques-im, lf-ques-hist, lf-ques-im-hist, hre-ques-hist, hre-ques-im-hist, hrea-ques-im-hist, mn-ques-hist, mn-ques-im-hist), and decoders in the decoders/ folder (gen and disc).

Generative (gen) decoding maximizes the likelihood of the ground-truth response and only has access to single input-output pairs of dialog, while discriminative (disc) decoding scores the 100 candidate responses provided for every round of dialog and maximizes the likelihood of the correct option.
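
For intuition, here is a minimal numpy sketch of the two objectives; the shapes and function names are illustrative, not the repository's actual API. The generative loss is a per-token cross-entropy over the ground-truth answer sequence, while the discriminative loss is a cross-entropy over the scores assigned to the 100 candidate options.

import numpy as np

def log_softmax(x):
    x = x - np.max(x)
    return x - np.log(np.sum(np.exp(x)))

# Generative decoding: sum of token-level negative log-likelihoods
# of the ground-truth answer sequence.
def gen_loss(token_logits, gt_tokens):
    # token_logits: (T, vocab_size); gt_tokens: length-T list of token ids
    return -sum(log_softmax(token_logits[t])[tok]
                for t, tok in enumerate(gt_tokens))

# Discriminative decoding: cross-entropy over the scores of the
# 100 candidate options, with the correct option as the target.
def disc_loss(option_scores, gt_index):
    # option_scores: (100,); gt_index: index of the correct option
    return -log_softmax(option_scores)[gt_index]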

Encoders and decoders can be arbitrarily plugged together. For example, to train an HRE model with question and history information only (no images), and generative decoding:

th train.lua -encoder hre-ques-hist -decoder gen -gpuid 0

Similarly, to train a Memory Network model with question, image and history information, and discriminative decoding:

th train.lua -encoder mn-ques-im-hist -decoder disc -gpuid 0

The training script saves model snapshots at regular intervals in the checkpoints/ folder.

It takes about 15-20 epochs to train models with generative decoding to convergence, and 4-8 epochs for discriminative decoding.

Evaluation

We evaluate model performance by where the model ranks the human response among the 100 candidate options for every round of dialog, using standard retrieval metrics: mean reciprocal rank (MRR), R@1, R@5, R@10, and mean rank.
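
These metrics follow directly from the rank assigned to the ground-truth answer at each round; a minimal numpy sketch, assuming 1-indexed ranks:

import numpy as np

def retrieval_metrics(ranks):
    # ranks: 1-indexed rank of the ground-truth answer at each round
    ranks = np.asarray(ranks, dtype=np.float64)
    return {
        'mrr': float(np.mean(1.0 / ranks)),
        'r@1': float(np.mean(ranks <= 1)),
        'r@5': float(np.mean(ranks <= 5)),
        'r@10': float(np.mean(ranks <= 10)),
        'mean_rank': float(np.mean(ranks)),
    }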

Model evaluation can be run using:

th evaluate.lua -loadPath checkpoints/model.t7 -gpuid 0

Note that evaluation requires image features data/data_img.h5, tokenized dialogs data/visdial_data.h5 and vocabulary mappings data/visdial_params.json.

Running Beam Search & Visualizing Results

We also include code for running beam search on your model snapshots. This gives significantly nicer results than argmax decoding, and can be run as follows:

th generate.lua -loadPath checkpoints/model.t7 -maxThreads 50

This computes predictions for 50 dialog threads from the val split and saves results to vis/results/results.json.
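
The beam search itself is implemented in Torch in generate.lua; as a language-agnostic reference, here is a minimal numpy sketch of the algorithm over an arbitrary next-token model (step_fn, start_token and end_token are placeholders, not names from this repository):

import numpy as np

def beam_search(step_fn, start_token, end_token, beam_size=5, max_len=20):
    # step_fn(tokens) -> log-probabilities over the vocabulary for the
    # next token, given the partial sequence `tokens`.
    beams = [([start_token], 0.0)]  # (partial sequence, cumulative log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            log_probs = step_fn(tokens)
            # Expand each beam with its beam_size most likely next tokens.
            for tok in np.argsort(log_probs)[-beam_size:]:
                candidates.append((tokens + [int(tok)],
                                   score + float(log_probs[tok])))
        # Keep the highest-scoring partials; retire those that emit the end token.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates:
            if tokens[-1] == end_token:
                finished.append((tokens, score))
            elif len(beams) < beam_size:
                beams.append((tokens, score))
        if not beams:
            break
    return max(finished + beams, key=lambda c: c[1])[0]

Real implementations usually length-normalize beam scores before picking the final sequence; the sketch skips that for brevity.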

cd vis
# python 3.6
python -m http.server
# python 2.7
# python -m SimpleHTTPServer

Now visit localhost:8000 in your browser to see generated results.

Sample results from HRE-QIH-G available here.

Download Extracted Features & Pretrained Models

All files available for download here.

  • visdial_data.h5: Tokenized captions, questions, answers, image indices
  • visdial_params.json: Vocabulary mappings and COCO image ids
  • data_img.h5: VGG16 image features for COCO train and val

Pretrained models

Model checkpoints available here.

Discriminative decoding

  • hre-qih-d.t7: Hierarchical Recurrent Encoder
  • hrea-qih-d.t7: Hierarchical Recurrent Encoder with Attention
  • mn-qih-d.t7: Memory Network
  • lf-qih-d.t7: Late Fusion

Generative decoding

  • hre-qih-g.t7: Hierarchical Recurrent Encoder
  • hrea-qih-g.t7: Hierarchical Recurrent Encoder with Attention
  • mn-qih-g.t7: Memory Network
  • lf-qih-g.t7: Late Fusion

Contributors

License

BSD
