Comments (2)
One really annoying thing about torch.no_grad is that it is not traceable. JAX has a stop_gradient primitive that operates on Tensors, so it is traceable: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.stop_gradient.html
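For comparison, here is a small sketch of how `jax.lax.stop_gradient` composes with `jax.grad` (the function `f` and the values are illustrative, not from the original discussion):

```python
import jax
import jax.numpy as jnp

def f(x):
    # stop_gradient returns its input unchanged but blocks gradients
    # through it; because it operates on values, it traces cleanly.
    shift = jax.lax.stop_gradient(x.mean())
    return jnp.sum((x - shift) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
# shift is treated as a constant during differentiation
print(jax.grad(f)(x))
```

Since `shift` is stopped, the gradient of each term is just 2 * (x_i - shift).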
However, I think it's useful to be able to use something like torch.no_grad inside of a transform. For example, one pattern I've seen is:
def f(x):
    with torch.no_grad():
        shift = x.mean()
    return x - shift
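To make the pattern concrete, here is a sketch using the `grad` transform (now exposed as `torch.func.grad`); the scalar-loss wrapper is added for illustration, and it assumes the proposed semantics where the no_grad block is invisible to the transform:

```python
import torch
from torch.func import grad

def loss(x):
    with torch.no_grad():
        shift = x.mean()  # invisible to grad: treated as a constant
    return ((x - shift) ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0])
# With shift constant, d/dx_i of sum((x - shift)**2) is 2 * (x_i - shift)
print(grad(loss)(x))
```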
Proposal:
- torch.no_grad does affect the grad/vjp transforms: any computation that happens inside torch.no_grad is invisible to vjp/grad.
- If a user calls grad/vjp inside of torch.no_grad, we raise a warning explaining that their gradients will be 0. (Or maybe this should be an error?)
- For tracing... either we introduce something like stop_gradient(Tensor) -> Tensor, or we figure out how to "trace" torch.no_grad. This sounds a bit like factory-function tracing and could potentially be done with a mode-based dispatch key.
Alternatives:
- functorch just straight up ignores torch.no_grad
- we introduce a functorch.stop_gradient or something
from functorch.
New proposal: here's what I think the semantics should be.
Case 1: `grad` gets called inside `torch.no_grad`.
`grad` should ignore `torch.no_grad` because it is "creating a new level of autograd above the current level". Another way to think about this is that `grad(f)` is a "function transform": its result should not be affected by context managers outside of the function `f`.
Case 2: `torch.no_grad` gets called inside `grad`.
`grad` should respect `torch.no_grad`.
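Case 1 can be exercised with a small sketch, assuming the `torch.func` API (the function and values here are illustrative):

```python
import torch
from torch.func import grad

def f(x):
    return (x ** 3).sum()

x = torch.tensor(2.0)

# Case 1: grad called inside torch.no_grad still differentiates,
# because the transform creates a new autograd level above the
# ambient one.
with torch.no_grad():
    g = grad(f)(x)
print(g)  # d/dx x**3 = 3 * x**2 = 12
```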
How does one actually implement this? We can probably do something with a mode stack here...
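One way to picture the mode-stack idea is the following toy sketch (purely illustrative, not functorch's actual internals): each grad transform pushes a fresh level whose grad-enabled flag starts out true, so an ambient no_grad cannot reach it, while a no_grad entered inside the transform flips only the innermost level.

```python
from contextlib import contextmanager

class LevelStack:
    """Toy model of the proposed semantics (an assumption, not
    functorch's real implementation)."""
    def __init__(self):
        self.levels = [True]  # base level: grad enabled

    @contextmanager
    def grad_transform(self):
        # Case 1: a new level starts with grad enabled, regardless
        # of any ambient no_grad below it.
        self.levels.append(True)
        try:
            yield
        finally:
            self.levels.pop()

    @contextmanager
    def no_grad(self):
        # Case 2: no_grad only disables the innermost level.
        prev = self.levels[-1]
        self.levels[-1] = False
        try:
            yield
        finally:
            self.levels[-1] = prev

    def grad_enabled(self):
        return self.levels[-1]

stack = LevelStack()
results = []
with stack.no_grad():                 # ambient no_grad
    with stack.grad_transform():
        results.append(stack.grad_enabled())   # Case 1: True
        with stack.no_grad():
            results.append(stack.grad_enabled())  # Case 2: False
print(results)  # [True, False]
```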
from functorch.