
Distinction Maximization Loss (DisMax)

Efficiently Improving Classification Accuracy, Uncertainty Estimation, and Out-of-Distribution Detection Simply Replacing the Loss and Calibrating

We keep the inference efficiency of a single network. No hyperparameter tuning. We train only once. SOTA.

Building robust deterministic deep neural networks is still a challenge. On the one hand, some approaches improve out-of-distribution detection at the cost of reducing classification accuracy in some situations. On the other hand, some methods simultaneously improve classification accuracy, out-of-distribution detection, and uncertainty estimation but reduce inference efficiency in addition to requiring training the same model many times to tune hyperparameters. In this paper, we propose training deterministic deep neural networks using our DisMax loss, which works as a drop-in replacement for the commonly used SoftMax loss (i.e., the combination of the linear output layer, the SoftMax activation, and the cross-entropy loss). Starting from the IsoMax+ loss, we create novel logits that are based on the distances to all prototypes rather than just the one associated with the correct class. We also propose a novel way to augment images to construct what we call fractional probability regularization. Moreover, we propose a new score to perform out-of-distribution detection and a fast way to calibrate the network after training. Our experiments show that DisMax usually outperforms all current approaches simultaneously in classification accuracy, uncertainty estimation, inference efficiency, and out-of-distribution detection, avoiding hyperparameter tuning and repetitive model training.
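
For intuition only, the sketch below illustrates IsoMax+-style distance-based logits, the starting point for DisMax: each logit is the negative scaled distance between the normalized feature vector and a learnable class prototype, so a smaller distance yields a larger logit (DisMax then goes further, making each logit also incorporate the distances to all prototypes). This is a conceptual sketch with illustrative names, not the exact DisMax implementation; see the paper and the losses module in this repository for the real formulation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistanceBasedLogits(nn.Module):
    # Conceptual sketch: logits derived from distances to learnable class prototypes.
    def __init__(self, num_features, num_classes):
        super().__init__()
        # One learnable prototype per class in feature space.
        self.prototypes = nn.Parameter(torch.zeros(num_classes, num_features))
        nn.init.normal_(self.prototypes, mean=0.0, std=1.0)
        # Learnable distance scale, as in IsoMax+.
        self.distance_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, features):
        # Pairwise distances between normalized features and normalized prototypes.
        distances = torch.cdist(F.normalize(features), F.normalize(self.prototypes))
        # Negative scaled distances act as logits: closer prototype => larger logit.
        return -torch.abs(self.distance_scale) * distances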

Read the full paper: Distinction Maximization Loss: Efficiently Improving Classification Accuracy, Uncertainty Estimation, and Out-of-Distribution Detection Simply Replacing the Loss and Calibrating.

Visit also the repository of our previous work: Entropic Out-of-Distribution Detection.


Use DisMax in your project!!!

Replace the SoftMax loss with the DisMax loss by changing a few lines of code!

Replace the last layer of the model classifier with the first part of the DisMax loss:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        (...)
        # Replace the final linear layer with the first part of the DisMax loss.
        #self.classifier = nn.Linear(num_features, num_classes)
        self.classifier = losses.DisMaxLossFirstPart(num_features, num_classes)

Replace the criterion with the second part of the DisMax loss:

model = Model()
#criterion = nn.CrossEntropyLoss()
criterion = losses.DisMaxLossSecondPart(model.classifier)

Preprocess the inputs before forwarding in the training loop:

# In the training loop, add the line of code below for preprocessing before forwarding.
inputs, targets = criterion.preprocess(inputs, targets) 
(...)
# The code below is preexistent. Just keep the following lines unchanged!
outputs = model(inputs)
loss = criterion(outputs, targets)

Detect during inference:

# Return the score values during inference.
scores = model.classifier.scores(outputs) 
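
Putting the pieces above together, here is a minimal end-to-end sketch. The optimizer choice, train_loader, test_inputs, and the threshold tau are illustrative placeholders, and it assumes the losses module referenced above is importable; only the DisMax calls themselves come from this repository:

import torch
import losses  # the module providing the DisMax loss parts

model = Model()  # defined as above, with losses.DisMaxLossFirstPart as classifier
criterion = losses.DisMaxLossSecondPart(model.classifier)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Training loop.
model.train()
for inputs, targets in train_loader:  # train_loader: your usual DataLoader
    inputs, targets = criterion.preprocess(inputs, targets)
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Inference: use the scores to flag likely out-of-distribution inputs.
model.eval()
with torch.no_grad():
    outputs = model(test_inputs)
    scores = model.classifier.scores(outputs)
    is_ood = scores < tau  # tau: a detection threshold chosen on validation data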

Run the example:

python example.py

Code

Software requirements

Much of the code is reused from deep_Mahalanobis_detector, odin-pytorch, and entropic-out-of-distribution-detection.

Please install all package requirements by running the command below:

pip install -r requirements.txt

Preparing the data

Please move to the data directory and run all the data preparation bash scripts:

# Download and prepare out-of-distribution data for the CIFAR10 and CIFAR100 datasets.
./prepare-cifar.sh

Reproducing the experiments

Train and evaluate the classification, uncertainty estimation, and out-of-distribution detection performances:

./run_cifar100_densenetbc100.sh
./run_cifar100_resnet34.sh
./run_cifar100_wideresnet2810.sh
./run_cifar10_densenetbc100.sh
./run_cifar10_resnet34.sh
./run_cifar10_wideresnet2810.sh

Analyzing the results

Print the experiment results:

./analize.sh

Citation

Please cite our paper if you use our loss in your work:

@article{macedo2022distinction,
      title={Distinction Maximization Loss: Efficiently Improving Classification Accuracy, Uncertainty Estimation, and Out-of-Distribution Detection Simply Replacing the Loss and Calibrating}, 
      author={David Macêdo and Cleber Zanchettin and Teresa Ludermir},
      year={2022},
      eprint={2205.05874},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
