
Laplace


The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer. The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations. The library documentation is available at https://aleximmer.github.io/Laplace.

There is also a corresponding paper, Laplace Redux — Effortless Bayesian Deep Learning, which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:

@inproceedings{laplace2021,
  title={Laplace Redux--Effortless {B}ayesian Deep Learning},
  author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
          and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
  booktitle={{N}eur{IPS}},
  year={2021}
}

The code to reproduce the experiments in the paper is also publicly available; it provides examples of how to use our library for predictive uncertainty quantification, model selection, and continual learning.

Setup

For full compatibility, install this package in a fresh virtual environment. Python >= 3.9 is assumed, since lower versions are deprecated or soon will be. PyTorch >= 2.0 is also required. To install laplace with pip, run the following:

pip install laplace-torch

For development purposes, clone the repository and then install:

# after cloning the repository, install in editable mode
pip install -e .
# run tests
pip install -e '.[tests]'
pytest tests/

Example usage

Post-hoc prior precision tuning of diagonal LA

In the following example, a pre-trained model is loaded, the Laplace approximation is fit to the training data (using a diagonal Hessian approximation over all parameters), and the prior precision is optimized via grid search on validation data ("gridsearch"). Afterwards, the resulting LA is used for prediction with the "probit" predictive for classification.

from laplace import Laplace

# Pre-trained model
model = load_map_model()

# User-specified LA flavor
la = Laplace(model, "classification",
             subset_of_weights="all",
             hessian_structure="diag")
la.fit(train_loader)
la.optimize_prior_precision(method="gridsearch", val_loader=val_loader)

# User-specified predictive approx.; x is a batch of test inputs
pred = la(x, link_approx="probit")

Differentiating the log marginal likelihood w.r.t. hyperparameters

The marginal likelihood can be used for model selection [10] and is differentiable w.r.t. continuous hyperparameters like the prior precision or observation noise. Here, we fit the library default, a KFAC last-layer LA, and differentiate the log marginal likelihood.

import torch

from laplace import Laplace

# Un- or pre-trained model
model = load_model()

# Default to recommended last-layer KFAC LA:
la = Laplace(model, likelihood="regression")
la.fit(train_loader)

# Hyperparameters as differentiable tensors
prior_prec = torch.tensor(1.0, requires_grad=True)
obs_noise = torch.tensor(1.0, requires_grad=True)

# Log marginal likelihood w.r.t. prior precision and observation noise
ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()

Laplace on foundation models like LLMs

This library also supports Huggingface models and parameter-efficient fine-tuning. See examples/huggingface_examples.py and examples/huggingface_examples.md for a full walkthrough.

First, we need to wrap the pretrained model so that its forward method takes a dict-like input. Note that this is what you get by default when iterating over a Huggingface dataloader. A dict-like input is convenient because different models take different numbers of inputs (e.g., GPT-like LLMs only take input_ids, while BERT-like ones take both input_ids and attention_mask). Inside this forward method you can do your usual preprocessing, like moving the tensor inputs onto the correct device.

from collections.abc import MutableMapping

import torch
from torch import nn
from transformers import (
    GPT2Config,
    GPT2ForSequenceClassification,
    PreTrainedTokenizer,
)

class MyGPT2(nn.Module):
    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained("gpt2")
        config.pad_token_id = tokenizer.pad_token_id
        config.num_labels = 2
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            "gpt2", config=config
        )

    def forward(self, data: MutableMapping) -> torch.Tensor:
        device = next(self.parameters()).device
        input_ids = data["input_ids"].to(device)
        attn_mask = data["attention_mask"].to(device)
        output_dict = self.hf_model(input_ids=input_ids, attention_mask=attn_mask)
        return output_dict.logits

Then you can "select" which of the LLM's parameters the Laplace approximation is applied to by switching off the gradients of the parameters you want to keep fixed. For example, we can replicate a last-layer Laplace this way (in actual practice, use Laplace(..., subset_of_weights='last_layer', ...) instead, though!):

model = MyGPT2(tokenizer)
model.eval()

# Enable grad only for the last layer
for p in model.hf_model.parameters():
    p.requires_grad = False
for p in model.hf_model.score.parameters():
    p.requires_grad = True

la = Laplace(
    model,
    likelihood="classification",
    # Will only hit the last-layer since it's the only one that is grad-enabled
    subset_of_weights="all",
    hessian_structure="diag",
)
la.fit(dataloader)
la.optimize_prior_precision()

test_data = next(iter(dataloader))
pred = la(test_data)

This is useful because it lets us apply the LA only to the parameter-efficient fine-tuning weights. E.g., we can keep the LLM itself fixed and apply the Laplace approximation only to the LoRA weights; Huggingface automatically switches off the gradients of the non-LoRA weights.

from laplace.curvature import AsdlGGN
from peft import LoraConfig, get_peft_model

def get_lora_model():
    model = MyGPT2(tokenizer)  # Note we don't disable grad
    config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["c_attn"],  # LoRA on the attention weights
        lora_dropout=0.1,
        bias="none",
    )
    lora_model = get_peft_model(model, config)
    return lora_model

lora_model = get_lora_model()

# Train it as usual here...

lora_model.eval()

lora_la = Laplace(
    lora_model,
    likelihood="classification",
    subset_of_weights="all",
    hessian_structure="diag",
    backend=AsdlGGN,
)

lora_la.fit(dataloader)

test_data = next(iter(dataloader))
lora_pred = lora_la(test_data)

Applying the LA over only a subset of the model parameters via the Subnetwork Laplace

This example shows how to fit the Laplace approximation over only a subnetwork within a neural network (while keeping all other parameters fixed at their MAP estimates), as proposed in [11]. It also exemplifies different ways to specify the subnetwork to perform inference over.

from laplace import Laplace

# Pre-trained model
model = load_model()

# Examples of different ways to specify the subnetwork
# via indices of the vectorized model parameters
#
# Example 1: select the 128 parameters with the largest magnitude
from laplace.utils import LargestMagnitudeSubnetMask
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
subnetwork_indices = subnetwork_mask.select()

# Example 2: specify the layers that define the subnetwork
from laplace.utils import ModuleNameSubnetMask
subnetwork_mask = ModuleNameSubnetMask(model, module_names=["layer.1", "layer.3"])
subnetwork_mask.select()
subnetwork_indices = subnetwork_mask.indices

# Example 3: manually define the subnetwork via custom subnetwork indices
import torch
subnetwork_indices = torch.tensor([0, 4, 11, 42, 123, 2021])

# Define and fit subnetwork LA using the specified subnetwork indices
la = Laplace(model, "classification",
             subset_of_weights="subnetwork",
             hessian_structure="full",
             subnetwork_indices=subnetwork_indices)
la.fit(train_loader)

Serialization

As with plain torch, we support two ways to serialize data.

One is the familiar state_dict approach. Here you need to save and re-create both the model and the Laplace instance. Use this for long-term storage of models and for sharing a fitted Laplace instance.

# Save model and Laplace instance
torch.save(model.state_dict(), "model_state_dict.bin")
torch.save(la.state_dict(), "la_state_dict.bin")

# Load serialized data
model2 = MyModel(...)
model2.load_state_dict(torch.load("model_state_dict.bin"))
la2 = Laplace(model2, "classification",
              subset_of_weights="all",
              hessian_structure="diag")
la2.load_state_dict(torch.load("la_state_dict.bin"))

The second approach is to save the whole Laplace object, including self.model. This is less verbose and more convenient since the trained model and the fitted Laplace data are stored in one place, but it also comes with some drawbacks. Use this for quick save-load cycles during experiments, say.

# Save Laplace, including la.model
torch.save(la, "la.pt")

# Load both
la = torch.load("la.pt")

Some Laplace variants such as LLLaplace might have trouble being serialized using the default pickle module, which torch.save() and torch.load() use (AttributeError: Can't pickle local object ...). In this case, the dill package will come in handy.

import dill

torch.save(la, "la.pt", pickle_module=dill)

With both methods, you are free to switch devices, for instance when you trained on a GPU but want to run predictions on CPU. In this case, use

torch.load(..., map_location="cpu")

Structure

The laplace package consists of two main components:

  1. The subclasses of laplace.BaseLaplace that implement different sparsity structures: different subsets of weights ('all', 'subnetwork' and 'last_layer') and different structures of the Hessian approximation ('full', 'kron', 'lowrank', 'diag' and 'gp'). This results in ten currently available options: laplace.FullLaplace, laplace.KronLaplace, laplace.DiagLaplace and laplace.FunctionalLaplace; the corresponding last-layer variations laplace.FullLLLaplace, laplace.KronLLLaplace, laplace.DiagLLLaplace and laplace.FunctionalLLLaplace (which are all subclasses of laplace.LLLaplace); laplace.SubnetLaplace (which only supports 'full' and 'diag' Hessian approximations); and laplace.LowRankLaplace (which only supports inference over 'all' weights). All of these can be conveniently accessed via the laplace.Laplace function, as illustrated in the sketch after this list.
  2. The backends in laplace.curvature which provide access to Hessian approximations of the corresponding sparsity structures, for example, the diagonal GGN.
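
For instance, requesting a last-layer LA with a Kronecker-factored Hessian via laplace.Laplace should yield a laplace.KronLLLaplace instance. A minimal sketch, assuming KronLLLaplace is importable from the top-level package (as the class names above suggest), with load_map_model standing in for any pre-trained classifier as in the earlier example:

from laplace import Laplace, KronLLLaplace

# Placeholder for any pre-trained classifier (an nn.Module)
model = load_map_model()

# The Laplace "factory" dispatches to the subclass matching the requested
# subset of weights and Hessian structure
la = Laplace(model, "classification",
             subset_of_weights="last_layer",
             hessian_structure="kron")
assert isinstance(la, KronLLLaplace)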

Additionally, the package provides utilities for decomposing a neural network into feature extractor and last layer for LLLaplace subclasses (laplace.utils.feature_extractor) and effectively dealing with Kronecker factors (laplace.utils.matrix).
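
As a quick sketch of the feature-extractor utility (assuming FeatureExtractor wraps a model, detects its last linear layer automatically, and exposes a forward_with_features method returning both the output and the last-layer features; see laplace.utils.feature_extractor for the exact interface):

from laplace.utils.feature_extractor import FeatureExtractor

# Illustrative only: split a classifier into feature extractor + last layer.
fe = FeatureExtractor(model)
out, features = fe.forward_with_features(x)  # x: a batch of inputs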

Finally, the package implements several options to select/specify a subnetwork for SubnetLaplace (as subclasses of laplace.utils.subnetmask.SubnetMask). Automatic subnetwork selection strategies include: uniformly at random (laplace.utils.subnetmask.RandomSubnetMask), by largest parameter magnitudes (LargestMagnitudeSubnetMask), and by largest marginal parameter variances (LargestVarianceDiagLaplaceSubnetMask and LargestVarianceSWAGSubnetMask). In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (ParamNameSubnetMask) or modules (ModuleNameSubnetMask) to perform Laplace inference over.
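
A short sketch of two strategies not shown in the subnetwork example above; we assume RandomSubnetMask and ParamNameSubnetMask follow the same constructor patterns as LargestMagnitudeSubnetMask and ModuleNameSubnetMask, respectively:

from laplace.utils import ParamNameSubnetMask, RandomSubnetMask

# Uniformly random subnetwork of 128 parameters
subnetwork_mask = RandomSubnetMask(model, n_params_subnet=128)
subnetwork_indices = subnetwork_mask.select()

# Manual specification by parameter name (assumed keyword, analogous to
# ModuleNameSubnetMask's module_names)
subnetwork_mask = ParamNameSubnetMask(model, parameter_names=["layer.1.weight", "layer.3.bias"])
subnetwork_indices = subnetwork_mask.select()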

Extendability

To extend the laplace package, new BaseLaplace subclasses can be designed, for example, Laplace with a block-diagonal Hessian structure. One can also implement custom subnetwork selection strategies as new subclasses of SubnetMask.
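
As an illustration of the latter, a custom selection strategy could look roughly as follows. This is a sketch only: we assume the abstract method to implement is get_subnet_mask(train_loader) and that it returns a binary mask over the vectorized model parameters; check laplace.utils.subnetmask.SubnetMask for the exact contract.

import torch
from torch.nn.utils import parameters_to_vector

from laplace.utils.subnetmask import SubnetMask

class SmallestMagnitudeSubnetMask(SubnetMask):
    """Illustrative only: select the n parameters closest to zero."""

    def __init__(self, model, n_params_subnet):
        super().__init__(model)
        self._n_params_subnet = n_params_subnet

    def get_subnet_mask(self, train_loader):
        # Assumed contract: return a binary mask over the vectorized parameters
        param_vector = parameters_to_vector(self.model.parameters()).detach()
        idx = torch.argsort(param_vector.abs())[: self._n_params_subnet]
        mask = torch.zeros_like(param_vector, dtype=torch.bool)
        mask[idx] = True
        return mask

Such a mask would then be used like the built-in ones: instantiate it, call select(), and pass the resulting indices as subnetwork_indices to Laplace.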

Alternatively, extending or integrating backends (subclasses of curvature.curvature) makes it possible to provide different Hessian approximations to the Laplace approximations. For example, currently curvature.CurvlinopsInterface, based on Curvlinops and the native torch.func (previously known as functorch), curvature.BackPackInterface, based on BackPACK, and curvature.AsdlInterface, based on ASDL, are available.

The curvature.CurvlinopsInterface backend is the default and provides all Hessian approximation variants except the low-rank Hessian. For the latter, curvature.AsdlInterface can be used. Note that curvature.AsdlInterface and curvature.BackPackInterface are less complete and less compatible than curvature.CurvlinopsInterface, so we recommend sticking with curvature.CurvlinopsInterface unless you have a specific need for ASDL or BackPACK.
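
For example, a concrete backend class can be passed explicitly when constructing the LA. A minimal sketch (BackPackGGN is one of the GGN backends shipped with the library; model and train_loader are placeholders):

from laplace import Laplace
from laplace.curvature import BackPackGGN

# Explicitly request the BackPACK GGN backend instead of the Curvlinops default
la = Laplace(model, "classification",
             subset_of_weights="all",
             hessian_structure="diag",
             backend=BackPackGGN)
la.fit(train_loader)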

Documentation

The documentation is available at https://aleximmer.github.io/Laplace or can be generated and/or viewed locally:

# assuming the repository was cloned
pip install -e '.[docs]'
# create docs and write to html
bash update_docs.sh
# .. or serve the docs directly
pdoc --http 0.0.0.0:8080 laplace --template-dir template

Contributing

Pull requests are very welcome. Please follow these guidelines:

  1. Install Laplace via pip install -e ".[dev]" which will install ruff and all requirements necessary to run the tests and build the docs.
  2. Use ruff as the autoformatter. Please refer to the makefile in the repository and run it via make ruff. Please note that the order of ruff check --fix and ruff format is important!
  3. Also use ruff as linter. Please manually fix all linting errors/warnings before opening a pull request.
  4. Fully document your changes in the form of Python docstrings, typehinting, and (if applicable) code/markdown examples in the ./examples subdirectory.
  5. Provide as many test cases as possible. Make sure all test cases pass.

Issues, bug reports, and ideas are also very welcome!

References

This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. Please consider citing the respective papers if you use any of their proposed methods via our laplace library.
