
Addressing Failure Prediction by Learning Model Confidence

Charles Corbière, Nicolas Thome, Avner Bar-Hen, Matthieu Cord, Patrick Pérez
Neural Information Processing Systems (NeurIPS), 2019

If you find this code useful for your research, please cite our paper:

@incollection{NIPS2019_8556,
   title = {Addressing Failure Prediction by Learning Model Confidence},
   author = {Corbi\`{e}re, Charles and Thome, Nicolas and Bar-Hen, Avner and Cord, Matthieu and P\'{e}rez, Patrick},
   booktitle = {Advances in Neural Information Processing Systems 32},
   editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
   pages = {2902--2913},
   year = {2019},
   publisher = {Curran Associates, Inc.},
   url = {http://papers.nips.cc/paper/8556-addressing-failure-prediction-by-learning-model-confidence.pdf}
}

Abstract

Reliably assessing the confidence of a deep neural network and predicting its failures is of primary importance for the practical deployment of these models. In this paper, we propose a new target criterion for model confidence, corresponding to the True Class Probability (TCP). We show that TCP is better suited to failure prediction than the classic Maximum Class Probability (MCP). In addition, we provide theoretical guarantees for TCP in the context of failure prediction. Since the true class is by essence unknown at test time, we propose to learn the TCP criterion on the training set, introducing a specific learning scheme adapted to this context. Extensive experiments are conducted to validate the relevance of the proposed approach. We study various network architectures and small- and large-scale datasets for image classification and semantic segmentation. We show that our approach consistently outperforms several strong methods, from MCP to Bayesian uncertainty, as well as recent approaches specifically designed for failure prediction.
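To make the criterion concrete, here is a minimal PyTorch sketch (illustrative only, not code from this repository) contrasting MCP, the maximum softmax probability, with TCP, the softmax probability assigned to the true class:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)          # a batch of 4 samples, 10 classes (dummy values)
labels = torch.tensor([3, 0, 7, 7])  # ground-truth classes

probs = F.softmax(logits, dim=1)
mcp = probs.max(dim=1).values                          # Maximum Class Probability
tcp = probs.gather(1, labels.unsqueeze(1)).squeeze(1)  # True Class Probability

Since TCP needs the ground-truth label, it can only be computed on training data, which is why the paper learns to predict it with an auxiliary confidence network (ConfidNet).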

Installation

  1. Clone the repo:
$ git clone https://github.com/valeoai/ConfidNet
  2. Install this repository and the dependencies using pip:
$ pip install -e ConfidNet

With this, you can edit the ConfidNet code on the fly and import functions and classes of ConfidNet in other projects as well.

  3. Optional. To uninstall this package, run:
$ pip uninstall ConfidNet

You can take a look at the Dockerfile if you are unsure about the steps needed to install this project.
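As a quick sanity check (assuming the editable install above went through), the package should be importable from Python:

# Minimal check that the editable install is visible to Python
import confidnet
print(confidnet.__file__)  # should point inside your local ConfidNet clone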

Datasets

The MNIST, SVHN, CIFAR-10 and CIFAR-100 datasets are managed by PyTorch dataloaders. The first time you run a script, the dataloader will download the relevant dataset into confidnet/data/DATASETNAME-data.
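For reference, the download behaviour is roughly equivalent to the following stand-alone torchvision sketch (illustrative only, not the loader code actually used by this repository; the root path merely mimics the DATASETNAME-data convention):

import torch
from torchvision import datasets, transforms

# First run downloads MNIST into the given root; later runs reuse the local copy.
train_set = datasets.MNIST(
    root="confidnet/data/mnist-data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)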

The CamVid dataset needs to be downloaded beforehand (available here) and must follow this structure:

<data_dir>/train/                       % Train images folder
<data_dir>/trainannot/                  % Train labels folder
<data_dir>/val/                         % Validation images folder
<data_dir>/valannot/                    % Validation labels folder
<data_dir>/test/                        % Test images folder
<data_dir>/testannot/                   % Test labels folder
<data_dir>/train.txt                    % List training samples
<data_dir>/val.txt                      % List validation samples
<data_dir>/test.txt                     % List test samples
...
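Before training, you can quickly verify that your local copy matches this layout with a small stand-alone snippet (the data_dir value below is a placeholder for your own <data_dir>):

from pathlib import Path

data_dir = Path("/path/to/camvid")  # replace with your <data_dir>
expected = ["train", "trainannot", "val", "valannot", "test", "testannot",
            "train.txt", "val.txt", "test.txt"]
missing = [name for name in expected if not (data_dir / name).exists()]
print("Missing entries:", missing or "none")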

Running the code

Training

First, to train a baseline model, create a config.yaml file adapted to your dataset. You can find examples in confidnet/confs/. Don't forget to set the output_folder entry to a path of your own (N.B.: if the folder doesn't exist yet, the script will create it). Then, simply execute the following command:

$ cd ConfidNet/confidnet
$ python3 train.py -c confs/your_config_file.yaml 

This creates an output folder at the location indicated in your config.yaml. The folder includes the model weights, the train/val split used, a copy of your config file, and TensorBoard logs.

By default, if the output folder already exists, training will resume from the last saved epoch. If you want to force training to restart from scratch, simply add the -f argument:

$ cd ConfidNet/confidnet
$ python3 train.py -c confs/your_config_file.yaml -f

When training ConfidNet, don't forget to set the path to your baseline model checkpoint in your config.yaml:

...
model:
    name: vgg16_selfconfid_classic
    resume: /path/to/weights_folder/model_epoch_040.ckpt
    uncertainty:

The same applies if you want to fine-tune ConfidNet: fill in the uncertainty entry.
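If you prefer to fill these entries programmatically, here is a small PyYAML sketch; note that the exact value expected by the uncertainty entry is not documented in this README, so the ConfidNet checkpoint path below is purely a hypothetical placeholder:

import yaml

config_path = "confs/your_config_file.yaml"  # your own config copy
with open(config_path) as f:
    cfg = yaml.safe_load(f)

cfg["model"]["resume"] = "/path/to/weights_folder/model_epoch_040.ckpt"
# Hypothetical placeholder: point the uncertainty entry at your ConfidNet weights
cfg["model"]["uncertainty"] = "/path/to/confidnet_weights_folder/model_epoch_XXX.ckpt"

with open(config_path, "w") as f:
    yaml.safe_dump(cfg, f)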

Testing

To test your model, use the following command:

$ cd ConfidNet/confidnet
$ python3 test.py -c path/to/your/experiment/folder/your_config_file.yaml -e NUM_EPOCHS -m METHOD
  • -c: path to the copy of the config file saved in the output folder
  • -e: epoch of the model weights to evaluate
  • -m: method used to compute uncertainty. Available methods are normal (MCP), mc_dropout, trust_score, and confidnet.

Results will be printed at the end of the script.
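The metrics commonly used for failure prediction, and reported in the paper, include AUROC, AUPR-Error and FPR at 95% TPR. If you want to recompute such metrics yourself from per-sample confidence scores, a generic scikit-learn sketch (not the evaluation code of this repository) could look like:

import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score, roc_curve

# Dummy inputs: per-sample confidence scores (e.g. MCP, TCP or ConfidNet output)
# and whether the classifier's prediction was correct (1) or an error (0).
confidence = np.array([0.95, 0.40, 0.88, 0.15, 0.70, 0.30])
correct = np.array([1, 0, 1, 0, 1, 0])
errors = 1 - correct

auroc = roc_auc_score(correct, confidence)                 # success/error separability
aupr_error = average_precision_score(errors, -confidence)  # errors as the positive class

# FPR at the threshold where 95% of errors are detected (TPR on errors = 0.95)
fpr, tpr, _ = roc_curve(errors, -confidence)
fpr_at_95_tpr = fpr[np.searchsorted(tpr, 0.95)]
print(auroc, aupr_error, fpr_at_95_tpr)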

Pre-trained models

Model weights for the MNIST and CIFAR-10 datasets used in the paper are available along with this release. Each zip file contains the weights of the pre-trained baseline model and the weights of ConfidNet. To use these weights:

  • unzip the files, preserving the folder structure
  • each folder (baseline or ConfidNet) contains at least the model weights and a config file
  • set the weights folder path in your config file
  • train your model as indicated earlier
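If you just want to inspect a downloaded checkpoint before pointing your config at it, a minimal sketch is shown below (the internal layout of the checkpoint is an assumption; it may be a raw state dict or a wrapper dictionary):

import torch

ckpt = torch.load("model_epoch_040.ckpt", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])  # peek at the first few keys / parameter names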

Acknowledgements
