
Comments (6)

k223kim commented on September 26, 2024

Hi Team ⚡️, currently I am working on this issue and would like to share how I reproduced the same error (just as a reference to anyone else who is working on it). It is quite similar to the code shown above, but just smaller :)

  1. Clone Hugging Face transformers and install it from source:
git clone https://github.com/huggingface/transformers.git
cd transformers
git checkout tags/v4.41.2
pip install -e .
  2. As mentioned above, add this to the CLIPTextTransformer class (after these attributes are created, and with thunder imported in that file):
        self.embeddings = thunder.jit(self.embeddings)
        self.encoder = thunder.jit(self.encoder)
        self.final_layer_norm = thunder.jit(self.final_layer_norm)
  3. Run the following script (it is taken from the Hugging Face repo):
from transformers import CLIPTokenizer, CLIPTextModel

model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled (EOS token) states

(cc. @t-vi )

from lightning-thunder.

t-vi commented on September 26, 2024

@k223kim debugged this further: the infinite recursion comes from to_printable assuming that tree_flatten will "simplify" its input. In reality, for BaseModelOutput (a dataclass from transformers, https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/modeling_outputs.py#L24-L47), tree_flatten returns the original input as one of the flattened objects, so to_printable keeps calling itself on the same object.

def to_printable(
    trace: Optional,
    x: Any,
    *,
    import_ctx: Optional[dict] = None,
    object_ctx: Optional[dict] = None,
) -> Printable:
    # Short-circuits if x is a Proxy
    if isinstance(x, ProxyInterface):
        return x

    if is_collection(x):
        flat, spec = tree_flatten(x)

        printables = []
        for f in flat:
            printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))

        printable = tree_unflatten(printables, spec)
        return printable
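To make the failure mode concrete, here is a minimal, self-contained sketch in plain Python. It does not use thunder or transformers: OpaqueOutput, toy_flatten, toy_is_collection, naive_to_printable and guarded_to_printable are all hypothetical stand-ins invented for illustration. The guard shown is just one possible way to break the cycle, not necessarily the fix that thunder adopted.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class OpaqueOutput:
    """Hypothetical stand-in for a container the flattener cannot split."""
    value: Any


def toy_flatten(x):
    """Toy flattener: splits lists/tuples into leaves; anything else is its own leaf."""
    if isinstance(x, (list, tuple)):
        leaves = []
        for item in x:
            leaves.extend(toy_flatten(item))
        return leaves
    return [x]


def toy_is_collection(x):
    # Deliberately broader than what toy_flatten can split (assumption for the demo).
    return isinstance(x, (list, tuple, OpaqueOutput))


def naive_to_printable(x):
    # Mirrors the shape of the excerpt above: flatten, then recurse on every leaf.
    if toy_is_collection(x):
        return [naive_to_printable(leaf) for leaf in toy_flatten(x)]
    return repr(x)


def guarded_to_printable(x):
    if toy_is_collection(x):
        leaves = toy_flatten(x)
        # If flattening hands back the very same object, treat it as a leaf
        # instead of recursing on it again.
        if len(leaves) == 1 and leaves[0] is x:
            return repr(x)
        return [guarded_to_printable(leaf) for leaf in leaves]
    return repr(x)


if __name__ == "__main__":
    try:
        naive_to_printable(OpaqueOutput(1))
    except RecursionError:
        print("naive version recursed without making progress")
    print(guarded_to_printable([1, OpaqueOutput(2)]))
```

In the naive version, the object counts as a collection but flattens to itself, so every recursive call sees the same input and Python eventually raises RecursionError.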


t-vi commented on September 26, 2024

A more minimal repro, which can be turned into a test as part of a fix:

import transformers
import torch
import thunder

def fn(x):
    return transformers.modeling_outputs.BaseModelOutput(x)

jfn = thunder.jit(fn)

x = torch.randn(5, 5)

print(jfn(x))


athitten commented on September 26, 2024

FYI, I just figured out that self.encoder contains an nn.ModuleList that is iterated over in a for loop (the list is shown below), which probably caused the recursion error.
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])

Applying thunder.jit to the individual modules in the for loop, instead of to the entire nn.ModuleList, fixed the RecursionError.


tfogal commented on September 26, 2024

Thanks @athitten ! That's really helpful.

Tagging triage review. Triage team, beyond the obvious "add support for control flow", I'm curious what our options are here.


t-vi commented on September 26, 2024

Staring at the traceback (rather than running it myself), it does not look like the modules themselves are the problem (litgpt also uses a for loop over a ModuleList). It looks more as if we have a trace that fails to print itself because of some reference cycle (which might be caused by the interpreter erroneously inserting it into the trace).

