MMNas: Deep Multimodal Neural Architecture Search

This repository contains the PyTorch implementation of MMnas for the visual question answering (VQA), visual grounding (VGD), and image-text matching (ITM) tasks.

Prerequisites

Software and Hardware Requirements

You may need a machine with at least 4 GPUs (>= 8GB each), 50GB of memory for VQA and VGD (150GB for ITM), and 50GB of free disk space. We strongly recommend using an SSD drive to guarantee high-speed I/O.

First, install the necessary packages:

  1. Install Python >= 3.6

  2. Install CUDA >= 9.0 and cuDNN

  3. Install PyTorch >= 0.4.1 with CUDA (PyTorch 1.x is also supported).

  4. Install spaCy and initialize the GloVe word vectors as follows:

    $ pip install -r requirements.txt
    $ wget https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz -O en_vectors_web_lg-2.1.0.tar.gz
    $ pip install en_vectors_web_lg-2.1.0.tar.gz
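Before downloading several hundred gigabytes of features, it can help to confirm that the machine meets the minimums listed above. The following is a minimal, hypothetical checker (not part of the repository): it compares the Python, PyTorch, and CUDA versions against the stated minimums and counts visible GPUs via `torch.cuda.device_count()`. The helper names are our own, and PyTorch is treated as optional so the script still reports something on a machine without it.

```python
import platform

# Minimum versions taken from the prerequisites list above (hypothetical helper).
MINIMUMS = {"python": "3.6", "torch": "0.4.1", "cuda": "9.0"}
MIN_GPUS = 4  # "at least 4 GPUs" from the hardware requirements


def parse_version(version: str) -> tuple:
    """Turn '0.4.1+cu90' into (0, 4, 1), dropping local suffixes such as '+cu90'."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at pre-release tags such as '0a0' or '0rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare two dotted versions, padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b


def check_environment() -> dict:
    """Report which prerequisites are satisfied on this machine."""
    report = {"python": meets_minimum(platform.python_version(), MINIMUMS["python"])}
    try:
        import torch
    except ImportError:
        torch = None
    if torch is None:
        report.update(torch=False, cuda=False, gpus=0)
    else:
        report["torch"] = meets_minimum(torch.__version__, MINIMUMS["torch"])
        cuda = torch.version.cuda  # None for CPU-only builds
        report["cuda"] = cuda is not None and meets_minimum(cuda, MINIMUMS["cuda"])
        report["gpus"] = torch.cuda.device_count()
    report["enough_gpus"] = report["gpus"] >= MIN_GPUS
    return report


if __name__ == "__main__":
    for name, value in check_environment().items():
        print(f"{name}: {value}")
```

The check does not cover cuDNN, system memory, or free disk space; those still need to be verified by hand.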

Dataset Preparations

Please follow the instructions in dataset_setup.md to download the datasets and features.

Search

To search an optimal architecture for a specific task, run

$ python3 search_[vqa|vgd|itm].py

At the end of each search epoch, the script outputs the optimal architecture according to the current architecture weights (for every block, the operator with the largest architecture weight is chosen). Once the optimal architecture stays unchanged for several consecutive epochs, you can stop the search process manually.
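The selection rule and the stopping heuristic described above can be sketched in plain Python. This is an illustration under stated assumptions, not the repository's code: it assumes the architecture weights are available as one list of floats per block, and the operator labels (`"SA"`, `"RSA"`, `"GA"`, `"FFN"`) and the `patience` parameter are hypothetical names chosen for the example.

```python
from typing import List, Sequence


def derive_architecture(
    arch_weights: Sequence[Sequence[float]],
    candidates: Sequence[Sequence[str]],
) -> List[str]:
    """For every block, keep the candidate operator with the largest architecture weight."""
    if len(arch_weights) != len(candidates):
        raise ValueError("need one weight vector per block")
    arch = []
    for weights, ops in zip(arch_weights, candidates):
        if len(weights) != len(ops):
            raise ValueError("need one weight per candidate operator")
        best = max(range(len(weights)), key=lambda i: weights[i])  # argmax
        arch.append(ops[best])
    return arch


def has_converged(history: Sequence[Sequence[str]], patience: int) -> bool:
    """True once the derived architecture is identical for `patience` consecutive epochs."""
    if patience < 1 or len(history) < patience:
        return False
    recent = [list(a) for a in history[-patience:]]
    return all(a == recent[0] for a in recent)


if __name__ == "__main__":
    # Made-up weights for two blocks with illustrative operator labels.
    candidates = [["SA", "RSA", "FFN"], ["GA", "FFN"]]
    history = []
    for epoch_weights in ([[0.4, 0.3, 0.3], [0.5, 0.5]],
                          [[0.2, 0.6, 0.2], [0.6, 0.4]],
                          [[0.1, 0.7, 0.2], [0.7, 0.3]],
                          [[0.1, 0.8, 0.1], [0.8, 0.2]]):
        history.append(derive_architecture(epoch_weights, candidates))
        print(len(history), history[-1])
        if has_converged(history, patience=3):
            print("architecture stable; the search could be stopped here")
            break
```

Ties are resolved in favor of the first candidate, since `max` returns the first maximal index.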

Training

The following script starts training a network with the optimal architecture found by the MMnas search:

$ python3 train_[vqa|vgd|itm].py --RUN='train' --ARCH_PATH='./arch/train_vqa.json'

You can also add:

  1. --VERSION=str, e.g. --VERSION='mmnas_vqa', to assign a name to your model.

  2. --GPU=str, e.g. --GPU='0, 1, 2, 3', to train the model on the specified GPU devices.

  3. --NW=int, e.g. --NW=8, to accelerate I/O speed.

  4. --RESUME to resume training from saved checkpoint parameters.

  5. --ARCH_PATH=str to train with a different searched architecture.
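If you launch many runs from a script, building the argument list programmatically avoids shell-quoting mistakes (for example with `--GPU='0, 1, 2, 3'`, which contains spaces). The helper below is a hypothetical sketch, not part of the repository; it only emits the flags documented in this README, and anything else (such as `--RESUME`, `--ARCH_EPOCH`, or `--CKPT_PATH`) is passed through `extra` verbatim, so check the exact form each flag expects in the repo's argument parser.

```python
import shlex
from typing import List, Optional, Sequence

TASKS = ("vqa", "vgd", "itm")


def build_command(
    task: str,
    run: str = "train",
    arch_path: Optional[str] = None,
    version: Optional[str] = None,
    gpu: Optional[str] = None,
    nw: Optional[int] = None,
    extra: Sequence[str] = (),
) -> List[str]:
    """Assemble an argv list for train_<task>.py from the documented flags."""
    if task not in TASKS:
        raise ValueError(f"unknown task {task!r}; expected one of {TASKS}")
    cmd = ["python3", f"train_{task}.py", f"--RUN={run}"]
    if arch_path is not None:
        cmd.append(f"--ARCH_PATH={arch_path}")
    if version is not None:
        cmd.append(f"--VERSION={version}")
    if gpu is not None:
        cmd.append(f"--GPU={gpu}")  # e.g. '0, 1, 2, 3' stays a single argv entry
    if nw is not None:
        cmd.append(f"--NW={nw}")
    cmd.extend(extra)  # passed through unchanged, e.g. '--ARCH_EPOCH=50'
    return cmd


if __name__ == "__main__":
    cmd = build_command("vqa", arch_path="./arch/train_vqa.json",
                        version="mmnas_vqa", gpu="0, 1, 2, 3", nw=8)
    # Shell-quoted form for logging; shlex.quote works on Python >= 3.3.
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

To actually start the run from the repository root, pass the list to `subprocess.run(cmd, check=True)`. Because each flag is a separate argv entry, no extra quotes are needed around the values.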

To evaluate an architecture obtained during the search stage, for example the architecture output at the 50th search epoch of the VQA model, run

$ python3 train_vqa.py --RUN='train' --ARCH_PATH='[PATH_TO_YOUR_SEARCHING_LOG]' --ARCH_EPOCH=50

Validation and Testing

Offline Evaluation

To run on the val or test split, modify the following args: --RUN={'val', 'test'} --CKPT_PATH=[Your Model Path].

Example:

$ python3 train_vqa.py --RUN='test' --CKPT_PATH=[Your Model Path] --ARCH_PATH=[Searched Architecture Path]

Online Evaluation (ONLY FOR VQA)

Test result files will be stored in ./logs/ckpts/result_test/result_train_[Your Version].json

You can upload the obtained result file to EvalAI to evaluate the scores on the test-dev and test-std splits.
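Before uploading, a quick local sanity check can catch a truncated or malformed file. This sketch assumes the file follows the standard VQA challenge results layout (a JSON list of objects with an integer `question_id` and a string `answer`); we have not verified this against the repository's writer, and the helper names are hypothetical.

```python
import json
import os
from typing import Any


def result_file_path(version: str, root: str = "./logs/ckpts/result_test") -> str:
    """Path pattern from this README: result_train_<version>.json under root."""
    return os.path.join(root, f"result_train_{version}.json")


def validate_vqa_results(results: Any) -> int:
    """Check the assumed VQA results layout and return the number of predictions."""
    if not isinstance(results, list):
        raise ValueError("expected a JSON list of predictions")
    for i, item in enumerate(results):
        if not isinstance(item, dict) or not {"question_id", "answer"}.issubset(item):
            raise ValueError(f"entry {i} lacks 'question_id' or 'answer'")
        if not isinstance(item["question_id"], int) or not isinstance(item["answer"], str):
            raise ValueError(f"entry {i} has wrongly typed fields")
    return len(results)


def load_and_validate(path: str) -> int:
    """Read a result file from disk and validate its contents."""
    with open(path, encoding="utf-8") as f:
        return validate_vqa_results(json.load(f))


if __name__ == "__main__":
    path = result_file_path("mmnas_vqa")
    print(path, load_and_validate(path), "predictions")
```

A matching count of predictions does not guarantee a valid submission; EvalAI remains the authority on the accepted format.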

Pretrained Models

We provide the pretrained models in pretrained_models.md to reproduce the experimental results in our paper.

Citation

If this repository is helpful for your research, we'd really appreciate it if you could cite the following paper:

@article{yu2020mmnas,
  title={Deep Multimodal Neural Architecture Search},
  author={Yu, Zhou and Cui, Yuhao and Yu, Jun and Wang, Meng and Tao, Dacheng and Tian, Qi},
  journal={Proceedings of the 28th ACM International Conference on Multimedia},
  pages = {3743--3752},
  year={2020}
}
