
Comments (3)

TheBloke commented on August 16, 2024

Update: I've done some more research and testing, and I've confirmed that the same (or a very similar) issue happens with simple test code that is known to work for other MPS users on Apple Silicon.

I've just raised a bug report on the Transformers Github, here: huggingface/transformers#22529

I'm pretty confident in saying this issue isn't specific to LLaMA_MPS, so I understand if you want to close this. But if anyone has any thoughts on how to debug it further, that'd be really helpful!


TheBloke commented on August 16, 2024

I've done a bit more research and found this PyTorch issue which looks very similar: pytorch/pytorch#92311

If I add an extra print statement to llama/generation.py to print the raw model output, I see this:

Enter your LLaMA prompt: fishing is
Thinking...
raw output:  [1, 9427, 292, 338]
raw output:  [1, 9427, 292, 338, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]
 ⁇ raw output:  [1, 9427, 292, 338, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]

So that definitely looks like the issue reported in PyTorch, where MPS returns -9223372036854775808 (the minimum int64 value) when argmax() is used, which I guess must be happening inside the Transformers code.
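For anyone wanting to reproduce this in isolation, a minimal check along these lines should show it (a sketch, not code from the linked issue; it just compares argmax() against max(..).indices on an MPS tensor shaped roughly like LLaMA's next-token logits):

    import torch

    if torch.backends.mps.is_available():
        # Stand-in for the model's next-token logits (batch of 1, ~32k vocab)
        logits = torch.randn(1, 32000, device="mps")
        # Reportedly returns -9223372036854775808 on affected setups
        print("argmax:", torch.argmax(logits, dim=-1).item())
        # Workaround path: max() returns both values and indices
        print("max().indices:", logits.max(dim=-1).indices.item())
    else:
        print("MPS backend not available")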


TheBloke commented on August 16, 2024

I got it working!

tomj@Eddie ~/src/LLaMA_MPS (main●●)$ PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 PYTORCH_ENABLE_MPS_FALLBACK=1 ~/anaconda3/envs/torch21/bin/python chat.py --ckpt_dir ~/Downloads/Torrents/Done/LLaMA/7B --tokenizer_path ~/src/llama.cpp/models/tokenizer.model --max_batch_size 8 --max_seq_len 256 --temperature 0.8
Seed: 55702
Loading checkpoint
Loaded in 14.32 seconds
Running the raw 'llama' model in an auto-complete mode.
Enter your LLaMA prompt: llamas are
Thinking...
2 to 3 times more likely to get lymphoma than other breeds.
Source: The Canine Lymphoma Book, by Dr. Marty Goldstein
What is the most common form of cancer in dogs?
Lymphoma (or lymphosarcoma), is the most common type of cancer diagnosed in dogs.

It definitely was the problem mentioned in the PyTorch issue. Specifically, torch.argmax() is broken and always returns -9223372036854775808. And the same problem applies to torch.multinomial() which calls argmax in the C++ code.
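To illustrate the replacement: the two forms below should be interchangeable, since max(dim=...) returns both values and indices (a sketch with dummy logits, not code from the repo):

    import torch

    logits = torch.randn(1, 32000)   # stand-in for next-token logits

    # Original call, which hits the MPS bug:
    next_token = torch.argmax(logits, dim=-1)

    # Equivalent rewrite that avoids torch.argmax():
    next_token = logits.max(dim=-1).indices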

To get this working I had to:

  1. Hack my local copy of Transformers, finding every relevant reference to argmax(..) and replacing it with max(..).indices
  2. Modify LLaMA_MPS's generation.py as follows:
    a. Change one reference to argmax(..) to max(..).indices (in the code path for temp = 0.0)
    b. Rewrite the section that calls torch.multinomial(). The first thing I tried was simply moving to the CPU to run torch.multinomial(), then moving back. That worked, but at the cost of CPU usage. So then I asked GPT-4 how to rewrite the code without using multinomial. It provided the following code; I have no idea if it's functionally identical, but it does appear to work. Based on that, here is how I changed generation.py (a self-contained version of the same idea follows the code):
                    next_token_scores = torch.nn.functional.softmax(
                        next_token_scores, dim=-1
                    )
                    # Can't run torch.multinomial() due to MPS bug
                    #next_token = torch.multinomial(
                    #    next_token_scores, num_samples=1
                    #)

                    ## This code written by GPT-4 when I told it I couldn't use torch.multinomial!
                    # Calculate cumulative distribution
                    cumulative_probs = torch.cumsum(next_token_scores, dim=-1)

                    # Sample random values between 0 and 1
                    random_values = torch.rand((cumulative_probs.size(0), 1))

                    # Find the indices where random_values would be inserted to maintain sorted order
                    next_token = torch.searchsorted(cumulative_probs, random_values)
                    ## End GPT-4 written code

                    next_token = next_token.squeeze(1)
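
Here is roughly the same idea as a self-contained function, for anyone who wants to test the sampling workaround outside of generation.py (a sketch, not the exact code above: it keeps the random values on the same device as the probabilities, and clamps the index in case floating-point rounding leaves the top of the CDF slightly below 1.0):

    import torch

    def sample_without_multinomial(probs: torch.Tensor) -> torch.Tensor:
        """Draw one sample per row from (batch, vocab) probabilities via
        inverse-CDF sampling, avoiding torch.multinomial()."""
        # Cumulative distribution over the vocab dimension
        cdf = torch.cumsum(probs, dim=-1)
        # Uniform random values, kept on the same device as the probabilities
        u = torch.rand(probs.size(0), 1, device=probs.device, dtype=probs.dtype)
        # First index where the CDF reaches the random value
        idx = torch.searchsorted(cdf, u)
        # Guard against rounding pushing the index past the last token
        return idx.clamp_(max=probs.size(-1) - 1).squeeze(1)

    # Example with dummy probabilities:
    probs = torch.nn.functional.softmax(torch.randn(2, 32000), dim=-1)
    print(sample_without_multinomial(probs))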

And now this produces output, and uses the GPU:
[screenshot: generation output with the GPU in use]

Overall performance isn't all that great and inference definitely seems slower than llama.cpp, but you did warn that this would likely be the case. My main reason for trying to get this working was to see if I could later try some model fine-tuning on MPS, so it was good to get GPU usage working in general, and to learn a bit more about PyTorch and Transformers in the process.

I have no idea if this info will benefit anyone else, and hopefully the PyTorch issue will be fixed soon anyway. But it was interesting exploring it.

Thanks again for making this code available!

