Comments (5)
`fsdp(jit(model))` seems affected, while `jit(fsdp(model))` is not; `fsdp(jit(model))` seems to fail to insert the param all-gather.
```
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]
  File "/opt/pytorch/lightning-thunder/thunder/tests/distributed/test_ddp.py", line 793, in test_fsdp_grad_parity_with_without_bucketing
    loss = cm(x).mean()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/module.py", line 62, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 617, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 202, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 473, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 190, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1551, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6701, in fn_
    raise e
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6669, in fn_2
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/tests/distributed/test_ddp.py", line 81, in forward
    return self.net2(new_gelu(self.net1(x)))
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1278, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/symbol.py", line 257, in __call__
    result = self.meta(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/torch/__init__.py", line 3999, in linear
    return prims.linear(a, w, bias)
  File "/opt/pytorch/lightning-thunder/thunder/core/symbol.py", line 253, in __call__
    result = self.meta(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/prims.py", line 3549, in linear_meta
    utils.check(
  File "/opt/pytorch/lightning-thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
RuntimeError: Expected w.shape=(4, 12) to have an innermost dimension of length 6, the same length as the innermost dimension of a.shape=(2, 6)!
```
The fsdp'd model is defined as follows:

```python
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(12, 12)
        self.net2 = nn.Linear(12, 8)

    def forward(self, x):
        return self.net2(new_gelu(self.net1(x)))
```
The input shape is (2, 12). So with two ranks, the sharded params of `net1` would have shapes (6, 12) and (6,), and those of `net2` (4, 12) and (4,). From the failure log, the bsym of `self.net1(x)` returns a tensor of shape (2, 6), leading to the mismatch between 12 and 6. If `net1`'s params were correctly all-gathered, the output of `prims.linear` should have shape (2, 12), but apparently it does not.
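The shape mismatch can be reproduced in plain PyTorch without any distributed machinery. This is a sketch, not thunder code: the shapes are taken from the log, and a world size of 2 is an assumption inferred from the sharded shapes.

```python
import torch
import torch.nn.functional as F

world_size = 2  # assumption inferred from the (6, 12)/(4, 12) shards
net1_w = torch.randn(12, 12)  # nn.Linear(12, 12).weight
net2_w = torch.randn(8, 12)   # nn.Linear(12, 8).weight

# FSDP shards each parameter along dim 0; rank 0's shards:
net1_w_shard = net1_w.chunk(world_size, dim=0)[0]  # (6, 12)
net2_w_shard = net2_w.chunk(world_size, dim=0)[0]  # (4, 12)

x = torch.randn(2, 12)
# Without the all-gather, net1 computes with its shard only:
h = F.linear(x, net1_w_shard)
print(h.shape)  # torch.Size([2, 6]) instead of torch.Size([2, 12])

# net2's shard (4, 12) then mismatches h's innermost dimension 6:
try:
    F.linear(h, net2_w_shard)
except RuntimeError as e:
    print(f"shape mismatch: {e}")
```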
from lightning-thunder.
Early transforms are applied around
lightning-thunder/thunder/__init__.py
Line 495 in 54bb614
and the param all-gathers are applied as one of them.
From the message above, the error could come from
lightning-thunder/thunder/__init__.py
Lines 473 to 475 in 54bb614
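For intuition, the early-transform machinery can be thought of as a pipeline that rewrites the trace step by step, with the FSDP transform expected to restore full parameter shapes via an all-gather. The following is a generic sketch, not thunder's actual implementation; the transform function and its shape bookkeeping are hypothetical.

```python
# Generic sketch, NOT thunder's actual code: early transforms run in
# order, each rewriting the result of the previous one.
def apply_early_transforms(state, transforms):
    for transform in transforms:
        state = transform(state)
    return state

# Hypothetical "param all-gather" transform: restore each
# dim-0-sharded parameter shape to its full, unsharded size.
def fsdp_allgather_transform(shapes, world_size=2):
    return {name: (dim0 * world_size, *rest)
            for name, (dim0, *rest) in shapes.items()}

sharded = {
    "net1.weight": (6, 12), "net1.bias": (6,),
    "net2.weight": (4, 12), "net2.bias": (4,),
}
full = apply_early_transforms(sharded, [fsdp_allgather_transform])
print(full["net1.weight"])  # (12, 12)
```

If this transform is skipped (the suspected bug), downstream shape checks such as `linear_meta` see the sharded shapes and fail exactly as in the log above.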
triage review — @t-vi to look at this next
Thanks for the report. We also see similar distributed tests failing in our CI.
As written in the header, my bisect points to #421.