
torchdeq's Issues

Unexpected behaviour of indexing?

Hi,

Thanks again for this library!

  1. Am I right that n_states / indexing can be used to implement the sparse fixed-point correction from DEQ Optical Flow?

  2. If yes, I am confused about the output in this example:

from torchdeq import get_deq

# Settings from `DEQ Optical Flow` paper
args = {
    "n_states": 2,
    "f_max_iter": 24,
}

deq = get_deq(args)

print('deq.indexing: ', deq.indexing)

Output: deq.indexing: [12, 12]
Expected output: [8, 16] (uniformly sampled between 0 and 24)

Am I misinterpreting this?
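For reference, here is what I mean by "uniformly sample" (an illustrative sketch of my reading of the paper, not necessarily torchdeq's actual scheme):

# Illustrative only: interior points of a uniform split of [0, f_max_iter]
f_max_iter, n_states = 24, 2
expected = [f_max_iter * (i + 1) // (n_states + 1) for i in range(n_states)]
print(expected)  # [8, 16]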

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 10 for tensor number 1 in the list.

Hi, I'm trying to use your framework to build my own DEQ model with a simple fully connected neural network, but I keep getting an error that I can't fix.
Could you help resolve it? Maybe there is something wrong with the definition of the neural network?

Many thanks in advance!

Code:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchdeq import get_deq

class FCNN(nn.Module):
    def __init__(self, input_dim, neurons_per_layer, output_dim):
        super(FCNN, self).__init__()

        self.linear1 = nn.Linear(input_dim, neurons_per_layer)
        self.linear2 = nn.Linear(neurons_per_layer, neurons_per_layer)
        self.out = nn.Linear(neurons_per_layer, output_dim)

    def forward(self, x, z):
        # z: links' flow  x: context

        zx = torch.cat((z, x))

        z_processed = F.relu(self.linear1(zx))
        z_processed = F.relu(self.linear2(z_processed))

        return F.relu(self.out(z_processed))


if __name__ == '__main__':

    seed = 1
    torch.manual_seed(seed)

    fcnn = FCNN(input_dim=15, neurons_per_layer=20, output_dim=10)
    print(fcnn)

    x_, z_ = torch.ones(5), torch.ones(10)  # z0
    for i in range(10):
        z_ = fcnn(x_, z_)
        print('z:', z_)

    # Let's try a multi-variable DEQ!
    deq = get_deq(f_solver='broyden', f_max_iter=20, f_tol=1e-6)

    x_, z_ = torch.ones(5), torch.zeros(10)  # z0
    # f = lambda z: fcnn(x, z)
    z_out, info = deq(fcnn, (x_, z_))

Errors:
Traceback (most recent call last):
  File "E:\PycharmProjects\torch_gpu\DEQ4TA\FCNN.py", line 55, in <module>
    z_out, info = deq(fcnn, (x_, z_))
  File "C:\Users\Leizhen\.conda\envs\torch_gpu\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Leizhen\.conda\envs\torch_gpu\lib\site-packages\torchdeq\core.py", line 592, in forward
    deq_func, z_star = deq_decorator(func, z_star, no_stat=self.no_stat)
  File "C:\Users\Leizhen\.conda\envs\torch_gpu\lib\site-packages\torchdeq\utils\layer_utils.py", line 139, in deq_decorator
    return func, func.list2vec(z_init)
  File "C:\Users\Leizhen\.conda\envs\torch_gpu\lib\site-packages\torchdeq\utils\layer_utils.py", line 60, in list2vec
    return torch.cat(z_list, dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 10 for tensor number 1 in the list.
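From the traceback, torchdeq flattens every tensor in the input tuple as a fixed-point state via torch.cat(z_list, dim=1), so each state needs a matching batch dimension, and the context x should arguably not be passed as a state at all. A minimal sketch of a call that avoids the error (my assumption of the intended usage, with dim 0 as the batch dimension):

# In FCNN.forward, concatenate along the feature dimension instead:
#     zx = torch.cat((z, x), dim=1)
x_ = torch.ones(1, 5)    # context, shape (batch, 5)
z_ = torch.zeros(1, 10)  # initial state, shape (batch, 10)

# Bind the context via a closure so the solver only iterates over z
f = lambda z: fcnn(x_, z)
z_out, info = deq(f, z_)
z_star = z_out[-1]  # final fixed-point estimate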

Fixed-point is not returned when indexing is set

It seems like the best fixed-point estimate z_star = lowest_xest is only added to the trajectory when the indexed trajectory is empty.
When one specifies indexing, one does not get the best fixed-point estimate in the trajectory.

Relevant Code

From the Broyden solver

# Store the solution at the specified index
if indexing and (nstep+1) in indexing:
    indexing_list.append(lowest_xest)

# ...

# at least return the lowest value when enabling  ``indexing''
if indexing and not indexing_list:
    indexing_list.append(lowest_xest)

info = solver_stat_from_info(stop_mode, lowest_dict, trace_dict, lowest_step_dict)
return lowest_xest, indexing_list, info

Note that the best fixed-point estimate z_star = lowest_xest is ignored in DEQIndexing

_, trajectory, info = self._solve_fixed_point()

Example

If the solver's nstep passes an index in indexing, the current lowest_xest is added to the trajectory.
Only if nothing was added to the trajectory is lowest_xest appended as a fallback.
This means that the trajectory sometimes contains the best fixed-point estimate lowest_xest and sometimes not.

Scenario 1: indexing=[8], nstep=5 -> trajectory contains fp_5
Scenario 2: indexing=[8], nstep=10 -> trajectory contains fp_8
Shouldn't the trajectory contain [fp_5, fp_8] (assuming fp_8 is the better estimate)?

Note that indexing defaults to indexing=[f_max_iter] if not specified otherwise, in which case the best fixed-point estimate is added to the trajectory. So the problem only arises if one specifies indexing or n_states, e.g. to implement the fixed-point correction loss.
It is also not a problem in DEQSliced.
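A minimal sketch of the behaviour I would expect instead (hypothetical, untested): append the final best estimate whenever indexing is set, so that z_star is always part of the trajectory.

# after the solver loop, instead of only using lowest_xest as a fallback
if indexing:
    indexing_list.append(lowest_xest)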

Custom autograd fails with torchdeq in eval mode

It's a very niche problem, but it tripped me up big time :')

Issue

For model.eval(), z_pred will not have tracked gradients (z_pred.requires_grad == False).
For a custom torch.autograd call, this leads to: RuntimeError: One of the differentiated Tensors does not require grad.

Minimal example


import torch

from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

        # deq
        self.deq = get_deq()
        apply_norm(self.layer, 'weight_norm')

    def implicit_layer(self, z, pos):
        # pos must enter the fixed-point map, otherwise autograd.grad(energy, pos)
        # raises 'One of the differentiated Tensors appears to not have been used in the graph'
        return self.layer(z) + pos.sum(-1, keepdim=True)

    def forward(self, x, pos):

        z = torch.zeros_like(x)

        reset_norm(self.layer)

        # bind pos via a closure so the solver only iterates over z
        f = lambda z: self.implicit_layer(z, pos)

        z_pred, info = self.deq(f, z)
        
        # if model.eval() -> z_pred[-1].requires_grad is False!
        energy = z_pred[-1].sum(-1, keepdim=True)  # per-sample energy, shape (10, 1)
        forces = -1 * (
            torch.autograd.grad(
                energy,
                pos,  # differentiate with respect to pos
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
            )[0]
        )

        return energy, forces


def run(model, eval=False):

    if eval:
        model.eval()
    else:
        model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for step in range(10):
        x = torch.randn(10, 10)
        pos = torch.randn(10, 3, requires_grad=True)  # energy is differentiated w.r.t. pos
        energy, forces = model(x, pos)
        
        # loss
        optimizer.zero_grad()
        energy_target = torch.randn(10, 1)
        energy_loss = torch.nn.functional.mse_loss(energy, energy_target)
        force_target = torch.randn(10, 3)
        force_loss = torch.nn.functional.mse_loss(forces, force_target)
        loss = energy_loss + force_loss

        if not eval:
            loss.backward()
            optimizer.step()
    
    return True

if __name__ == '__main__':
    model = MyModel()
    success = run(model, eval=False)
    print(f'train success: {success}')
    success = run(model, eval=True)
    print(f'eval success: {success}')

With model.train() it works perfectly well. With model.eval() we get the error: RuntimeError: One of the differentiated Tensors does not require grad.

Desired behaviour

A flag one could set so that z_pred[-1].requires_grad is always True, even under model.eval(), e.g.:

self.deq = get_deq(grad_in_eval=True)
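Until something like this exists, possible workarounds (untested sketches; whether they help depends on how torchdeq decides when to track gradients):

# keep only the DEQ module in train mode, in case torchdeq keys off self.training
model.eval()
model.deq.train()

# or force gradient mode around the forward pass
with torch.enable_grad():
    energy, forces = model(x, pos)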

Implementation of `torchdeq.utils.mixed_init` different from original paper

From the paper:
"we experimented with initializing the hidden states with zeros on half of the examples in the batch, and with standard Gaussian noise on the rest of the examples"

"Mixed initialization: During each training forward pass, each sample was assigned with either zero initialization (i.e. the fixed point was initialized with the 0 vector) or standard normal distribution (i.e. ...) using a Bernoulli random variable of probability 0.5 (i.e. the examples that were run with zero vs. normal initializations were roughly half-half."

Current implementation:

def mixed_init(z_shape, device=None):
    """
    Initializes a tensor with a shape of `z_shape` with half Gaussian random values and half zeros.

    Proposed in the paper, `Path Independent Equilibrium Models Can Better Exploit Test-Time Computation <https://arxiv.org/abs/2211.09961>`_,
    for better path independence.

    Args:
        z_shape (tuple): Shape of the tensor to be initialized.
        device (torch.device, optional): The desired device of returned tensor. Default None.

    Returns:
        torch.Tensor: A tensor of shape `z_shape` with values randomly initialized and zero masked.
    """
    z_init = torch.randn(*z_shape, device=device)
    mask = torch.zeros_like(z_init, device=device).bernoulli_(0.5)
    return z_init * mask

It seems more appropriate to do this instead to match the paper.

*mask_shape, _ = z_shape
mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)

This form has the disadvantage of assuming that all but the last dimension are batch dimensions. But this seems to be quite a reasonable assumption, and downstream users can easily adjust to this by reshaping and rearranging the dimensions.
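For completeness, a self-contained sketch of the per-sample version (a hypothetical helper, not part of torchdeq; assumes dim 0 is the batch dimension):

import torch

def mixed_init_per_sample(z_shape, device=None):
    # one Bernoulli draw per example, broadcast over all remaining dims
    z_init = torch.randn(*z_shape, device=device)
    mask_shape = (z_shape[0],) + (1,) * (len(z_shape) - 1)
    mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5)
    return z_init * mask

z0 = mixed_init_per_sample((8, 16, 32))
# roughly half the examples are all-zero, the rest standard normal
print((z0.flatten(1).abs().sum(dim=1) == 0).float().mean())  # ~0.5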
