Comments (10)

pseudotensor commented on June 23, 2024

I commented in the PR, but the same inference code run against that PR hit no error, thanks!

FrankEssenberger commented on June 23, 2024

> @FrankEssenberger When using the branch from #23, the main branch, or the latest release?
>
> Also, do you know roughly your input data length? That could also be related, e.g. if the input is 2250 tokens or so.

Sorry I was on holiday - it worked with the latest version of the code. Thanks.

pseudotensor commented on June 23, 2024

The same thing happens if I use model_id = "h2oai/h2ogpt-4096-llama2-7b-chat".

pseudotensor commented on June 23, 2024

If I set attention_sink_window_size=4096, then it doesn't fail for Mistral. Do I have to set the window size larger than or equal to the input token size?

For MPT it still fails, with a different error about 2048 vs. the input token size, so maybe MPT is not fully compatible.

tomaarsen commented on June 23, 2024

> If I set attention_sink_window_size=4096, then it doesn't fail for Mistral. Do I have to set the window size larger than or equal to the input token size?

That was my initial intuition, partly because 4096 is likely the window size you want anyway if the model should be able to use the last 4k tokens in memory. However, it shouldn't throw a CUDA indexing error either way - I'm looking into it now.
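
For reference, setting the window size at load time looks roughly like this (a minimal sketch; the exact Mistral checkpoint and the other loading arguments are assumptions, not values from this thread):

```python
from attention_sinks import AutoModelForCausalLM

# attention_sink_window_size is the number of most recent tokens kept in the KV cache,
# in addition to the attention_sink_size initial "sink" tokens.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",  # assumed checkpoint
    device_map="auto",
    attention_sink_size=4,
    attention_sink_window_size=4096,
)
```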

For MPT, you have to edit the configuration if you want to use anything over 2048 tokens:

# NOTE: Running mpt-7b with pure transformers with a sequence length of larger than 2048 requires the following:
# ```python
# config = AutoConfig.from_pretrained(args.model_name_or_path)
# config.max_seq_len = 8192
# model = AutoModelForCausalLM.from_pretrained(
#     model_name_or_path,
#     config=config,
#     ...,
# )
# ```
# This is not required for "attention_sinks" or "windowed". To prevent crashes I'll put `num_tokens` to 2048 for "transformers"
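
As a runnable sketch of that override combined with attention_sinks (the exact model id, device settings, and sink/window sizes here are assumptions):

```python
from transformers import AutoConfig, AutoTokenizer
from attention_sinks import AutoModelForCausalLM

model_id = "mosaicml/mpt-7b"  # assumed model id

# Raise MPT's maximum sequence length before instantiating the model;
# the stock config caps it at 2048 tokens.
config = AutoConfig.from_pretrained(model_id)
config.max_seq_len = 8192

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    device_map="auto",
    attention_sink_size=4,            # initial "sink" tokens kept in the cache
    attention_sink_window_size=4092,  # most recent tokens kept in the cache
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```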

That said, I'm not sure if MPT can reasonably process sequences longer than 2048, I think the model implodes after 2048, but perhaps not with attention_sinks? Definitely worth a try.

  • Tom Aarsen

tomaarsen commented on June 23, 2024

I actually get a different error when running your code:

Traceback (most recent call last):
  File "[sic]\attention_sinks\issue_22.py", line 412, in <module>
    generated_tokens = model.generate(
  File "[sic]\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "[sic]\transformers\src\transformers\generation\utils.py", line 1658, in generate
    return self.greedy_search(
  File "[sic]\transformers\src\transformers\generation\utils.py", line 2506, in greedy_search
    outputs = self(
  File "[sic]\\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "[sic]\transformers\src\transformers\models\mistral\modeling_mistral.py", line 1048, in forward
    outputs = self.model(
  File "[sic]\\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "[sic]\attention_sinks\attention_sinks\inject_mixin.py", line 131, in wrapped_forward
    outputs = old_forward(*args, **kwargs)
  File "[sic]\transformers\src\transformers\models\mistral\modeling_mistral.py", line 891, in forward
    attention_mask = self._prepare_decoder_attention_mask(
  File "[sic]\transformers\src\transformers\models\mistral\modeling_mistral.py", line 813, in _prepare_decoder_attention_mask
    expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
RuntimeError: The size of tensor a (3817) must match the size of tensor b (3818) at non-singleton dimension 3

Will dig into this deeper.

Edit: This is probably because I'm using the wrong transformers version. My bad

tomaarsen commented on June 23, 2024

@pseudotensor Perhaps you can experiment with

pip install git+https://github.com/tomaarsen/attention_sinks.git@hotfix/long_input_seq

I'll do some more tests of my own later.

pseudotensor commented on June 23, 2024

Thanks! I should clarify that I'm using transformers==4.34.1 -- I had upgraded just in case it would help with the failure, but it didn't change anything.

I'll check in the morning w.r.t. the PR.

FrankEssenberger commented on June 23, 2024

Hi, just a quick comment. I see the same error with Mistral: with attention_sink_window_size=2300 it works, and with attention_sink_window_size=2200 it fails with the out-of-bounds error. Since Mistral has a 4096-token sliding window, the error could be related to a different issue.

tomaarsen commented on June 23, 2024

@FrankEssenberger When using the branch from #23, the main branch, or the latest release?

Also, do you know roughly your input data length? That could also be related, e.g. if the input is 2250 tokens or so.
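
If it helps, a quick way to check that (the checkpoint and prompt_text here are placeholders):

```python
from transformers import AutoTokenizer

# Count the prompt length in tokens, using the same tokenizer as the model you load.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")  # placeholder checkpoint
prompt_text = "..."  # your actual prompt
print(f"Prompt length: {len(tokenizer(prompt_text).input_ids)} tokens")
```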
