
concretedropout's Introduction

ConcreteDropout


PyTorch implementation of Concrete Dropout

This repository provides an implementation of the theory described in the Concrete Dropout paper. The code provides a simple PyTorch interface which ensures that the module can be integrated into existing code with ease.

  • Python 3.6+
  • MIT License

Overview:

Obtaining reliable uncertainty estimates is a challenge that normally requires a grid search over various dropout probabilities; for larger models this can be computationally prohibitive. The Concrete Dropout paper proposes a dropout variant in which the dropout probability is learned rather than grid-searched, which improves performance and yields better uncertainty estimates.

Concrete Dropout optimises the dropout probability through gradient descent, minimising an objective with respect to that parameter. Dropout can be viewed as defining an approximating distribution, q(w), to the posterior over the weights. Under this interpretation it is possible to add a regularisation term to the loss function that depends on the KL divergence, KL[q(w)||p(w)]; this ensures that the approximate posterior does not deviate too far from the prior. As is often the case, the KL divergence is computationally intractable, so an approximation is developed; details can be seen in equations [2-4] of the paper.
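The approximate KL term can be sketched as a per-layer regulariser. The form below follows Yarin Gal's reference Keras implementation as we understand it: a weight-decay term scaled by 1/(1 - p), plus the negative entropy of a Bernoulli(p) variable scaled by the layer's input dimension. It is a sketch, not necessarily what this repository does, and the function name `concrete_dropout_regulariser` is hypothetical.

```python
import torch


def concrete_dropout_regulariser(weight: torch.Tensor,
                                 p: torch.Tensor,
                                 input_dim: int,
                                 weight_regulariser: float,
                                 dropout_regulariser: float) -> torch.Tensor:
    """Approximate KL term for one layer whose input passes through dropout p.

    Hypothetical helper; the form is assumed from Gal's reference Keras code.
    """
    # L2 penalty on the layer weights, scaled by 1 / (1 - p)
    weight_term = weight_regulariser * torch.sum(weight ** 2) / (1.0 - p)
    # p*log(p) + (1-p)*log(1-p) is the negative Bernoulli entropy -H(p);
    # minimising it pushes p towards 0.5, where the entropy is largest
    neg_entropy = p * torch.log(p) + (1.0 - p) * torch.log(1.0 - p)
    dropout_term = dropout_regulariser * input_dim * neg_entropy
    return weight_term + dropout_term
```

Because `p` enters both terms as a tensor, the regulariser is differentiable with respect to the dropout probability and can simply be added to the task loss.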

In standard dropout the dropout mask is sampled from a Bernoulli distribution; unfortunately this discrete distribution does not work with the re-parameterisation trick, which is required to calculate the derivative of the objective with respect to the dropout probability. To allow the derivative to be calculated, a continuous relaxation of the discrete Bernoulli distribution is used, specifically the Concrete distribution. This relaxation has a simple parameterisation that reduces to a sigmoid expression, as seen in equation [5].
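The relaxation in equation [5] can be sketched as follows: with uniform noise u ~ U(0, 1) and temperature t, the relaxed drop indicator is sigmoid((log p - log(1 - p) + log u - log(1 - u)) / t). The helper name `concrete_dropout_mask` is hypothetical, and the temperature of 0.1 and the stabilising eps of 1e-7 are assumptions borrowed from Gal's reference code.

```python
import torch


def concrete_dropout_mask(p: torch.Tensor, shape, temperature: float = 0.1,
                          eps: float = 1e-7) -> torch.Tensor:
    """Sample a relaxed (continuous) keep-mask for dropout probability p.

    Hypothetical helper; temperature and eps values are assumed defaults.
    """
    u = torch.rand(shape)  # uniform noise, u ~ U(0, 1)
    # log-odds of dropping, perturbed by logistic noise log(u) - log(1 - u)
    drop_logit = (torch.log(p + eps) - torch.log(1 - p + eps)
                  + torch.log(u + eps) - torch.log(1 - u + eps))
    # a low temperature pushes the sigmoid towards a near-binary 0/1 value
    drop_prob = torch.sigmoid(drop_logit / temperature)
    return 1.0 - drop_prob  # relaxed keep-mask with values in [0, 1]
```

During training the layer input would then be multiplied by this mask and rescaled by 1 / (1 - p), exactly as in inverted dropout; unlike a Bernoulli sample, the mask is a differentiable function of `p`.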

Through use of the Concrete relaxation it is now possible to compute the derivatives of the objective with help from the re-parameterisation trick and optimise the dropout probability through gradient descent.
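To illustrate that gradients now reach the dropout probability, the toy loop below learns p by plain gradient descent. The objective (the mean of the relaxed keep-mask, whose expectation is roughly 1 - p) is purely illustrative and is not the Concrete Dropout loss. Parameterising p as the sigmoid of an unconstrained `p_logit` is a common design choice that we assume here; all function names are hypothetical.

```python
import torch


def relaxed_keep_mask(p, shape, temperature=0.1, eps=1e-7):
    # Concrete relaxation of a Bernoulli(1 - p) keep-mask (re-parameterised)
    u = torch.rand(shape)
    drop_logit = (torch.log(p + eps) - torch.log(1 - p + eps)
                  + torch.log(u + eps) - torch.log(1 - u + eps))
    return 1.0 - torch.sigmoid(drop_logit / temperature)


def optimise_dropout_probability(steps=200, lr=0.1, seed=0):
    """Toy demo: minimising the mean keep-mask should push p upwards."""
    torch.manual_seed(seed)
    # unconstrained logit, so p = sigmoid(p_logit) always stays in (0, 1)
    p_logit = torch.tensor(-2.0, requires_grad=True)  # p starts near 0.12
    optimiser = torch.optim.SGD([p_logit], lr=lr)
    for _ in range(steps):
        p = torch.sigmoid(p_logit)
        loss = relaxed_keep_mask(p, (256,)).mean()
        optimiser.zero_grad()
        loss.backward()  # gradient flows through the noise-based sample to p
        optimiser.step()
    return torch.sigmoid(p_logit).item()
```

Calling `optimise_dropout_probability()` returns a dropout probability noticeably larger than the initial value of about 0.12.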

Example:

An example using ConcreteDropout is implemented in mnist_example.py and can be run with:

python3 mnist_example.py

MNIST Results

References:

@misc{gal2017concrete,
    title={Concrete Dropout},
    author={Yarin Gal and Jiri Hron and Alex Kendall},
    year={2017},
    eprint={1705.07832},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}
Code by Yarin Gal, author of the paper.
PyTorch implementation of Concrete Dropout
Made by Daniel Kelshaw

concretedropout's People

Contributors

arose13, danielkelshaw


Forkers

ml-edu arose13

concretedropout's Issues

Rethink Class Decorator Approach

In the initial development of this code I was experimenting with the use of class decorators to add functionality, namely the concrete_regulariser. Coming back to this code several months later I realise that, although this works, it is not the nicest way of implementing a solution.

I would like to rework the regularisation by creating a base class that handles it, rather than requiring the user to wrap their model with a pre-defined decorator.
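One possible shape for the base-class approach described in this issue is sketched below. All names (`RegularisedModule`, `layer_regularisation`, `PenalisedLinear`, `Net`) are hypothetical, and `PenalisedLinear` is only a toy stand-in for a ConcreteDropout layer. Leaf layers opt in by defining `layer_regularisation()`, and the base class sums it over `self.modules()`, so nesting models does not double-count terms.

```python
import torch
import torch.nn as nn


class RegularisedModule(nn.Module):
    """Hypothetical base class that sums regularisation from all sub-layers."""

    def regularisation(self) -> torch.Tensor:
        # self.modules() walks the whole module tree, including self
        terms = [m.layer_regularisation() for m in self.modules()
                 if hasattr(m, "layer_regularisation")]
        if not terms:
            return torch.zeros(())
        return torch.stack(terms).sum()


class PenalisedLinear(nn.Linear):
    """Toy stand-in for a ConcreteDropout layer: a linear layer with an L2 term."""

    def __init__(self, in_features, out_features, coeff=1e-2):
        super().__init__(in_features, out_features)
        self.coeff = coeff

    def layer_regularisation(self) -> torch.Tensor:
        return self.coeff * self.weight.pow(2).sum()


class Net(RegularisedModule):
    """Users subclass the base class instead of applying a decorator."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(PenalisedLinear(4, 8), nn.ReLU(),
                                  PenalisedLinear(8, 1))

    def forward(self, x):
        return self.body(x)
```

A training step would then compute `loss = criterion(model(x), y) + model.regularisation()`, with no decorator involved.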

Is softmax required?

The model passes the logits through softmax in the output layer, but F.cross_entropy() is then used, which combines log_softmax and nll_loss. Perhaps the softmax in the output layer is not necessary?
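The behaviour described in this issue can be checked directly: `F.cross_entropy` is equivalent to `F.nll_loss` applied to `F.log_softmax`, so it expects raw logits, and feeding it softmax outputs applies softmax twice and changes the loss. The random tensors below are purely illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)             # raw outputs: 5 samples, 10 classes
targets = torch.randint(0, 10, (5,))    # integer class labels

# cross_entropy applies log_softmax internally, so it takes raw logits
ce = F.cross_entropy(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# passing softmax outputs instead applies softmax a second time
double_softmax = F.cross_entropy(F.softmax(logits, dim=1), targets)
```

If class probabilities are needed (for example for uncertainty estimates), softmax can be applied to the logits outside the loss computation.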

Provide MNIST Example

An example of using ConcreteDropout in training a network for MNIST classification should be included - this provides a demonstration of the capabilities of the network.

Update README

At the moment the README.md provides no particular indication as to what this project is about - it would be a good idea to update this to describe the purpose of the code that has been developed.

Implement ConcreteDropout Class

The ConcreteDropout class should inherit from nn.Module so that it can be used in line with other PyTorch modules. Once implemented, the ConcreteDropout module can be tested by using it in a regular neural network.
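A minimal sketch of a ConcreteDropout module inheriting from nn.Module is shown below. It wraps a layer, learns the dropout probability through a logit parameter, and stores the layer's regularisation term after each forward pass. The constructor arguments, default values, and the `regularisation` attribute are assumptions loosely modelled on Gal's reference Keras code, not this repository's API. Dropout is applied in both training and evaluation mode, as Monte Carlo dropout requires.

```python
import torch
import torch.nn as nn


class ConcreteDropout(nn.Module):
    """Sketch of Concrete Dropout wrapped around a layer (hypothetical API)."""

    def __init__(self, layer: nn.Module, weight_regulariser: float = 1e-6,
                 dropout_regulariser: float = 1e-5, init_p: float = 0.1,
                 temperature: float = 0.1, eps: float = 1e-7):
        super().__init__()
        self.layer = layer
        self.weight_regulariser = weight_regulariser
        self.dropout_regulariser = dropout_regulariser
        self.temperature = temperature
        self.eps = eps
        # unconstrained logit; p = sigmoid(p_logit) stays in (0, 1)
        self.p_logit = nn.Parameter(torch.logit(torch.tensor(init_p)))
        self.regularisation = torch.zeros(())

    @property
    def p(self) -> torch.Tensor:
        return torch.sigmoid(self.p_logit)

    def _concrete_dropout(self, x: torch.Tensor) -> torch.Tensor:
        p, eps = self.p, self.eps
        u = torch.rand_like(x)
        drop_logit = (torch.log(p + eps) - torch.log(1 - p + eps)
                      + torch.log(u + eps) - torch.log(1 - u + eps))
        keep_mask = 1.0 - torch.sigmoid(drop_logit / self.temperature)
        return x * keep_mask / (1.0 - p)  # inverted-dropout rescaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.layer(self._concrete_dropout(x))
        p = self.p
        # weight-decay term over the wrapped layer's parameters
        sum_sq = sum(torch.sum(w ** 2) for w in self.layer.parameters())
        weight_reg = self.weight_regulariser * sum_sq / (1.0 - p)
        # negative Bernoulli entropy, scaled by the feature dimension
        # (assumes a fully connected layer acting on the last dimension)
        neg_entropy = p * torch.log(p) + (1.0 - p) * torch.log(1.0 - p)
        dropout_reg = self.dropout_regulariser * x.shape[-1] * neg_entropy
        self.regularisation = weight_reg + dropout_reg
        return out
```

The stored term would be added to the task loss, e.g. `loss = criterion(y_hat, y) + layer.regularisation`, so that the optimiser updates `p_logit` together with the layer weights.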

Configure Repository

In order to provide a suitable development environment the repository should be set up with an appropriate directory structure as well as functionality for automated continuous integration.

Weight regulariser and dropout regulariser

Hi! I read Yarin Gal's paper and I did not understand how the weight regulariser and dropout regulariser are initialised. The author provides a formula, but it is not very clear (e.g. what does "prior length-scale" mean, and which value should be assigned to this variable?). Could you explain how you found the values used to initialise the weight regulariser and the dropout regulariser?
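For reference, the docstrings in Yarin Gal's Keras implementation of Concrete Dropout give the two coefficients as weight_regulariser = l^2 / (tau * N) and dropout_regulariser = 2 / (tau * N), where l is the prior length-scale, tau the model precision (inverse observation noise) and N the number of training points. As we understand Gal's formulation, the weight prior is N(0, I / l^2), so a larger l encodes a belief in smaller weights. The helper below is a hypothetical sketch of that arithmetic; the choice of l remains a modelling decision that the formula does not settle.

```python
def concrete_dropout_regularisers(length_scale: float, n_data: int,
                                  tau: float = 1.0):
    """Return (weight_regulariser, dropout_regulariser).

    Hypothetical helper for l**2 / (tau * N) and 2 / (tau * N).
    tau defaults to 1.0 here as an assumption, not a recommendation.
    """
    weight_regulariser = length_scale ** 2 / (tau * n_data)
    dropout_regulariser = 2.0 / (tau * n_data)
    return weight_regulariser, dropout_regulariser
```

For instance, with l = 1e-2 and N = 10000 (illustrative values only), the helper gives roughly 1e-8 for the weight regulariser and 2e-4 for the dropout regulariser.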

Demonstrate Usage in README.md

It would be good to demonstrate how the modules developed can be used in a very simple way. This could help users integrate the code found in this repository.
