Comments (7)
PyTorch has really weird special cases around scalar tensors. Not sure how easy it will be for us to replicate the semantics.
from functorch.
> Not sure how easy it will be for us to replicate the semantics.

I suspect we're going to need to do so regardless.
@zou3519 - what's the expected behavior in this case? I'm trying to figure out what this should actually do.

Reasoning it out:

- `x = torch.randn((3))` gives a tensor with shape `(3,)` (dim=1)
- `f(x)`: `Tensor<{arbitrary shape; dim >= 1}> -> Tensor<{resultant shape; dim >= 1}>`
- `vmap(f)(x)`: `Tensor<{arbitrary shape + 1 dim}> -> Tensor<{resultant shape + 1 dim}>`

The types/shapes don't match up for `x`. E.g. `x` only has a single dimension when passed to `vmap(f)(x)` and would thus fail? Or should `x` be promoted to `(3, 1)` (or `(1, 3)`) automatically and then passed to `vmap(f)(x)`?

This seems like an issue in the general dispatcher (rather than a batching-rule registration style). E.g. when you promote a function using `vmap`, the resultant function should check for scalar inputs (i.e. dim=1) and automatically promote to dim=2?
@laurencer good question. For out-of-place operations and ignoring views, `vmap(f)(x)` should be equivalent to running `torch.stack([f(xi) for xi in x.unbind(0)])`. This heuristic tells us the following:

- `x = torch.randn((3))` gives a tensor with shape `(3,)` (dim=1)
- `f(x[0])` gives a tensor with shape `[]` (dim=0), so `torch.stack([f(xi) for xi in x.unbind(0)])` gives a tensor of shape `[3]` (dim=1)
- so `vmap(f)(x)` should give us a tensor with shape `(3,)` (dim=1)
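The unbind/stack heuristic is easy to check directly. A minimal sketch, assuming `torch.func.vmap` (the `torch.func` successor to `functorch.vmap`) and an arbitrary `f` that reduces to a 0-dim tensor:

```python
# Check that vmap(f)(x) matches torch.stack([f(xi) for xi in x.unbind(0)]).
import torch
from torch.func import vmap

def f(t):
    return t.sum()  # maps any tensor to a 0-dim (scalar) tensor

x = torch.randn(3)

# Reference semantics: apply f to each slice along dim 0, then stack.
expected = torch.stack([f(xi) for xi in x.unbind(0)])
result = vmap(f)(x)

print(result.shape)  # torch.Size([3])
```

Both sides produce a tensor of shape `(3,)`, matching the reasoning above.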
I think the batching rule for `squeeze` should check whether the tensor has dim 1 (and that the `dim` argument is equal to 0). If it does, then it returns an alias of the tensor (via `tensor.alias()`).
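The rule described above can be sketched in Python. This is a hypothetical illustration (the real rule lives in functorch's C++ batch rules; `squeeze_batch_rule` and its exact conventions are made up here), where `bdim` is the index of the batch dimension, or `None` if the tensor is not batched:

```python
# Hypothetical sketch of a new-style squeeze batching rule.
import torch

def squeeze_batch_rule(self, bdim):
    if bdim is None:
        return self.squeeze(), None
    self = self.movedim(bdim, 0)  # move the batch dim to the front
    if self.dim() == 1:
        # Logical 0-dim tensor per batch element: nothing to squeeze.
        # (The C++ rule would return an alias here, via Tensor::alias().)
        return self, 0
    # Squeeze only the non-batch dims of size 1, keeping the batch dim.
    shape = [self.shape[0]] + [s for s in self.shape[1:] if s != 1]
    return self.reshape(shape), 0

x = torch.ones(4, 1, 3, 1)  # batch dim 0, logical shape (1, 3, 1)
out, out_bdim = squeeze_batch_rule(x, 0)
print(out.shape, out_bdim)  # torch.Size([4, 3]) 0
```

Note the rule must preserve the batch dim even when it has size 1, which is why it can't just call `self.squeeze()` on the batched tensor.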
@zou3519 - is there a description of the new versus old style of batching rules?
I'm having a bit of trouble understanding what the following does and whether it's still needed in the new-style:
```cpp
if (!participatesInCurrentLevel(self)) {
  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
  return self.squeeze();
}
```
From the docs, it seems as though it's now a more functional approach where raw/native Torch tensors are passed in (instead of wrappers), along with optional parameters that describe the batch dimension (instead of it being carried on the wrappers).

Is it right to assume that the `optional<uint64_t>` parameter passed after every Tensor is the number of batch dimensions (e.g. if it is set to `2`, then the first `2` dimensions of the Tensor are batch dimensions)? If it is empty, then there are no batch dimensions?

Also, if I return an empty batch dimension, it appears to assume that there is `1` batch dimension (or this might be the original input number)?
Also, there are implementations for `squeeze` and `squeeze.dim` - through trial-and-error I figured out that `squeeze.dim` is invoked when the `dim` optional parameter is passed. What's the mechanism/way to understand more about this dispatch method (or is this just the PyTorch dispatcher)?
> is there a description of the new versus old style of batching rules?

Unfortunately no, not yet. This is the doc for the new style, but the old style was me hacking everything to work.
> I'm having a bit of trouble understanding what the following does and whether it's still needed in the new-style:

That is not needed in the new style if you use the `VMAP_SUPPORT` macro.
> Is it right to assume that the `optional<uint64_t>` parameter passed after every Tensor is the number of batch dimensions (e.g. if it is set to 2 then the first 2 dimensions of the Tensor are batch dimensions)?

The `optional<int64_t>` passed after every Tensor is "the index of the batch dimension" (not the number of batch dimensions!). The new-style batching rules assume that there is only a single batch dimension (but there is some magic somewhere else that allows the batching rule to operate on multiple batch dimensions).
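"Index of the batch dimension" can be seen from the user-facing side too. A minimal sketch, assuming `torch.func.vmap`: with `in_dims=1`, the batch dim of `x` sits at index 1, which is the value the batching rule receives as its optional bdim argument:

```python
# With in_dims=1 the batch dimension is at index 1, so each "logical"
# sample seen by the mapped function has shape (5,).
import torch
from torch.func import vmap

x = torch.randn(5, 3)  # 3 samples of shape (5,), batched over dim 1
out = vmap(lambda t: t * 2, in_dims=1, out_dims=1)(x)
print(out.shape)  # torch.Size([5, 3])
```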
> If it is empty then there are no batch dimensions?

Yes.
> Also if I return an empty batch dimension - then it appears to assume that there is 1 batch dimension (or this might be the original input number)?

You should be able to return `nullopt`.
> Also there's implementations for squeeze and squeeze.dim - through trial-and-error I figured out that squeeze.dim is invoked when the dim optional parameter is passed. What's the mechanism/way to understand more about this dispatch method (or is this just the PyTorch dispatcher)?

This is just the PyTorch dispatcher, but the summary is:

- Look in `native_functions.yaml`.
- Depending on how you invoke `squeeze` in Python, one of the {`squeeze`, `squeeze.dim`} operators gets called. If one just calls `x.squeeze()`, it'll invoke the `squeeze` operator; if one calls `x.squeeze(1)`, it'll invoke the `squeeze.dim` operator.
- There's some logic in PyTorch's Python bindings that parses the inputs to e.g. `squeeze` and selects which of {`squeeze`, `squeeze.dim`} actually gets called. If you're interested I can point you to that code (it is autogenerated, so I can't link you to it directly on GitHub).
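The overload selection above is easy to observe from Python, since the two operators have different semantics:

```python
# x.squeeze() hits the `squeeze` operator (removes ALL size-1 dims);
# x.squeeze(1) hits the `squeeze.dim` operator (removes only dim 1).
import torch

x = torch.zeros(2, 1, 3, 1)
print(x.squeeze().shape)   # torch.Size([2, 3])    -- aten::squeeze
print(x.squeeze(1).shape)  # torch.Size([2, 3, 1]) -- aten::squeeze.dim
```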
Fixed by #81, thank you @laurencer!