
curvlinops's Introduction

scipy linear operators of deep learning matrices in PyTorch

Python 3.8+

This library implements scipy.sparse.linalg.LinearOperators for deep learning matrices, such as

  • the Hessian
  • the Fisher/generalized Gauss-Newton (GGN)
  • the Monte-Carlo approximated Fisher
  • the Fisher/GGN's KFAC approximation (Kronecker-Factored Approximate Curvature)
  • the uncentered gradient covariance (aka empirical Fisher)
  • the output-parameter Jacobian of a neural net and its transpose

Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU. The library supports defining these matrices not only on a mini-batch, but on data sets (looping over batches during a matvec operation).

You can plug these linear operators into scipy routines while the heavy lifting (matrix-vector multiplies) is carried out in PyTorch, potentially on a GPU. My favorite example of such a routine is scipy.sparse.linalg.eigsh, which lets you compute a subset of eigenpairs.
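
For instance, here is a minimal sketch of computing leading Hessian eigenvalues with eigsh. It assumes the HessianLinearOperator constructor takes a model, a loss function, a parameter list, and an iterable of batches; the toy model, data, and shapes are purely illustrative.

import torch
from scipy.sparse.linalg import eigsh
from curvlinops import HessianLinearOperator  # assumed import path

# Toy problem; model, data, and shapes are illustrative only.
model = torch.nn.Linear(10, 2)
loss_func = torch.nn.CrossEntropyLoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch.randn(8, 10), torch.randint(0, 2, (8,)))]  # one mini-batch

H = HessianLinearOperator(model, loss_func, params, data)  # behaves like a scipy LinearOperator
top_eigenvalues, _ = eigsh(H, k=3, which="LA")  # leading eigenpairs, computed via matvecs only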

The library also provides linear operator transformations, like taking the inverse (inverse matrix-vector product via conjugate gradients) or slicing out sub-matrices.
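
To illustrate the inverse transformation, the generic pattern is to solve A x = v with an iterative solver instead of materializing the inverse. Below is a minimal sketch with SciPy's conjugate gradient method; the diagonal stand-in matrix and damping are assumptions, not the library's implementation.

import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

# Stand-in for a damped, positive-definite curvature matrix, accessed only via matvecs.
A_dense = np.diag(np.linspace(1.0, 5.0, 20))
A = LinearOperator(A_dense.shape, matvec=lambda v: A_dense @ v)

v = np.random.rand(20)
x, info = cg(A, v)  # info == 0 signals convergence
print(np.linalg.norm(A_dense @ x - v))  # small residual: x approximates the inverse applied to v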

Finally, it offers functionality to probe properties of the represented matrices, like their spectral density, trace, or diagonal.

Installation

pip install curvlinops-for-pytorch

Examples

Future ideas

Other features that could be supported in the future include:

Logo image credits

curvlinops's People

Contributors

f-dangel, kourbou, ltatzel, runame, wiseodd


curvlinops's Issues

Allow arbitrarily-ordered parameters in KFAC

At the moment, the params supplied to KFAC must be in the same order as the NN's parameters, i.e.

model = Sequential(Linear(...), Linear(...))

# supported
params_allowed = [model[0].weight, model[0].bias, model[1].weight, model[1].bias]

# not supported
params_forbidden = [model[1].bias, model[0].bias, model[0].weight, model[1].weight]

While parameters will often be supplied in the correct order, dealing with arbitrary orders should be supported.
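
One possible fix is to sort the supplied parameters into the model's canonical order before building the internal mapping, e.g. by matching storage pointers. The helper sort_into_model_order below is purely illustrative, not part of the library:

from torch.nn import Linear, Sequential

def sort_into_model_order(model, params):
    """Return ``params`` sorted into the order of ``model.parameters()``."""
    position = {p.data_ptr(): idx for idx, p in enumerate(model.parameters())}
    return sorted(params, key=lambda p: position[p.data_ptr()])

model = Sequential(Linear(4, 3), Linear(3, 2))
shuffled = [model[1].bias, model[0].bias, model[0].weight, model[1].weight]
canonical = sort_into_model_order(model, shuffled)  # weight, bias, weight, bias order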

Support `BCEWithLogitsLoss`

Requested by @wiseodd for laplace-torch.

This consists of three parts:

  • Support sqrt_hessian for KFAC (type 2)
  • Support sample_grad_output for FisherMC
  • Support draw_label for KFAC (type 1)

Linear operator for KFAC

Implement a linear operator that multiplies with the Kronecker-factorized curvature approximation of the Fisher (a usage sketch follows the list below).

  • The first version will only support networks with parameters in Linear and Conv2d layers
  • There should be an option to treat .weight and .bias jointly or separately.
  • There could be an option for KFAC-expand and KFAC-reduce
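
A minimal sketch of how such an operator might be used, assuming it ends up with a constructor analogous to the library's other operators (model, loss function, parameters, data); all names, shapes, and data here are illustrative:

import torch
from curvlinops import KFACLinearOperator  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(10, 8), torch.nn.ReLU(), torch.nn.Linear(8, 3))
loss_func = torch.nn.CrossEntropyLoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch.randn(16, 10), torch.randint(0, 3, (16,)))]

kfac = KFACLinearOperator(model, loss_func, params, data)
v = torch.cat([torch.randn_like(p).flatten() for p in params]).numpy()
Kv = kfac @ v  # matrix-vector product with the Kronecker-factored approximation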

[BUG] Scaling of (KFAC) empirical/MC Fisher broken in some cases when using mean reduction for loss

The reduction_factor is assumed to be just the dataset size in the implementation of the empirical Fisher, which is incorrect in the case of MSELoss, BCEWithLogitsLoss, and in some cases also for CrossEntropyLoss (when the model output has more than two dimensions). KFAC with fisher_type="empirical" also requires a change in the scaling, similar to here.

The same applies to FisherMCLinearOperator, but not for KFAC with fisher_type="mc".

[BUG | KFAC] Changing device invalidates parameter mapping

Whenever KFAC is instantiated on GPU and we call .to_device(device("cpu")), this invalidates the internal mapping from parameter .data_ptr()s to module names, which is needed to identify a parameter's position in the list format. Currently, this bug can pass silently because we do not check in matmat whether all parameter positions were processed.
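
A minimal demonstration of why the mapping goes stale (it only requires a GPU; the mapping dictionary mimics the pointer-to-name bookkeeping described above):

import torch
from torch.nn import Linear

if torch.cuda.is_available():
    layer = Linear(2, 2).cuda()
    # Mapping from storage pointers to parameter names, mimicking the internal bookkeeping.
    mapping = {p.data_ptr(): name for name, p in layer.named_parameters()}
    layer.to(torch.device("cpu"))  # moving the module re-allocates parameter storage
    print([p.data_ptr() in mapping for p in layer.parameters()])  # [False, False]: pointers went stale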

Feature request: Diagonal estimation algorithms

Similar to Hutchinson's method for trace estimation, one can approximate the diagonal of a matrix from projections onto random vectors, see for instance Equation 16, or Equation 9.

There are no implementations for Hessian diagonal estimation in scipy, so it would be nice to offer such methods via a LinearOperator interface in this library.
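
A minimal numpy sketch of the basic estimator, using the identity E[v * (A v)] = diag(A) for Rademacher vectors v; the test matrix and sample count are arbitrary:

import numpy as np
from scipy.sparse.linalg import LinearOperator, aslinearoperator

def estimate_diagonal(A: LinearOperator, num_samples: int = 1000) -> np.ndarray:
    """Estimate diag(A) from matvecs with random sign vectors."""
    dim = A.shape[1]
    estimate = np.zeros(dim)
    for _ in range(num_samples):
        v = np.random.choice([-1.0, 1.0], size=dim)
        estimate += v * (A @ v)
    return estimate / num_samples

sym = np.random.rand(10, 10)
sym = sym + sym.T  # symmetric test matrix
print(np.max(np.abs(estimate_diagonal(aslinearoperator(sym), 5_000) - np.diag(sym))))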

Feature request: Trace estimation algorithms

The trace is often used to summarize curvature matrices in second-order methods or for generalization metrics.

I could not find libraries that provide trace estimation methods for scipy.sparse.linalg.LinearOperators. The closest library is Nico's matfree, which has Hutchinson trace estimation for JAX. pyhessian has Hutchinson trace estimation in PyTorch, but does not use a LinearOperator interface and only considers the Hessian.

So it would be useful to offer trace estimation via a scipy-based linear operator interface in this library.

Possible algorithms are:

  • (Basic) Hutchinson trace estimation (see Section 4); a minimal sketch follows this list
  • (Advanced) Hutch++ (paper, matlab implementation)
  • (Advanced) NA-Hutch++ (paper): I decided against implementing NA-Hutch++, since it does not offer memory savings over Hutch++. According to the paper, non-adaptive methods have practical benefits when used with batch-multiplies of the linear operator. However, the linear operators offered by this library only support efficient matvecs (matmats are loops over matvecs) and hence cannot leverage this benefit. Another point against implementing and maintaining this method is that, according to the meyer2020hutch paper, NA-Hutch++ "tends to perform slightly worse in our experiments."
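
A minimal numpy sketch of basic Hutchinson trace estimation against a scipy LinearOperator, using E[v^T A v] = trace(A) for Rademacher vectors v; the test matrix and sample count are arbitrary:

import numpy as np
from scipy.sparse.linalg import LinearOperator, aslinearoperator

def estimate_trace(A: LinearOperator, num_samples: int = 1000) -> float:
    """Estimate trace(A) from matvecs with random sign vectors."""
    dim = A.shape[1]
    samples = []
    for _ in range(num_samples):
        v = np.random.choice([-1.0, 1.0], size=dim)
        samples.append(v @ (A @ v))
    return float(np.mean(samples))

sym = np.random.rand(20, 20)
sym = sym + sym.T  # symmetric test matrix
print(estimate_trace(aslinearoperator(sym), 5_000), np.trace(sym))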

`A.torch_matvec()` changes the shape of tensors

When A is a KFACLinearOperator or KFACInverseLinearOperator for model parameters given by a list of tensors, the expected behaviour when applying A to a list of tensors of the same shape would be that it returns a list of the same shape. This is not the case:

params = [torch.zeros(5,4), torch.zeros(5, 3, 3, 3)]
A = KFACLinearOperator(
    ...
    params,
    ...
)
v = [torch.randn_like(p) for p in params]
y = A.torch_matvec(v)
print([y_elem.shape for y_elem in y])
# Expected output: same shapes as `v` and `params`:
>>> [(5, 4), (5, 3, 3, 3)]
# Actual output: flattened to a 2D tensor:
>>> [(5, 4), (5, 27)]

Add option for heuristic and exact damping to `KFACInverseLinearOperator`

There are two different ways to set the damping for the KFACInverseLinearOperator that we should implement:

  1. 'Heuristic' damping, introduced in section 6.3 in the original K-FAC paper.
  2. 'Exact' damping, which can be efficiently implemented for Kronecker factored matrices, e.g. see equation (21) in Grosse et al., 2023.

There is also an 'adaptive' damping scheme, as described in the original K-FAC paper, but I do not plan to implement this for now.
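
For a single Kronecker-factored block A ⊗ B, 'exact' damping (A ⊗ B + lam * I)^{-1} can be applied cheaply via eigendecompositions of the factors. Below is a minimal numpy sketch under the assumption of symmetric factors; the function name and shapes are illustrative:

import numpy as np

def kron_damped_solve(A, B, lam, V):
    """Solve (A ⊗ B + lam * I) vec(X) = vec(V), with V of shape (dim_A, dim_B)."""
    evals_A, Q_A = np.linalg.eigh(A)
    evals_B, Q_B = np.linalg.eigh(B)
    # Rotate into the joint eigenbasis, divide by the damped eigenvalues, rotate back.
    V_rot = Q_A.T @ V @ Q_B
    return Q_A @ (V_rot / (np.outer(evals_A, evals_B) + lam)) @ Q_B.T

# Check against the dense, explicitly damped Kronecker product.
A = np.random.rand(3, 3); A = A @ A.T
B = np.random.rand(4, 4); B = B @ B.T
lam, V = 0.1, np.random.rand(3, 4)
dense = np.linalg.solve(np.kron(A, B) + lam * np.eye(12), V.reshape(-1))
print(np.allclose(kron_damped_solve(A, B, lam, V).reshape(-1), dense))  # True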

Add `state_dict` functionality to `KFACLinearOperator` and `KFACInverseLinearOperator`

Since KFACLinearOperator and KFACInverseLinearOperator both have state that is potentially expensive to compute, it is often convenient to store the (inverted) Kronecker factors to disk to save computation. To implement this, it makes sense to add a state_dict method to them, together with load_state_dict and a classmethod from_state_dict.
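
An illustrative sketch of what such state could look like and how it would round-trip through torch.save; the dictionary keys are hypothetical and the identity matrices are placeholders for the actual Kronecker factors:

import torch

# Hypothetical layout: per-module Kronecker factors keyed by module name.
state_dict = {
    "input_covariances": {"0": torch.eye(5), "2": torch.eye(9)},
    "gradient_covariances": {"0": torch.eye(4), "2": torch.eye(3)},
}
torch.save(state_dict, "kfac_state.pt")

restored = torch.load("kfac_state.pt")  # later, load_state_dict / from_state_dict would consume this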

Error when installing `curvlinops` as dependency

From laplace-torch Github actions:

Collecting curvlinops-for-pytorch@ git+https://github.com/f-dangel/curvlinops (from laplace-torch==0.1a2)
  Cloning https://github.com/f-dangel/curvlinops to /tmp/pip-install-c9oemcf1/curvlinops-for-pytorch_8092676930d443ea8493f3627b08db31
  Running command git clone --filter=blob:none --quiet https://github.com/f-dangel/curvlinops /tmp/pip-install-c9oemcf1/curvlinops-for-pytorch_8092676930d443ea8493f3627b08db31
  Resolved https://github.com/f-dangel/curvlinops to commit 84c2ce755c2c5325330e115a0d46f13ef15cfd2e
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'error'
  error: subprocess-exited-with-error
  
  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> [17 lines of output]
      Traceback (most recent call last):
        File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
          main()
        File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 118, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-oh_[92](https://github.com/aleximmer/Laplace/actions/runs/9799308281/job/27059312114?pr=202#step:4:93)17v/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 327, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=[])
        File "/tmp/pip-build-env-oh_9217v/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 297, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-oh_9217v/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 497, in run_setup
          super().run_setup(setup_script=setup_script)
        File "/tmp/pip-build-env-oh_9217v/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 313, in run_setup
          exec(code, locals())
        File "<string>", line 9, in <module>
      ModuleNotFoundError: No module named 'packaging'
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error

Verify that the order of the mini-batches is deterministic where it matters

Some linear operators like FisherMCLinearOperator rely on deterministic mini-batches, i.e. that shuffle=False for the used data loader. Otherwise, there will be the following error which does not point out the cause of the issue:

RuntimeError: Check for deterministic matvec failed.

To avoid this, we could check that shuffle=False upon construction; one possible check is sketched below.
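
A minimal sketch of such a check: DataLoader does not store the shuffle flag itself, but a non-shuffling loader uses a SequentialSampler. The dataset below is a placeholder.

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.randn(8, 2), torch.randn(8, 1))  # placeholder data
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# Proposed check at linear-operator construction time:
if not isinstance(loader.sampler, SequentialSampler):
    raise ValueError("The data loader must be deterministic; construct it with shuffle=False.")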

Make less/no assumption about the data

Currently, curvlinops assumes that X and y are both torch.Tensor. This holds for the classic deep learning setting, but with the rise of LLMs and other more complex models, one should not make that assumption.

For instance, the input of a Huggingface model is a UserDict:

from collections import UserDict

data = UserDict({
    'input_ids': torch.LongTensor(...),
    'attention_mask': torch.LongTensor(...),
    'labels': torch.LongTensor(...)
})

In this case, one can extract X and y via (example from laplace-torch)

if isinstance(data, (UserDict, dict)):  # to support Huggingface datasets
    X, y = data, data['labels'].to(self._device)

However, curvlinops strongly assumes X to be a tensor in several places of the implementation.

I think the best way to circumvent this is to not assume anything about X (preferably also y, e.g. for multi-output models).
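
A minimal illustration of this direction: treat X opaquely and only extract what is actually needed. The helper batch_size_of is hypothetical, not part of curvlinops:

from collections import UserDict
import torch

def batch_size_of(X) -> int:
    """Infer the batch size without assuming X is a tensor."""
    if isinstance(X, (dict, UserDict)):  # e.g. Huggingface-style inputs
        return next(iter(X.values())).shape[0]
    return X.shape[0]

print(batch_size_of(torch.randn(4, 3)))  # 4
print(batch_size_of(UserDict({"input_ids": torch.zeros(4, 7, dtype=torch.long)})))  # 4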

Feature request: add method to convert to matrix

To convert a linear operator to an explicit matrix, we need the following one-liner:

matrix = operator @ np.eye(operator.shape[-1])

SciPy also has the aslinearoperator function, which converts a matrix to an operator.

What do you think about adding a method that does the converse, i.e.

matrix = operator.asmatrix()

It would basically involve the one-liner above. Pros, cons?

Improve efficiency of MC Fisher

Currently, we loop over the data points and MC samples, which is the 'best' way in the sense of FLOPs (batch_size * mc_samples backpropagations), but suffers from poor parallelization. We could use functorch's vmap, but currently the library relies mostly on autograd.grad and I would like to keep it this way as there are certain benefits (e.g. recycling the computation graph when multiplying onto multiple vectors, block-diagonal approximations).

Another approach I propose is to phrase multiplication with the MC Fisher as a GGN-vector product where the loss function is chosen such that its Hessian corresponds to the summed outer product of sampled gradients.
This uses more FLOPs (batch_size * C backpropagations), but does not require a loop over batch_size. Therefore, this should often yield better run time. One downside of this approach is that it costs as much as multiplying with the exact GGN. This renders the motivation for having an MC-sampled version (less accurate, same cost) weaker.
