Comments (12)
It sounds like the ask is that you'd like to generate the autograd graph AOT?
There's currently some prototype functionality (well, arguably the entire repo is prototype :P ) for tracing through the C++ dispatcher (which includes the autograd engine).
See the examples in the `nnc` folder: https://github.com/zou3519/functorch/blob/main/examples/nnc/simple_function.py
Specifically, `make_fx` will do this. However, it will 1. specialize on your inputs, and 2. return functions at the aten level.
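For reference, a minimal sketch of that usage (assuming the `grad` and `make_fx` exports from this prototype repo; this is not the linked example verbatim):

```python
import torch
from functorch import grad, make_fx  # exports from this prototype repo

def f(x):
    return torch.sin(x).sum()

# make_fx traces through the dispatcher (including the autograd engine) and
# returns an fx.GraphModule whose nodes are aten-level ops, specialized to
# the example input's dtype/device/shape.
grad_graph = make_fx(grad(f))(torch.randn(5))
print(grad_graph.code)
```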
... a very exciting prototype 🙂
Our ideal outcome would be to take a TorchScript model that involves a `torch.autograd.grad` call and turn it into a TorchScript model that doesn't. The main reason for this is inference: being tied to the autograd engine severely limits our options for accelerated inference (TRTorch, for example, but also TorchScript optimizations). It would also be great to be able to train without walking the autograd graph twice, but that's secondary.
From what you're saying, it sounds like tracing through the C++ dispatcher is exactly right for this. To clarify a couple of points:
- What, exactly, is NNC? Having a hard time figuring out how it relates to TorchScript, etc.
- Would `torch.jit.script(make_fx(model))` give me a `ScriptModule` I can use like any other? (`model` here would already contain calls to `torch.autograd.grad`.) In particular, a `ScriptModule` I can `torch::jit::load` into a C++ deployment context?
- When you say it will "specialize on your inputs," you mean specialize to dtype + device? Or something else?
- Do "functions at the aten level" have any particular limitations?
Really appreciate your taking the time!
> Our ideal outcome would be to take a TorchScript model that involves a `torch.autograd.grad` call and turn it into a TorchScript model that doesn't.
I'm not sure that we'll be able to start from a TorchScript model and trace out the autograd graph. I think it's possible in theory, but it would require a bit more infrastructure that I'm not currently working on.
> What, exactly, is NNC? Having a hard time figuring out how it relates to TorchScript, etc.
NNC (neural network compiler, also called `tensorexpr` in our code base) is a codegen compiler (kind of in the vein of TVM or Halide). Currently, in this repo, we're using it primarily for overhead reduction. The idea is that if your operations/tensors are small enough, PyTorch framework overhead is often the dominating factor. Generating a single binary blob that does your computation can lead to significant speedups.
> Would `torch.jit.script(make_fx(model))` give me a `ScriptModule` I can use like any other?
Currently, somewhat awkwardly, it doesn't work. This is because I'm tracing out into `torch.ops.aten`, and there are some awkward mismatches there. However, `torch.jit.trace(make_fx(model))` does work, and will give you a ScriptModule you can use like any other.
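A sketch of that workaround, under the same assumptions as above (prototype-repo imports; the saved file name is just illustrative):

```python
import torch
from functorch import grad, make_fx

def f(x):
    return (x ** 2).sum()

example = torch.randn(4)
fx_grad = make_fx(grad(f))(example)  # fx.GraphModule of aten-level ops

# torch.jit.script currently trips over the aten-level calls, but tracing
# the GraphModule works and yields a ScriptModule that can be saved and
# later loaded (e.g. via torch::jit::load from C++).
scripted = torch.jit.trace(fx_grad, example)
torch.jit.save(scripted, "grad_f.pt")
```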
> When you say it will "specialize on your inputs," you mean specialize to dtype + device?
And shape. Unfortunately, within the autograd engine, there are a lot of instances where the autograd rules depend on the shapes of the tensors.
Do "functions at the aten level" have any particular limitations?
Not really - it's basically just the C++ API instead of the Python API.
> I'm not sure that we'll be able to start from a TorchScript model and trace out the autograd graph. I think it's possible in theory, but it would require a bit more infrastructure that I'm not currently working on.
Would starting with a Python model work?
> NNC (neural network compiler)...
Very cool! I'm assuming, then, that normal TorchScript optimization already uses NNC everywhere it can, and that using `nnc_jit` wouldn't give you any particular speed up?
> And shape. Unfortunately, within the autograd engine, there are a lot of instances where the autograd rules depend on the shapes of the tensors.
Hm, I see. For example, if I have:

```python
def f(x):
    outshape = x.shape[:-1]
    x = x.reshape(-1, x.shape[-1])
    x = 2 * x
    return x.reshape(outshape + (x.shape[-1],))
```
I will not be able to generalize to different leading shapes on `x` (`(10, 3)` vs `(13, 3)`, for example) even if they have the same number of dimensions?
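(A sketch of how to see this concretely, assuming the prototype repo's `make_fx` export: trace the function above at one shape and print the graph; the sizes read from `x.shape` show up as plain constants.)

```python
import torch
from functorch import make_fx  # assumed prototype-repo export

def f(x):
    outshape = x.shape[:-1]
    x = x.reshape(-1, x.shape[-1])
    x = 2 * x
    return x.reshape(outshape + (x.shape[-1],))

# The reshape arguments in the printed graph carry the example input's
# concrete sizes (read from x.shape during tracing), which is the
# specialization in question.
traced = make_fx(f)(torch.randn(10, 3))
print(traced.code)
```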
> Would starting with a Python model work?
Yes.
> Very cool! I'm assuming, then, that normal TorchScript optimization already uses NNC everywhere it can, and that using `nnc_jit` wouldn't give you any particular speed up?
`nnc_jit` is currently targeted at the overhead-dominated CPU use cases, so I suspect that it might not be very useful for you. But in those use cases, it does some things like lower the entire model to a binary blob, so it can be substantially faster than TorchScript.
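A rough sketch of that overhead-bound CPU use case; the `nnc_jit` import path is an assumption based on the linked examples/nnc/simple_function.py, so treat this as illustrative rather than exact:

```python
import torch
from functorch import grad, nnc_jit  # nnc_jit location is an assumption

def f(x):
    return torch.sin(x).sum()

# For tiny CPU tensors, lowering the whole gradient computation into one
# NNC-generated kernel avoids the per-op framework overhead.
grad_f = nnc_jit(grad(f))
print(grad_f(torch.randn(3)))
```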
> I will not be able to generalize to different leading shapes on `x` (`(10, 3)` vs `(13, 3)`, for example) even if they have the same number of dimensions?
Currently, no. This is something that we're aware of, and we're trying to figure out ways of addressing this.
Aha. Can `nnc_jit` work with models of "real" complexity, inlining more complicated operations? (Say, `tensordot`, `permute`, `cat`, whatever.) Or is it limited to elementwise operations that are already supported for fusion?
> Currently, no. This is something that we're aware of, and we're trying to figure out ways of addressing this.
How strong are the shape dependencies in autograd, in a rough sense? In the example I gave above, for example, if I manually rewrote the `fx` graph to use the right dynamic shape as the argument to `reshape`, instead of the traced constant, would you expect it to give the right result for a modified batch dimension? (Same number of dimensions.)
> can `nnc_jit` work with models of "real" complexity, inlining more complicated operations? (Say, `tensordot`, `permute`, `cat`, whatever.)
Yes - we currently have lowerings for `permute` and `cat`. Generating fast lowerings/schedules for things like `tensordot` is significantly harder, so we've usually just been calling PyTorch C++ implementations for those. In cases where you are primarily overhead bound this can still be a significant win. Once again though, this stuff currently only works on CPU.
> How strong are the shape dependencies in autograd, in a rough sense? In the example I gave above, for example, if I manually rewrote the fx graph to use the right dynamic shape as the argument to reshape, instead of the traced constant, would you expect it to give the right result for a modified batch dimension?
It's hard to say - there are 2 sources of shape specialization in autograd. The first one is the user-facing stuff, where you can generally get around it by re-implementing things with ops like `torch.flatten` instead of explicit accesses to shapes. However, the harder stuff is shape specialization within C++.
It's possible that if we're just changing the batch dimension we can avoid it, but that requires some investigation that I haven't done.
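A sketch of that user-facing workaround applied to the earlier example; the `view_as` call is just one illustrative way to restore the shape (it only works here because the last dimension is unchanged), and none of this addresses the C++-side specialization:

```python
import torch

def f(x):
    # Merge all leading dims into one batch dim without reading x.shape.
    y = torch.flatten(x, start_dim=0, end_dim=-2)
    y = 2 * y
    # Restore the leading dims; view_as takes a tensor rather than explicit
    # integer sizes, so no concrete shape is written into the user code.
    return y.view_as(x)
```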
> ...tensordot is significantly easier,...

Do you mean harder?

> Yes - we currently have lowerings for `permute` and `cat`. Generating fast lowerings/schedules for things like `tensordot` is significantly easier, so we've usually just been calling PyTorch C++ implementations for those. In cases where you are primarily overhead bound this can still be a significant win. Once again though, this stuff currently only works on CPU.
Awesome!
> However, the harder stuff is shape specialization within C++.
Would you guess that this is usually specialization for speed, or does it affect correctness? (Context: for us, we have "batch" dimensions that constantly change during inference, but don't change much.)
> Do you mean harder?
whoops
> Would you guess that this is usually specialization for speed, or does it affect correctness?
Hmm, it's often stuff like pulling out the shapes, and then doing an explicit reshape using those shapes or something like that.
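(As a rough Python-level analogue of that pattern, not the actual C++ derivative code: the backward of a reshape reads the input's concrete sizes and reshapes the gradient with them.)

```python
import torch

def reshape_backward(grad_output: torch.Tensor, original_input: torch.Tensor) -> torch.Tensor:
    # "pulling out the shapes, and then doing an explicit reshape using
    # those shapes" - if these sizes get captured as constants during
    # tracing, the traced backward is specialized to them.
    return grad_output.reshape(original_input.shape)
```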
Aha ok, so stuff that really is risky for correctness. Still, would be interesting to see if those problems come up for our networks. (I'd mostly worry about matmuls, `tensordot`s, and `einsum`s; does that sound right?)
I'm not actually totally sure when it comes up - perhaps it would just work in some cases? I'll make a note to investigate that further at some point.
If you ever do end up looking into any of these things, would be very curious to hear what you find; I will also probably play around with `functorch` for this once I have the time.
Thanks very much for the answers + all the great work on compilers for PyTorch!