
Calibration of Bayesian Neural Networks

Analysis of the paper by Kuleshov et al. (2018), Accurate Uncertainties for Deep Learning Using Calibrated Regression.

Harvard University
Class: AM 207 (Stochastic Methods for Data Analysis, Inference and Optimization)
Deliverables: Project Report and the source code in this repository

Summary of Research

The Issue of Miscalibration

Proper quantification of uncertainty is crucial for applying statistical models to real-world situations. The Bayesian approach to modeling provides a principled way of obtaining such uncertainty estimates. Yet, for various reasons, these estimates are often inaccurate: a nominal 95% posterior predictive interval, for example, may contain the true outcome much more or less often than 95% of the time. Such a model is miscalibrated.
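
As a quick diagnostic, the coverage of a central predictive interval can be checked directly on held-out data. The sketch below is generic illustration code (the function name and interface are ours, not from the paper):

import numpy as np

def empirical_coverage(y_true, lower, upper):
    """Fraction of held-out observations inside their predictive intervals."""
    y_true, lower, upper = map(np.asarray, (y_true, lower, upper))
    return np.mean((y_true >= lower) & (y_true <= upper))

# A value far from 0.95 for the 2.5%-97.5% interval signals miscalibration:
# coverage = empirical_coverage(y_test, q025, q975)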

Sources of Miscalibration

In our project, we first demonstrate that the problem of miscalibration exists and show why it exists for Bayesian neural networks (BNNs) in regression tasks. We focus on the following sources of miscalibration:

  • The prior is wrong, e.g. too strong and overly certain
  • The likelihood function is wrong, e.g. the model is biased: the neural network is too simple to capture the data
  • The noise specification in the likelihood is wrong
  • The inference is approximate or is performed incorrectly

Our aim is to establish a causal link between each aspect of the model-building process and a miscalibrated outcome.

Contribution of the Reviewed Paper

Proposition: [Kuleshov et al., 2018] propose a simple calibration algorithm for regression. The method is heavily inspired by Platt scaling [Platt, 1999], which consists of training an additional sigmoid function to map potentially non-probabilistic outputs of a classifier to empirical probabilities.
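
For reference, Platt scaling itself amounts to fitting a logistic regression on the raw classifier scores. A minimal sketch using scikit-learn, with a helper name of our own (not code from either paper):

import numpy as np
from sklearn.linear_model import LogisticRegression

def platt_scale(scores, labels):
    """Return a function mapping raw classifier scores to probabilities."""
    lr = LogisticRegression()  # learns sigmoid(a * score + b)
    lr.fit(np.asarray(scores).reshape(-1, 1), labels)
    return lambda s: lr.predict_proba(np.asarray(s).reshape(-1, 1))[:, 1]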

Unique contribution: The paper contributes to the subject literature by:

  • extending recalibration methods previously used for classification tasks (Platt scaling) to regression;
  • proposing a procedure that is universally applicable to any regression model, be it Bayesian or frequentist, and does not require modifying the model: the algorithm is applied to the output of any existing model in a postprocessing step (see the sketch below).
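
A minimal sketch of that postprocessing step, assuming a model that outputs a predictive CDF H(x) for every input (the helper names are ours; like the paper, we use isotonic regression for the auxiliary model R):

import numpy as np
from sklearn.isotonic import IsotonicRegression

def fit_recalibrator(cdf_values):
    """Fit R: [0, 1] -> [0, 1] mapping predicted to empirical probabilities.

    cdf_values[t] = H(x_t)(y_t) evaluated on a held-out calibration set.
    """
    p = np.sort(np.asarray(cdf_values))
    # Empirical frequency of calibration points with H(x)(y) <= p
    p_hat = np.searchsorted(p, p, side="right") / len(p)
    r = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    r.fit(p, p_hat)
    return r

# The recalibrated CDF is the composition R(H(x)(y)):
# r = fit_recalibrator(cdf_values); adjusted = r.predict(h_new)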

Claim: The authors claim that the method outperforms other techniques by consistently producing well-calibrated forecasts, given enough i.i.d. data. Based on their experiments, the procedure also improves predictive performance in several tasks, such as time-series forecasting and reinforcement learning.
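
To quantify "well-calibrated", the paper measures a calibration error: a weighted sum of squared gaps between expected and observed confidence levels. A sketch with uniform weights (the uniform weighting and grid of levels are our simplification):

import numpy as np

def calibration_error(cdf_values, levels=np.linspace(0.05, 0.95, 19)):
    """Sum of squared gaps between expected and observed confidence levels."""
    cdf_values = np.asarray(cdf_values)
    observed = np.array([np.mean(cdf_values <= p) for p in levels])
    return np.sum((levels - observed) ** 2)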

Evaluation of the Claims

We evaluate the claims through a series of experiments on synthetic datasets, each targeting a different source of miscalibration. Our methodology is as follows:

  1. Data Generation: We generate the data from a known true function with Gaussian or non-Gaussian noise (see the sketch after this list).

  2. Model Building: We then build multiple feedforward BNN models using:

    • different network architectures
    • several priors on the weights, depending on model complexity
    • different variances of the Gaussian noise in the likelihood function
  3. Inference: We obtain the posterior of the model by:

    • sampling from it with the No-U-Turn Sampler (NUTS)
    • approximating the posterior using Variational Inference with reparametrization and isotropic Gaussians

    We check for convergence using trace plots, the effective sample size, and the Gelman-Rubin diagnostic. In the case of variational inference, we track the ELBO during optimization.

  4. Recalibration: Finally, we apply the proposed recalibration algorithm to the obtained model. We then visually compare the posterior predictives before and after calibration to the true distribution of the data. This allows us to identify scenarios where the algorithm works well and where it fails.
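
To make step 1 concrete, here is a toy data-generating process in the spirit of our experiments (the exact functions and noise levels used in the report differ):

import numpy as np

rng = np.random.default_rng(0)

def generate(n=200, noise="gaussian"):
    x = rng.uniform(-4, 4, size=n)
    f = np.sin(x) + 0.5 * x                          # known true function
    if noise == "gaussian":
        y = f + rng.normal(0.0, 0.3, size=n)         # homoscedastic Gaussian noise
    else:
        y = f + 0.3 * rng.standard_t(df=3, size=n)   # heavy-tailed alternative
    return x, y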

See the full version of the project report for the summary of findings and conclusions.

Reproducing the Results

The final report depends on the following Python data science stack:

  • NumPy
  • SciPy
  • pandas
  • Dask
  • scikit-learn
  • matplotlib
  • Jupyter Notebook

It also requires the probabilistic library NumPyro (based on JAX), which provides fast implementations of sampling and variational inference algorithms.
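
For illustration, a one-hidden-layer BNN with NUTS in NumPyro (step 3 above) might look as follows. The architecture, priors, and noise scale below are placeholder assumptions, not the exact settings from the report:

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def bnn(x, y=None, hidden=10, sigma=0.3):
    # Standard-normal priors on all weights; fixed Gaussian noise scale
    w1 = numpyro.sample("w1", dist.Normal(0, 1).expand([1, hidden]))
    b1 = numpyro.sample("b1", dist.Normal(0, 1).expand([hidden]))
    w2 = numpyro.sample("w2", dist.Normal(0, 1).expand([hidden, 1]))
    b2 = numpyro.sample("b2", dist.Normal(0, 1))
    mean = (jnp.tanh(x[:, None] @ w1 + b1) @ w2).squeeze(-1) + b2
    numpyro.sample("y", dist.Normal(mean, sigma), obs=y)

x = jnp.linspace(-4, 4, 100)
y = jnp.sin(x) + 0.1 * random.normal(random.PRNGKey(1), (100,))

mcmc = MCMC(NUTS(bnn), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), x, y)
mcmc.print_summary()  # reports n_eff and r_hat convergence diagnostics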

Use the provided conda environment specification to satisfy all dependencies:

$ conda env create -f report/environment.yml
$ conda activate am207

Intermediate experiments preceding the final report also make use of PyMC3 and autograd. These are considerably slower to run and are optional for reproducing the project results.

Repository Structure

Directory     Description
calibration   Initial implementation of the calibration algorithm and metrics
experiments   Sources of miscalibration in Bayesian neural networks
report        The final report together with all the code in the corresponding ./code subfolder
slides        Slides for the intermediate meetings held in the course of the project

We use Jupyter notebooks (*.ipynb) throughout the repository. The ./report/code folder is the only one that contains original source code in *.py format. All of the remaining *.py files can be ignored: Jupytext produced them from the corresponding notebooks to enable clean commit diffs for the project team.

