
Comments (6)

aws-donkrets commented on July 3, 2024

Hi enochlev.
You are likely seeing a significant amount of overhead because the context prefill step is not being performed in parallel and is falling back to populating the KV cache token by token. When initializing the model, you have the option to provide context length estimates for the number of tokens that you expect to see as inputs. By default, this is set to half the size of n_positions.
You can set this field to a single integer or to a list of integers to create models for different input size estimates:

LlamaForSampling.from_pretrained(..., context_length_estimate=3000, n_positions=4000)

# Alternative
LlamaForSampling.from_pretrained(..., context_length_estimate=[100, 1500, 3000], n_positions=4000)

At inference time, the model will select the context length estimate network that fits all of the input tokens. For any remaining tokens, it will fall back to token-by-token prefill (this is the behavior that you are currently seeing).
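
For intuition, the split behaves roughly like the sketch below (an illustration of the bucketing idea described above, not the library's actual internals):

```
def split_prefill(prompt_len, context_length_estimate=(100, 1500, 3000)):
    # Illustration only: a compiled network that fits the whole prompt does the
    # prefill in parallel; anything beyond the largest estimate falls back to
    # token-by-token prefill.
    estimates = sorted(context_length_estimate)
    if prompt_len <= estimates[-1]:
        return prompt_len, 0                              # fully parallel prefill
    return estimates[-1], prompt_len - estimates[-1]      # remainder is token-by-token

print(split_prefill(2500))   # (2500, 0)   -> fast, fully parallel prefill
print(split_prefill(3500))   # (3000, 500) -> 500 tokens prefilled one at a time
```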


enochlev commented on July 3, 2024

Also, let me know if this is out of context for this repo.


aws-donkrets commented on July 3, 2024

Hi enochlev. This is not out of context for this repo :-) Are you seeing the 10-30 second overhead consistently for your 3000-token context, or are you seeing the delay only when running a single inference? Sometimes there is a warm-up period for model execution.
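
If you want to check, a loop like this (assuming a loaded neuron_model and tokenized input_ids) separates a one-time warm-up from a per-request cost:

```
import time

# Time several consecutive calls on the same input: if only run 0 is slow it is
# warm-up; if every run is slow it is the prefill overhead described above.
for i in range(5):
    start = time.time()
    neuron_model.sample(input_ids, sequence_length=len(input_ids[0]) + 80, top_k=50)
    print(f"run {i}: {time.time() - start:.1f} s")
```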


enochlev commented on July 3, 2024

I appreciate your time and your response.

It is not a warm-up delay; it happens consistently on every inference.

After further testing, I have found that this issue only occurs when I run the model with TorchServe, via:
torchserve --ncs --start --model-store model_store --ts-config config.properties --models llama-2-13b

Running model inference manually, I do not see any delay:

import time

import torch
from transformers_neuronx.llama.model import LlamaForSampling
neuron_model = LlamaForSampling.from_pretrained('model_store/llama-2-13b/llama-2-13b-split/', batch_size=2, tp_degree=6, amp='bf16')
neuron_model.load('model_store/llama-2-13b/neuron_artifacts')
neuron_model.to_neuron()



with torch.inference_mode():
    start = time.time()
    generated_sequences = neuron_model.sample(input_ids, sequence_length=len(input_ids[0]) + 80, top_k=50)
    elapsed = time.time() - start

print(f'generated sequences {generated_sequences} in {elapsed} seconds')
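
(input_ids above comes from the tokenizer; the path and prompts here are just an example matching my setup)

```
from transformers import LlamaTokenizer

# Example only: build the input_ids used above from a batch of prompts
tokenizer = LlamaTokenizer.from_pretrained('model_store/llama-2-13b/llama-2-13b-split/')
tokenizer.pad_token = tokenizer.eos_token
prompts = ["Hello, how are you?", "Tell me a story."]  # batch_size=2 to match compilation
input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids
```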

I even modified the entire inf2_handler.py to work outside the TorchServe API, like this:

```
import logging
import os
from abc import ABC
from threading import Thread

import torch_neuronx
from transformers import AutoConfig, LlamaTokenizer
from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
from transformers_neuronx.llama.model import LlamaForSampling

from ts.handler_utils.hf_batch_streamer import TextIteratorStreamerBatch
from ts.handler_utils.micro_batching import MicroBatching
from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LLMHandler(BaseHandler, ABC):
    
    """Transformers handler class for text completion streaming on Inferentia2."""

    def __init__(self):
        super().__init__()
        self.initialized = False
        self.max_length = None
        self.tokenizer = None
        self.output_streamer = None
        # enable micro batching
        self.handle = MicroBatching(self)

    def initialize(self):
        model_dir = 'model_store/llama-2-13b'
        model_checkpoint_path = 'model_store/llama-2-13b/llama-2-13b-split/'
        os.environ["NEURONX_CACHE"] = "on"
        os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache"
        os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"

        # micro batching initialization
        micro_batching_parallelism = True
        if micro_batching_parallelism:
            logger.info(
                f"Setting micro batching parallelism  from model_config_yaml: {micro_batching_parallelism}"
            )
            self.handle.parallelism = {'preprocess': 2, 'inference': 1, 'postprocess': 2}

        micro_batch_size = 2
        logger.info(f"Setting micro batching size: {micro_batch_size}")
        self.handle.micro_batch_size = micro_batch_size

        # settings for model compilation and loading
        amp = "bf16"
        tp_degree = 6
        self.max_length = 50

        # allocate "tp_degree" number of neuron cores to the worker process
        os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
        try:
            num_neuron_cores_available = (
                torch_neuronx.xla_impl.data_parallel.device_count()
            )
            assert num_neuron_cores_available >= int(tp_degree)
        except (RuntimeError, AssertionError) as error:
            logger.error(
                "Required number of neuron cores for tp_degree "
                + str(tp_degree)
                + " are not available: "
                + str(error)
            )

            raise error

        self.tokenizer = LlamaTokenizer.from_pretrained(model_checkpoint_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = LlamaForSampling.from_pretrained(
            model_checkpoint_path,
            batch_size=self.handle.micro_batch_size,
            amp=amp,
            tp_degree=tp_degree,
        )
        logger.info("Starting to compile the model")
        print("loading artifacts.....................")
        self.model.load(f"{model_dir}/neuron_artifacts")
        self.model.to_neuron()
        logger.info("Model has been successfully compiled")
        model_config = AutoConfig.from_pretrained(model_checkpoint_path)
        self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
        self.output_streamer = TextIteratorStreamerBatch(
            self.tokenizer,
            batch_size=self.handle.micro_batch_size,
            skip_special_tokens=True,
        )

        self.initialized = True

    def preprocess(self, input_text):

        # Ensure the compiled model can handle the input received
        if len(input_text) > self.handle.micro_batch_size:
            raise ValueError(
                f"Model is compiled for batch size {self.handle.micro_batch_size} but received input of size {len(input_text)}"
            )

        # Pad input to match compiled model batch size
        input_text.extend([""] * (self.handle.micro_batch_size - len(input_text)))

        return self.tokenizer(input_text, return_tensors="pt", padding=True)

    def inference(self, tokenized_input):
        print("Inferance")
        generation_kwargs = dict(
            input_ids =tokenized_input,
            streamer=self.output_streamer,
            max_new_tokens=self.max_length,
        )
        self.model.reset_generation()
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        #micro_batch_idx = self.handle.get_micro_batch_idx()
        #micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
        for new_text in self.output_streamer:
            logger.debug("send response stream")
            print(new_text)
            # send_intermediate_predict_response(
            #     new_text[: len(micro_batch_req_id_map)],
            #     micro_batch_req_id_map,
            #     "Intermediate Prediction success",
            #     200,
            #     self.context,
            # )

        thread.join()

        #return [""] * len(micro_batch_req_id_map)

    def postprocess(self, inference_output):
        return inference_output

    def get_micro_batch_req_id_map(self, micro_batch_idx: int):
        start_idx = micro_batch_idx * self.handle.micro_batch_size
        micro_batch_req_id_map = {
            index: self.context.request_ids[batch_index]
            for index, batch_index in enumerate(
                range(start_idx, start_idx + self.handle.micro_batch_size)
            )
            if batch_index in self.context.request_ids
        }

        return micro_batch_req_id_map
```

and there was no startup delay.
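
Driving it outside TorchServe then looks roughly like this (prompts are just an example; inference() prints the streamed text):

```
handler = LLMHandler()
handler.initialize()

# preprocess pads/tokenizes a batch of micro_batch_size prompts
batch = handler.preprocess(["Hello, how are you?", "Tell me a story."])
# inference streams the generated text via the print() in its loop
handler.inference(batch["input_ids"])
```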

I guess I can modify this and build my own API wrapper around it, but it would be nice if I didn't have to rewrite/maintain API code.


enochlev commented on July 3, 2024

That fixed it.

That is the exact behavior I was observing: the input length (not input + output) had to be smaller than the largest value in context_length_estimate for generation to start instantly.
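
For anyone hitting the same thing with the TorchServe handler: the change amounts to passing context_length_estimate (and n_positions) in the from_pretrained call in initialize(), e.g. with the values discussed above:

```
self.model = LlamaForSampling.from_pretrained(
    model_checkpoint_path,
    batch_size=self.handle.micro_batch_size,
    amp=amp,
    tp_degree=tp_degree,
    # make sure the largest estimate covers the longest prompts you expect
    context_length_estimate=3000,
    n_positions=4000,
)
```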


aws-donkrets commented on July 3, 2024

enochlev - glad to hear this resolved your issue; closing this ticket for now.

