Comments (6)
Hi Team ⚡️, I am currently working on this issue and would like to share how I reproduced the same error (as a reference for anyone else working on it). It is quite similar to the code shown above, just smaller :)
- Clone and install transformers from source:

```shell
git clone https://github.com/huggingface/transformers.git
cd transformers
git checkout tags/v4.41.2
pip install -e .
```
- As mentioned above, add this to the `CLIPTextTransformer` class:

```python
self.embeddings = thunder.jit(self.embeddings)
self.encoder = thunder.jit(self.encoder)
self.final_layer_norm = thunder.jit(self.final_layer_norm)
```
- Run the following script (taken from the Hugging Face transformers docs):

```python
from transformers import CLIPTokenizer, CLIPTextModel

model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled (EOS token) states
```
(cc @t-vi)
from lightning-thunder.
@k223kim debugged this more: the infinite recursion comes from `to_printable` assuming that `tree_flatten` will "simplify" the input, when in reality it returns the original input as one of the flattened objects for `BaseModelOutput` (a dataclass from transformers: https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/modeling_outputs.py#L24-L47).
```python
def to_printable(
    trace: Optional,
    x: Any,
    *,
    import_ctx: Optional[dict] = None,
    object_ctx: Optional[dict] = None,
) -> Printable:
    # Short-circuits if x is a Proxy
    if isinstance(x, ProxyInterface):
        return x

    if is_collection(x):
        flat, spec = tree_flatten(x)
        printables = []
        for f in flat:
            printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))

        printable = tree_unflatten(printables, spec)
        return printable
```
A more minimal repro to use as a test in a fix:

```python
import torch
import transformers
import thunder

def fn(x):
    return transformers.modeling_outputs.BaseModelOutput(x)

jfn = thunder.jit(fn)
x = torch.randn(5, 5)
print(jfn(x))
```
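One plausible direction for a fix (a sketch of the idea only, not Thunder's actual patch) is for `to_printable` to notice that flattening returned the input unchanged and treat it as an atomic value instead of recursing. Again using simplified stand-ins rather than Thunder's real helpers:

```python
import dataclasses

def tree_flatten(x):
    # Simplified stand-in for the real pytree flatten: unregistered types
    # (including dataclasses) are leaves that flatten to themselves.
    if isinstance(x, (list, tuple)):
        leaves = [leaf for item in x for leaf in tree_flatten(item)[0]]
        return leaves, type(x)
    return [x], None

def is_collection(x):
    return dataclasses.is_dataclass(x) or isinstance(x, (list, tuple, dict))

def to_printable(x):
    # Hypothetical guard: if flattening hands back the object itself as its
    # only leaf, stop recursing and render it as an opaque value.
    if is_collection(x):
        flat, _ = tree_flatten(x)
        if len(flat) == 1 and flat[0] is x:
            return repr(x)
        return [to_printable(f) for f in flat]
    return x

@dataclasses.dataclass
class BaseModelOutput:
    last_hidden_state: object = None

# Terminates instead of raising RecursionError.
print(to_printable(BaseModelOutput(last_hidden_state=[1, 2])))
```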
FYI, I just figured out that `self.encoder` consists of an `nn.ModuleList` iterated with a for loop (shown below), which probably caused the recursion error.

```python
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
```

Applying `thunder.jit` to the individual modules of the loop, instead of to the entire `nn.ModuleList`, fixed the RecursionError.
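In pattern form, the workaround swaps one wrapper around the whole container for per-layer wrappers. In the sketch below, `compile_module` is a hypothetical stand-in for `thunder.jit` so the snippet runs without Thunder installed; the layer class is likewise illustrative, not the real `CLIPEncoderLayer`:

```python
def compile_module(module):
    # Stand-in for thunder.jit; the real workaround wraps each layer with it.
    return module

class EncoderLayer:
    def __call__(self, x):
        return x + 1

layers = [EncoderLayer() for _ in range(4)]

# Problematic (with the real thunder.jit): wrapping the whole encoder means
# the Python for loop over the ModuleList is traced, hitting the RecursionError.
#     self.encoder = thunder.jit(self.encoder)

# Workaround: wrap each layer individually and keep the loop in eager Python.
jitted_layers = [compile_module(layer) for layer in layers]

x = 0
for layer in jitted_layers:
    x = layer(x)
print(x)  # prints 4
```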
Thanks @athitten! That's really helpful.
Tagging triage review. Triage team, beyond the obvious "add support for control flow", I'm curious what our options are here.
Staring down the traceback (rather than running it myself), it does not look like the issue is the modules themselves (litgpt also uses a for loop over a ModuleList), but rather that we have a trace that fails to print itself because of some reference cycle (which might be caused by the interpreter erroneously inserting that into the trace).