
Comments (15)

mripacampus commented on August 27, 2024

Exactly! This is how I did it:

import json

import torch
from queue import Queue
from threading import Thread

from transformers import GenerationConfig

# CallbackLogitsWarper (defined in a later comment below) and generate_prompt
# (from alpaca-lora) are assumed to be in scope.

def generate_streaming_completion(options):
  model = options.pop("model")
  tokenizer = options.pop("tokenizer")
  model_options = options.pop("model_options")
  stream = model_options.stream and model_options.num_beams == 1

  q = Queue()

  generation_config = GenerationConfig(
    temperature=model_options.temperature,
    top_p=model_options.top_p,
    top_k=model_options.top_k,
    num_beams=model_options.num_beams,
    max_new_tokens=model_options.max_new_tokens,
  )

  prompt = generate_prompt(model_options.instruction, model_options.input)
  inputs = tokenizer(prompt, return_tensors="pt")
  input_ids = inputs["input_ids"].cuda()


  def stream_callback(res):
    q.put(json.dumps({ "text": res }) + '\n')

  logits_processor = [CallbackLogitsWarper(tokenizer, stream_callback)] if stream else None

  def generate():
    with torch.no_grad():
      model.eval()
      model.generate(
          input_ids=input_ids,
          generation_config=generation_config,
          logits_processor=logits_processor,
          return_dict_in_generate=True,
          # output_scores=True,
          # max_new_tokens=600
      )
      print('STREAMING DONE')
    torch.cuda.empty_cache()
    q.put("[DONE]")

  # start the generate function in a new Thread so that the code doesn't stop executing here
  Thread(target=generate, args=()).start()

  while True:
    next_item = q.get(True, 10000)  # blocks until an item is available (with a 10,000 s timeout)
    if next_item == "[DONE]":
        yield next_item
        break
    yield next_item
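
A minimal usage sketch, not part of the original comment: one way a server could consume this generator. The Flask app, the /completions route and the build_model_options helper are hypothetical placeholders:

from flask import Flask, Response, request

app = Flask(__name__)

@app.route("/completions", methods=["POST"])
def completions():
    # model, tokenizer and model_options would come from app state / the request
    # in a real server; build_model_options is a hypothetical helper.
    options = {
        "model": model,
        "tokenizer": tokenizer,
        "model_options": build_model_options(request.get_json()),
    }
    # Stream the newline-delimited JSON chunks (and the final "[DONE]") to the client.
    return Response(generate_streaming_completion(options), mimetype="application/x-ndjson")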

from alpaca-lora.

canonhui commented on August 27, 2024

That's awesome! Solved my problem perfectly! Greatly appreciated😊

from alpaca-lora.

baleksey commented on August 27, 2024

Yeah, you can do that. You need to create a LogitsWarper:

import torch
from transformers import LogitsWarper

class CallbackLogitsWarper(LogitsWarper):
    def __init__(self, tokenizer, callback):
        self.tokenizer = tokenizer
        self.callback = callback
        self.res_tokens = []

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        self.res_tokens.append(input_ids[0][-1])
        result = self.tokenizer.decode(self.res_tokens).lstrip()
        self.callback(result)  # send the text decoded so far back to the caller
        return scores

Then add the logits_processor param to your model.generate():

def callback(result):
    print(result)
     
generation_output = model.generate(
    input_ids=input_ids,
    generation_config=generation_config,
    return_dict_in_generate=True,
    output_scores=True,
    max_new_tokens=100,
    logits_processor=[CallbackLogitsWarper(tokenizer, callback)],
)

from alpaca-lora.

marcoripa96 commented on August 27, 2024

Thanks for the fast reply! I tried it, but those don't seem to be the final tokens produced in the generation_output variable.

This is what I did:

import sys

import torch
from transformers import GenerationConfig, LogitsWarper

class CallbackLogitsWarper(LogitsWarper):
    def __init__(self, tokenizer, callback):
        self.tokenizer = tokenizer
        self.callback = callback
        self.res_tokens = []

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        self.res_tokens.append(input_ids[0][-1])
        # result = self.tokenizer.decode(self.res_tokens).lstrip()
        result = self.tokenizer.decode(input_ids[0][-1])
        self.callback(result)
        return scores
generation_config = GenerationConfig(
    temperature=0.1,
    top_p=0.75,
    num_beams=4,
)

def callback(result):
    sys.stdout.write(result)

def evaluate(instruction, input=None):
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].cuda()
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=256,
        logits_processor=[CallbackLogitsWarper(tokenizer, callback)]
    )
    
    for s in generation_output.sequences:
        output = tokenizer.decode(s)
        print()
        print("Response:", output.split("### Response:")[1].strip())

And this is what "streaming" one token at a time produces, compared to the final generation_output:

Instruction: Give me a random sentence
 : <0x0A> The sun was sh ining bright ly in the sky, casting a ating the low over the landscape. </s> <0x0A>
Response: The sun was shining brightly in the clear blue sky.

from alpaca-lora.

baleksey commented on August 27, 2024

In GenerationConfig, remove the num_beams parameter entirely (or set it to 1).

from alpaca-lora.

marcoripa96 commented on August 27, 2024

It kinda works now, but I still don't get why I receive partial tokens, and I don't really have a way to know which tokens need to be joined ("ch icken" in the example below). Additionally, what's with the : and <0x0A> strings at the beginning of the generated sentence?

Instruction: Tell me a funny joke
Response obtained decoding one token at a time : <0x0A> Why did the ch icken cross the road? To get to the other side! </s>
Response obtained decoding all tokens together: Why did the chicken cross the road? To get to the other side!

from alpaca-lora.

baleksey commented on August 27, 2024

": and <0x0A>" - are the part of trained response (actually it's ":\n"). You can replace(":\n","") before print out so you won't see it. As for strange "ch icken" this is because (not sure 100%) different tokens combinations can produce different results by tokenizer.decode() and if you try to decode each token by it's own - you'll see that often. That's why in my example I decode the whole result and send it instead of separate tokens. In this way result will be stable and expected.

from alpaca-lora.

mripacampus commented on August 27, 2024

I see. The problem with decoding the whole sequence each time is that it's really not optimal. Imagine a server streaming the response: ideally you would send one token at a time, and a possible client would concatenate the responses to form the complete output.

I tried looking around, and it seems there are open issues about supporting streaming in the generate function of the HF transformers library. I guess for now I can use this.
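
One possible interim workaround, sketched here rather than taken from the thread or the linked issues: keep decoding the accumulated sequence for correctness, but emit only the newly added suffix, so a client can simply concatenate the chunks it receives:

import torch
from transformers import LogitsWarper

class DeltaCallbackLogitsWarper(LogitsWarper):
    """Variant of CallbackLogitsWarper that emits only the text added since the last call."""

    def __init__(self, tokenizer, callback):
        self.tokenizer = tokenizer
        self.callback = callback
        self.res_tokens = []
        self.sent_text = ""

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        self.res_tokens.append(input_ids[0][-1])
        # Decode the whole sequence so sub-word pieces are stitched correctly...
        full_text = self.tokenizer.decode(self.res_tokens).lstrip()
        # ...but send only the part the client has not seen yet.
        delta = full_text[len(self.sent_text):]
        if delta:
            self.callback(delta)
            self.sent_text = full_text
        return scores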

from alpaca-lora.

baleksey commented on August 27, 2024

It's not optimal, agreed. I'm using this hack in my local app/GUI, so I don't really care for now.

Please let us know if you find a way to stream one token at a time without those artifacts.

from alpaca-lora.

maralski commented on August 27, 2024

Take a look at the approach used in this repo. It's a different model, but you may be able to replicate the streaming logic.

from alpaca-lora.

canonhui commented on August 27, 2024

Can we yield the value of the result variable in CallbackLogitsWarper.__call__ to a generator?

from alpaca-lora.

mripacampus commented on August 27, 2024

You can; I used a queue for that. Enqueue the result in the callback, then loop indefinitely, waiting until there is a new item in the queue, and yield each item. I'll give an example later on.
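
For reference, a condensed sketch of that pattern (the full version is in the top comment on this page); run_generation here is a hypothetical wrapper that calls model.generate with a callback-based logits processor:

from queue import Queue
from threading import Thread

def stream_tokens(run_generation):
    # run_generation(callback) is expected to invoke callback(text) for every new
    # chunk and to return once generation is finished (hypothetical signature).
    q = Queue()

    def worker():
        run_generation(q.put)  # each callback invocation enqueues a chunk
        q.put(None)            # sentinel: generation finished

    Thread(target=worker).start()

    # Generator side: block until a new chunk arrives, then yield it to the caller.
    while True:
        item = q.get()
        if item is None:
            break
        yield item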

from alpaca-lora.

canonhui commented on August 27, 2024

So in this way, we need multiprocessing to simultaneously run the generate and the yield?

from alpaca-lora.

fragro commented on August 27, 2024

Looks like the issue with incomplete words has been solved with the latest PR to transformers:

huggingface/transformers#22449

Going to give this a shot this evening.

from alpaca-lora.

gante commented on August 27, 2024

@fragro some extra context behind huggingface/transformers#22449 -- different tokenizers have different strategies to stitch tokens together, so the solution I've built there is a simple heuristic that (AFAIK) works with all models/tokenizers.

If you notice it is causing slowdowns, I'm sure better model-specific strategies can be found. However, since decoding is not a heavy task, it should be fine :)
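
For anyone landing here later, a rough sketch of the streamer-based API that grew out of that work, assuming a transformers version recent enough to ship TextIteratorStreamer and reusing model, tokenizer, input_ids and generation_config from the snippets above:

from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# As noted above for the callback approach, streaming needs num_beams=1.
generation_kwargs = dict(
    input_ids=input_ids,
    generation_config=generation_config,
    streamer=streamer,
    max_new_tokens=256,
)

# generate() blocks, so run it in a background thread and consume the text chunks
# from the streamer as they become available.
Thread(target=model.generate, kwargs=generation_kwargs).start()

for text_chunk in streamer:
    print(text_chunk, end="", flush=True)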

from alpaca-lora.
