
lightning-thunder's Introduction


Make PyTorch models Lightning fast.


Lightning.ai • Performance • Get started • Install • Examples • Inside Thunder • Get involved! • Documentation


Welcome to ⚡ Lightning Thunder

Thunder makes PyTorch models Lightning fast.

Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (for instance, nvFuser, torch.compile, cuDNN, and TransformerEngine FP8).

It supports both single and multi-GPU configurations. Thunder aims to be usable, understandable, and extensible.

 

Note

Lightning Thunder is in alpha. Feel free to get involved, but expect a few bumps along the way.

 

Single-GPU performance

Thunder can achieve significant speedups over standard non-compiled PyTorch code ("PyTorch eager"), through the compounding effects of optimizations and the use of best-in-class executors. The figure below shows the pretraining throughput for Llama 2 7B as implemented in LitGPT.

[Figure: Llama 2 7B pretraining throughput, Thunder vs. PyTorch eager on H100]

As shown in the plot above, Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.

 

Multi-GPU performance

Thunder also supports distributed strategies such as DDP and FSDP for training models on multiple GPUs. The following plot displays the normalized throughput measured for Llama 2 7B without FP8 mixed precision; support for FSDP is in progress.

[Figure: normalized multi-GPU training throughput for Llama 2 7B with DDP and FSDP]

 

Get started

The easiest way to get started with Thunder, requiring no extra installations or setups, is by using our Zero to Thunder Tutorial Studio.

 

Install Thunder

To use Thunder on your local machine, first install nvFuser nightly and PyTorch nightly together as follows:

# install nvFuser which installs the matching nightly PyTorch
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com

Then, install Thunder as follows:

# install thunder
pip install lightning-thunder
Advanced install options

 

Install from main

Alternatively, you can install the latest version of Thunder directly from this GitHub repository as follows:

# 1) Install nvFuser and PyTorch nightly dependencies:
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
# 2) Install Thunder itself
pip install git+https://github.com/Lightning-AI/lightning-thunder.git

 

Install to tinker and contribute

If you are interested in tinkering with and contributing to Thunder, we recommend cloning the Thunder repository and installing it in pip's editable mode:

git clone https://github.com/Lightning-AI/lightning-thunder.git
cd lightning-thunder
pip install -e .

 

Develop and run tests

After cloning the lightning-thunder repository and installing it as an editable package as explained above, you can set up your environment for developing Thunder by installing the development requirements:

pip install -r requirements/devel.txt

Now you can run the tests:

pytest thunder/tests

Thunder is very thoroughly tested, so expect this to take a while.
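Since the full suite is large, it can be handy to run only a subset while iterating. Standard pytest selection flags work as usual (the -k expression below is just an illustration):

# stop at the first failure
pytest thunder/tests -x

# run only tests whose names match a keyword expression
pytest thunder/tests -k "jit"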

 

Hello World

Below is a simple example of how Thunder allows you to compile and run PyTorch code:

import torch
import thunder


def foo(a, b):
    return a + b


jfoo = thunder.jit(foo)

a = torch.full((2, 2), 1)
b = torch.full((2, 2), 3)

result = jfoo(a, b)

print(result)

# prints
# tensor(
#  [[4, 4]
#   [4, 4]])

The compiled function jfoo takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of larger PyTorch programs.
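Thunder can also compile nn.Module instances directly, as several of the issue reproductions below do. A minimal sketch (the Linear module and shapes here are arbitrary examples):

import torch
import thunder

model = torch.nn.Linear(4, 4)
jmodel = thunder.jit(model)  # wraps the original module in a ThunderModule

x = torch.randn(2, 4)
print(jmodel(x).shape)  # torch.Size([2, 4])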

 

Train models

Thunder is in its early stages and should not be used for production runs yet.

However, it can already deliver outstanding performance for pretraining and finetuning LLMs supported by LitGPT, such as Mistral, Llama 2, Gemma, Falcon, and others.

Check out the LitGPT integration to learn about running LitGPT and Thunder together.

 

Inside Thunder: A brief look at the core features

Given a Python callable or PyTorch module, Thunder can generate an optimized program that:

  • Computes its forward and backward passes
  • Coalesces operations into efficient fusion regions
  • Dispatches computations to optimized kernels
  • Distributes computations optimally across machines

To do so, Thunder ships with:

  • A JIT for acquiring Python programs targeting PyTorch and custom operations
  • A multi-level intermediate representation (IR) to represent operations as a trace of a reduced operation set
  • An extensible set of transformations on the trace of a computational graph, such as grad, fusions, distributed (like ddp, fsdp), functional (like vmap, vjp, jvp)
  • A way to dispatch operations to an extensible collection of executors

Thunder is written entirely in Python. Even its trace is represented as valid Python at all stages of transformation. This allows unprecedented levels of introspection and extensibility.
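Because traces are valid Python, they can simply be printed after a jitted function has run. A minimal sketch, reusing the Hello World example above and the thunder.last_traces helper that also appears in the issue reports below:

import torch
import thunder


def foo(a, b):
    return a + b


jfoo = thunder.jit(foo)
jfoo(torch.randn(2, 2), torch.randn(2, 2))

# the last trace is the fully transformed Python program that actually executes
print(thunder.last_traces(jfoo)[-1])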

Thunder doesn't generate code for accelerators, such as GPUs, directly. It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors such as nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.

Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations.

 

Documentation

Online documentation is available. To build documentation locally you can use

make docs

and point your browser to the generated docs at docs/build/index.html.

 

Get involved!

We appreciate your feedback and contributions. If you have feature requests, questions, or want to contribute code or config files, please don't hesitate to use the GitHub Issue tracker.

We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.

 

License

Lightning Thunder is released under the Apache 2.0 license. See the LICENSE file for details.

lightning-thunder's Issues

Handle returning NamedTuples from the JIT

NamedTuples have (some) support in the interpreter, but to return them from the JIT, we would need some more things:

  • ensure value tracking covers creation (populating the attribute_wrappers) - this is in interpreter,
  • when seeing NamedTuples being returned, add epilogue code to create the NamedTuple from its contents (which are available),

This is nontrivial as it requires return values to be routed through the epilogue, but I don't think that is grave.

This is probably an intermediate to advanced issue for people fond of the JIT / frontend bits.
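For concreteness, a minimal sketch of the behavior this issue asks for; the NamedTuple and function below are hypothetical, and returning the NamedTuple from the jitted function is exactly what does not work yet:

from typing import NamedTuple

import torch
import thunder


class Stats(NamedTuple):
    mean: torch.Tensor
    std: torch.Tensor


def foo(x):
    return Stats(x.mean(), x.std())


jfoo = thunder.jit(foo)
# desired: the epilogue reconstructs and returns a Stats NamedTuple
out = jfoo(torch.randn(4, 4))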

How do I access the `ThunderModule` if I'm compiling a function?

🚀 Feature

Motivation

Sometimes the code requires that a ThunderModule is passed. However, if the user is compiling a function that takes the module as an argument, they don't have a way to get a reference to it.

For example, #96 implements a workaround for this issue with the no_sync context manager.

Pitch

Provide an API to get this reference. Maybe it's something like thunder.compile_data(jitted_function).module.

Additional context

The design might need to consider the presence of multiple ThunderModules.
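A sketch of the scenario, with the proposed accessor left as a comment since it does not exist yet (the function and module below are arbitrary examples):

import torch
import thunder


def training_step(model, x):
    # the jitted function receives the module as a plain argument
    return model(x).sum()


model = torch.nn.Linear(4, 4)
jstep = thunder.jit(training_step)
loss = jstep(model, torch.randn(2, 4))

# proposed (hypothetical) API from the pitch above:
# tmodule = thunder.compile_data(jstep).module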

Partial function is not supported in `grad_transform`

🚀 Feature

Hitting the error below:

root@847841b8737c:/opt/pytorch/lightning-thunder# python /volume/pooling.py
Traceback (most recent call last):
  File "/volume/pooling.py", line 36, in <module>
    o = jit_model(image)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 632, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 265, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 574, in get_computation_and_inputs
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
  File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 216, in split_forward_backward
    fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3879, in forward_and_backward_from_trace
    forward_trace = construct_trace()(augmented_forward_fn, *trace.args, **trace.kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in fn_
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 528, in _trace
    result = fn(*proxyargs, **proxykwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3850, in augmented_forward_fn
    result, env = augmented_forward_pass(*args, trace=trace, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3461, in augmented_forward_pass
    result, env = eval_trace(
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 1698, in eval_trace
    prim_func = symbol_mapper(symbol)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3385, in vjp_symbol_mapper
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
  File "/opt/pytorch/lightning-thunder/thunder/core/vjp_utils.py", line 63, in make_aug_forward_and_backward
    joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in fn_
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 506, in _trace
    proxyargs, proxykwargs = _unpack_inputs(fn, trace, args, kwargs, rename_proxies=rename_proxies)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 273, in _unpack_inputs
    si = get_siginfo(fn, args, kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/codeutils.py", line 313, in get_siginfo
    check(
  File "/opt/pytorch/lightning-thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
NotImplementedError: Support for partials with positional args (like ('test',)) is not implemented yet

I was trying to use something like

from functools import partial

foo = partial(bar, pos_arg0)
OperatorExecutor.register_operator(..., grad_transform=foo)

This isn't a high-priority issue, since we can easily work around it for now. Filing the issue just to keep track of the missing feature.

Add `torch.nn.Dropout` recomputation support during the backward pass to Thunder

🚀 Feature

I would like to have Thunder save the seed and offset from random number generation to allow for the recomputation of Dropout in the backward pass.

There are two pieces needed to make it work:

  • Support stateless (deterministic) PRNG. This is done with thunder.prims.uniform_philox.
  • Trace transform to query PyTorch's PRNG state before each uniform call, replacing uniform with uniform_philox, and incrementing the PRNG state properly. This is not implemented.

Motivation

Multihead Attention modules in LLMs often use dropout where the memory used is the square of the sequence length.
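For intuition, a minimal PyTorch sketch of the idea using a plain torch.Generator (thunder.prims.uniform_philox would play this role inside Thunder): if the PRNG state is saved before the forward dropout, the mask can be regenerated in the backward pass instead of being kept in memory.

import torch

g = torch.Generator().manual_seed(0)
state = g.get_state()  # save the PRNG state before sampling the dropout mask

x = torch.randn(4, 4)
p = 0.5
mask = torch.rand(x.shape, generator=g) > p  # forward: sample the mask
y = x * mask / (1 - p)

# backward: restore the saved state and regenerate the identical mask
g.set_state(state)
mask_again = torch.rand(x.shape, generator=g) > p
assert torch.equal(mask, mask_again)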

cc @apaz-cli

Handling inplace through SSA

This issue is to facilitate discussion of inplace handling, namely the "big" solution of having a static single assignment (SSA) representation.

For any handling of inplace, we want to make certain that two things are achieved:

  • we don't want to take shortcuts that complicate passes by introducing the need to detect obstacles to optimizations, because it would harm usability and extensibility of Thunder.
  • we don't want to create ad-hoc band-aids to get things working that we would need to regress on later to introduce more proper handling because developing in the open more or less means no regressions.

Some thoughts from video/chat discussions:

About the problem:

  • The key difficulty in SSA is that we would need to keep track of which tensors get modified by an inplace update (i.e. which have memory that is to be updated), so we would need to know about views (the fancy term is alias analysis),
  • this is difficult for some things in PyTorch (i.e. reshape),
  • "assuming the worst" works to some extent.

Solution considerations:

  • Likely we would want inplace updates to have all affected tensors as outputs.
  • on inputs we would need to check for aliases as part of the prologue (maybe with a separate "assume aliasing is OK" cache mode of sorts later),
  • operations need to know if their output is a view of their inputs (difficult for reshape, easy for most others),
  • initially, we would only check if tensors share storage,
  • likely the translation could be done in the interpretation phase,
  • we would need to have versioning / disambiguation of versions for tensor proxies during this, but not when we have the SSA.

Later versions could refine the alias analysis as needed.
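A small PyTorch example of why alias analysis matters here: an in-place update through a view also changes the base tensor, so a representation that only tracks the view's output would miss the effect on the base.

import torch

a = torch.zeros(4)
b = a.view(2, 2)  # b is a view sharing storage with a
b.add_(1)         # in-place update through the view

print(a)                             # tensor([1., 1., 1., 1.]) -- a changed too
print(a.data_ptr() == b.data_ptr())  # True: same underlying storage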

@tfogal @mruberry @IvanYashchuk

CI fails to build `cuda 12.1 | torch 2.3 /test | cudnn FE v1.2`

🐛 Bug

The CI job build_push cuda 12.1 | torch 2.3 /test | cudnn FE v1.2 fails to build, apparently because PyTorch bumped the Triton dependency to 2.3.0.
https://github.com/Lightning-AI/lightning-thunder/runs/23676066094

To Reproduce

82.88 The conflict is caused by:
82.88     The user requested triton==2.2.0
82.88     torch 2.3.0+cu121 depends on triton==2.3.0; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12"

@Borda

Better name for elements of list in `prologue_trace` and `computation_trace`

import thunder
import torch

def foo(xs):
    result = []
    for x in xs:
        result.append(x + x)
    return result

jfoo = thunder.jit(foo)

o = jfoo([torch.randn(3,),] * 6)
print(thunder.last_prologue_traces(jfoo)[-1])
print(thunder.last_traces(jfoo)[-1])

Names for the arguments to the computation trace are: res, x, a, b, t_0_4, t_0_5. It would be nice if there were a consistent pattern.

Traces

Prologue Trace

# Constructed by Transform for execution (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 1)
    # prims.check_len(args, 1)
  # kwargs: "Any"
  check_len(kwargs, 0)
    # prims.check_len(kwargs, 0)
  subscr: "Any" = args[0]
  res: "cpu f32[3]" = subscr[0]
  x: "cpu f32[3]" = subscr[1]
  a: "cpu f32[3]" = subscr[2]
  b: "cpu f32[3]" = subscr[3]
  t_0_4: "cpu f32[3]" = subscr[4]
  t_0_5: "cpu f32[3]" = subscr[5]
  ...
  return (res, x, a, b, t_0_4, t_0_5)

Computation Trace

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def computation(res, x, a, b, t_0_4, t_0_5):
  # res: "cpu f32[3]"
  # x: "cpu f32[3]"
  # a: "cpu f32[3]"
  # b: "cpu f32[3]"
  # t_0_4: "cpu f32[3]"
  # t_0_5: "cpu f32[3]"
  result = torch.add(res, res)  # result: "cpu f32[3]"
    # result = ltorch.add(res, res, alpha=None)  # result: "cpu f32[3]"
      # result = prims.add(res, res)  # result: "cpu f32[3]"
  del res
  ...
  return [result, t1, t2, t3, t4, t5]

Distributed Tests failing but CI is green

On the latest main (94c9494), the CI flow for distributed tests shows success: https://github.com/Lightning-AI/lightning-thunder/runs/23172744261

But looking at the log, there are a few tests that have failed.

Sample

=================================== FAILURES ===================================
_ CompileDDPTest.test_fsdp_grad_parity_with_without_bucketing_executor_nvfuser_bucketing_block_zero2 _
/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_distributed.py:533: in wrapper
    self._join_processes(fn)
/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_distributed.py:752: in _join_processes
    self._check_return_codes(elapsed_time)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

Link to log: https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=196660&view=logs&j=47e66f3c-897a-5428-da11-bf5c7745762e&t=97be8351-284a-5dba-49eb-f9fe7c3ed1a2&l=811

cc @Borda

`TensorBase.cuda`

🚀 Feature

Implement Tensor.cuda that returns a cuda-backed copy of the given tensor.

Motivation

NeMo text-to-image model. It's plausible that the source tensor used in the model is a GPU tensor already, so we might be able to get by with just returning a tensor copy without worrying about movement between devices.
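A minimal sketch of the kind of program this feature would enable (illustrative only; it is expected to fail until Tensor.cuda is supported, and it requires a CUDA device):

import torch
import thunder


def foo(x):
    # desired: .cuda() returns a CUDA-backed copy of the tensor
    return x.cuda() * 2


jfoo = thunder.jit(foo)
out = jfoo(torch.randn(4))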

Label tracking meta-issue (edit me to get automatically CC'ed on issues!)

This issue is used by lightning-probot to manage subscriptions to labels. To subscribe yourself to a label, add a line * label @yourusername, or add your username to an existing line (space separated) in the body of this issue. Do not try to subscribe in comments, the bot only parses the initial post.

This is a copy of pytorch/pytorch#24422.

As a courtesy to others, please do not edit the subscriptions of users who are not you.


The current list of labels can be retrieved with $ gh label list --limit 1000 --json name --jq '.[] | "* " + .name' | sort -n

adding DDP/FSDP transform after JITting does not work

🐛 Bug

The snippet below looks hacky, but it's how I'm approaching support for having the user control the thunder.jit call outside of Fabric: Lightning-AI/litgpt#1204

The objective is that fsdp|ddp can be applied after the thunder.jit call.

It works with FSDP, but not with DDP where it fails with:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/carlos/lightning-thunder/kk.py", line 21, in <module>
[rank1]:     out = tmodel(x)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 194, in forward
[rank1]:     res = self._forward_fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 629, in fn_
[rank1]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 262, in cache_info_wrapper
[rank1]:     res = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 571, in get_computation_and_inputs
[rank1]:     computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/executors/torch_autograd.py", line 283, in split_forward_backward
[rank1]:     bw_trace = optimize_allreduce_in_ddp_backward(bw_trace, compile_data)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 198, in optimize_allreduce_in_ddp_backward
[rank1]:     updated_bwd_trace = visitor_transform(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/transforms.py", line 368, in visitor_transform
[rank1]:     visit_type = visit(bsym)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 133, in __call__
[rank1]:     self.gradient_buckets.tell(grads_of_bsym[0], self.process_group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 150, in tell
[rank1]:     self._maybe_allreduce(bucket, group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 138, in _maybe_allreduce
[rank1]:     self.bucket_to_future[bucket] = dist_prims.all_reduce(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/symbol.py", line 246, in __call__
[rank1]:     result = self.meta(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
[rank1]:     result = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/prims.py", line 87, in all_reduce_meta
[rank1]:     utils.check_type(group, torch.distributed.ProcessGroup)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 107, in check_type
[rank1]:     check(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 103, in check
[rank1]:     raise exception_type(s())
[rank1]: ValueError: None had an unexpected type <class 'NoneType'>. Supported types are <class 'torch.distributed.distributed_c10d.ProcessGroup'>

To Reproduce

import os
import thunder
import torch
import torch.distributed as torch_dist

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

tmodel = thunder.jit(model)
tmodel._lc_cd.fn = thunder.distributed.ddp(tmodel._lc_cd.fn)

out = tmodel(x)

if local_rank == 0:
    print(thunder.last_backward_traces(tmodel)[-1].python())

torchrun --nproc-per-node 2 bug.py

cc @carmocca @awaelchli @crcrpar @kshitij12345 since you fixed a similar issue in #23

Operator support for `F.one_hot`

🐛 Bug

Thunder fails when attempting to compile a graph containing torch.nn.functional.one_hot within the forward pass.
The error message indicates that the input to the method must be a Tensor, but a TensorProxy is received instead.

To Reproduce

Steps to reproduce the behavior:

  • Define a PyTorch model class with a forward pass involving F.one_hot to convert the input tensor to a one-hot encoded representation.
  • Create an instance of the model and evaluate it on a random input tensor.
  • Compile the model using thunder.jit.
  • Call the compiled model with the same input tensor.

Example

import torch
import torch.nn as nn
import torch.nn.functional as F

import thunder


class MLP(nn.Module):
    def __init__(self, hidden_size=1024):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(6 * 256, hidden_size, bias=False)
        self.head = nn.Linear(hidden_size, 32000, bias=False)

    def forward(self, inputs):
        x = F.one_hot(inputs, 6).reshape(-1, 6 * 256).float()
        x = self.hidden(x)
        logits = self.head(x)
        return logits


x = torch.randint(0, 6, (1, 256))

model = MLP(1024).eval()
print(model(x))

model = thunder.jit(model)
print(model(x))
Output
tensor([[-0.1134, -0.0827, -0.0205,  ...,  0.0757,  0.0066,  0.0974]],
       grad_fn=<MmBackward0>)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-6-6425e5faad6e>](https://localhost:8080/#) in <cell line: 23>()
     21 
     22 model = thunder.jit(model)
---> 23 print(model(x))

16 frames
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             # type ignore was added because at this point one knows that
   1510             # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-> 1511             name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
   1512             if name:
   1513                 tracing_state.push_scope(name)

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518         finally:
   1519             if recording_scopes:
-> 1520                 tracing_state.pop_scope()
   1521         return result
   1522 

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in forward(self, *args, **kwargs)
    192 
    193     def forward(self, *args, **kwargs):
--> 194         res = self._forward_fn(*args, **kwargs)
    195         return res
    196 

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in fn_(*args, **kwargs)
    609         cs.calls += 1
    610 
--> 611         cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    612         cs.last_trace_host_execution_start = time.time_ns()
    613 

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in cache_info_wrapper(*args, **kwargs)
    260         tok = _cache_info_ctx.set({})
    261         try:
--> 262             res = fn(*args, **kwargs)
    263         finally:
    264             _cache_info_ctx.reset(tok)

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in get_computation_and_inputs(*args, **kwargs)
    496                 prologue_trc: TraceCtx
    497                 computation_trc: TraceCtx
--> 498                 prologue_trc, computation_trc, *maybe_epilogue = interpreter(
    499                     fn, args, kwargs, sharp_edges=cd.sharp_edges
    500                 )

[/usr/local/lib/python3.10/dist-packages/thunder/__init__.py](https://localhost:8080/#) in _general_frontend(fn, args, kwargs, sharp_edges)
    173 # Translates the Python function to a thunder program using the thunder interpreter
    174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
    176 
    177 

[/usr/local/lib/python3.10/dist-packages/thunder/core/jit_ext.py](https://localhost:8080/#) in thunder_general_jit(fn, args, kwargs, sharp_edges)
   1384     with general_jit_ctx(ctx):
   1385         with tracectx(computation_trace):
-> 1386             result = jfn(*args, **kwargs)
   1387             prims.python_return(result)
   1388             process_recorded_modifications(ctx, epilogue_trace)

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in fn_(*args, **kwargs)
   6578                 assert isinstance(e, BaseException), e
   6579                 runtimectx.curexc = None
-> 6580                 raise e
   6581 
   6582             return interpretation_result

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in fn_2()
   6541                 def getfn():
   6542                     def fn_2(args, kwargs):
-> 6543                         return fn(*args, **kwargs)
   6544 
   6545                     return fn_2

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _wrapped_call_impl()
   1509             # type ignore was added because at this point one knows that
   1510             # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-> 1511             name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
   1512             if name:
   1513                 tracing_state.push_scope(name)

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _call_impl()
   1518         finally:
   1519             if recording_scopes:
-> 1520                 tracing_state.pop_scope()
   1521         return result
   1522 

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in forward()
      9 
     10     def forward(self, inputs):
---> 11         x = F.one_hot(inputs, 6).reshape(-1, 6 * 256).float()
     12         x = self.hidden(x)
     13         logits = self.head(x)

[/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py](https://localhost:8080/#) in _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)
   6067         kwargs_ = {unwrap(k): unwrap(v) for k, v in kwargs.items()}
   6068         try:
-> 6069             opaque_result: Any = fn(*args_, **kwargs_)
   6070         except Exception as e:
   6071             runtimectx.curexc = e

TypeError: one_hot(): argument 'input' (position 1) must be Tensor, not TensorProxy

Environment

  • OS: Ubuntu/Google Colab
  • Python Version: 3.10
  • PyTorch Version: 2.3.0.dev20240314+cu121
  • Thunder Version: 0.1.0
  • Installation:
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
pip install lightning-thunder

Additional context

  • Other functional methods like F.relu don't seem to raise this issue.

Non-supported diffusion transformer operators

TensorBase.bfloat16
_set_grad_enabled of torch._C
_VariableFunctionsClass.empty of torch
TensorBase.long
TensorBase.type
TensorBase.__setitem__
_VariableFunctionsClass.lerp of torch
device of torch
TensorBase.clone
TensorBase.masked_fill_
TensorBase.get_device
TensorBase.grad_fn
_VariableFunctionsClass.linspace of torch

cc @apaz-cli

Operator support for `F.hardswish`

🚀 Feature

Implement HardSwish activation function.

Motivation

A relatively easy activation function implementation and a good first issue, as nikitaved suggested under #64.

Pitch

Add HardSwish (x * ReLU6(x + 3) / 6) leveraging existing ReLU6 support.
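For reference, a plain PyTorch sketch of the proposed formula expressed through relu6 (the Thunder-side operator registration itself is not shown here):

import torch
import torch.nn.functional as F


def hardswish(x: torch.Tensor) -> torch.Tensor:
    # HardSwish(x) = x * ReLU6(x + 3) / 6
    return x * F.relu6(x + 3.0) / 6.0


x = torch.randn(8)
# matches PyTorch's built-in implementation
assert torch.allclose(hardswish(x), F.hardswish(x))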

cc @apaz-cli

[lit-GPT] Thunder with torch.compile executor performs consistently worse than Thunder on all model sizes/batch sizes on Pythia models

🐛 Bug

The performance of using the hybridized torch.compile executor with Thunder is worse than plain Thunder on Pythia models. This set of models differs from the Llama architecture in a few main ways:

  1. Use LayerNorm instead of RMSNorm
  2. Use GeLU instead of `SiLU(x) * x`
  3. Uses parallel residual (i.e. the MLP block is computed with an input computed before the Attention block, not after)

Example performance on H100 Single Node FP16 for Pythia6.9B, MBS=1, GBS=8, FSDP ZeRO2 w/o bucketing
Thunder iteration time (ms) = 232.74 ms
Thunder + torch.compile iteration time (ms) = 239.23 ms

cc @crcrpar @apaz-cli

Support `memory_format` on `to()`

🚀 Feature

to(memory_format=something) is part of the MegatronImagen model in NeMo.

Ideally, this would work:

$ git diff .
diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
index 4fa6cd230..2cf7a8ffa 100644
--- a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
+++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py
@@ -31,6 +31,7 @@ from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo.core.classes.common import Serialization
 from nemo.utils import logging
+import thunder
 
 try:
     from apex import amp
@@ -190,6 +191,7 @@ class MegatronImagen(MegatronBaseModel):
         self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)
 
         self.model = self.model_provider_func()
+        self.model = thunder.jit(self.model)
 
         if self.trainer.precision in ['bf16', 'bf16-mixed']:
             self.autocast_dtype = torch.bfloat16
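A minimal standalone sketch of the pattern in question, independent of NeMo (illustrative; the jitted call is expected to hit the missing memory_format support):

import torch
import thunder


def foo(x):
    return x.to(memory_format=torch.channels_last) + 1


jfoo = thunder.jit(foo)
out = jfoo(torch.randn(1, 3, 8, 8))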

Motivation

Trying to evaluate NeMo models in thunder and expand our model support there. Megatron-based models appear to be widely used.

Alternatives

I wonder if we could temporarily just accept the keyword without actually doing anything about it. I imagine that would be very slow, but it might allow us to get models like this one into thunder more easily.

I'll start trying to convert smaller parts of the model next.

Additional context

Model in question:

https://github.com/NVIDIA/NeMo/blob/23baa48e441ecb6cc6b49c23bf8cfc076db38bdc/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py#L175

I think the `to` that is failing for me is actually this line:
https://github.com/NVIDIA/NeMo/blob/23baa48e441ecb6cc6b49c23bf8cfc076db38bdc/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py#L135

Model test:
log.txt

Support for torchvision models, e.g., a simple ViT

🐛 Bug

I was trying to run a simple torchvision ViT and am getting the following error:

File "/teamspace/studios/this_studio/minimal-vit/01_pytorch-vit.py", line 136, in <module>
    train(
  File "/teamspace/studios/this_studio/minimal-vit/01_pytorch-vit.py", line 31, in train
    logits = model(features)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 194, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 611, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 262, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 498, in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/__init__.py", line 175, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/jit_ext.py", line 1386, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 6580, in fn_
    raise e
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 6543, in fn_2
    return fn(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
    x = self.encoder(x)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
    return self.ln(self.layers(self.dropout(input)))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
    x, _ = self.self_attention(x, x, x, need_weights=False)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1236, in forward
    any_nested = query.is_nested or key.is_nested or value.is_nested
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 5942, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/interpreter.py", line 1253, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/proxies.py", line 1234, in __getattr__
    method: None | Callable = resolve_method(attr, self)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/core/langctxs.py", line 68, in resolve_method
    method: Callable = ctx.get_method(id, *args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/thunder/torch/langctx.py", line 40, in get_method
    raise AttributeError(f"The {self.name} language context has no method {id}")
AttributeError: The torch language context has no method is_nested

Not sure how to go about debugging this. I thought that sharing this may help improve Thunder's support for more models and edge cases.

To Reproduce

Steps to reproduce the behavior:

I attached self-contained code in the zip.

# Runs PyTorch eager, works ok!
python 01_pytorch-vit.py

# Runs torch.compile, works ok!
python 01_pytorch-vit.py --compilation_option "torch.compile"

# Runs thunder.jit(), fails! (See error above)
python 01_pytorch-vit.py --compilation_option "thunder_default"

Code sample

See zip attached

Expected behavior

Either a clearer error message or ideally it should work :)

Environment

Same as Zero to Thunder studio.

Archive.zip

cc @apaz-cli

caching in make_aug_forward_and_backward breaks TE executor.

As discussed offline, caching in make_aug_forward_and_backward leads to reusing the symbols created by transformer_engine_ex, which are stateful, resulting in an incorrect program.
Ref:

key = (bsym.sym, subkey := _make_cache_key(bsym.args, bsym.kwargs))
cached_result = _cache.get(key, None) if subkey is not None else None
if cached_result is not None:
    return cached_result

Sample Program

import torch
import thunder
from thunder.executors.transformer_engineex import transformer_engine_ex
from transformer_engine.pytorch import fp8_autocast
dim = 256

class ThunderModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(dim, dim, bias=False)
        self.fc2 = torch.nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))

x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda()

thunder_model = ThunderModel().cuda()

jit_model = thunder.jit(thunder_model, executors=(transformer_engine_ex,),)

with fp8_autocast():
    o = jit_model(x).sum()

print(thunder.last_traces(jit_model)[-1])

Generated Trace (te_linear_0 is called twice):

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
  # args: "Collection"
  t0, t1, t2, = args
  del args
  (t6, ctx_te_1) = te_linear_0(t0, t1, None)
  t7 = torch.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
    # t7 = ltorch.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
      # t7 = prims.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
  t8 = torch.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
    # t8 = ltorch.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
      # t8 = prims.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
  del t6
  (t13, C12) = te_linear_0(t8, t2, None)
  del t8
  return {'output': t13, 'flat_args': [t0, t1, t2], 'flat_output': (t13,)}, ((t7,), (C12, ctx_te_1))

Sunset `thunder/benchmarks/distributed.py` and Improve `thunder/benchmarks/benchmark_litgpt.py`

  • [cosmetic] improve the format of JSON output of benchmark_litgpt.py

https://github.com/Lightning-AI/lightning-thunder/blob/cdd43a7fc1110eec10f1854250299b84d1c3b2a8/thunder/benchmarks/distributed.py has been useful, but I find it not easy to extend, e.g. to support gradient accumulation.

https://github.com/Lightning-AI/lightning-thunder/blob/cdd43a7fc1110eec10f1854250299b84d1c3b2a8/thunder/benchmarks/benchmark_litgpt.py would be easier to work with, as in #45, which adds gradient accumulation with no_sync.

cc @crcrpar @carmocca @awaelchli

The `_FabricModule` cannot be jitted after #78

🐛 Bug

extensions/thunder/pretrain.py:146: in setup
    main(
extensions/thunder/pretrain.py:233: in main
    fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
extensions/thunder/pretrain.py:253: in fit
    validate(fabric, model, val_dataloader, max_iters=2)  # sanity check
../nightly-env/lib/python3.10/site-packages/torch/utils/_contextlib.py:115: in decorate_context
    return func(*args, **kwargs)
extensions/thunder/pretrain.py:389: in validate
    loss = forward_and_loss(model, input_ids, targets)
../lightning-thunder/thunder/__init__.py:629: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
../lightning-thunder/thunder/__init__.py:262: in cache_info_wrapper
    res = fn(*args, **kwargs)
../lightning-thunder/thunder/__init__.py:504: in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
../lightning-thunder/thunder/__init__.py:175: in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
../lightning-thunder/thunder/core/jit_ext.py:1430: in thunder_general_jit
    result = jfn(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6669: in fn_
    raise e
../lightning-thunder/thunder/core/interpreter.py:6632: in fn_2
    return fn(*args, **kwargs)
extensions/thunder/pretrain.py:371: in forward_and_loss
    logits = model(input_ids)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../lightning/src/lightning/fabric/wrappers.py:142: in forward
    with precision.forward_context():
../lightning/src/lightning/fabric/plugins/precision/half.py:54: in forward_context
    return self.tensor_init_context()
../lightning/src/lightning/fabric/plugins/precision/half.py:46: in tensor_init_context
    return _DtypeContextManager(self._desired_input_dtype)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    def __init__(self, dtype: torch.dtype) -> None:
>       self._previous_dtype: torch.dtype = torch.get_default_dtype()
E       NotImplementedError: Trying to call function torch.get_default_dtype, but it is not yet supported. Please file an issue requesting support. To find out which operations are not yet recongnized by `thunder.jit`, please run `examine` as per:
E       
E       from thunder.examine import examine
E       examine(<your thunder.jit callable argument>, ...)

../lightning/src/lightning/fabric/plugins/precision/utils.py:33: NotImplementedError

Jitting the _FabricModule is currently necessary to compile the joint forward and loss computation.

To Reproduce

from lightning import Fabric
import torch
import thunder

fabric = Fabric(devices=1, precision="16-true")
model = torch.nn.Linear(1, 1, bias=False, device=fabric.device)
x = torch.randn(1, 1)
x = fabric.to_device(x)

fmodel = fabric.setup(model)
tmodel = thunder.jit(fmodel)

print(tmodel(x))

cc @nikitaved

Does `jit` understand monkeypatched methods?

🐛 Bug

Tensor.register_hook is currently not supported by Thunder.

In Lightning Fabric, we use this once for error checking that the user properly called backward. https://github.com/Lightning-AI/pytorch-lightning/blob/096b063d6eeb41567409f4a6b9bac6f5af28ed93/src/lightning/fabric/wrappers.py#L232-L233

Since this hook is not critical, as it's only meant to avoid user errors, I would like to be able to monkeypatch it externally.

However, it doesn't seem like it has an effect with Thunder:

To Reproduce

import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
import thunder

model = torch.nn.Linear(1, 1, bias=False, device="cuda")
x = torch.randn(1, 1, device="cuda", requires_grad=True)

fabric = Fabric(accelerator="cuda", devices=1)
model = fabric.setup(model)

# monkeypatch what's causing trouble
assert isinstance(model, _FabricModule)
assert model._register_backward_hook is not None
model._register_backward_hook = lambda *_: None

model = thunder.jit(model)

y = model(x)
y.backward()
print(y)
print(x.grad)

This fails because Thunder doesn't support register_hook:

AttributeError: The torch language context has no method register_hook

Interestingly, a non-Fabric snippet doesn't fail, so there is something funny going on:

import thunder
import torch

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(1, 1, bias=False)

    def forward(self, x):
        y = self.model(x)
        self.register_hook(y)
        return y

    def register_hook(self, tensor):
        tensor.register_hook(self.hook)

    def hook(self, _):
        print("hi")

model = Wrapper()
x = torch.randn(1, 1)

model.register_hook = lambda *_: None

model = thunder.jit(model)

y = model(x)
y.backward()

Updating an nn.Module attribute in forward raises an exception in the prologue trace.

import torch
import thunder

import thunder.examine

class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.bar = 1

    def forward(self, x):
        self.bar = self.bar + 1
        # self.bar = 2  # This works
        return x

m = MyModule()

x = torch.randn(16, 16, device='cuda')

jit_linear = thunder.jit(m)

o = jit_linear(x)

Error:

File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 537, in get_computation_and_inputs
    inps = pro(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "thunder.prologue_0", line 16, in prologue
  File "/home/kkalambarkar/lightning-thunder/thunder/executors/pythonex.py", line 100, in _check_number_type_and_value_impl
    utils.check(
  File "/home/kkalambarkar/lightning-thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
RuntimeError: Expected 2 to be equal to and have the type of 1

cc @apaz-cli

Benchmark targets for test_nanogpt_cross_entropy_grad have an import issue

🐛 Bug

Benchmark targets for test_nanogpt_cross_entropy_grad have an import issue.

To Reproduce

Steps to reproduce the behavior:

root@8d345ed01185:/opt/pytorch/lightning-thunder# pytest -vvvs thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad]
============================================================================================== test session starts ==============================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0 -- /usr/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/opt/pytorch/lightning-thunder/.hypothesis/examples'))
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: cov-4.1.0, hypothesis-6.100.0, random-order-1.1.1, timestamper-0.0.10, timeout-2.2.0, xdist-3.5.0, shard-0.1.2, benchmark-4.0.0
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 1 item
Running 1 items in this shard: thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad]

[2024-04-10 21:48:06] thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad] FAILED

=================================================================================================== FAILURES ====================================================================================================
______________________________________________________________________________ test_nanogpt_cross_entropy_grad[thunder+apex-grad] _______________________________________________________________________________

benchmark = <pytest_benchmark.fixture.BenchmarkFixture object at 0x7fe26c8f0cd0>
executor = functools.partial(<function thunder_grad_transform at 0x7fe26ca5ecb0>, compile_fn=<function thunder_apex_executor at 0x7fe26cad8b80>)

    @pytest.mark.parametrize(
        "executor,",
        (grad_executors + apex_grad_executors),
        ids=(grad_executors_ids + apex_grad_executors_ids),
    )
    def test_nanogpt_cross_entropy_grad(benchmark, executor: None | Callable):
        if executor is None:
            pytest.skip("Executor is unavailable")

        bench: Benchmark = NanoGPTCrossEntropyBenchmark(
            config="gpt2-xl", device="cuda:0", dtype=thunder.bfloat16, requires_grad=True
        )

        setup = make_setup(bench)
        fn = executor(bench)
        fn = wrap_for_benchmark(fn)

>       benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1)

thunder/benchmarks/targets.py:479:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:137: in pedantic
    return self._raw_pedantic(target, args=args, kwargs=kwargs, setup=setup, rounds=rounds,
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:211: in _raw_pedantic
    runner(loops_range)
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:95: in runner
    result = function_to_benchmark(*args, **kwargs)
thunder/benchmarks/targets.py:60: in fn_
    result = fn(*args, **kwargs)
thunder/benchmarks/targets.py:235: in wrapper
    populate_grads(grads, cfn, args=args, kwargs=kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

grads = [tensor([[0.0000e+00, 2.2768e-18, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0090e-12,...4, 8.5986e-29, 0.0000e+00,  ..., 0.0000e+00, 5.7932e-31,
         0.0000e+00]], device='cuda:0', dtype=torch.bfloat16)]
tom = <function NanoGPTCrossEntropyBenchmark.fn.<locals>.foo at 0x7fe26c93c8b0>
args = (tensor([[ 79.0000, 227.0000,   8.5625,  ..., 166.0000, 152.0000, 154.0000],
        [240.0000, 224.0000,   2.3125,  ....cuda:0', dtype=torch.bfloat16, grad_fn=<ViewBackward0>), tensor([223, 144, 141,  ..., 219, 169, 186], device='cuda:0'))
kwargs = {}

    def populate_grads(grads: list[TensorProxy], tom: None | torch.nn.Module = None, args=None, kwargs=None) -> None:
        idx: int = 0
        from thunder import ThunderModule, compile_data

>       if isinstance(tom, ThunderModule) or thunder.compile_data(tom).using_jit:
E       NameError: name 'thunder' is not defined

thunder/core/transforms.py:555: NameError
============================================================================================ short test summary info ============================================================================================
FAILED thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad] - NameError: name 'thunder' is not defined
======================================================================================== 1 failed, 5 warnings in 11.31s =========================================================================================

Code sample

see above

Expected behavior

benchmark should be able to run

Environment

  • internal image: pjnl-20240410
  • thunder: dba8ce7

Additional context

same issues on those two:

FAILED thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex-grad] - NameError: name 'thunder' is not defined
FAILED thunder/benchmarks/targets.py::test_nanogpt_cross_entropy_grad[thunder+apex+nvfuser-grad] - NameError: name 'thunder' is not defined
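
The failing check in populate_grads already imports compile_data locally but then calls it through the thunder module name, which is not imported in transforms.py. A minimal sketch of a possible fix (an assumption based on the traceback, not a confirmed patch) is to call the local import directly:

# Sketch of a possible fix in thunder/core/transforms.py (assumption based on
# the traceback): call the locally imported compile_data instead of the
# unimported `thunder` module name.
def populate_grads(grads: list[TensorProxy], tom: None | torch.nn.Module = None, args=None, kwargs=None) -> None:
    idx: int = 0
    from thunder import ThunderModule, compile_data

    if isinstance(tom, ThunderModule) or compile_data(tom).using_jit:
        ...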

cc @tfogal @IvanYashchuk

`test_vjp_correctness` fails with ops that return tensors that do not require grads.

🐛 Bug

As per the title. To reproduce, one could uncomment these tests in #118 to get:

thunder/tests/test_grad.py:423: in test_vjp_correctness                                                                                                                                                                                       
    result = run_snippet(                                                                                                                                                                                                                     
thunder/tests/framework.py:483: in run_snippet                                                                                                                                                                                                
    raise ex                                                                                                                                                                                                                                  
thunder/tests/framework.py:475: in run_snippet                                                                                                                                                                                                
    snippet(*args, **kwargs)                                                                                                                                                                                                                  
thunder/tests/test_grad.py:394: in snippet_vjp_correctness                                                                                                                                                                                    
    check_vjp(func, *args, executor=executor)                                                                                                                                                                                                 
thunder/tests/test_grad.py:304: in check_vjp                                                                                                                                                                                                  
    _, J_star_v = executor.make_callable_legacy(vjp(f), disable_torch_autograd_support=True)(primals, v)                                                                                                                                      
thunder/common.py:783: in _fn                                                                                                                                                                                                                 
    trc_or_result = trace(compile_data=cd)(processed_function, *args, **kwargs)                                                                                                                                                               
thunder/core/interpreter.py:1298: in fn_                                                                                                                                                                                                      
    return fn(*args, **kwargs)                                                                                                                                                                                                                
thunder/common.py:534: in _trace                                                                                                                                                                                                              
    result = fn(*proxyargs, **proxykwargs)                                                                                                                                                                                                    
thunder/core/transforms.py:3629: in _vjp                                                                                                                                                                                                      
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)                                                                                                                                                                         
thunder/core/transforms.py:3603: in vjp_call_metafunc                                                                                                                                                                                         
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)                                                                                                                                                                     
thunder/core/transforms.py:3414: in augmented_forward_pass                                                                                                                                                                                    
    result, env = eval_trace(                                                                                                                                                                                                                 
thunder/core/transforms.py:1693: in eval_trace                                                                                                                                                                                                
    prim_func = symbol_mapper(symbol)                                                                                                                                                                                                         
thunder/core/transforms.py:3338: in vjp_symbol_mapper                                                                                                                                                                                         
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)                                                                                                                                                                             
thunder/core/vjp_utils.py:99: in make_aug_forward_and_backward                                                                                                                                                                                
    backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0])                                                                                                                                    
thunder/core/utils.py:1062: in find_producer_symbols                                                                                                                                                                                          
    if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:                                                                                                                                                          
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
                                                                                                                                                                                                                                              
x = None                                                                                                                                                                                                                                      
                                                                                                                                                                                                                                              
>   if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:                                                                                                                                                          
E   AttributeError: 'NoneType' object has no attribute 'name'                                                                                                                                                                                 
                                                                                                                                                                                                                                              
thunder/core/utils.py:1062: AttributeError     

Cuda only?

Hi

Thanks for sharing this with the community. Much appreciated.

I am wondering if this works only with CUDA hardware. For example, does it work with AMD GPUs through ROCm?

torchex running pooling without decomposition

🚀 Feature

Running max_poolXd through its decomposition is expensive in Thunder. The torch executor should be able to run these as a single aten call in the forward as well as the backward pass via a custom grad_transform.

Motivation

Currently, if we run the example below:

import torch
import torch.nn as nn

import thunder

dtype = torch.float16
batch_size = 32
test_grad = True

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        layers = list()
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.layer = nn.Sequential(*layers)

    def forward(self, inp):
        return self.layer(inp)

model = Model()

model = model.cuda()
model = model.to(dtype)

image = torch.randn(batch_size, 3, 224, 224, dtype=dtype).cuda()
if test_grad:
    image.requires_grad_()

def fn(arg):
    return model(arg)

jit_model = thunder.jit(fn)

# warm up
for i in range(20):
    o = jit_model(image)
    if test_grad:
        o.sum().backward()
    o = fn(image)
    if test_grad:
        o.sum().backward()

import time
fwd_traces = thunder.last_traces(jit_model)
print("fwd_traces:\n", fwd_traces[-1])
if test_grad:
    bwd_traces = thunder.last_backward_traces(jit_model)
    print("bwd_traces:\n", bwd_traces[-1])

torch.cuda.synchronize()

t0 = time.time()
for i in range(10):
    o = jit_model(image)
    if test_grad:
        o.sum().backward()
        image.grad = None
torch.cuda.synchronize()
print("jit_model elapsed time: ", time.time() - t0)

torch.cuda.synchronize()
t0 = time.time()
for i in range(10):
    o = fn(image)
    if test_grad:
        o.sum().backward()
        image.grad = None
torch.cuda.synchronize()
print("torch eager elapsed time: ", time.time() - t0)

jit_model elapsed time: 0.02024698257446289
torch eager elapsed time: 0.002202272415161133

fwd graph looks like

from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(arg):
  # arg: "cuda:0 f16[32, 3, 224, 224]"
  t0 = prims.pad(arg, -float('inf'), [(0, 0, 0), (0, 0, 0), (1, 1, 0), (1, 1, 0)])  # t0: "cuda:0 f16[32, 3, 226, 226]"
  t1 = ltorch.arange(9, None, 1, device=devices.Device("cuda:0"), dtype=None)  # t1: "cuda:0 i64[9]"
    # t1 = prims.iota(9, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64)  # t1: "cuda:0 i64[9]"
  t2 = prims.broadcast_in_dim(t1, [9, 1], [0])  # t2: "cuda:0 i64[9, 1]"
  t3 = prims.broadcast_in_dim(t1, [1, 9], [1])  # t3: "cuda:0 i64[1, 9]"
  t4 = prims.broadcast_in_dim(t2, (9, 9), (0, 1))  # t4: "cuda:0 i64[9, 9]"
  t5 = prims.broadcast_in_dim(t3, (9, 9), (0, 1))  # t5: "cuda:0 i64[9, 9]"
  t6 = prims.eq(t4, t5)  # t6: "cuda:0 b8[9, 9]"
  t7 = prims.convert_element_type(t6, dtypes.float16)  # t7: "cuda:0 f16[9, 9]"
  t8 = prims.reshape(t7, (1, 9, 1, 3, 3))  # t8: "cuda:0 f16[1, 9, 1, 3, 3]"
  t9 = prims.broadcast_in_dim(t8, (3, 9, 1, 3, 3), (0, 1, 2, 3, 4))  # t9: "cuda:0 f16[3, 9, 1, 3, 3]"
  t10 = prims.reshape(t9, (27, 1, 3, 3))  # t10: "cuda:0 f16[27, 1, 3, 3]"
  t11 = prims.convolution(t0, t10, None, (2,), (0,), (1,), False, (0, 0), 3)  # t11: "cuda:0 f16[32, 27, 112, 112]"
  t12 = prims.reshape(t11, (32, 3, 9, 112, 112))  # t12: "cuda:0 f16[32, 3, 9, 112, 112]"
  t13 = prims.convert_element_type(t12, dtypes.float32)  # t13: "cuda:0 f32[32, 3, 9, 112, 112]"
  t14 = prims.amax(t13, (2,))  # t14: "cuda:0 f32[32, 3, 112, 112]"
  t15 = prims.convert_element_type(t14, dtypes.float16)  # t15: "cuda:0 f16[32, 3, 112, 112]"
  return {'output': t15, 'flat_args': [arg], 'flat_output': (t15,)}, ((t10, t14, t13), (0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 3, 32, 27, 112, 112, 32, 3, 9, 112, 112, 2))

bwd graph:

 i239 = operator.neg(i86)  # i239
    # i239 = prims.neg(i86)  # i239
  del i86
  i257 = operator.neg(i27)  # i257
    # i257 = prims.neg(i27)  # i257
  del i27
  i258 = operator.neg(i28)  # i258
    # i258 = prims.neg(i28)  # i258
  del i28
  i259 = operator.neg(i30)  # i259
    # i259 = prims.neg(i30)  # i259
  del i30
  i260 = operator.neg(i31)  # i260
    # i260 = prims.neg(i31)  # i260
  del i31
  i261 = operator.neg(i33)  # i261
    # i261 = prims.neg(i33)  # i261
  del i33
  i262 = operator.neg(i34)  # i262
    # i262 = prims.neg(i34)  # i262
  del i34
  i263 = operator.neg(i36)  # i263
    # i263 = prims.neg(i36)  # i263
  del i36
  i264 = operator.neg(i37)  # i264
    # i264 = prims.neg(i37)  # i264
  del i37
  t303 = torch.unsqueeze(t19, 2)  # t303
    # t303 = ltorch.unsqueeze(t19, 2)  # t303
      # t303 = prims.broadcast_in_dim(t19, [32, 3, 1, 2, 2], [0, 1, 3, 4])  # t303
  del t19
  t220 = Tensor.expand(t303, [32, 3, 1, 2, 2])  # t220
    # t220 = ltorch.expand(t303, [32, 3, 1, 2, 2])  # t220
      # t220 = prims.broadcast_in_dim(t303, (32, 3, 1, 2, 2), (0, 1, 2, 3, 4))  # t220
  del t303
  t221 = Tensor.expand(t220, (i104, i105, i106, i107, i108))  # t221
    # t221 = ltorch.expand(t220, (i104, i105, i106, i107, i108))  # t221
      # t221 = prims.broadcast_in_dim(t220, (i104, i105, i106, i107, i108), (0, 1, 2, 3, 4))  # t221
  del t220
  t233 = torch.permute(t15, (1, 0, 2, 3))  # t233
    # t233 = ltorch.permute(t15, (1, 0, 2, 3))  # t233
      # t233 = prims.transpose(t15, (1, 0, 2, 3))  # t233
  del t15
  t234 = torch.reshape(t233, [1, i91, 9, 3, 3])  # t234
    # t234 = ltorch.reshape(t233, [1, i91, 9, 3, 3])  # t234
      # t234 = prims.reshape(t233, (1, i91, 9, 3, 3))  # t234
  del t233
  t235 = torch.permute(t234, (1, 0, 2, 3, 4))  # t235
    # t235 = ltorch.permute(t234, (1, 0, 2, 3, 4))  # t235
      # t235 = prims.transpose(t234, (1, 0, 2, 3, 4))  # t235
  del t234
  t236 = torch.reshape(t235, [3, 9, 3, 3])  # t236
    # t236 = ltorch.reshape(t235, [3, 9, 3, 3])  # t236
      # t236 = prims.reshape(t235, (3, 9, 3, 3))  # t236
  del t235
  [t230, t282] = nvFusion0(i10, i104, i105, i106, i107, i108, i109, i9, t0, t17, t21, t221)
    # t18 = prims.convert_element_type(t17, dtypes.float32)  # t18
    # t282 = prims.pad(t0, 0.0, [(0, 0, 0), (0, 0, 0), (i9, 3, 0), (i10, 3, 0)])  # t282
    # t217 = prims.convert_element_type(t21, dtypes.float32)  # t217
    # t218 = prims.broadcast_in_dim(t217, [32, 3, 1, 2, 2], [0, 1, 3, 4])  # t218
    # t219 = prims.broadcast_in_dim(t218, (i104, i105, i106, i107, i108), (0, 1, 2, 3, 4))  # t219
    # t222 = prims.eq(t18, t221)  # t222
    # t223 = prims.sum(t222, (i109,))  # t223
    # t224 = prims.broadcast_in_dim(t223, [32, 3, 1, 2, 2], [0, 1, 3, 4])  # t224
    # t225 = prims.convert_element_type(t222, dtypes.float32)  # t225
    # t226 = prims.mul(t219, t225)  # t226
    # t227 = prims.broadcast_in_dim(t224, (32, 3, 9, 2, 2), (0, 1, 2, 3, 4))  # t227
    # t228 = prims.convert_element_type(t227, dtypes.float32)  # t228
    # t229 = prims.div(t226, t228)  # t229
    # t230 = prims.convert_element_type(t229, dtypes.float16)  # t230
  del i10, i104, i105, i106, i107, i108, i109, i9, t0, t17, t21, t221
  t283 = torch.permute(t282, (1, 0, 2, 3))  # t283
    # t283 = ltorch.permute(t282, (1, 0, 2, 3))  # t283
      # t283 = prims.transpose(t282, (1, 0, 2, 3))  # t283
  del t282
  t284 = torch.reshape(t283, [i16, 3, 32, 11, 11])  # t284
    # t284 = ltorch.reshape(t283, [i16, 3, 32, 11, 11])  # t284
      # t284 = prims.reshape(t283, (i16, 3, 32, 11, 11))  # t284
  del t283
  t285 = torch.permute(t284, (1, 0, 2, 3, 4))  # t285
    # t285 = ltorch.permute(t284, (1, 0, 2, 3, 4))  # t285
      # t285 = prims.transpose(t284, (1, 0, 2, 3, 4))  # t285
  del t284
  t286 = torch.reshape(t285, [3, 32, 11, 11])  # t286
    # t286 = ltorch.reshape(t285, [3, 32, 11, 11])  # t286
      # t286 = prims.reshape(t285, (3, 32, 11, 11))  # t286
  del t285
  t231 = torch.reshape(t230, (i97, i98, i99, i100))  # t231
    # t231 = ltorch.reshape(t230, (i97, i98, i99, i100))  # t231
      # t231 = prims.reshape(t230, (i97, i98, i99, i100))  # t231
  del t230, i97, i98, i99, i100
  t232 = torch_pad_prim_impl(t231, 0.0, [(0, 0, 0), (0, 0, 0), (0, 0, 1), (0, 0, 1)])  # t232
  del t231
  t237 = torch.flip(t236, (2, 3))  # t237
    # t237 = ltorch.flip(t236, (2, 3))  # t237
      # t237 = prims.flip(t236, (2, 3))  # t237
  del t236
  t238 = torch.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91)  # t238
    # t238 = ltorch.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91)  # t238
      # t238 = prims.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91)  # t238
  del t232, t237, i87, i89, i90, i91
  [t270] = nvFusion1(i239, i257, i258, i259, i260, i261, i262, i263, i264, t2, t238)
    # t241 = prims.pad(t238, 0.0, [(0, 0, 0), (0, 0, 0), (i239, 0, 0), (i239, 0, 0)])  # t241
    # t265 = prims.pad(t241, 0.0, [(i257, i258, 0), (i259, i260, 0), (i261, i262, 0), (i263, i264, 0)])  # t265
    # t266 = prims.slice(t265, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1])  # t266
    # t267 = prims.slice(t266, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1])  # t267
    # t268 = prims.slice(t267, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1])  # t268
    # t269 = prims.slice(t268, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1])  # t269
    # t270 = prims.where(t2, t269, 0.0)  # t270
  del i239, i257, i258, i259, i260, i261, i262, i263, i264, t2, t238
  t287 = torch.permute(t270, (1, 0, 2, 3))  # t287
    # t287 = ltorch.permute(t270, (1, 0, 2, 3))  # t287
      # t287 = prims.transpose(t270, (1, 0, 2, 3))  # t287
  del t270
  t288 = torch.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16)  # t288
    # t288 = ltorch.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16)  # t288
      # t288 = prims.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16)  # t288
  del t286, t287, i11, i12, i7, i8, i14, i15, i16
  t289 = torch.permute(t288, (1, 0, 2, 3))  # t289
    # t289 = ltorch.permute(t288, (1, 0, 2, 3))  # t289
      # t289 = prims.transpose(t288, (1, 0, 2, 3))  # t289
  del t288
  return (None, t289)

Pitch

I'm prototyping this in a draft PR (not functional yet!)

Alternatives

We could also have pooling layers as a prim, but I don't think that's a necessity at this point.
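
For comparison with the decomposed Thunder trace above, a quick eager-side check (a sketch; profiler output formatting varies by PyTorch version) of which kernels PyTorch dispatches for the same pooling op in forward and backward:

import torch
import torch.nn.functional as F

# Profile eager max_pool2d forward + backward to see which aten kernels run,
# for comparison with the pad/convolution decomposition in the trace above.
x = torch.randn(32, 3, 224, 224, device="cuda", dtype=torch.float16, requires_grad=True)

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
) as prof:
    out = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
    out.sum().backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))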

Skipped distributed tests show up as passed (return 0)

          Note that this can be very misleading because a skipped test also returns 0, so it can make it seem like a test passed when it didn't run

Originally posted by @carmocca in #130 (comment)

python -um pytest -sv "$test" --pythonwarnings ignore --junitxml="$test-results.xml" 2>&1 > "$test-output.txt"
pytest_status=$?
printf "$test status >>> $pytest_status\n"
if [ $pytest_status -ne 0 ]; then
    status=$pytest_status
    cat "$test-output.txt"
fi
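
Since the run already emits a junitxml file, one option (a minimal sketch, assuming pytest's standard junitxml attributes) is to treat "everything skipped" differently from "passed" by reading the suite counters:

# Minimal sketch: distinguish "all tests skipped" from "tests passed", since
# pytest exits with 0 in both cases. Assumes pytest's standard junitxml layout.
import xml.etree.ElementTree as ET

def ran_any_tests(junitxml_path: str) -> bool:
    root = ET.parse(junitxml_path).getroot()
    suite = root[0] if root.tag == "testsuites" else root  # newer pytest wraps suites
    tests = int(suite.get("tests", 0))
    skipped = int(suite.get("skipped", 0))
    return tests - skipped > 0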

cc @Borda

Represent slices natively in traces

🚀 Feature

Motivation

Tensor slices are represented in traces as:

  t107 = torch_slice_prim_impl(t53, [0, 0, 0, 0], [4, 32, 2048, 0], [1, 1, 1, 1])  # t107: "cuda:0 bf16[4, 32, 2048, 0]"

But there is no import of torch_slice_prim_impl in the trace, and we could represent this with plain Python instead.

This reference comes from:

slice_prim_impl = ex.register_operator("torch_slice_prim_impl", meta=prims.slice_prim.meta, fn=_slice_prim_impl)
_register_implementation(prims.slice_prim, slice_prim_impl, checker=_always_executable)

# TODO When getitem is fully supported this can be changed to be an execution transform instead of a direct impl
def _slice_prim_impl(
    a: torch.Tensor, start_indices: Sequence[int], end_indices: Sequence[int], strides: None | Sequence[int] = None
) -> torch.Tensor:
    _strides = strides if strides is not None else [1] * len(start_indices)
    slices: list = []
    for start, stop, step in zip(start_indices, end_indices, _strides):
        slices.append(slice(start, stop, step))
    return operator.getitem(a, slices)

Pitch

Instead represent it with __getitem__ and slice():

t123 = t321.__getitem__([slice(0, 3), slice(0, 5)])  # t123: "cuda:..."
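
For reference, a quick eager-PyTorch check (not tied to Thunder's trace format, and using a tuple of slices here) that the __getitem__/slice() spelling produces the same result as native slicing, so the trace would be valid Python without extra imports:

import torch

t = torch.arange(16.0).reshape(4, 4)

# Native slicing and the explicit __getitem__/slice() spelling are equivalent.
a = t[0:3, 0:2]
b = t.__getitem__((slice(0, 3), slice(0, 2)))
assert torch.equal(a, b)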

Alternatives

Add the torch_slice_prim_impl import from torchex to the trace so that it's a valid program

cc @apaz-cli @nikitaved

Support `is_cuda`

Something like the following should work

import thunder
import torch

def foo(x):
    if not x.is_cuda:
        x = x.to('cuda')
    return x * x

x = torch.randn(3, device='cpu')
jit_foo = thunder.jit(foo)
o = jit_foo(x)

print(thunder.last_traces(jit_foo)[-1])

Above fails with

  File "/home/kkalambarkar/lightning-thunder/thunder/core/proxies.py", line 1234, in __getattr__
    method: None | Callable = resolve_method(attr, self)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/langctxs.py", line 68, in resolve_method
    method: Callable = ctx.get_method(id, *args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/torch/langctx.py", line 40, in get_method
    raise AttributeError(f"The {self.name} language context has no method {id}")
AttributeError: The torch language context has no method is_cuda

Functional JIT loading closures sharp edge

Strategy required

This issue resumes from PR 2410; we need to decide on a strategy for the closures sharp edge. Let's start simple: I think we can all agree that this is a sharp edge if we jit foo:

x = 5
def foo():
    return x

And that's because we are using a variable from outside the jitted scope. However, here is where things get interesting: should we consider the following a sharp edge?

def foo(x):
    def bar():
        return x
    return bar()

I assume that, since we captured x when jitting foo, this should not be a sharp edge for bar, because the variable was declared in the scope (or in this case captured). To fix such a case we can remember which variables we captured and then look them up when we see a freevar. However, @mruberry has an interesting point: what happens if the variable gets deleted? How can we deal with something like:

def foo():
  a = 5

  def bar():
    nonlocal a
    del a

  bar()

  return a
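
For reference, here is what plain CPython itself does with that last example (a minimal check; the exact error message varies across Python versions):

# The read of `a` after the nonlocal delete fails in plain CPython
# (UnboundLocalError, a NameError subclass, on current versions).
def foo():
    a = 5

    def bar():
        nonlocal a
        del a

    bar()
    return a

try:
    foo()
except NameError as e:
    print(type(e).__name__, e)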

In conclusion, what do you think should be the definition of sharp edge in this context?

cc @apaz-cli @t-vi @mruberry

Comparison with `torch.compile` instead of Eager

📚 Documentation

Hey! I saw your tool and the plots showing "acceleration", but you compare against un-optimised eager PyTorch, which is obviously slower. Could you provide a graph comparing against a basic native PyTorch loop where you torch.compile the model? It would be useful for people who already have some optimisations in their pipelines but would like to try your framework instead.

Thanks

Build op provenance tracking into compile trace output

🚀 Feature

The request is to be able to clearly connect the practitioner's model code to the graph trace produced by Thunder. Ideally, each traced node should map back to the model code from which it was generated.

Motivation

This would tremendously help with debugging issues around graph capture and graph optimization (such as rematerialization, DCE, etc.). It also improves user understanding of what Thunder is doing, and it could be very helpful for developers building tools that operate on top of Thunder graphs.
Examples from TorchInductor are in the Pitch.

Pitch/Additional Context

Example of FX graph debug output from TorchInductor, mapping decomposed nodes in the traced graph back to the practitioner's model code.

        # File: /scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py:35, code: out = self.conv(x)
        convolution: f32[16, 64, 56, 56] = torch.ops.aten.convolution.default(primals_7, primals_1, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
        
        # File: /scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py:36, code: out = self.bn(out)
        add: i64[] = torch.ops.aten.add.Tensor(primals_6, 1);  primals_6 = None
        var_mean = torch.ops.aten.var_mean.correction(convolution, [0, 2, 3], correction = 0, keepdim = True)
        getitem: f32[1, 64, 1, 1] = var_mean[0]
        getitem_1: f32[1, 64, 1, 1] = var_mean[1];  var_mean = None
        add_1: f32[1, 64, 1, 1] = torch.ops.aten.add.Tensor(getitem, 1e-05)
        rsqrt: f32[1, 64, 1, 1] = torch.ops.aten.rsqrt.default(add_1);  add_1 = None
        sub: f32[16, 64, 56, 56] = torch.ops.aten.sub.Tensor(convolution, getitem_1)
        mul: f32[16, 64, 56, 56] = torch.ops.aten.mul.Tensor(sub, rsqrt);  sub = None
        squeeze: f32[64] = torch.ops.aten.squeeze.dims(getitem_1, [0, 2, 3]);  getitem_1 = None
        squeeze_1: f32[64] = torch.ops.aten.squeeze.dims(rsqrt, [0, 2, 3]);  rsqrt = None
        mul_1: f32[64] = torch.ops.aten.mul.Tensor(squeeze, 0.1)
        mul_2: f32[64] = torch.ops.aten.mul.Tensor(primals_4, 0.9);  primals_4 = None
        add_2: f32[64] = torch.ops.aten.add.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
        squeeze_2: f32[64] = torch.ops.aten.squeeze.dims(getitem, [0, 2, 3]);  getitem = None
        mul_3: f32[64] = torch.ops.aten.mul.Tensor(squeeze_2, 1.0000199302441455);  squeeze_2 = None
        mul_4: f32[64] = torch.ops.aten.mul.Tensor(mul_3, 0.1);  mul_3 = None
        mul_5: f32[64] = torch.ops.aten.mul.Tensor(primals_5, 0.9);  primals_5 = None
        add_3: f32[64] = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
        unsqueeze: f32[64, 1] = torch.ops.aten.unsqueeze.default(primals_2, -1)
        unsqueeze_1: f32[64, 1, 1] = torch.ops.aten.unsqueeze.default(unsqueeze, -1);  unsqueeze = None
        unsqueeze_2: f32[64, 1] = torch.ops.aten.unsqueeze.default(primals_3, -1);  primals_3 = None
        unsqueeze_3: f32[64, 1, 1] = torch.ops.aten.unsqueeze.default(unsqueeze_2, -1);  unsqueeze_2 = None
        mul_6: f32[16, 64, 56, 56] = torch.ops.aten.mul.Tensor(mul, unsqueeze_1);  mul = unsqueeze_1 = None
        add_4: f32[16, 64, 56, 56] = torch.ops.aten.add.Tensor(mul_6, unsqueeze_3);  mul_6 = unsqueeze_3 = None
        
        # File: /scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py:37, code: out = self.relu(out)
        relu: f32[16, 64, 56, 56] = torch.ops.aten.relu.default(add_4);  add_4 = None
        le: b8[16, 64, 56, 56] = torch.ops.aten.le.Scalar(relu, 0)

Similarly, in the final codegen output one can see which decomposed nodes belong to each generated kernel and which aten-level op each decomposed node came from. Inside each kernel there are also comments describing the practitioner code stack.
Thunder already covers this to some extent, since the trace output lists all the decomposed nodes that are part of an nvFusion block, but a mapping back to the original code would be fantastic for better understanding.

# aten._native_batch_norm_legit_functional => add_1, add_4, mul, mul_6, rsqrt, sub, var_mean
# aten.relu => relu
# aten.threshold_backward => le
triton_poi_fused__native_batch_norm_legit_functional_relu_threshold_backward_4 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[4194304], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*i1', 7: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]})
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_relu_threshold_backward_4(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3211264
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x3 = xindex
    x1 = (xindex // 3136) % 64

    # ORIGIN:
    # call_function aten.relu.default
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 37, in forward\    out = self.relu(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.add.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.rsqrt.default
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.mul.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.var_mean.correction
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.mul.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.add.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN


    # ORIGIN:
    # call_function aten.sub.Tensor
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 36, in forward\    out = self.bn(out)\
    # END ORIGIN

    tmp0 = tl.load(in_ptr0 + (x3), None)
    tmp1 = tl.load(in_ptr1 + (x1), None)
    tmp3 = tl.load(in_ptr2 + (x1), None)
    tmp10 = tl.load(in_ptr3 + (x1), None)
    tmp12 = tl.load(in_ptr4 + (x1), None)
    tmp2 = tmp0 - tmp1
    tmp4 = 50176.0
    tmp5 = tmp3 / tmp4
    tmp6 = 1e-05
    tmp7 = tmp5 + tmp6
    tmp8 = tl.math.rsqrt(tmp7)
    tmp9 = tmp2 * tmp8
    tmp11 = tmp9 * tmp10
    tmp13 = tmp11 + tmp12
    tmp14 = tl.where(0 != 0, 0, tl.where(0 > tmp13, 0, tmp13))

    # ORIGIN:
    # call_function aten.le.Scalar
    #   File "/scratch/mojitos/Pytorch/resnet/test_conv_bn_relu.py", line 37, in forward\    out = self.relu(out)\
    # END ORIGIN

    tmp15 = 0.0
    tmp16 = tmp14 <= tmp15
    tl.store(out_ptr0 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp14, None)
    tl.store(out_ptr1 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp16, None)
''')

cc @carmocca

Add stride operation primitive

🚀 Feature

I would like to have Thunder manage stride information to allow for tensor manipulations.

In particular I think the following points need to be discussed:

  • Where does the stride information go? Should it be part of TensorProxy?
  • What can a stride manipulation primitive look like? Are there any particular things we need to be careful about?

Motivation

This will enable us to add new operators that reshape the tensor using the stride like torch.as_strided or torch.Tensor.unfold.
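
To illustrate what such a primitive would need to express, here is a small eager-PyTorch example of the two operators mentioned above (a sketch; shapes chosen just for illustration):

import torch

x = torch.arange(6.0)

# as_strided builds a view purely from shape and stride metadata ...
windows = x.as_strided((4, 3), (1, 1))

# ... which here matches the sliding windows produced by Tensor.unfold.
assert torch.equal(windows, x.unfold(0, 3, 1))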

Support for CUDA kernels

🚀 Feature

Hi there 👋

From the main readme file I noticed that Thunder accepts custom kernels, but only ones written in Triton.
Is there a plan to support CUDA kernels?

Motivation

I'm only at the beginning of my custom-kernels journey, so I might misunderstand something.

From what I've seen online, there are many highly optimized CUDA kernels already available (since CUDA has been around for quite a while). Plus, there is a good chance that someone with a lot of experience writing CUDA kernels (but not Triton) wants to use Thunder (or even integrate it into an existing project).

I personally would like to write custom CUDA kernels for the LitGPT repo after I finish reading the PMPP book.

[feature request] Indexing with boolean masks

🚀 Feature

Indexing with a boolean mask is not supported, for example:

import torch
import thunder

x = torch.randn(10)  # any 1-D tensor; `x` was not defined in the original snippet
m = x <= 0.5

def f(x, m):
    return x[m]

jf = thunder.jit(f)
jf(x, m)

fails with

RuntimeError: Advanced indexing currently only supports zero or one-dimensional integer tensors, but found a tensor with dtype bool8 and 1 dimensions
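
For context, the same indexing works in eager PyTorch and is equivalent to advanced indexing with the mask's nonzero positions (a small sketch; the data-dependent output length is presumably part of what makes this harder to trace):

import torch

x = torch.randn(10)
m = x <= 0.5

# Boolean-mask indexing selects the masked elements; the result's length
# depends on the mask's contents, and it matches integer advanced indexing
# with the nonzero positions of the mask.
assert torch.equal(x[m], x[m.nonzero().squeeze(1)])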

cc @apaz-cli

Make DDP/FSDP a regular transform

🚀 Feature

Make DDP/FSDP a regular transform (to a large part including making transforms flexible enough to support this).

Motivation

Currently DDP/FSDP is not a regular transform, leading to things like #94 and limiting composability and sequencing.
One of the key pieces is that the adjustments we currently make to the prologue during tracing with DDP/FSDP would need to be done by the transform itself, so we need to allow mutation of prologues through transforms. This is also in line with the needs of other transforms that change prologue signatures (LoRA, quantization, but also value-and-grad things), so this generalization should happen.

cc @carmocca @awaelchli @crcrpar

Non-`topk` related issue in `mixtral`-like model tests.

🐛 Bug

Now that we have topk supported, it is time to unlock some tests. However, the following diff:

diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index d1d55073..ad69a721 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -613,7 +613,7 @@ def test_nanogpt():
         "falcon-7b-like",
         "falcon-40b-like",
         "codellama2-like",
-        pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)),
+        "mixtral-like",
     ),
 )
 @pytest.mark.parametrize(

Breaks pytest -sv thunder/tests/test_jit_general.py -k test_litgpt_variants[cpu-mixtral-like] with

___________________________________________________________________________________________________ test_litgpt_variants[cpu-mixtral-like] ___________________________________________________________________________________________________

name = 'mixtral-like', device = device(type='cpu')

    @skipif_not_pytorch_2_1
    @pytest.mark.parametrize(
        "name",
        (
            "gpt-neox-like",
            "llama1-like",
            "long-context-like",
            "llama2-like",
            "falcon-7b-like",
            "falcon-40b-like",
            "codellama2-like",
            "mixtral-like",
        ),
    )
    @pytest.mark.parametrize(
        "device",
        ("cpu", "cuda"),
    )
    def test_litgpt_variants(name, device):
        if device == "cuda" and not torch.cuda.is_available():
            pytest.skip("CUDA not available")
    
        device = torch.device(device)
    
        x = torch.randint(0, 200, (5, 5), device=device)
        config = litgpt_model.Config.from_name(name)
    
        with device:
            reference = litgpt_model.GPT(config)
        expected_logits = reference(x)
    
        expected_logits.sum().backward()
    
        with device:
            model = litgpt_model.GPT(config)
        model.load_state_dict(reference.state_dict())
        tom = thunder.jit(model, executors=nvfuserex if device.type == "cuda" else torchex)
>       actual_logits = tom(x)

thunder/tests/test_jit_general.py:642: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/__init__.py:194: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:629: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
thunder/__init__.py:262: in cache_info_wrapper
    res = fn(*args, **kwargs)
thunder/__init__.py:504: in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
thunder/__init__.py:175: in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
thunder/core/jit_ext.py:1430: in thunder_general_jit
    result = jfn(*args, **kwargs)
thunder/core/interpreter.py:6684: in fn_
    raise e
thunder/core/interpreter.py:6647: in fn_2
    return fn(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:94: in forward
    x = block(x, cos, sin, mask, input_pos)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:187: in forward
    x = self.mlp(self.norm_2(x)) + x
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:347: in forward
    token_idx, expert_idx = torch.where(mask)
thunder/core/interpreter.py:1258: in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
thunder/core/symbol.py:250: in __call__
    result = self.meta(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (t157,), kwargs = {}, tok = <Token used var=<ContextVar name='langctx' at 0x7fa2ad45a340> at 0x7f9bf1b6bdc0>

    @wraps(fn)
    def _fn(*args, **kwargs):
        try:
            tok = set_langctx(self.langctx)
>           result = fn(*args, **kwargs)
E           TypeError: where() missing 2 required positional arguments: 'a' and 'b'

thunder/core/langctxs.py:124: TypeError
========================================================================================================== short test summary info ===========================================================================================================
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants[cpu-mixtral-like] - TypeError: where() missing 2 required positional arguments: 'a' and 'b'
=============================================================================================== 1 failed, 54 deselected, 10 warnings in 8.04s ================================================================================================
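
For context, the litgpt line that trips Thunder uses the single-argument form of torch.where, which in eager PyTorch is documented as an alias for nonzero(as_tuple=True); so the missing piece appears to be support for (or a mapping of) that one-argument form. A small eager check:

import torch

# torch.where(condition) with a single argument is equivalent to
# torch.nonzero(condition, as_tuple=True) in eager PyTorch.
mask = torch.tensor([[True, False], [False, True]])

a = torch.where(mask)
b = torch.nonzero(mask, as_tuple=True)
assert all(torch.equal(x, y) for x, y in zip(a, b))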

Weight tying + FSDP = nvfuser internal error

🐛 Bug

To Reproduce

Code:

import os
import torch
import torch.distributed as tdist
import thunder
from thunder.tests.lit_gpt_model import GPT, Config

if __name__ == "__main__":
    tdist.init_process_group(backend="nccl")
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", LOCAL_RANK)
    torch.set_default_device(device)

    config = Config(block_size=256, padded_vocab_size=32000, n_layer=6, n_head=6, head_size=48, n_embd=288, rotary_percentage=1.0, parallel_residual=False, bias=False, _norm_class='RMSNorm', _mlp_class='LLaMAMLP', intermediate_size=768)
    with device:
        model = GPT(config)

    model.transformer.wte.weight = model.lm_head.weight

    model = thunder.distributed.fsdp(model)
    model = thunder.jit(model)

    input_ids = torch.randint(1, 30010, (128, 256), dtype=torch.long, device=device)
    logits = model(input_ids)
    print(logits.shape)

Run with:

torchrun --nproc-per-node 2 --local-ranks-filter 0 repro.py

Nvfuser repro:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[None, None, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.ops.mul(T1, T1)
    T3 = fd.ops.sum(T2, dims=[2], keepdim=False, dtype=DataType.Null)
    S4 = fd.define_scalar(128, dtype=DataType.Int)
    S5 = fd.define_scalar(256, dtype=DataType.Int)
    S6 = fd.define_scalar(1, dtype=DataType.Int)
    V7 = fd.define_vector([S4, S5, S6], dtype=DataType.Int)
    T8 = fd.ops.broadcast_in_dim(T3, shape=V7, broadcast_dims=[0, 1])
    S9 = fd.define_scalar(288.000, dtype=DataType.Double)
    S10 = fd.ops.reciprocal(S9)
    T11 = fd.ops.mul(T8, S10)
    S12 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T13 = fd.ops.add(T11, S12)
    T14 = fd.ops.rsqrt(T13)
    S15 = fd.define_scalar(128, dtype=DataType.Int)
    S16 = fd.define_scalar(256, dtype=DataType.Int)
    S17 = fd.define_scalar(288, dtype=DataType.Int)
    V18 = fd.define_vector([S15, S16, S17], dtype=DataType.Int)
    T19 = fd.ops.broadcast_in_dim(T14, shape=V18, broadcast_dims=[0, 1, 2])
    T20 = fd.ops.mul(T1, T19)
    T21 = fd.ops.mul(T20, T0)
    fd.add_output(T14)
    fd.add_output(T21)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((128, 256, 288), (0, 0, 1)),
    torch.randn((9437184,), dtype=torch.float32, device='cuda:0').as_strided((128, 256, 288), (73728, 288, 1)),
]
fd.execute(inputs)

Short error:

RuntimeError: _result == CUDA_SUCCESS INTERNAL ASSERT FAILED at "/workspace/Fuser/csrc/executor_utils.cpp":907, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. CUDA error: CUDA_ERROR_ASSERT failed with error device-side assert triggered

Full error:

error.txt

Removing one of:

  • FSDP
  • 30010 as the highest input value
  • weight tying

makes the problem go away.

cc @tfogal
