Comments (13)

dacorvo commented on July 22, 2024

Note that the error on the 24-core host is a lot longer, but it is more or less a repetition of the same lines.

dacorvo commented on July 22, 2024

Script to reproduce:

stress_generation.py:
import argparse
import time

import torch
from transformers import AutoConfig, AutoTokenizer, set_seed

from optimum.neuron import NeuronModelForCausalLM


def generate(model, tokenizer, prompts, length):
    # Encode tokens and generate
    tokens = tokenizer(prompts, return_tensors="pt", padding=True)
    start = time.time()
    with torch.inference_mode():
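        # min_length == max_length pins the total sequence length (prompt + new tokens) to exactly `length`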
        sample_output = model.generate(
            **tokens,
            do_sample=True,
            min_length=length,
            max_length=length
        )
    end = time.time()
    outputs = [tokenizer.decode(tok) for tok in sample_output]
    return outputs, (end - start)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model", type=str, help="A neuron model in a local directory.")
    parser.add_argument(
        "--prompts",
        type=str,
        default="One of my fondest memory is",
        help="The prompts to use for generation, using | as separator.",
    )
    parser.add_argument("--inc-length", type=int, default=128, help="The number of tokens in each increment.")
    parser.add_argument("--max-length", type=int, default=2048, help="The maximum number of generated tokens.")
    parser.add_argument(
        "--iterations", type=int, default=1, help="The number of time one should repeat the sequence."
    )
    parser.add_argument("--seed", type=int, default=None, help="Pass a seed for reproducibility.")
    args = parser.parse_args()
    if args.seed is not None:
        set_seed(args.seed)
    prompts = args.prompts.split("|")
    config = AutoConfig.from_pretrained(args.model)
    batch_size = config.neuron["batch_size"]
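    # The exported Neuron model has a static batch size, so pad the prompt list to match it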
    if len(prompts) < batch_size:
        prompts = prompts + [prompts[-1]] * (batch_size - len(prompts))
    model = NeuronModelForCausalLM.from_pretrained(args.model, export=False, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # Specify padding options for the decoder-only architecture
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    for _ in range(args.iterations):
        for length in range(args.inc_length, args.max_length + 1, args.inc_length):
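            # Feed the decoded outputs back in as the prompts for the next, longer increment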
            prompts, latency = generate(model, tokenizer, prompts, length)
            print(prompts)
            print(f"{length} outputs generated using Neuron model in {latency:.4f} s")

Just type something like:

python stress_generation.py ./data/llama2_7b_bs4 --inc-length 1024 --max-length 2048

awsilya commented on July 22, 2024

@dacorvo, thank you. I'm going to try reproducing this issue.

awsilya commented on July 22, 2024

@dacorvo I tried a regular transformers-neuronx flow and could not reproduce the problem. Could you send me the script or optimum-cli command that you used to export to Neuron?

dacorvo commented on July 22, 2024

@awsilya I just compiled the model with the following script: https://github.com/huggingface/optimum-neuron/blob/main/examples/text-generation/run_generation.py.

I launched the following command:

NEURON_CC_FLAGS="-O1" python examples/text-generation/run_generation.py meta-llama/Llama-2-7b-hf --auto_cast_type f16 --num_cores 2 --batch_size 4 --save_dir ./llama2_7b_bs4

And then I launched the stress test above.
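
For reference, a rough Python-only sketch of an equivalent export (treat the export kwargs as illustrative; they mirror the CLI arguments above but may vary between optimum-neuron versions):

import os

from optimum.neuron import NeuronModelForCausalLM

# Same compiler flag as the command above; set before export so compilation picks it up
os.environ["NEURON_CC_FLAGS"] = "-O1"

# Export kwargs mirror the CLI arguments above (other options, e.g. sequence length,
# are left at their defaults); exact names may differ across optimum-neuron versions.
model = NeuronModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    export=True,
    batch_size=4,
    num_cores=2,
    auto_cast_type="f16",
)
model.save_pretrained("./llama2_7b_bs4")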

awsilya commented on July 22, 2024

@dacorvo - this is for 2 cores on inf2.2xl and batch 2? Is it reproducible, or do I need to compile for batch 4 and 24 cores?

dacorvo commented on July 22, 2024

Sorry, I made a mistake: it is indeed for batch 4 only. I modified the command in my previous comment.

awsilya commented on July 22, 2024

@dacorvo very interesting issue, thank you for discovering it. We have an execution timeout, mostly to detect h/w issues, and it is set to 30 sec by default. A batch-4, len=2048 execution takes ~50 sec, so we time out too soon. For now, just set the timeout to a longer value via the env var:

NEURON_RT_EXEC_TIMEOUT=120 (this is in seconds)
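
For example, with the stress script above:

NEURON_RT_EXEC_TIMEOUT=120 python stress_generation.py ./data/llama2_7b_bs4 --inc-length 1024 --max-length 2048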

Longer term, I need to figure out a heuristic for picking a reasonable timeout.

dacorvo commented on July 22, 2024

I increased the timeout, but now I have an OOM error:

2023-Oct-02 08:17:45.0945  4703:4866  ERROR  TDRV:dmem_alloc_internal                     Failed to alloc DEVICE memory: 13743184
2023-Oct-02 08:17:45.0946  4703:4865  ERROR  TDRV:dmem_alloc_internal                     Failed to alloc DEVICE memory: 13743184
2023-Oct-02 08:17:45.0957  4703:4866  ERROR  TDRV:dml_dump                                Wrote nrt memory alloc debug info to /tmp/nrt_mem_log_device_0_651a7ca9.csv
2023-Oct-02 08:17:45.0957  4703:4866  ERROR  TDRV:log_dev_mem                             Failed to allocate 13.107MB (usage: dma rings) on ND 0:NC 1, current utilization:
        * total: 15.970GB
        * model code: 233.719MB
        * model constants: 1.885MB
        * tensors: 8.216GB
        * shared scratchpad: 512.000MB
        * runtime: 3.562KB
        * dma rings: 6.518GB
        * collectives: 518.002MB

2023-Oct-02 08:17:45.0957  4703:4866  ERROR  TDRV:ddrs_create_ring                        Failed to create descriptors pring for dynamic ring
2023-Oct-02 08:17:45.0957  4703:4866  ERROR  TDRV:ddrs_build_ring_from_template           Failed to allocate memory for dma ring for queue qSPPIOParam0_4
2023-Oct-02 08:17:45.0957  4703:4866  ERROR  TDRV:ddrs_build_ring_and_add_to_ring_set     Failed to build dynamic ring template for queue qSPPIOParam0_4
2023-Oct-02 08:17:45.0957  4703:4866  ERROR  TDRV:ddrs_build_dma_rings_for_queue          Failed to build and add ring to ring qSPPIOParam0_4 set
2023-Oct-02 08:17:45.0957  4703:4866  ERROR  TDRV:kbl_compute_build_compute_resources     Failed to build dma rings for io
2023-Oct-02 08:17:45.0957  4703:4866  ERROR  NMGR:dlr_infer                               Failed to build ioq switching descriptors
terminate called after throwing an instance of 'c10::Error'
  what():  nrt_execute status=4
Exception raised from task at /opt/workspace/KaenaPyTorchRuntime/neuron_op/ops/tensor.cpp:845 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fe946a96457 in /home/ubuntu/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fe946a603ec in /home/ubuntu/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: neuron::task(void*) + 0x2e9 (0x7fe8edd72c19 in /home/ubuntu/.local/lib/python3.8/site-packages/torch_neuronx/lib/libtorchneuron.so)
frame #3: <unknown function> + 0x8609 (0x7fe9bdd12609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #4: clone + 0x43 (0x7fe9bde4c133 in /lib/x86_64-linux-gnu/libc.so.6)

Aborted (core dumped)

awsilya commented on July 22, 2024

@dacorvo this is unrelated; did you use the -O1 compiler option?

awsilya commented on July 22, 2024

This is what I did:

  • compile using NEURON_CC_FLAGS="-O1" python examples/text-generation/run_generation.py meta-llama/Llama-2-7b-hf --auto_cast_type f16 --num_cores 2 --batch_size 4 --save_dir ./llama2_7b_bs4

  • execute - got the timeout error

  • added NEURON_RT_EXEC_TIMEOUT=120 and executed again; this time it was successful and took about 50 sec for 2048 tokens.

dacorvo commented on July 22, 2024

I did the exact same thing and got the OOM error: shall I open yet another issue?

awsilya commented on July 22, 2024

@dacorvo - I misread the logs above. This is not a compilation issue, because we are past model loading and are executing inference. Yes, please open a new ticket; it's unrelated to everything else we have debugged so far - "OOM while running llama inference".

Are you passing longer sequences? Trying to run multiple requests in parallel? Also, is this tp=2 or tp=24?

thanks.
