Comments (2)

zou3519 commented on June 20, 2024

One really annoying thing about torch.no_grad is that it is not traceable. JAX instead has a stop_gradient primitive that operates on Tensors, which makes it traceable: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.stop_gradient.html

However, I think it's useful to be able to use something like torch.no_grad inside of a transform. For example, one pattern I've seen is:

import torch

def f(x):
    # Compute the shift without recording it for autograd.
    with torch.no_grad():
        shift = x.mean()
    return x - shift
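
To make the pattern concrete, here is a minimal sketch using plain torch.autograd (not functorch) of what the no_grad block changes: the shift is treated as a constant, so the gradient of f(x).sum() is all ones instead of cancelling to zero. The helper f_tracked is hypothetical and exists only for contrast.

```python
import torch

def f(x):
    # The mean is computed without being recorded by autograd,
    # so `shift` behaves like a constant in the backward pass.
    with torch.no_grad():
        shift = x.mean()
    return x - shift

def f_tracked(x):
    # Hypothetical contrast case: same computation, but the mean is tracked.
    return x - x.mean()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Shift treated as a constant: d/dx sum(x - c) is 1 for every element.
(g_const,) = torch.autograd.grad(f(x).sum(), x)

# Shift tracked: sum(x - mean(x)) is identically 0, so the gradient vanishes.
(g_tracked,) = torch.autograd.grad(f_tracked(x).sum(), x)
```

This is the behavior a transform would need to preserve if it respects torch.no_grad inside the transformed function.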

Proposal:

  • torch.no_grad does affect grad/vjp transforms. Any computation that happens within torch.no_grad is invisible to vjp/grad
  • If a user calls grad/vjp inside of torch.no_grad, we raise a warning that explains that their gradients will be 0. (Or maybe this should be an error?)
  • For tracing... either we introduce something like stop_gradient(Tensor) -> Tensor or figure out how to "trace" torch.no_grad. This sounds a bit like factory function tracing and could potentially be done with a mode-based dispatch key

Alternatives:

  • functorch just straight up ignores torch.no_grad
  • we introduce a functorch.stop_gradient or something
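
For the tensor-level option, one possible shape of stop_gradient(Tensor) -> Tensor is sketched below. This is an assumption, not an existing functorch API: the helper simply wraps Tensor.detach(), which is an ordinary operation on a tensor rather than a context manager. The sketch only demonstrates the behavior under plain autograd; it does not establish how such an op should interact with functorch's transform levels.

```python
import torch

def stop_gradient(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not part of functorch): returns a tensor that
    # shares data with `t` but is cut off from the autograd graph.
    return t.detach()

def f(x):
    # Same pattern as the torch.no_grad example, expressed as an op on a tensor.
    shift = stop_gradient(x.mean())
    return x - shift

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# The shift contributes nothing to the backward pass.
(g,) = torch.autograd.grad(f(x).sum(), x)
```

Because the call takes a tensor and returns a tensor, it is the kind of operation a tracer can record, which is the property the JAX primitive has.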

zou3519 commented on June 20, 2024

New proposal: here's what I think the semantics should be.

Case 1: grad gets called inside torch.no_grad.

  • grad should ignore torch.no_grad because it's "creating a new level of autograd above the current level"
  • Another way to think about this is that grad(f) is a "function transform": its result should not be affected by context managers that are outside of the function f

Case 2: torch.no_grad gets called inside grad.

  • grad should respect torch.no_grad

How does one actually implement this? We can probably do something with a mode stack here...
