

Bayesian Metric Learning with Laplace Approximation

💻 Official implementation of "Bayesian Metric Learning for Uncertainty Quantification in Image Retrieval" by Frederik Warburg*, Marco Miani*, Silas Brack and Søren Hauberg.

🔥 tl;dr: We use the Laplace Approximation for the Contrastive Loss to optimize a latent space with Metric Learning.
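For readers new to the objective, here is a minimal Python sketch of the textbook contrastive loss on a single pair of embeddings. It is an illustration under assumptions, not the repository's implementation: the function name and margin values are hypothetical, and the loss used in src/losses may differ (for example in its distance function or normalization).

```python
import math


def contrastive_loss(a, b, is_positive, margin=1.0):
    """Textbook contrastive loss for one pair of embeddings (illustrative sketch).

    Positive pairs are pulled together by their squared Euclidean distance;
    negative pairs are pushed apart until they are at least `margin` apart.
    """
    dist = math.dist(a, b)  # Euclidean distance (Python 3.8+)
    if is_positive:
        return dist ** 2
    return max(0.0, margin - dist) ** 2


# Hypothetical 2-D embeddings at distance 5 from each other.
anchor, other = [0.0, 0.0], [3.0, 4.0]
print(contrastive_loss(anchor, other, is_positive=True))               # similar pair: penalized by distance
print(contrastive_loss(anchor, other, is_positive=False, margin=6.0))  # dissimilar pair inside the margin
print(contrastive_loss(anchor, other, is_positive=False, margin=2.0))  # dissimilar pair beyond the margin: zero loss
```

The hinge on negative pairs means that only dissimilar pairs closer than the margin contribute to the loss.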

📰 paper: https://arxiv.org/abs/2302.01332

📰 Abstract: We propose the first Bayesian encoder for metric learning. Rather than relying on neural amortization as done in prior works, we learn a distribution over the network weights with the Laplace Approximation. We actualize this by first proving that the contrastive loss is a valid log-posterior. We then propose three methods that ensure a positive definite Hessian. Lastly, we present a novel decomposition of the Generalized Gauss-Newton approximation. Empirically, we show that our Laplacian Metric Learner (LAM) estimates well-calibrated uncertainties, reliably detects out-of-distribution examples, and yields state-of-the-art predictive performance.
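To make the weight-space idea concrete, the following is a toy Python sketch of a diagonal Laplace approximation with a Generalized Gauss-Newton (GGN) Hessian for a linear model f(x) = w · x. It is not the LAM implementation, which works on neural networks and builds on stochman. The function names, the linear model, the constant output Hessian and the prior precision are all illustrative assumptions. In this toy the output Hessian is a positive constant, so positive definiteness is automatic. The abstract notes that for the contrastive loss it has to be ensured explicitly.

```python
import math
import random


def diag_ggn_precision(xs, prior_precision=1.0, output_hessian=1.0):
    """Diagonal of the GGN posterior precision for a linear model f(x) = w . x.

    The Jacobian of f with respect to w is x itself, so the GGN term
    J^T * output_hessian * J has diagonal sum_n output_hessian * x_nj**2.
    Adding the prior precision keeps every entry strictly positive.
    """
    precision = [prior_precision] * len(xs[0])
    for x in xs:
        for j, x_j in enumerate(x):
            precision[j] += output_hessian * x_j ** 2
    return precision


def predictive_moments(w_map, precision, x, n_samples=2000, seed=0):
    """Monte Carlo mean and variance of f(x) with w ~ N(w_map, diag(1 / precision))."""
    rng = random.Random(seed)
    stds = [1.0 / math.sqrt(p) for p in precision]
    outputs = []
    for _ in range(n_samples):
        w = [rng.gauss(m, s) for m, s in zip(w_map, stds)]  # one posterior weight sample
        outputs.append(sum(w_j * x_j for w_j, x_j in zip(w, x)))
    mean = sum(outputs) / n_samples
    var = sum((o - mean) ** 2 for o in outputs) / n_samples
    return mean, var


# Hypothetical training inputs that only ever vary along the first feature.
train_x = [[1.0, 0.0], [2.0, 0.0], [1.0, 0.0]]
w_map = [0.5, 0.5]  # stand-in for the trained (MAP) weights
precision = diag_ggn_precision(train_x)
_, var_in = predictive_moments(w_map, precision, [1.0, 0.0])
_, var_out = predictive_moments(w_map, precision, [0.0, 1.0])
print(var_in, var_out)  # the direction unseen in training gets the larger variance
```

For a linear model the variance has the closed form Var f(x) = Σ_j x_j² / precision_j, so sampling is not strictly needed. It is shown because weight sampling carries over to nonlinear networks. Higher predictive variance on inputs unlike the training data is the kind of signal that can be used for out-of-distribution detection.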

Getting Started

git clone https://github.com/FrederikWarburg/bayesian-metric-learning;
cd bayesian-metric-learning;
git clone https://github.com/IlMioFrizzantinoAmabile/stochman;
cd stochman;
python setup.py develop;
cd ../src;

Your file structure should look like:

bayesian-metric-learning
├── configs             # config files, organised by experiment
├── img                 # figures
├── scripts             # scripts for running code
├── src                 # source code
│   ├── datasets        # data loading code
│   ├── evaluate        # evaluation code
│   ├── lightning       # PyTorch Lightning models (+ baselines)
│   ├── losses          # specialized loss functions
│   ├── miners          # miners
│   ├── models          # network architectures
│   └── utils           # helpers
├── requirements.txt    # Python packages required to run the code
└── stochman

Train your LAM

cd src;
CUDA_VISIBLE_DEVICES=0 python run.py --config ../configs/fashionmnist/laplace_online_arccos_fix.yaml;

or a baseline model (e.g. the deterministic model):

cd src;
CUDA_VISIBLE_DEVICES=0 python run.py --config ../configs/fashionmnist/deterministic.yaml;

Remember to change the data_dir in the .yaml config file.
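If you prefer to set the path from a script, here is a small, hypothetical Python helper that uses only the standard library. It assumes data_dir appears as a top-level `data_dir: ...` line in the config, which may not hold for every file in configs/. If the key is nested or absent it raises instead of guessing. The path is written unquoted, so paths containing YAML special characters are better edited by hand.

```python
import re
from pathlib import Path


def set_data_dir(config_path, data_dir):
    """Replace a top-level `data_dir:` line in a YAML config (hypothetical helper)."""
    path = Path(config_path)
    text = path.read_text()
    # A callable replacement avoids re interpreting backslashes in Windows paths.
    new_text, count = re.subn(
        r"(?m)^data_dir:.*$", lambda _match: f"data_dir: {data_dir}", text
    )
    if count == 0:
        raise KeyError(f"no top-level data_dir key found in {config_path}")
    path.write_text(new_text)


# Example (hypothetical dataset location):
# set_data_dir("../configs/fashionmnist/deterministic.yaml", "/data/fashionmnist")
```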

Citation

If you find this code useful, please consider citing us:

@article{Warburg2023LAM,
  title={Bayesian Metric Learning for Uncertainty Quantification in Image Retrieval},
  author={Frederik Warburg and Marco Miani and Silas Brack and Søren Hauberg},
  journal={CoRR},
  year={2023}
}
@article{LAE2022,
  title={Laplacian Autoencoders for Learning Stochastic Representations},
  author={Marco Miani and Frederik Warburg and Pablo Moreno-Muñoz and Nicki Skafte Detlefsen and Søren Hauberg},
  journal={NeurIPS},
  year={2022}
}
@article{software:stochman,
  title={StochMan},
  author={Nicki S. Detlefsen and Alison Pouplin and Cilie W. Feldager and Cong Geng and Dimitris Kalatzis and Helene Hauschultz and Miguel González Duque and Frederik Warburg and Marco Miani and Søren Hauberg},
  journal={GitHub. Note: https://github.com/MachineLearningLifeScience/stochman/},
  year={2021}
}
