
pwuethri / active_learning_bnn


This project forked from lukaserlenbach/active_learning_bnn


This directory contains a sample of the code I wrote during my master's thesis on Active Learning for Bayesian Neural Networks with Gaussian Processes. Two example scripts are provided, which show: 1. the superiority of the GPA Sampler over random point selection, and 2. the runtime improvement of the Fast GPA Sampler over the GPA Sampler.

License: MIT License

Python 100.00%

active_learning_bnn's Introduction

Active Learning for Bayesian Neural Networks with Gaussian Processes

(last modified 01.02.2021 by Lukas Erlenbach, LinkedIn profile)

This directory contains parts of the code that I wrote during my master's thesis (which is available here).

The project implements an active learning framework for Bayesian Neural Networks on regression tasks and is based on a paper by Tsymbalov et al. [1]. In addition, it contains a generalization with a faster runtime, which is described in chapter 5.4 of the thesis.

(very short) Introduction

(I also published a blog post with an intuitive introduction on Medium.)

As example data, the housing dataset from sklearn, a well-known regression problem, is used. The aim is to train a Bayesian Neural Network with a low RMSE while using as few training data points as possible.

After a Bayesian Neural Network has been trained on an initial set of points, Active Learning iterations are performed. In each iteration, a Sampler selects points from a pool (without access to the labels of these points); these points are then added to the training data, and training is resumed.
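The loop described above can be sketched in a few lines. This is a minimal, framework-free illustration under stated assumptions, not the repository's code: the names `active_learning_loop`, `fit`, `select` and `make_random_sampler` are hypothetical, the "model" is whatever object `fit` returns, and the random sampler plays the role of the random point selection baseline.

```python
import random
from typing import Callable, List, Sequence, Tuple

# Hypothetical data layout: each point is an (input, label) pair.
Point = Tuple[float, float]


def active_learning_loop(
    train: List[Point],
    pool: List[Point],
    fit: Callable[[List[Point]], object],
    select: Callable[[object, Sequence[float], int], List[int]],
    n_iterations: int,
    batch_size: int,
) -> List[Point]:
    """Pool-based active learning: train, select from the pool, add, retrain."""
    train, pool = list(train), list(pool)
    model = fit(train)  # initial training on the starting set
    for _ in range(n_iterations):
        if not pool:
            break
        # The sampler sees only the pool inputs, never their labels.
        pool_inputs = [x for x, _ in pool]
        chosen = set(select(model, pool_inputs, batch_size))
        # Reveal the labels of the chosen points by moving them to the training set.
        train.extend(p for i, p in enumerate(pool) if i in chosen)
        pool = [p for i, p in enumerate(pool) if i not in chosen]
        # Re-fit for simplicity; a BNN would typically resume training from its current weights.
        model = fit(train)
    return train


def make_random_sampler(seed: int):
    """Baseline sampler (hypothetical helper): choose pool indices uniformly at random."""
    rng = random.Random(seed)

    def sample(model, pool_inputs: Sequence[float], k: int) -> List[int]:
        return rng.sample(range(len(pool_inputs)), min(k, len(pool_inputs)))

    return sample


def fit_mean(points: List[Point]) -> float:
    """Toy stand-in for a model: predicts the mean label."""
    return sum(y for _, y in points) / len(points)


if __name__ == "__main__":
    data = [(float(i), 2.0 * i) for i in range(13)]
    final_train = active_learning_loop(data[:3], data[3:], fit_mean, make_random_sampler(0), n_iterations=3, batch_size=2)
    print(len(final_train))  # 3 initial points + 3 iterations * 2 points = 9
```

Swapping `make_random_sampler` for a variance-based sampler is the only change needed to compare selection strategies, which mirrors how the first example script compares the GPA Sampler against random selection.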

For further details, please refer to the paper [1], my blog post, or my thesis.

Two example scripts are provided which showcase:

  1. That the GPA Sampler from [1] is superior to randomly selecting additional points.
  2. That the Fast GPA Sampler computes the same points as the GPA Sampler from [1] while reducing the runtime.
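A minimal sketch of greedy posterior-variance selection may help make the GPA Sampler more concrete. It assumes that the covariance between points is estimated empirically from T stochastic forward passes of the BNN (for example with MC dropout), and that the sampler repeatedly picks the pool point with the largest Gaussian-process posterior variance given the training points and the points already selected. The function `gpa_select`, its argument layout and the `noise` jitter are hypothetical; the repository's GPA Sampler and [1] may differ in details.

```python
import numpy as np


def gpa_select(samples: np.ndarray, n_train: int, k: int, noise: float = 1e-6) -> list:
    """Greedily pick k pool points with maximal GP posterior variance (illustrative sketch).

    samples: shape (T, n_train + n_pool), T stochastic predictions per point;
    the first n_train columns are training points. Returns 0-based pool indices.
    """
    cov = np.cov(samples, rowvar=False)  # empirical covariance between all points
    n_total = cov.shape[0]
    conditioned = list(range(n_train))  # points the posterior is conditioned on
    chosen = []
    for _ in range(k):
        candidates = [j for j in range(n_train, n_total) if j not in chosen]
        prior_var = np.diag(cov)[candidates]
        if conditioned:
            K_cc = cov[np.ix_(conditioned, conditioned)] + noise * np.eye(len(conditioned))
            K_xc = cov[np.ix_(candidates, conditioned)]
            # Posterior variance k(x,x) - k(x,C) K_CC^{-1} k(C,x); the inverse is recomputed every step.
            K_inv = np.linalg.inv(K_cc)
            var = prior_var - np.einsum("ij,jk,ik->i", K_xc, K_inv, K_xc)
        else:
            var = prior_var
        best = candidates[int(np.argmax(var))]
        chosen.append(best)
        conditioned.append(best)  # GP variance needs no label, so condition immediately
    return [j - n_train for j in chosen]
```

Because a Gaussian process's posterior variance does not depend on the observed labels, a selected point can be conditioned on right away, which allows a whole batch to be chosen without querying any labels. The explicit inversion inside the loop is the cost that a faster variant would try to avoid.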

Usage

To run the experiments, it is recommended to set up a virtualenv with python3.8 and install the requirements via

    virtualenv --python=python3.8 venv
    source venv/bin/activate
    pip install -r requirements.txt

Afterwards the two example scripts can be called via

    python source/compare_gpa_rand.py
    python source/compare_fastgpa_batchgpa.py

The first script takes about 10 min to run on my machine, the second less than 2 min. Both create a results directory containing a log file, the experiment configurations, and a plot of the most important metrics.

The first script compares the GPA Sampler from [1] to random point selection. The results depend on the chosen random seed, but in most cases the GPA Sampler leads to faster and more stable convergence. (For qualitative results, refer to chapter 6 of the thesis.)

The second script compares the Fast GPA Sampler (from my thesis) to the GPA Sampler (from [1]). In theory, both samplers compute the same posterior variance; in practice, rounding errors sometimes cause differences in convergence. In general, the Fast GPA Sampler does the same job in less time because it avoids repeatedly inverting the posterior covariance matrix.
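One standard way to avoid repeated inversion is shown below as a sketch; it is not necessarily the method described in chapter 5.4 of the thesis. Conditioning a Gaussian on one additional noisy observation at point j is a rank-one update of the covariance matrix, Sigma <- Sigma - Sigma[:, j] Sigma[j, :] / (Sigma[j, j] + noise), so each greedy step costs O(N^2) instead of an inversion over all conditioned points. The names `fast_greedy_select` and `condition_on` and the `noise` parameter are hypothetical, and the prior covariance matrix is assumed to be given (for example estimated from stochastic forward passes).

```python
import numpy as np


def condition_on(post: np.ndarray, j: int, noise: float) -> None:
    """In-place rank-one downdate: condition the Gaussian on a noisy observation at point j."""
    col = post[:, j].copy()
    post -= np.outer(col, col) / (col[j] + noise)


def fast_greedy_select(cov: np.ndarray, n_train: int, k: int, noise: float = 1e-6) -> list:
    """Greedy max-variance selection using rank-one updates instead of matrix inversion.

    cov: (n_train + n_pool) x (n_train + n_pool) prior covariance; the first
    n_train rows/columns are training points. Returns 0-based pool indices.
    """
    post = np.array(cov, dtype=float, copy=True)
    for j in range(n_train):  # absorb the training points once
        condition_on(post, j, noise)

    chosen = []
    for _ in range(k):
        var = np.diag(post)[n_train:].copy()
        for c in chosen:  # never pick the same point twice
            var[c - n_train] = -np.inf
        best = n_train + int(np.argmax(var))
        chosen.append(best)
        condition_on(post, best, noise)  # O(N^2) update, no inversion
    return [j - n_train for j in chosen]
```

Sequentially conditioning on independent noisy observations yields the same posterior covariance as conditioning on all of them at once, so in exact arithmetic this selects the same points as recomputing the inverse at every step. In floating point, tiny differences can occasionally flip a near-tie, which is consistent with the occasional convergence differences mentioned above.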

To change the experimental setup, adjust the parameter values (in particular the random seed and the train/pool/test sizes) in the Python scripts and in the .yaml configuration files in configs/.
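The README does not show how the scripts consume these parameters. As a hypothetical illustration of what the random seed and the train/pool/test sizes control, the helper below (`split_train_pool_test` is an invented name, not part of the repository) shuffles a dataset reproducibly and cuts it into three disjoint sets. Changing the seed changes which points start in the training set and which sit in the pool, which is one reason the results of the first script depend on the seed.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_pool_test(
    data: Sequence[T], n_train: int, n_pool: int, n_test: int, seed: int
) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle data reproducibly and cut it into disjoint train/pool/test sets."""
    if n_train + n_pool + n_test > len(data):
        raise ValueError("requested sizes exceed the dataset size")
    indices = list(range(len(data)))
    random.Random(seed).shuffle(indices)  # same seed -> same split
    train = [data[i] for i in indices[:n_train]]
    pool = [data[i] for i in indices[n_train:n_train + n_pool]]
    test = [data[i] for i in indices[n_train + n_pool:n_train + n_pool + n_test]]
    return train, pool, test
```

Using a dedicated `random.Random(seed)` instance keeps the split reproducible without touching the global random state that other parts of an experiment might rely on.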

Example Results from source/compare_gpa_rand.py

Example Results

Example Results from source/compare_fastgpa_batchgpa.py

Example Results

References
