locuslab / torchdeq
Modern Fixed Point Systems using PyTorch
License: MIT License
Hi,
Thanks again for this library!
Am I right that n_states / indexing can be used to implement the sparse fixed-point correction of DEQ Optical Flow? If yes, I am confused about the output in this example:
from torchdeq import get_deq
# Settings from `DEQ Optical Flow` paper
args = {
    "n_states": 2,
    "f_max_iter": 24,
}
deq = get_deq(args)
print('deq.indexing: ', deq.indexing)
Output: deq.indexing: [12, 12]
Expected output: [8, 16]
(uniformly sampled between 0 and 24)
Am I misinterpreting?
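For reference, uniformly sampling n_states=2 indices over (0, 24] would give round(24 * i / 3) for i in {1, 2}, i.e. [8, 16], matching the expectation above. A possible workaround, under the assumption (suggested by the indexing / n_states pairing above, but not verified) that an explicit indexing list can be passed instead of n_states:

from torchdeq import get_deq

# Workaround sketch: pass explicit trajectory indices instead of n_states.
# Treat the 'indexing' key as an assumption about the accepted settings.
args = {
    "f_max_iter": 24,
    "indexing": [8, 16],  # sample the solver trajectory at steps 8 and 16
}
deq = get_deq(args)
print('deq.indexing: ', deq.indexing)  # hoped-for output: [8, 16]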
Hi, I am trying to use your framework to build my own DEQ model with a simple fully connected neural network, but I keep getting an error that I can't fix.
Can you help resolve it? Maybe there is something wrong with the definition of the neural network?
Many thanks in advance!
Code:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm
class FCNN(nn.Module):
    def __init__(self, input_dim, neurons_per_layer, output_dim):
        super(FCNN, self).__init__()
        self.linear1 = nn.Linear(input_dim, neurons_per_layer)
        self.linear2 = nn.Linear(neurons_per_layer, neurons_per_layer)
        self.out = nn.Linear(neurons_per_layer, output_dim)

    def forward(self, x, z):
        # z: links' flow, x: context
        zx = torch.cat((z, x))
        z_processed = F.relu(self.linear1(zx))
        z_processed = F.relu(self.linear2(z_processed))
        return F.relu(self.out(z_processed))
if __name__ == '__main__':
    seed = 1
    torch.manual_seed(seed)

    fcnn = FCNN(input_dim=15, neurons_per_layer=20, output_dim=10)
    print(fcnn)

    x_, z_ = torch.ones(5), torch.ones(10)  # z0
    for i in range(10):
        z_ = fcnn(x_, z_)
        print('z:', z_)

    # Let's try a multi-variable DEQ!
    deq = get_deq(f_solver='broyden', f_max_iter=20, f_tol=1e-6)
    x_, z_ = torch.ones(5), torch.zeros(10)  # z0
    # f = lambda z: fcnn(x, z)
    z_out, info = deq(fcnn, (x_, z_))
Errors:
Traceback (most recent call last):
  File "E:\PycharmProjects\torch_gpu\DEQ4TA\FCNN.py", line 55, in <module>
    z_out, info = deq(fcnn, (x_, z_))
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\core.py", line 592, in forward
    deq_func, z_star = deq_decorator(func, z_star, no_stat=self.no_stat)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\utils\layer_utils.py", line 139, in deq_decorator
    return func, func.list2vec(z_init)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\utils\layer_utils.py", line 60, in list2vec
    return torch.cat(z_list, dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 10 for tensor number 1 in the list.
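A likely fix (a sketch, not verified against this torchdeq version): the traceback shows that every tensor in the input tuple is flattened as a separate fixed-point state and concatenated along dim=1, so the context x_ (size 5) and the state z_ (size 10) cannot share one state vector, and the states are expected to carry a batch dimension. Closing over x_ and solving for z alone should avoid both problems:

# Sketch of a fix: treat x_ as a closed-over context, not a solver state,
# and give the state a leading batch dimension (states are concatenated
# along dim=1 inside the solver).
x_, z_ = torch.ones(1, 5), torch.zeros(1, 10)
f = lambda z: fcnn(x_, z)      # only z is a fixed-point state
z_out, info = deq(f, z_)
z_star = z_out[-1]             # final fixed-point estimate

(With batched inputs, FCNN.forward should concatenate along the feature axis, i.e. torch.cat((z, x), dim=-1).)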
It seems like the best fixed-point estimate z_star = lowest_xest is only returned when the indexed trajectory is empty. When one specifies indexing, one does not get the best fixed-point estimate.
From the Broyden solver:
# Store the solution at the specified index
if indexing and (nstep+1) in indexing:
    indexing_list.append(lowest_xest)
# ...
# at least return the lowest value when enabling ``indexing''
if indexing and not indexing_list:
    indexing_list.append(lowest_xest)

info = solver_stat_from_info(stop_mode, lowest_dict, trace_dict, lowest_step_dict)
return lowest_xest, indexing_list, info
Note that the best fixed-point estimate z_star = lowest_xest is discarded in DEQIndexing:

_, trajectory, info = self._solve_fixed_point()

If the solver's nstep passes an index in indexing, the best estimate so far is added to the trajectory. Only if nothing at all was added is lowest_xest appended as a fallback. This means the trajectory sometimes contains the final best fixed-point estimate lowest_xest and sometimes not:
Scenario 1: indexing=[8], nstep=5 -> trajectory contains fp_5
Scenario 2: indexing=[8], nstep=10 -> trajectory contains fp_8
Shouldn't the trajectory contain [fp_5, fp_8] (assuming fp_8 is the better estimate)?
Note that indexing defaults to indexing=[f_max_iter] if not specified otherwise, in which case the best fixed-point estimate is added to the trajectory. So the problem only arises if one specifies indexing or n_states, e.g. to implement the fixed-point correction loss. It is also not a problem in DEQSliced.
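A sketch of the behavior this issue seems to ask for (a hypothetical patch against the solver snippet above, not the library's actual code): always close the indexed trajectory with the best running estimate.

# Hypothetical patch sketch: guarantee the indexed trajectory ends with the
# best fixed-point estimate, so DEQIndexing never silently drops lowest_xest.
if indexing:
    if not indexing_list or indexing_list[-1] is not lowest_xest:
        indexing_list.append(lowest_xest)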
Among the parameters accepted by the get_deq function, I see the backward-pass settings b_solver and b_max_iter. What are they used for? As far as I know, the backward pass refers to backpropagation, which does not need a fixed-point solver...
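For context, here is the standard implicit-differentiation reasoning behind a DEQ backward pass (general theory, not a claim about this library's exact internals). Differentiating the fixed-point condition z* = f(z*, x; θ) via the implicit function theorem gives

dL/dθ = (dL/dz*) (I − ∂f/∂z|_{z*})^{-1} (∂f/∂θ),

and the vector-Jacobian product u = (dL/dz*) (I − ∂f/∂z)^{-1} is itself the fixed point of the linear equation u = u (∂f/∂z) + dL/dz*. So backpropagating through a DEQ requires solving a second fixed-point problem; presumably that is what b_solver handles, capped at b_max_iter iterations, mirroring f_solver and f_max_iter for the forward solve.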
It's a very niche problem, but it tripped me up big time :')
With model.eval(), z_pred will not have tracked gradients (z_pred.requires_grad == False).
For custom torch.autograd calls this leads to an error: RuntimeError: One of the differentiated Tensors does not require grad.
import torch
from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer = torch.nn.Linear(10, 10)
        # deq
        self.deq = get_deq()
        apply_norm(self.layer, 'weight_norm')

    def implicit_layer(self, x, pos):
        # pos must enter the computation, or autograd cannot differentiate
        # the energy with respect to it
        return self.layer(x) + pos.sum(dim=-1, keepdim=True)

    def forward(self, x, pos):
        z = torch.zeros_like(x)
        reset_norm(self.layer)
        f = lambda z: self.implicit_layer(z, pos)
        z_pred, info = self.deq(f, z)
        # if model.eval() -> z_pred[-1].requires_grad is False!
        energy = z_pred[-1].sum(dim=-1, keepdim=True)  # shape (10, 1)
        forces = -1 * (
            torch.autograd.grad(
                energy,
                # differentiate with respect to pos;
                # 'One of the differentiated Tensors appears to not have been
                # used in the graph' means pos was not 'used' to compute energy
                pos,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                # allow_unused=True,
            )[0]
        )
        return energy, forces
def run(model, eval=False):
    if eval:
        model.eval()
    else:
        model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for step in range(10):
        x = torch.randn(10, 10)
        # pos must require grad so that forces = -dE/dpos is defined
        pos = torch.randn(10, 3, requires_grad=True)
        energy, forces = model(x, pos)
        # loss
        optimizer.zero_grad()
        energy_target = torch.randn(10, 1)
        energy_loss = torch.nn.functional.mse_loss(energy, energy_target)
        force_target = torch.randn(10, 3)
        force_loss = torch.nn.functional.mse_loss(forces, force_target)
        loss = energy_loss + force_loss
        if not eval:
            loss.backward()
            optimizer.step()
    return True

if __name__ == '__main__':
    model = MyModel()
    success = run(model, eval=False)
    print(f'train success: {success}')
    success = run(model, eval=True)
    print(f'eval success: {success}')
With model.train() it works perfectly well. With model.eval() we get the error: RuntimeError: One of the differentiated Tensors does not require grad.
Proposed solution: a flag such that z_pred[-1].requires_grad is always True, even under model.eval():

self.deq = get_deq(grad_in_eval=True)
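Until such a flag exists, a possible workaround (a sketch, assuming the eval-mode solver output is simply detached from the graph): re-attach the fixed point with one extra application of the layer under torch.enable_grad(), inside forward after the solve.

# Workaround sketch: one extra function application re-enters the autograd
# graph, so gradients w.r.t. pos are available even after an eval-mode solve.
with torch.enable_grad():
    z_star = z_pred[-1].detach().requires_grad_()
    z_star = f(z_star)  # f closes over pos, reconnecting pos to the output
energy = z_star.sum(dim=-1, keepdim=True)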
From the paper:
"we experimented with initializing the hidden states with zeros on half of the examples in the batch, and with standard Gaussian noise on the rest of the examples"
"Mixed initialization: During each training forward pass, each sample was assigned with either zero initialization (i.e. the fixed point was initialized with the 0 vector) or standard normal distribution (i.e. ...) using a Bernoulli random variable of probability 0.5 (i.e. the examples that were run with zero vs. normal initializations were roughly half-half."
Current implementation:
torchdeq/torchdeq/utils/init.py, lines 4 to 21 at 4f6bd5f
To match the paper, it seems more appropriate to do this instead:
*mask_shape, _ = z_shape
mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)
This form has the disadvantage of assuming that all but the last dimension are batch dimensions. But this seems to be quite a reasonable assumption, and downstream users can easily adjust to this by reshaping and rearranging the dimensions.
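Put together, the proposed per-sample variant might look like this (a sketch; the mixed_init name and (z_shape, device) signature are my guess at the existing interface in init.py):

import torch

def mixed_init(z_shape, device=None):
    # One Bernoulli(0.5) draw per sample (all but the last dim are treated
    # as batch dims), so each example is initialized either entirely with
    # zeros or entirely with standard Gaussian noise, as in the paper.
    *mask_shape, _ = z_shape
    mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)
    return torch.randn(*z_shape, device=device) * mask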