
Comments (13)

frankfliu commented on July 24, 2024

What error did you see?

Tracing a PyTorch model has nothing to do with Chinese language support. As long as the model supports Chinese, it should just work.

leleZeng commented on July 24, 2024

I downloaded the large-v3 version of the Whisper model and hit an exception when running it locally:

ai.djl.engine.EngineException: PytorchStreamReader failed locating file constants.pkl: file not found
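
(For context: this error means the file being loaded is not a TorchScript archive. Raw Hugging Face checkpoints are plain torch.save pickles, so a TorchScript reader cannot find constants.pkl inside them. A minimal Python reproduction sketch, with an illustrative path:)

import torch

# Raw Hugging Face weights (e.g. pytorch_model.bin) are torch.save archives,
# not TorchScript modules, so torch.jit.load raises:
#   PytorchStreamReader failed locating file constants.pkl: file not found
try:
    torch.jit.load("pytorch_model.bin")  # illustrative path to raw HF weights
except RuntimeError as e:
    print(e)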

frankfliu commented on July 24, 2024

You need to trace the model, see: https://docs.djl.ai/examples/docs/whisper_speech_text.html#trace-the-model

leleZeng commented on July 24, 2024

I input Chinese audio, but the output was not as expected.
This is the code to convert Whisper Large to TorchScript:

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import torch
import numpy as np

processor = WhisperProcessor.from_pretrained("openai/whisper-large")
processor.tokenizer.save_pretrained("whisper-tokenizer")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large", return_dict=False)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

test = []
for ele in ds:
    test.append(ele["audio"]["array"])

input_features = processor(np.concatenate(test), return_tensors="pt").input_features
generated_ids = model.generate(inputs=input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Original: " + transcription)

# Start tracing
traced_model = torch.jit.trace_module(model, {"generate": [input_features]})
generated_ids = traced_model.generate(input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Traced: " + transcription)

torch.jit.save(traced_model, "whisper-large.pt") 

Tracing log: (screenshot)

The program now runs without crashing and there are no errors, but the output looks wrong:

[WARN ] - Simple repository pointing to a non-archive file.
[INFO ] - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
[INFO ] - Number of inter-op threads is 6
[INFO ] - Number of intra-op threads is 10
[INFO ] - <|startoftranscript|> <|zh|> Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission Commission <|endoftext|>
Input #0, wav, from 'D:\workspace\djl\examples\build\example\chinese.wav':
  Duration: 00:00:45.96, bitrate: 256 kb/s
  Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 16000 Hz, 1 channels, s16, 256 kb/s

frankfliu commented on July 24, 2024

Does the Python code produce correct output?

leleZeng commented on July 24, 2024

Does the Python code produce correct output?

The tracing-log screenshot above contains the Python output; there were only some warning logs.

frankfliu commented on July 24, 2024

I mean, do you get the expected result when you run inference in Python?

leleZeng commented on July 24, 2024

The log printed by the tracing demo code looks problematic:

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import torch
import numpy as np

processor = WhisperProcessor.from_pretrained("openai/whisper-base")
processor.tokenizer.save_pretrained("whisper-tokenizer")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base", return_dict=False)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

test = []
for ele in ds:
    test.append(ele["audio"]["array"])

input_features = processor(np.concatenate(test), return_tensors="pt").input_features
generated_ids = model.generate(inputs=input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Original: " + transcription)

# Start tracing
traced_model = torch.jit.trace_module(model, {"generate": [input_features]})
generated_ids = traced_model.generate(input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Traced: " + transcription)

torch.jit.save(traced_model, "whisper-large.pt")

(screenshot of the problematic tracing log)

Inference using plain Python code works correctly:

from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa

processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base", return_dict=True)

# Load and process the FLAC audio file
audio_path = "jfk.flac"
audio, sample_rate = librosa.load(audio_path, sr=16000)  # Resample audio to 16000 Hz

input_features = processor(audio, return_tensors="pt").input_features

# Generate text transcription
generated_ids = model.generate(input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

print("Transcription: " + transcription)

(screenshot of the transcription output)

leleZeng commented on July 24, 2024

Hi Frank! I successfully converted the model after switching to a different computer, but after loading the model in Java and feeding it Chinese audio for inference, the output is in English. Is this caused by WhisperTranslator? Can you please advise me on how to correct it?

(screenshot of the Java output)

Python inference: (screenshot)

frankfliu commented on July 24, 2024

WhisperTranslator was not implemented properly; I fixed the tokenizer decode function, so it should now support Chinese. You can use our default model from the model zoo; it should just work.

Did you use forced_decoder_ids when you traced your model? That might be the reason you get English.
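
(For reference, a minimal sketch of forcing Chinese transcription before tracing, using the standard Hugging Face transformers API; the checkpoint name is illustrative:)

from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2", return_dict=False
)

# Pin the decoder prompt to <|zh|> + <|transcribe|> so the traced generate()
# path does not auto-detect the language and drift to English.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="zh", task="transcribe"
)
# ... then trace with torch.jit.trace_module as in the scripts above.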

leleZeng commented on July 24, 2024

I tried the new WhisperTranslator implementation, but the result is still English text. Also, inference on a three-second audio clip is slow: the first run took about fifty seconds and the second about forty.

        Criteria<Audio, String> criteria =
                Criteria.builder()
                        .setTypes(Audio.class, String.class)
                        .optModelUrls("file:build/example/whisper-large-v2.pt")
                        .optEngine("PyTorch")
                        .optDevice(Device.cpu())
                        .optTranslator(new CustWhisperTranslator())
                        .build();
        whisperModel = criteria.loadModel();
        predictor = whisperModel.newPredictor();
public class CustWhisperTranslator implements NoBatchifyTranslator<Audio, String> {

    private static final Map<Character, Byte> BYTES_DECODER = bpeDecoder();

    private List<AudioProcessor> processors;
    private Vocabulary vocabulary;

    public CustWhisperTranslator() {
        processors = new ArrayList<>();
    }

    /** {@inheritDoc} */
    @Override
    public void prepare(TranslatorContext ctx) throws IOException {
        Path path = ctx.getModel().getModelPath();
        Path melFile = path.resolve("mel_80_filters.npz");

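        // Whisper expects 30 seconds of 16 kHz audio: 30 * 16000 = 480000 samples.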
        processors.add(new PadOrTrim(480000));
        // Use model's NDManager
        NDManager modelManager = ctx.getModel().getNDManager();
        processors.add(LogMelSpectrogram.newInstance(melFile, 80, modelManager));

        Map<String, Integer> vocab;
        Map<String, Integer> added;
        Type type = new TypeToken<Map<String, Integer>>() {}.getType();
        try (Reader reader = Files.newBufferedReader(path.resolve("vocab.json"))) {
            vocab = JsonUtils.GSON.fromJson(reader, type);
        }
        try (Reader reader = Files.newBufferedReader(path.resolve("added_tokens.json"))) {
            added = JsonUtils.GSON.fromJson(reader, type);
        }
        String[] result = new String[vocab.size() + added.size()];
        vocab.forEach((key, value) -> result[value] = key);
        added.forEach((key, value) -> result[value] = key);
        vocabulary = new DefaultVocabulary(Arrays.asList(result));
    }

    /** {@inheritDoc} */
    @Override
    public NDList processInput(TranslatorContext ctx, Audio input) throws Exception {
        NDArray samples = ctx.getNDManager().create(input.getData());
        for (AudioProcessor processor : processors) {
            samples = processor.extractFeatures(samples.getManager(), samples);
        }
        samples = samples.expandDims(0);
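        // The empty placeholder named "module_method:generate" tells DJL's PyTorch
        // engine to call the traced module's generate() method instead of forward().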
        NDArray placeholder = ctx.getNDManager().create("");
        placeholder.setName("module_method:generate");
        return new NDList(samples, placeholder);
    }

    /** {@inheritDoc} */
    @Override
    public String processOutput(TranslatorContext ctx, NDList list) throws Exception {
        NDArray result = list.singletonOrThrow();
        StringBuilder sb = new StringBuilder();
//        List<String> sentence = new ArrayList<>();
        for (long ele : result.toLongArray()) {
//            sentence.add(vocabulary.getToken(ele));
            sb.append(vocabulary.getToken(ele));
            if ("<|endoftext|>".equals(vocabulary.getToken(ele))) {
                break;
            }
        }
//        String output = String.join(" ", sentence);
//        return output.replaceAll("[^a-zA-Z0-9<|> ,.!]", "");
        byte[] buf = new byte[sb.length()];
        for (int i = 0; i < sb.length(); ++i) {
            char c = sb.charAt(i);
            buf[i] = BYTES_DECODER.get(c);
        }

        return new String(buf, StandardCharsets.UTF_8);
    }

    /**
     * Returns list of utf-8 byte and a mapping to unicode strings.
     *
     * <p>We specifically avoids mapping to whitespace/control characters the bpe code barfs on. The
     * reversible bpe codes work on unicode strings. This means you need a large # of unicode
     * characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token
     * dataset you end up needing around 5K for decent coverage. This is a significant percentage of
     * your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
     * unicode strings.
     */
    private static Map<Character, Byte> bpeDecoder() {
        Map<Character, Byte> map = new ConcurrentHashMap<>();
        for (char i = '!'; i <= '~'; ++i) {
            map.put(i, (byte) i);
        }
        for (char i = '¡'; i <= '¬'; ++i) {
            map.put(i, (byte) i);
        }
        for (char i = '®'; i <= 'ÿ'; ++i) {
            map.put(i, (byte) i);
        }

        int n = 0;
        for (char i = 0; i < 256; ++i) {
            if (!map.containsKey(i)) {
                map.put((char) (256 + n), (byte) i);
                ++n;
            }
        }
        return map;
    }

}
[WARN ] - Simple repository pointing to a non-archive file.
[INFO ] - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
[INFO ] - Number of inter-op threads is 6
[INFO ] - Number of intra-op threads is 10
[INFO ] - Start time: [2024-04-15T12:17:34.685+0800]
[INFO ] - End time: [2024-04-15T12:18:21.346+0800], result: [<|startoftranscript|><|zh|><|en|><|notimestamps|> Beijing's weather<|endoftext|>]
[INFO ] - Start time: [2024-04-15T12:18:21.378+0800]
[INFO ] - End time: [2024-04-15T12:19:09.164+0800], result: [<|startoftranscript|><|zh|><|en|><|notimestamps|> Beijing's weather<|endoftext|>]
Input #0, wav, from 'D:\workspace\djl\examples\build\example\chinese.wav':
  Metadata:
    encoder         : Adobe Audition CC 2017.1 (Macintosh)
    date            : 2020-06-05
    creation_time   : 14:26:45
    time_reference  : 0
  Duration: 00:00:03.10, bitrate: 275 kb/s
  Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 16000 Hz, 1 channels, s16, 256 kb/s
Input #0, wav, from 'D:\workspace\djl\examples\build\example\chinese.wav':
  Metadata:
    encoder         : Adobe Audition CC 2017.1 (Macintosh)
    date            : 2020-06-05
    creation_time   : 14:26:45
    time_reference  : 0
  Duration: 00:00:03.10, bitrate: 275 kb/s
  Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 16000 Hz, 1 channels, s16, 256 kb/s

In addition, when I first traced the model I did not use the forced_decoder_ids parameter. Later, I added model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe"), and the result was the same.
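
(A useful sanity check here, sketched under the assumption that the traced module keeps its generate method after torch.jit.save/load, just as it has one right after tracing: run the traced model on the same Chinese clip in Python before loading it in Java. If this already prints English, the problem is baked into the traced graph rather than caused by WhisperTranslator. File names are illustrative.)

import librosa
import torch
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
traced_model = torch.jit.load("whisper-large-v2.pt")

# Load the same Chinese clip used in the Java test, resampled to 16 kHz.
audio, _ = librosa.load("chinese.wav", sr=16000)
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

generated_ids = traced_model.generate(input_features)
print(processor.batch_decode(generated_ids, skip_special_tokens=False)[0])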

frankfliu commented on July 24, 2024

I tested with https://resources.djl.ai/demo/pytorch/whisper/whisper_en.zip, and it can output Chinese.

Can you share the Python script that traces the model? Are you using openai/whisper-base?

leleZeng commented on July 24, 2024

This is my code; I am using the openai/whisper-large-v2 model.

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import torch
import numpy as np

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
processor.tokenizer.save_pretrained("whisper-tokenizer")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2", return_dict=False)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

test = []
for ele in ds:
    test.append(ele["audio"]["array"])

input_features = processor(np.concatenate(test), return_tensors="pt").input_features
generated_ids = model.generate(inputs=input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Original: " + transcription)

# Start tracing
traced_model = torch.jit.trace_module(model, {"generate": [input_features]})
generated_ids = traced_model.generate(input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Traced: " + transcription)

torch.jit.save(traced_model, "whisper-large-v2.pt")
