
Comments (5)

crcrpar commented on September 26, 2024

fsdp(jit(model)) seems to be affected, while jit(fsdp(model)) is not. fsdp(jit(model)) appears to fail to insert the param all-gather.

[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/tests/distributed/test_ddp.py", line 793, in test_fsdp_grad_parity_with_without_bucketing
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     loss = cm(x).mean()
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return self._call_impl(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return forward_call(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/module.py", line 62, in forward
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     res = self._forward_fn(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 617, in fn_
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 202, in cache_info_wrapper
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     res = fn(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 473, in get_computation_and_inputs
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     jit_results: TraceResults = interpreter(
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 190, in _general_frontend
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1551, in thunder_general_jit
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     result = jfn(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6701, in fn_
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     raise e
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6669, in fn_2
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return fn(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return fn.__func__(fn.__self__, *args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return self._call_impl(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return fn.__func__(fn.__self__, *args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return forward_call(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return fn.__func__(fn.__self__, *args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/tests/distributed/test_ddp.py", line 81, in forward
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return self.net2(new_gelu(self.net1(x)))
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return fn.__func__(fn.__self__, *args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return self._call_impl(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return fn.__func__(fn.__self__, *args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return forward_call(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6066, in _impl
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return fn.__func__(fn.__self__, *args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 116, in forward
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return F.linear(input, self.weight, self.bias)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1278, in wrapping_wrapper
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     res = ufn(*uargs, **ukwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/symbol.py", line 257, in __call__
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     result = self.meta(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 132, in _fn
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     result = fn(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/torch/__init__.py", line 3999, in linear
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     return prims.linear(a, w, bias)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/symbol.py", line 253, in __call__
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     result = self.meta(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 132, in _fn
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     result = fn(*args, **kwargs)
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/prims.py", line 3549, in linear_meta
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     utils.check(
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]   File "/opt/pytorch/lightning-thunder/thunder/core/baseutils.py", line 103, in check
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660]     raise exception_type(s())
[rank1]:E0517 16:21:20.180000 140528173395840 torch/testing/_internal/common_distributed.py:660] RuntimeError: Expected w.shape=(4, 12) to have an innermost dimension of length 6, the same length as the innermost dimension of a.shape=(2, 6)!

The fsdp'd model is defined as follows:

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(12, 12)
        self.net2 = nn.Linear(12, 8)

    def forward(self, x):
        return self.net2(new_gelu(self.net1(x)))

The input shape is (2, 12), as in

x = torch.ones((2, 12), device=device)
and the world size is 2.

So the sharded params of net1 have shapes (6, 12) and (6,), and those of net2 have shapes (4, 12) and (4,).
From the failure log, the bsym for self.net1(x) returns a tensor of shape (2, 6), so net2 receives an innermost dimension of 6 where its weight expects 12.
If net1's params were correctly all-gathered, the output of prims.linear would have the shape (2, 12), but apparently they are not.
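The shape arithmetic above can be checked with a small shape-only sketch. It assumes parameters are split evenly along dim 0 across ranks, and treats new_gelu as shape-preserving; the helper names (shard_dim0, linear_out_shape) are hypothetical and are not Thunder or PyTorch APIs.

```python
# Shape-only sketch; helper names are hypothetical, not Thunder or PyTorch APIs.
WORLD_SIZE = 2


def shard_dim0(shape, world_size=WORLD_SIZE):
    """Shape of one rank's shard when a parameter is split evenly along dim 0."""
    assert shape[0] % world_size == 0, "this sketch assumes an even split"
    return (shape[0] // world_size, *shape[1:])


def linear_out_shape(a_shape, w_shape):
    """Output shape of linear(a, w), where w is (out_features, in_features)."""
    if a_shape[-1] != w_shape[-1]:
        raise RuntimeError(
            f"Expected w.shape={w_shape} to have an innermost dimension of "
            f"length {a_shape[-1]}, the same as a.shape={a_shape}"
        )
    return (*a_shape[:-1], w_shape[0])


x = (2, 12)
net1_w, net2_w = (12, 12), (8, 12)  # nn.Linear(12, 12) and nn.Linear(12, 8)

net1_w_shard = shard_dim0(net1_w)
net2_w_shard = shard_dim0(net2_w)
net1_b_shard = shard_dim0((12,))
net2_b_shard = shard_dim0((8,))

# Without the all-gather, net1 is applied with its sharded weight.
h_sharded = linear_out_shape(x, net1_w_shard)
try:
    linear_out_shape(h_sharded, net2_w_shard)
    error = None
except RuntimeError as e:
    error = str(e)

# With all-gathered (unsharded) weights the shapes line up.
h_full = linear_out_shape(x, net1_w)
out_full = linear_out_shape(h_full, net2_w)
```

With the sharded weights, the second linear call raises the same kind of mismatch as in the log; with the full weights the chain goes through.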

from lightning-thunder.

crcrpar commented on September 26, 2024

Early transforms are applied around

for transform in early_transforms:

and the param all-gathers are applied as one of them.

From the message above, the error could come from

jit_results: TraceResults = interpreter(
    fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
)
which runs before the needed transform is applied.
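To illustrate why the ordering matters, here is a toy model of the situation, not Thunder's code: a stand-in "trace" step that validates shapes through two linear layers, and a stand-in "all-gather transform" that restores the unsharded dim 0. All names here are hypothetical, and the sketch only shows that shape checks running on sharded params fail before a later transform gets a chance to fix them; it does not suggest how the issue should be resolved.

```python
# Toy model of the ordering problem; none of these names are Thunder APIs.
WORLD_SIZE = 2


def trace_linear_chain(x_shape, weight_shapes):
    """Stand-in for tracing: propagate shapes through linear layers, checking each."""
    shape = x_shape
    for w in weight_shapes:
        if shape[-1] != w[-1]:
            raise RuntimeError(
                f"Expected w.shape={w} to have an innermost dimension "
                f"of length {shape[-1]}"
            )
        shape = (*shape[:-1], w[0])
    return shape


def all_gather_shapes(weight_shapes, world_size=WORLD_SIZE):
    """Stand-in for the param all-gather transform: restore the unsharded dim 0."""
    return [(w[0] * world_size, *w[1:]) for w in weight_shapes]


sharded = [(6, 12), (4, 12)]  # per-rank weight shards of net1 and net2
x = (2, 12)

# Tracing first: the shape check sees sharded weights and fails.
try:
    trace_linear_chain(x, sharded)
    traced_before_gather = "ok"
except RuntimeError as e:
    traced_before_gather = f"failed: {e}"

# If the weights were gathered before the shape check, the chain would pass.
traced_after_gather = trace_linear_chain(x, all_gather_shapes(sharded))
```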


mruberry commented on September 26, 2024

triage review — @t-vi to look at this next


xwang233 commented on September 26, 2024

Thanks for the report. We also see similar distributed tests failing in our CI.


nikitaved commented on September 26, 2024

As written in the header, my bisect points to #421.

