Variational Depth Search in ResNets

arXiv | Python 3.7+ | PyTorch 1.3 | License: MIT

One-shot neural architecture search allows joint learning of weights and network architecture, reducing computational cost. We limit our search space to the depth of residual networks and formulate an analytically tractable variational objective that yields an unbiased approximate posterior over depths in one shot. We propose a heuristic to prune our networks based on this distribution. We compare our proposed method against manual search over network depths on the MNIST, Fashion-MNIST and SVHN datasets. We find that pruned networks do not incur a loss in predictive performance, obtaining accuracies competitive with unpruned networks. Marginalising over depth allows us to obtain better-calibrated test-time uncertainty estimates than regular networks, in a single forward pass.
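The idea of marginalising over depth in a single forward pass can be illustrated with a small, dependency-free Python sketch. This is our illustration, not the repository's code. It assumes that a shared output head is applied after every residual block (depth 0 meaning the head is applied directly to the input) and that `q_depth` holds the approximate posterior probability of each depth. Because every shallower network's activations are a prefix of the deepest network's, one pass over the blocks produces all per-depth predictions, which are then averaged under q(d). The names `marginal_predictive`, `blocks`, `head` and `q_depth` are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def softmax(logits: Sequence[float]) -> Vector:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def marginal_predictive(
    x: Vector,
    blocks: Sequence[Callable[[Vector], Vector]],
    head: Callable[[Vector], Vector],
    q_depth: Sequence[float],
) -> Vector:
    """Return sum_d q(d) * softmax(head(h_d)) for d = 0..len(blocks).

    Assumption: a shared output head is applied after every residual block,
    and h_0 = x. All depths are evaluated in one pass over the blocks.
    """
    if len(q_depth) != len(blocks) + 1:
        raise ValueError("q_depth needs one probability per depth, including depth 0")
    if abs(sum(q_depth) - 1.0) > 1e-9:
        raise ValueError("q_depth must sum to 1")
    h = x
    # Depth 0: output head applied to the input itself.
    probs = [q_depth[0] * p for p in softmax(head(h))]
    for d, block in enumerate(blocks, start=1):
        # Residual update: h_d = h_{d-1} + f_d(h_{d-1}).
        h = [hv + bv for hv, bv in zip(h, block(h))]
        # Accumulate this depth's prediction, weighted by its posterior mass.
        probs = [acc + q_depth[d] * p for acc, p in zip(probs, softmax(head(h)))]
    return probs
```

For example, with two toy blocks and an identity head, `marginal_predictive([1.0, 0.0], blocks, head, [0.2, 0.5, 0.3])` returns a class-probability vector that sums to 1; putting all posterior mass on one depth recovers that depth's ordinary softmax prediction.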

Requirements

Python packages:

  • test-tube 0.7.5
  • PyTorch 1.3.1, torchvision 0.4.2
  • NumPy 1.17.4
  • Matplotlib 3.1.2
  • scikit-learn 0.22
  • SciPy 1.3.3
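To check an existing environment against the versions pinned above, a short script along these lines can help. It is a sketch, not part of the repository: the helper name `version_mismatches` is hypothetical, the keys are the pip distribution names we assume for each package (for example `torch` for PyTorch), and `importlib.metadata` requires Python 3.8+ (on 3.7 the `importlib_metadata` backport would be needed).

```python
from importlib import metadata  # standard library in Python 3.8+
from typing import Callable, Dict

# Versions pinned in this README, keyed by (assumed) pip distribution name.
REQUIRED: Dict[str, str] = {
    "test-tube": "0.7.5",
    "torch": "1.3.1",
    "torchvision": "0.4.2",
    "numpy": "1.17.4",
    "matplotlib": "3.1.2",
    "scikit-learn": "0.22",
    "scipy": "1.3.3",
}


def version_mismatches(
    required: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, str]:
    """Map each package whose installed version differs from its pin to what was found."""
    mismatches = {}
    for name, pinned in required.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = "not installed"
        if installed != pinned:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    for name, installed in version_mismatches(REQUIRED).items():
        print(f"{name}: expected {REQUIRED[name]}, found {installed}")
```

The comparison is an exact string match, so a local build suffix such as `1.3.1+cpu` is reported as a mismatch even though it is the same release.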

Running Experiments from the Paper

Each Python script takes a single integer argument specifying which CUDA device to use. If you have only one GPU, use 0. If you don't have a GPU, pass any integer and the CPU will be used automatically.

First, change into the experiments directory:

cd experiments
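The device argument described above can be handled roughly as follows. This is a hypothetical sketch of the behaviour the README describes, not the scripts' actual code: `parse_device_index` and `resolve_device` are illustrative names, and the GPU check is passed in as a boolean so the logic runs without PyTorch installed.

```python
import argparse
from typing import List, Optional


def parse_device_index(argv: Optional[List[str]] = None) -> int:
    """Parse the single positional integer that selects a CUDA device."""
    parser = argparse.ArgumentParser(description="Run an experiment on a chosen device.")
    parser.add_argument(
        "device", type=int, help="CUDA device index (ignored when no GPU is available)"
    )
    return parser.parse_args(argv).device


def resolve_device(index: int, cuda_available: bool) -> str:
    """Use the requested CUDA device if a GPU exists, otherwise fall back to the CPU."""
    return f"cuda:{index}" if cuda_available else "cpu"
```

With PyTorch installed, the result would typically be wrapped as `torch.device(resolve_device(parse_device_index(), torch.cuda.is_available()))`.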

Spirals

To reproduce the plots from our paper exactly, you will need to run each script multiple times. Different runs of each experiment are automatically saved separately.

Train Learnt Depth Networks with every maximum depth:

python scan_max_depth_spirals.py 0

Train Deterministic Depth Networks of every depth:

python scan_deterministic_depth_spirals.py 0

Train Learnt Depth Networks with different dataset sizes:

python scan_data_amount.py 0

Train Learnt Depth Networks with different dataset complexities:

python scan_spiral_complexity.py 0

Train Learnt Depth Networks with different widths:

python scan_width_spirals.py 0
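Since each spiral script must be run several times to reproduce the paper's plots, a small driver can repeat them. This is a sketch under stated assumptions: it is run from the `experiments` directory, uses the current interpreter, and the helper names and the repeat count are hypothetical (the README does not state how many runs the paper used).

```python
import subprocess
import sys
from typing import List, Sequence

# The spiral scripts listed in this README.
SPIRAL_SCRIPTS = [
    "scan_max_depth_spirals.py",
    "scan_deterministic_depth_spirals.py",
    "scan_data_amount.py",
    "scan_spiral_complexity.py",
    "scan_width_spirals.py",
]


def build_commands(scripts: Sequence[str], device: int, repeats: int) -> List[List[str]]:
    """Return one argv list per (repeat, script), running every script once per repeat."""
    return [
        [sys.executable, script, str(device)]
        for _ in range(repeats)
        for script in scripts
    ]


def run_all(commands: Sequence[Sequence[str]]) -> None:
    """Run the commands one after another, stopping at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)


# Example (3 repeats is an illustrative choice, not the paper's setting):
# run_all(build_commands(SPIRAL_SCRIPTS, device=0, repeats=3))
```

Looping over repeats in the outer loop means an interrupted sweep still leaves complete runs of every experiment, rather than many runs of only the first one.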

Images

Each script runs experiments on all three datasets (MNIST, Fashion-MNIST and SVHN) and repeats each experiment 4 times.

Train Learnt Depth Networks:

python scan_max_depth_images.py 0

Train Deterministic Depth Networks:

python scan_deterministic_depth_images.py 0
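Three datasets with 4 repeats each means 3 × 4 = 12 training runs per images script. The sketch below enumerates that grid and summarises a metric across repeats as mean and standard error. The `results` dictionary format is a hypothetical stand-in; it is not the format in which the scripts actually save their outputs.

```python
import math
import statistics
from itertools import product
from typing import Dict, List, Sequence, Tuple

DATASETS = ("MNIST", "Fashion-MNIST", "SVHN")
N_REPEATS = 4  # the README states each experiment is repeated 4 times


def experiment_grid(
    datasets: Sequence[str] = DATASETS, n_repeats: int = N_REPEATS
) -> List[Tuple[str, int]]:
    """Enumerate every (dataset, repeat index) pair covered by one images script."""
    return list(product(datasets, range(n_repeats)))


def summarise(results: Dict[str, Sequence[float]]) -> Dict[str, Tuple[float, float]]:
    """Return (mean, standard error of the mean) of a metric for each dataset.

    `results` maps a dataset name to one metric value per repeat (hypothetical format).
    """
    summary = {}
    for name, values in results.items():
        mean = statistics.mean(values)
        # Sample standard deviation divided by sqrt(n); undefined for a single run.
        sem = statistics.stdev(values) / math.sqrt(len(values)) if len(values) > 1 else float("nan")
        summary[name] = (mean, sem)
    return summary
```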

Generate Plots from Paper

All plotting code is contained in the notebooks in the ./notebooks/ folder. Once the experiments have been run, executing the notebooks will generate the plots.
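If you prefer to regenerate the plots without opening Jupyter interactively, the notebooks can be executed headlessly with `jupyter nbconvert --to notebook --execute --inplace`. The sketch below assumes Jupyter and nbconvert are installed (they are not in the requirements list above), and that the notebook folder path is given relative to the current working directory; `nbconvert_commands` and `execute_notebooks` are hypothetical helper names.

```python
import subprocess
from pathlib import Path
from typing import List


def nbconvert_commands(notebook_dir: str = "notebooks") -> List[List[str]]:
    """Build one in-place `jupyter nbconvert --execute` command per notebook, sorted by path."""
    notebooks = sorted(Path(notebook_dir).glob("*.ipynb"))
    return [
        ["jupyter", "nbconvert", "--to", "notebook", "--execute", "--inplace", str(nb)]
        for nb in notebooks
    ]


def execute_notebooks(notebook_dir: str = "notebooks") -> None:
    """Execute every notebook in place, stopping at the first one that fails."""
    for cmd in nbconvert_commands(notebook_dir):
        subprocess.run(cmd, check=True)
```

Executing in place overwrites each notebook with its freshly computed outputs, so commit or copy the originals first if you want to keep them unchanged.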

Citation

If you find this repo useful, please cite:

@misc{antoran2020variational,
    title={Variational Depth Search in ResNets},
    author={Javier Antorán and James Urquhart Allingham and José Miguel Hernández-Lobato},
    year={2020},
    eprint={2002.02797},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}
