
torchdeq's Introduction


TorchDEQ: A Library for Deep Equilibrium Models


Documentation | Colab Tutorial | DEQ Zoo | Roadmap | Citation

Introduction

Deep Equilibrium Models, or DEQs, are a recently developed class of implicit neural networks that merge the concepts of fixed point systems with modern deep learning. Fundamentally, a DEQ model defines its output as the equilibrium of a nonlinear system. This can be represented as:

$$\mathbf{z}^\star=f_\theta(\mathbf{z}^\star, \mathbf{x})$$

Here, $\mathbf{x}$ is the input fed into the network, while the fixed point $\mathbf{z}^\star$ is its output.
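To make the fixed point equation concrete, here is a minimal sketch in plain Python that finds $\mathbf{z}^\star$ by naive iteration for a scalar toy function. The function `f`, the weight `w`, and the helper `solve_fixed_point` are illustrative assumptions, not part of TorchDEQ, whose solvers are more sophisticated.

```python
import math


def f(z, x, w=0.5):
    """Toy layer f_theta(z, x). Since |df/dz| <= w < 1, it is a contraction,
    so a unique fixed point exists and naive iteration converges to it."""
    return math.tanh(w * z + x)


def solve_fixed_point(f, x, z0=0.0, tol=1e-10, max_iter=1000):
    """Iterate z <- f(z, x) until two successive iterates differ by less than tol."""
    z = z0
    for n_iter in range(1, max_iter + 1):
        z_next = f(z, x)
        if abs(z_next - z) < tol:
            return z_next, n_iter
        z = z_next
    return z, max_iter


z_star, n_iter = solve_fixed_point(f, x=0.3)
# At equilibrium, applying f once more leaves z_star (numerically) unchanged.
residual = abs(z_star - f(z_star, 0.3))
print(f"z* = {z_star:.6f}, iterations = {n_iter}, residual = {residual:.2e}")
```

The residual $|\mathbf{z}^\star - f_\theta(\mathbf{z}^\star, \mathbf{x})|$ is the quantity a fixed point solver drives toward zero.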

Enter TorchDEQ - a fully featured, out-of-the-box, and PyTorch-based library tailored for the design and deployment of DEQs. It provides intuitive, decoupled, and modular interfaces to customize general-purpose DEQs for arbitrary tasks, all with just a handful of code lines.

Dive into the world of DEQ with TorchDEQ! Craft your own DEQ effortlessly in just a single line of code. Kickstart your journey with our Colab Tutorial — best enjoyed with a comforting cup of tea!

Installation

  • Through pip.

    pip install torchdeq
  • From source.

    git clone https://github.com/locuslab/torchdeq.git && cd torchdeq
    pip install -e .

Quick Start

  • Automatic arg parser decorator. Call this function to add commonly used DEQ args to your program's argument parser.

    add_deq_args(parser)
  • Automatic DEQ instantiation. Call get_deq to get your DEQ layer in a single line! The implementation is highly decoupled and agnostic to your model design.

    deq = get_deq(args)
  • Easy DEQ forward. Even for a multi-equilibria system, you can execute your DEQ forward in a single line!

    # Assume fn is a function of three tensors a, b, c.
    def fn(a, b, c):
        # Do something here...
        # The output tensors must have the same shapes as the inputs.
        return a, b, c

    # A callable object (`fn` here) that defines your fixed point system.
    # `fn` can be a functor defined in your PyTorch forward function.
    # A functor can take your input injection from the local variables.
    # You can also pass a PyTorch Module into the DEQ class.
    z_out, info = deq(fn, (a0, b0, c0))
  • Automatic DEQ backward. Gradients (both exact and inexact) are tracked automatically! Working with TorchDEQ is the same as working with any standard PyTorch operator. Just post-process z_out as you would normal tensors!

Contributions

We warmly welcome contributions to TorchDEQ from the community! If you have suggestions for improving the library, introducing new features, or identifying and fixing bugs, please open an issue to discuss with us! Once a direction has been discussed, we can proceed to build, test, and submit a pull request (PR) to TorchDEQ together. Keep each PR clean, well tested, and focused on a single change! While numerical errors and stability issues may seem minor initially, they can accumulate into significant effects over time.

We have provided a preliminary roadmap for the development of this library and are always open to fresh perspectives. Feel free to reach out for questions, discussions, or library developments! Here is my email.

Logo Explained

The logo we’ve chosen draws inspiration from the ancient symbol, Ouroboros, a powerful emblem depicting a serpent or dragon eternally consuming its own tail. Unearthed in the tomb of Tutankhamun, the Ouroboros symbolizes the cyclicality of time, embodying both creation and destruction, inception and conclusion. It’s a profound representation of infinity and wholeness, transcending various mythologies and philosophies across time.

For DEQ models, our choice of logo bears a metaphorical weight. The dragon, denoting $f(\mathbf{x})$, biting its tail, representing $\mathbf{x}$, paints a vivid picture of a function attaining a fixed point. It's a metaphor layered with meaning, visualizing the attainment of stability, illustrated by the dragon completing its circle by biting its tail. This symbol is not just a snapshot of equilibrium; it's a dynamic representation of the infinite nature inherent in DEQ models.

Citation

@misc{torchdeq,
    author = {Zhengyang Geng and J. Zico Kolter},
    title = {TorchDEQ: A Library for Deep Equilibrium Models},
    year = {2023},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/locuslab/torchdeq}},
}

Acknowledgements

This codebase is largely inspired by remarkable projects from the community. We would like to sincerely thank DEQ, DEQ-Flow, PyTorch, and scipy for their awesome open-source work.

torchdeq's People

Contributors

eltociear, gsunshine


torchdeq's Issues

Implementation of `torchdeq.utils.mixed_init` different from original paper

From the paper:
"we experimented with initializing the hidden states with zeros on half of the examples in the batch, and with standard Gaussian noise on the rest of the examples"

"Mixed initialization: During each training forward pass, each sample was assigned with either zero initialization (i.e. the fixed point was initialized with the 0 vector) or standard normal distribution (i.e. ...) using a Bernoulli random variable of probability 0.5 (i.e. the examples that were run with zero vs. normal initializations were roughly half-half."

Current implementation:

    def mixed_init(z_shape, device=None):
        """
        Initializes a tensor with a shape of `z_shape` with half Gaussian random values and half zeros.
        Proposed in the paper, `Path Independent Equilibrium Models Can Better Exploit Test-Time Computation <https://arxiv.org/abs/2211.09961>`_,
        for better path independence.

        Args:
            z_shape (tuple): Shape of the tensor to be initialized.
            device (torch.device, optional): The desired device of returned tensor. Default None.

        Returns:
            torch.Tensor: A tensor of shape `z_shape` with values randomly initialized and zero masked.
        """
        z_init = torch.randn(*z_shape, device=device)
        mask = torch.zeros_like(z_init, device=device).bernoulli_(0.5)
        return z_init * mask

It seems more appropriate to do this instead to match the paper.

    *mask_shape, _ = z_shape
    mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)

This form has the disadvantage of assuming that all but the last dimension are batch dimensions. But this seems to be quite a reasonable assumption, and downstream users can easily adjust to this by reshaping and rearranging the dimensions.
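The difference between the two masks can be checked with a small sketch. The helper name `mixed_init_per_sample` and the parameter `p` are hypothetical and not part of torchdeq. As in the suggestion above, every dimension except the last is assumed to be a batch dimension.

```python
import torch


def mixed_init_per_sample(z_shape, device=None, p=0.5):
    """Hypothetical variant: each sample is either all Gaussian noise (with
    probability p) or all zeros, instead of being masked element by element."""
    z_init = torch.randn(*z_shape, device=device)
    # One Bernoulli draw per sample; all but the last dimension are treated as batch dims.
    *mask_shape, _ = z_shape
    mask = torch.empty(tuple(mask_shape), device=device).bernoulli_(p).unsqueeze(-1)
    # The mask broadcasts over the last (feature) dimension.
    return z_init * mask


torch.manual_seed(0)
z = mixed_init_per_sample((8, 16))

# Every row should be either entirely zero or entirely non-zero.
row_is_zero = (z == 0).all(dim=-1)
row_is_noise = (z != 0).all(dim=-1)
print(row_is_zero)
```

With the element-wise mask from the current implementation, a row would instead mix zeros and Gaussian values, so neither `row_is_zero` nor `row_is_noise` would typically hold for it.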

Custom autograd fails with torchdeq in eval mode

It's a very niche problem, but it tripped me up big time :')

Issue

In model.eval() mode, z_pred does not track gradients (z_pred[-1].requires_grad == False).
With a custom torch.autograd.grad call, this leads to an error: RuntimeError: One of the differentiated Tensors does not require grad.

Minimal example


import torch

import torchdeq
from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

        # deq
        self.deq = get_deq()
        apply_norm(self.layer, 'weight_norm')

    def implicit_layer(self, x):
        return self.layer(x)
    
    def forward(self, x, pos):

        z = torch.zeros_like(x)

        reset_norm(self.layer)

        f = lambda z: self.f(z, pos)

        z_pred, info = self.deq(self.implicit_layer, z)
        
        # if model.eval() -> z_pred[-1].requires_grad is False!
        energy = z_pred[-1]
        forces = -1 * (
            torch.autograd.grad(
                energy,
                # diff with respect to pos
                # if you get 'One of the differentiated Tensors appears to not have been used in the graph'
                # then because pos is not 'used' to calculate the energy
                pos, 
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                # allow_unused=True, 
            )[0]
        )

        return energy, forces


def run(model, eval=False):

    if eval:
        model.eval()
    else:
        model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for step in range(10):
        x = torch.randn(10, 10)
        pos = torch.randn(10, 3)
        energy, forces = model(x, pos)
        
        # loss
        optimizer.zero_grad()
        energy_target = torch.randn(10, 1)
        energy_loss = torch.nn.functional.mse_loss(energy, energy_target)
        force_target = torch.randn(10, 3)
        force_loss = torch.nn.functional.mse_loss(forces, force_target)
        loss = energy_loss + force_loss

        if not eval:
            loss.backward()
            optimizer.step()
    
    return True

if __name__ == '__main__':
    model = MyModel()
    success = run(model, eval=False)
    print(f'train success: {success}')
    success = run(model, eval=True)
    print(f'eval success: {success}')

In model.train() mode this works perfectly well. In model.eval() mode we get the error: RuntimeError: One of the differentiated Tensors does not require grad.

Desired behaviour

A flag to set such that z_pred[-1].requires_grad is always True, even when model.eval().
self.deq = get_deq(grad_in_eval=True)
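Until such a flag exists, one possible workaround is to apply the layer once more with autograd enabled after the gradient-free solve. The sketch below uses plain PyTorch rather than torchdeq. The names `layer`, `f`, and `solve_then_reattach` are hypothetical, and the no-grad loop only mimics a solver that returns a detached fixed point. The resulting gradient passes through a single step of `f`, so it approximates the exact implicit gradient rather than reproducing it.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 4)
with torch.no_grad():
    layer.weight.mul_(0.1)  # keep the toy map contractive so iteration converges


def f(z, pos):
    """Toy fixed point function with pos as the input injection."""
    return torch.tanh(layer(z) + pos)


def solve_then_reattach(pos, n_iter=50):
    z = torch.zeros_like(pos)
    with torch.no_grad():  # stands in for a solver that does not track gradients
        for _ in range(n_iter):
            z = f(z, pos)
    # One extra application with autograd enabled re-attaches the graph to pos.
    return f(z, pos)


pos = torch.randn(3, 4, requires_grad=True)
energy = solve_then_reattach(pos)
forces = -torch.autograd.grad(
    energy, pos, grad_outputs=torch.ones_like(energy), create_graph=True
)[0]
print(energy.requires_grad, forces.shape)
```

With torchdeq the same idea would presumably amount to calling the fixed point function once on z_pred[-1] before differentiating, but that is an assumption and has not been checked against the library.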

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 10 for tensor number 1 in the list.

Hi, I am trying to use your framework to build my own DEQ model with a simple fully connected neural network, but I keep getting an error that I can't fix.
Can you help resolve this error? Maybe there is something wrong with the definition of the neural network?

Many thanks in advance!

Code:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchdeq import get_deq


class FCNN(nn.Module):
    def __init__(self, input_dim, neurons_per_layer, output_dim):
        super(FCNN, self).__init__()

        self.linear1 = nn.Linear(input_dim, neurons_per_layer)
        self.linear2 = nn.Linear(neurons_per_layer,neurons_per_layer)
        self.out = nn.Linear(neurons_per_layer, output_dim)

    def forward(self, x, z):
        # z: links' flow  x: context

        zx = torch.cat((z, x))

        z_processed = F.relu(self.linear1(zx))
        z_processed = F.relu(self.linear2(z_processed))

        return F.relu(self.out(z_processed))


if __name__ == '__main__':

    seed = 1
    torch.manual_seed(seed)

    fcnn = FCNN(input_dim=15, neurons_per_layer=20, output_dim=10)
    print(fcnn)

    x_, z_ = torch.ones(5), torch.ones(10)  # z0
    for i in range(10):
        z_ = fcnn(x_, z_)
        print('z:', z_)

    # Let's try a multi-variable DEQ!
    deq = get_deq(f_solver='broyden', f_max_iter=20, f_tol=1e-6)

    x_, z_ = torch.ones(5), torch.zeros(10)  # z0
    # f = lambda z: fcnn(x, z)
    z_out, info = deq(fcnn, (x_, z_))

Errors:
Traceback (most recent call last):
  File "E:\PycharmProjects\torch_gpu\DEQ4TA\FCNN.py", line 55, in <module>
    z_out, info = deq(fcnn, (x_, z_))
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\core.py", line 592, in forward
    deq_func, z_star = deq_decorator(func, z_star, no_stat=self.no_stat)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\utils\layer_utils.py", line 139, in deq_decorator
    return func, func.list2vec(z_init)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\utils\layer_utils.py", line 60, in list2vec
    return torch.cat(z_list, dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 10 for tensor number 1 in the list.
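One plausible reading of the traceback is that the tuple (x_, z_) is treated as two fixed point variables, and list2vec concatenates them along dim=1 with dim 0 as the batch, which fails for unbatched tensors of sizes 5 and 10. The sketch below is an assumption rather than a confirmed fix. It adds a batch dimension, concatenates along the feature dimension, and closes over the context x in a functor so that only z is iterated, as the commented-out lambda in the code above hints. It uses plain PyTorch so it runs without torchdeq.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FCNN(nn.Module):
    def __init__(self, input_dim, neurons_per_layer, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, neurons_per_layer)
        self.linear2 = nn.Linear(neurons_per_layer, neurons_per_layer)
        self.out = nn.Linear(neurons_per_layer, output_dim)

    def forward(self, x, z):
        # Concatenate along the feature dimension; dim 0 stays the batch dimension.
        zx = torch.cat((z, x), dim=-1)
        h = F.relu(self.linear1(zx))
        h = F.relu(self.linear2(h))
        return F.relu(self.out(h))


torch.manual_seed(1)
fcnn = FCNN(input_dim=15, neurons_per_layer=20, output_dim=10)

x = torch.ones(1, 5)     # context, with a leading batch dimension
z0 = torch.zeros(1, 10)  # initial state, with a leading batch dimension

# Only z is the fixed point variable; x enters as an input injection via the closure.
f = lambda z: fcnn(x, z)
z1 = f(z0)
print(z1.shape)

# With torchdeq, the call would then presumably be: z_out, info = deq(f, z0)
```

The key property is that `f` maps a tensor of the shape of `z0` to a tensor of the same shape, which matches the Quick Start requirement that inputs and outputs of the fixed point function share shapes.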
