
curated-transformers's People

Contributors

adrianeboyd, danieldk, kennethenevoldsen, kit1980, mayankjobanputra, omahs, shademe, svlandeg, vinbo8

curated-transformers's Issues

Dictionary update length error

I tried quantization in our own project and got the error below.

We have three separate configs: one for the tagger, one for the parser, and one for the NER component. We then use an assemble step to merge them, and quantizing the assembled model produces this error.
config files: https://gist.github.com/SzaboGergo01/5ca9abeaaf199a686d8b41335e9fe261

Running command: /home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/bin/python -m spacy quantize-transformer --max-mse-loss 0.000003 models/hu_core_news_trf_xl-3.5.0 models/hu_core_news_trf_xl-3.5.0-quantized
Traceback (most recent call last):
  File "/home/a100/anaconda3/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/a100/anaconda3/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/spacy/__main__.py", line 4, in <module>
    setup_cli()
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/spacy/cli/_util.py", line 74, in setup_cli
    command(prog_name=COMMAND)
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/click/core.py", line 1259, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/typer/main.py", line 497, in wrapper
    return callback(**use_params)  # type: ignore
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/curated_transformers/cli/quantize.py", line 43, in quantize_cli
    nlp_quantize_dynamic(
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/curated_transformers/cli/quantize.py", line 75, in nlp_quantize_dynamic
    quantize_dynamic(
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/curated_transformers/cli/quantize.py", line 108, in quantize_dynamic
    quantized_model = torch.quantization.quantize_dynamic(
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/torch/ao/quantization/quantize.py", line 447, in quantize_dynamic
    model = copy.deepcopy(model)
  File "/home/a100/anaconda3/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/a100/anaconda3/lib/python3.8/copy.py", line 272, in _reconstruct
    y.__setstate__(state)
  File "/home/a100/gszabo/roberta/hu_core_news_trf_xl/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1237, in __setstate__
    self.__dict__.update(state)
ValueError: dictionary update sequence element #0 has length 12; 2 is required

My spaCy environment:

## Info about spaCy

- **spaCy version:** 3.5.1
- **Platform:** Linux-5.4.0-136-generic-x86_64-with-glibc2.10
- **Python version:** 3.8.5

The torch.quantization.quantize_dynamic() function appears to be the source of the error. Have you ever encountered something similar, or do you know of a solution?
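
For context, here is a minimal sketch of the underlying call outside of spaCy (the toy model and the qconfig spec are my own stand-ins); the deep copy in its default inplace=False path is where the traceback above ends up:

import torch
from torch import nn

# Toy stand-in for the assembled pipeline's transformer model.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# quantize_dynamic deep-copies the model unless inplace=True; the
# __setstate__ error above is raised while deep-copying a submodule.
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)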

Loading a model saved locally

Hi,
I probably missed something, but how do you load a model saved locally (e.g. a fine-tuned Llama model)?
I'm trying to use AutoGenerator, but I cannot load a local model (HF format).

The documentation states that the model architecture is inferred from the repo name, so I guess there may be a conflict with a custom local name.

Thank you.
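
Edit: in the meantime, the workaround I'm considering is to push the fine-tuned checkpoint to a private Hub repo and load it from there. A rough sketch, assuming the local directory follows the standard HF layout (the repo id and path below are placeholders):

import torch
from huggingface_hub import HfApi

from curated_transformers.generation import AutoGenerator

api = HfApi()
api.create_repo("my-org/my-finetuned-llama", private=True, exist_ok=True)
api.upload_folder(
    folder_path="/path/to/finetuned-llama",  # local HF-format checkpoint
    repo_id="my-org/my-finetuned-llama",
)

generator = AutoGenerator.from_hf_hub(
    name="my-org/my-finetuned-llama",
    device=torch.device("cuda", index=0),
)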

Make `QkvMode` ADT-like

Some behaviors, like how to split heads, do not apply to all modes, so they should be scoped.
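
A rough sketch of what "ADT-like" could look like (the class and method names here are hypothetical, not the current API): each mode becomes its own variant, and mode-specific behavior such as head splitting lives on the variant instead of being switched on elsewhere.

from dataclasses import dataclass
from typing import Union

from torch import Tensor


@dataclass
class SeparateQKV:
    """Separate query/key/value projections; no merged tensor to split."""


@dataclass
class MergedSplitBefore:
    """One merged QKV projection, split into heads before slicing Q/K/V."""

    def split(self, qkv: Tensor, n_heads: int) -> Tensor:
        batch, seq, _ = qkv.shape
        # (batch, seq, n_heads, 3, head_width)
        return qkv.view(batch, seq, n_heads, 3, -1)


# Attention layers store exactly one variant and dispatch on its type.
QkvMode = Union[SeparateQKV, MergedSplitBefore]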

Move the old Falcon architecture to the extras/addons package

We currently support two different architectures for the Falcon family of models. This adds a fair bit of complexity to the codebase, since one of the architectures doesn't lend itself to reusing the components we've already implemented. It would therefore be a good idea to move this architecture to the extras/addons package.

Support for loading from an AbstractFileSystem

A problem that I face very often with HuggingFace transformers is efficiently loading a model from private cloud storage. transformers unfortunately does not support fsspec URLs in its .from_pretrained API. As a consequence, loading a checkpoint from cloud storage is both inefficient and slightly ugly, because we first have to go through disk.

A better alternative would be to load directly from an fsspec file system:

encoder = BERTEncoder.load(
   fs=GCSFileSystem(...),
   device=torch.device("cuda", index=0),
)

or perhaps directly by passing an fsspec-compliant URL:

encoder = BERTEncoder.load(
   url="gs://my-bucket/.../my-model/",
   device=torch.device("cuda", index=0),
)

The HuggingFace Hub can also be accessed through fsspec (see its documentation), so perhaps that could help abstract the storage layer completely.

This would be a very nice and useful addition to the package for cases where hosting models on the Hub is not possible.
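
For comparison, the workaround I currently use with HF transformers has to go through disk first (the bucket path and local directory are placeholders):

import gcsfs
from transformers import AutoModel

# Copy the checkpoint from cloud storage to a local directory first...
fs = gcsfs.GCSFileSystem()
fs.get("gs://my-bucket/my-model/", "/tmp/my-model/", recursive=True)

# ...and only then load it from disk.
model = AutoModel.from_pretrained("/tmp/my-model")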

Add Low-Rank Adapters injection into base models

Low-Rank Adaptation (LoRA) has become the de facto parameter-efficient fine-tuning technique for adapting a base language model to a specific task. curated-transformers already supports dynamic quantization using bitsandbytes, so adding utilities to inject trainable adapters would open the door to using curated-transformers as a replacement for the HuggingFace transformers + peft stack. This could also enable a very nice fine-tuning integration into spaCy in the future.

For reference, I find this implementation in lit-gpt really readable.

Do you find this idea interesting?

If so, as for the user-facing API, drawing inspiration from HuggingFace peft, it could look something like this:

# Load and quantize the base model
model = AutoGenerator.from_hf_hub(
    name="meta-llama/Llama-2-7b-chat-hf",
    device=torch.device("cuda", index=0),
    quantization_config=BitsAndBytesConfig.for_4bit(
        quantization_dtype=Dtype4Bit.FP4,
        compute_dtype=torch.bfloat16,
        double_quantization=True,
    ),
)

# Replace targeted linear layers with `LoRALayer` wrappers around the original weights
model_with_adapters = inject_adapters(
    base_model=model,
    lora_config=LoraConfig(
        rank=64,
        alpha=16,
        dropout=0.1,
        bias=LoraBias.NONE,
        target_modules=[...]
    ),
)
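
For reference, a minimal sketch of what such a `LoRALayer` wrapper would compute (this is the standard LoRA formulation, not an existing curated-transformers class):

import math

import torch
from torch import nn


class LoRALinear(nn.Module):
    """Frozen base linear layer plus a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 64, alpha: int = 16, dropout: float = 0.1):
        super().__init__()
        self.base = base
        for param in self.base.parameters():
            param.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.empty(rank, base.in_features))
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        self.scaling = alpha / rank
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + (alpha / rank) * B A x, with only A and B trainable.
        update = self.dropout(x) @ self.lora_a.T @ self.lora_b.T
        return self.base(x) + update * self.scaling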

Truncation of sequences that are beyond the model's maximum length

Hi,
First, I would like to thank you for this library :-) I'm really enjoying it.

I tried to tokenize a sequence with around 4K tokens and then fed it to a RoBERTa-based model (CodeBERT). This led to the following error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-12-a373b5333f39> in <cell line: 1>()
      2    ids = input_sentence.padded_tensor(padding_id=0, pad_left=True)
      3    mask = input.attention_mask(pad_left=True)
----> 4    model_output = encoder(piece_ids=ids, attention_mask=mask)

10 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/curated_transformers/models/transformer.py in forward(self, piece_ids, attention_mask, positions, type_ids)
    122         type_ids: Optional[Tensor] = None,
    123     ) -> ModelOutput:
--> 124         embeddings = self.embeddings(piece_ids, positions=positions, type_ids=type_ids)
    125         layer_output = embeddings
    126 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/curated_transformers/models/roberta/embeddings.py in forward(self, piece_ids, positions, type_ids)
     96         if positions is None:
     97             positions = self._get_positions(piece_ids)
---> 98         return super().forward(
     99             piece_ids,
    100             positions=positions,

/usr/local/lib/python3.10/dist-packages/curated_transformers/layers/transformer.py in forward(self, piece_ids, positions, type_ids)
    180             if positions is None:
    181                 positions = self._get_positions(piece_ids)
--> 182             position_embeddings = self.position_embeddings(positions)
    183             embeddings += position_embeddings
    184 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    160 
    161     def forward(self, input: Tensor) -> Tensor:
--> 162         return F.embedding(
    163             input, self.weight, self.padding_idx, self.max_norm,
    164             self.norm_type, self.scale_grad_by_freq, self.sparse)

/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2231         # remove once script supports set_grad_enabled
   2232         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2233     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2234 
   2235 

IndexError: index out of range in self

For reference, here is the code I was using:

MODEL_TAG = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_hf_hub(name=MODEL_TAG, revision="main")
model = RoBERTaEncoder.from_hf_hub(
    name=MODEL_TAG,
    revision="main",
)
code = [
   'void avcodec_string(char *buf, int buf_size, AVCodecContext *enc, int encode)\n\n{\n\n    const char *codec_type;\n\n    const char *codec_name;\n\n    const char *profile = NULL;\n\n    const AVCodec *p;\n\n    int64_t bitrate;\n\n    int new_line = 0;\n\n    AVRational display_aspect_ratio;\n\n    const char *separator = enc->dump_separator ? (const char *)enc->dump_separator : ", ";\n\n\n\n    if (!buf || buf_size <= 0)\n\n        return;\n\n    codec_type = av_get_media_type_string(enc->codec_type);\n\n    codec_name = avcodec_get_name(enc->codec_id);\n\n    if (enc->profile != FF_PROFILE_UNKNOWN) {\n\n        if (enc->codec)\n\n            p = enc->codec;\n\n        else\n\n            p = encode ? avcodec_find_encoder(enc->codec_id) :\n\n                        avcodec_find_decoder(enc->codec_id);\n\n        if (p)\n\n            profile = av_get_profile_name(p, enc->profile);\n\n    }\n\n\n\n    snprintf(buf, buf_size, "%s: %s", codec_type ? codec_type : "unknown",\n\n             codec_name);\n\n    buf[0] ^= \'a\' ^ \'A\'; /* first letter in uppercase */\n\n\n\n    if (enc->codec && strcmp(enc->codec->name, codec_name))\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf), " (%s)", enc->codec->name);\n\n\n\n    if (profile)\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf), " (%s)", profile);\n\n    if (   enc->codec_type == AVMEDIA_TYPE_VIDEO\n\n        && av_log_get_level() >= AV_LOG_VERBOSE\n\n        && enc->refs)\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", %d reference frame%s",\n\n                 enc->refs, enc->refs > 1 ? "s" : "");\n\n\n\n    if (enc->codec_tag) {\n\n        char tag_buf[32];\n\n        av_get_codec_tag_string(tag_buf, sizeof(tag_buf), enc->codec_tag);\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 " (%s / 0x%04X)", tag_buf, enc->codec_tag);\n\n    }\n\n\n\n    switch (enc->codec_type) {\n\n    case AVMEDIA_TYPE_VIDEO:\n\n        {\n\n            char detail[256] = "(";\n\n\n\n            av_strlcat(buf, separator, buf_size);\n\n\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 "%s", enc->pix_fmt == AV_PIX_FMT_NONE ? 
"none" :\n\n                     av_get_pix_fmt_name(enc->pix_fmt));\n\n            if (enc->bits_per_raw_sample && enc->pix_fmt != AV_PIX_FMT_NONE &&\n\n                enc->bits_per_raw_sample < av_pix_fmt_desc_get(enc->pix_fmt)->comp[0].depth)\n\n                av_strlcatf(detail, sizeof(detail), "%d bpc, ", enc->bits_per_raw_sample);\n\n            if (enc->color_range != AVCOL_RANGE_UNSPECIFIED)\n\n                av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                            av_color_range_name(enc->color_range));\n\n\n\n            if (enc->colorspace != AVCOL_SPC_UNSPECIFIED ||\n\n                enc->color_primaries != AVCOL_PRI_UNSPECIFIED ||\n\n                enc->color_trc != AVCOL_TRC_UNSPECIFIED) {\n\n                if (enc->colorspace != (int)enc->color_primaries ||\n\n                    enc->colorspace != (int)enc->color_trc) {\n\n                    new_line = 1;\n\n                    av_strlcatf(detail, sizeof(detail), "%s/%s/%s, ",\n\n                                av_color_space_name(enc->colorspace),\n\n                                av_color_primaries_name(enc->color_primaries),\n\n                                av_color_transfer_name(enc->color_trc));\n\n                } else\n\n                    av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                                av_get_colorspace_name(enc->colorspace));\n\n            }\n\n\n\n            if (av_log_get_level() >= AV_LOG_DEBUG &&\n\n                enc->chroma_sample_location != AVCHROMA_LOC_UNSPECIFIED)\n\n                av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                            av_chroma_location_name(enc->chroma_sample_location));\n\n\n\n            if (strlen(detail) > 1) {\n\n                detail[strlen(detail) - 2] = 0;\n\n                av_strlcatf(buf, buf_size, "%s)", detail);\n\n            }\n\n        }\n\n\n\n        if (enc->width) {\n\n            av_strlcat(buf, new_line ? 
separator : ", ", buf_size);\n\n\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     "%dx%d",\n\n                     enc->width, enc->height);\n\n\n\n            if (av_log_get_level() >= AV_LOG_VERBOSE &&\n\n                (enc->width != enc->coded_width ||\n\n                 enc->height != enc->coded_height))\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         " (%dx%d)", enc->coded_width, enc->coded_height);\n\n\n\n            if (enc->sample_aspect_ratio.num) {\n\n                av_reduce(&display_aspect_ratio.num, &display_aspect_ratio.den,\n\n                          enc->width * enc->sample_aspect_ratio.num,\n\n                          enc->height * enc->sample_aspect_ratio.den,\n\n                          1024 * 1024);\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         " [SAR %d:%d DAR %d:%d]",\n\n                         enc->sample_aspect_ratio.num, enc->sample_aspect_ratio.den,\n\n                         display_aspect_ratio.num, display_aspect_ratio.den);\n\n            }\n\n            if (av_log_get_level() >= AV_LOG_DEBUG) {\n\n                int g = av_gcd(enc->time_base.num, enc->time_base.den);\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", %d/%d",\n\n                         enc->time_base.num / g, enc->time_base.den / g);\n\n            }\n\n        }\n\n        if (encode) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", q=%d-%d", enc->qmin, enc->qmax);\n\n        } else {\n\n            if (enc->properties & FF_CODEC_PROPERTY_CLOSED_CAPTIONS)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", Closed Captions");\n\n            if (enc->properties & FF_CODEC_PROPERTY_LOSSLESS)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", lossless");\n\n        }\n\n        break;\n\n    case AVMEDIA_TYPE_AUDIO:\n\n        av_strlcat(buf, separator, buf_size);\n\n\n\n        if (enc->sample_rate) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     "%d Hz, ", enc->sample_rate);\n\n        }\n\n        av_get_channel_layout_string(buf + strlen(buf), buf_size - strlen(buf), enc->channels, enc->channel_layout);\n\n        if (enc->sample_fmt != AV_SAMPLE_FMT_NONE) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", %s", av_get_sample_fmt_name(enc->sample_fmt));\n\n        }\n\n        if (   enc->bits_per_raw_sample > 0\n\n            && enc->bits_per_raw_sample != av_get_bytes_per_sample(enc->sample_fmt) * 8)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     " (%d bit)", enc->bits_per_raw_sample);\n\n        break;\n\n    case AVMEDIA_TYPE_DATA:\n\n        if (av_log_get_level() >= AV_LOG_DEBUG) {\n\n            int g = av_gcd(enc->time_base.num, enc->time_base.den);\n\n            if (g)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", %d/%d",\n\n                         enc->time_base.num / g, enc->time_base.den / g);\n\n        }\n\n        break;\n\n    case AVMEDIA_TYPE_SUBTITLE:\n\n        if (enc->width)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", %dx%d", enc->width, enc->height);\n\n        break;\n\n    
default:\n\n        return;\n\n    }\n\n    if (encode) {\n\n        if (enc->flags & AV_CODEC_FLAG_PASS1)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", pass 1");\n\n        if (enc->flags & AV_CODEC_FLAG_PASS2)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", pass 2");\n\n    }\n\n    bitrate = get_bit_rate(enc);\n\n    if (bitrate != 0) {\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", %"PRId64" kb/s", bitrate / 1000);\n\n    } else if (enc->rc_max_rate > 0) {\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", max. %"PRId64" kb/s", (int64_t)enc->rc_max_rate / 1000);\n\n    }\n\n}\n',
]
with torch.no_grad():
    input_sentence = tokenizer(code)
    ids = input_sentence.padded_tensor(padding_id=0, pad_left=False)
    mask = input_sentence.attention_mask(pad_left=False)
    model_output = model(piece_ids=ids, attention_mask=mask)

I went through the API docs and skimmed the source code, and it appears that truncation is not supported. Note that when I manually truncated the sequence (see the sketch below), I was able to feed it to the RoBERTa encoder.
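
Continuing from the snippet above, the manual truncation looked roughly like this (512 is RoBERTa's maximum number of positions; the AttentionMask import path and constructor are my assumptions about the current API):

from curated_transformers.layers import AttentionMask  # assumed import path

MAX_LEN = 512  # maximum number of positions for RoBERTa/CodeBERT

with torch.no_grad():
    input_sentence = tokenizer(code)
    ids = input_sentence.padded_tensor(padding_id=0, pad_left=False)
    # Clip the piece ids to the model's maximum length and rebuild the
    # attention mask from the non-padding positions.
    ids = ids[:, :MAX_LEN]
    mask = AttentionMask(ids != 0)
    model_output = model(piece_ids=ids, attention_mask=mask)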

Output logits for generation

Logits are required for calculating perplexities, but I couldn't find any way to obtain them currently. Maybe there is already a way; if so, please point me to it.
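
In case it helps frame the request, this is roughly the computation I want to run once per-position logits are exposed (logits and piece_ids here are placeholders for whatever the generator would return):

import torch
import torch.nn.functional as F


def perplexity(logits: torch.Tensor, piece_ids: torch.Tensor) -> torch.Tensor:
    """Perplexity from next-token logits.

    logits: (seq_len, vocab_size); piece_ids: (seq_len,).
    """
    # Position i predicts piece i + 1, so shift by one.
    log_probs = F.log_softmax(logits[:-1], dim=-1)
    target_log_probs = log_probs.gather(-1, piece_ids[1:].unsqueeze(-1)).squeeze(-1)
    return torch.exp(-target_log_probs.mean())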

Add an extras/contrib package

Some models deviate so much from standard transformer encoders/decoders (e.g. DeBERTa and the old Falcon architecture) that we probably should not support them in mainline Curated Transformers, to avoid cluttering the abstractions too much (it's curated, after all). However, it would still be nice to provide official support for these models. Consider an add-on package for them.

Probably requires #346.

Optimal QLoRA settings

In HF transformers, the default QLoRA settings do not replicate the QLoRA of the original paper, leaving valuable performance on the table for ML practitioners who rely on the library defaults.
LoRA has to be applied to certain parts of the network; please see this tweet by Tim Dettmers:

https://twitter.com/Tim_Dettmers/status/1695377756232589459

I guess this has to be customized for each model architecture, which sounds like a feature for curated-transformers to me.
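
For illustration, with HF peft and a Llama-style model the paper's setting amounts to targeting all linear layers of each block, roughly like this (the module names are those of the HF Llama implementation and would differ per architecture):

from peft import LoraConfig

# QLoRA paper: apply LoRA to every linear projection in the blocks,
# not only the attention query/value projections that many defaults use.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # feed-forward projections
    ],
)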

Decoupling weights downloading from loading

First, thank you for open-sourcing this library. I'm currently evaluating curated-transformers to assess whether I can migrate from transformers to this more focused and lightweight library. One of the few things I'm missing compared to transformers is a way to decouple model downloading from model loading. In a distributed training setting, we only want the weights to be cached on disk by a single process on each node and then loaded by all the child processes. This is especially critical when the model is very large and cannot fit on a single device.

Currently, the following snippet downloads and loads the model:

encoder = BERTEncoder.from_hf_hub(
   name="bert-base-uncased",
   revision="main",
   device=torch.device("cuda", index=0),
)

I suggest (optionally) letting the user decouple downloading and loading with an API similar to this:

BERTEncoder.pull_from_hf_hub(name="bert-base-uncased", revision="main")
encoder = BERTEncoder.from_hf_hub(
   name="bert-base-uncased",
   revision="main",
   device=torch.device("cuda", index=0),
)

where the first line caches the weights on disk, and the second checks the cache first and only downloads on a cache miss.

Does it sound like a useful feature to you?
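
For reference, the workaround I currently use relies on huggingface_hub's shared on-disk cache (a sketch only; it assumes torch.distributed is already initialized and that from_hf_hub reads from the same cache):

import os

import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download

from curated_transformers.models import BERTEncoder

# Only the local rank-0 process on each node warms the cache...
if int(os.environ.get("LOCAL_RANK", "0")) == 0:
    snapshot_download(repo_id="bert-base-uncased", revision="main")
dist.barrier()

# ...after which every process loads from the warm cache.
encoder = BERTEncoder.from_hf_hub(
    name="bert-base-uncased",
    revision="main",
    device=torch.device("cuda", index=0),
)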

Unable to load a fine-tuned causal LM MPT model

I am unable to load a fine-tuned causal LM MPT model from the HF Hub. I am getting the following error.

I used the following script to fine-tune the MPT models using HF transformers: Script

Is this expected? How do I fix this issue?

  File "src/experiments/fine-tune.py", line 367, in main
    generator = AutoGenerator.from_hf_hub(name="xxxx", device=torch.device("cuda", index=0))
  File "/home/monk/Projects/curated-transformers/curated_transformers/generation/auto_generator.py", line 61, in from_hf_hub
    generator_cls.from_hf_hub(
  File "/home/monk/Projects/curated-transformers/curated_transformers/generation/default_generator.py", line 81, in from_hf_hub
    causal_lm = AutoCausalLM.from_hf_hub(
  File "/home/monk/Projects/curated-transformers/curated_transformers/models/auto_model.py", line 196, in from_hf_hub
    causal_lm = cls._instantiate_model_from_hf_hub(
  File "/home/monk/Projects/curated-transformers/curated_transformers/models/auto_model.py", line 58, in _instantiate_model_from_hf_hub
    module = module_cls.from_hf_hub(
  File "/home/monk/Projects/curated-transformers/curated_transformers/models/hf_hub.py", line 143, in from_hf_hub
    load_model_from_checkpoints(
  File "/home/monk/Projects/curated-transformers/curated_transformers/util/serde.py", line 138, in load_model_from_checkpoints
    _emplace_module_state_dict(
  File "/home/monk/Projects/curated-transformers/curated_transformers/util/serde.py", line 228, in _emplace_module_state_dict
    apply_to_module(module, apply)
  File "/home/monk/Projects/curated-transformers/curated_transformers/util/pytorch.py", line 42, in apply_to_module
    func(itr)
  File "/home/monk/Projects/curated-transformers/curated_transformers/util/serde.py", line 214, in apply
    raise ValueError(
ValueError: Key `bias` found in state dict but no data in module `decoder.output_layer_norm`
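
To narrow this down, it might help to check whether the fine-tuned checkpoint contains layer-norm bias keys that the curated MPT module does not expect (per the error, decoder.output_layer_norm has no bias). A rough sketch, assuming a single non-sharded PyTorch checkpoint file (the path is a placeholder):

import torch

state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# List the norm-related bias keys present in the checkpoint; a stray
# final layer-norm bias here would explain the error above.
print(sorted(k for k in state_dict if "norm" in k and k.endswith(".bias")))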

Check BOS/EOS usage

Do we want to stick to the piece encoders' current method of using the BOS and EOS markers to delimit individual documents (as opposed to sentences)? If not, how do we get the sentence boundaries at this stage of the pipeline?

Support for Encoder-Decoder-style architectures

I regularly follow the developments on this project, and I must say that I am very interested in and pleased with the direction curated-transformers is taking. The code is very understandable and high quality; it's a pleasure to work with. Congratulations!

This is perhaps already in your plans, but just to mention it here: I think a very nice addition to the project would be at least one reference implementation of an encoder-decoder-style architecture, such as T5. T5 models are very popular for some tasks, especially in the <1B-parameter range, which is still very relevant nowadays. We currently have reference implementations for decoder-style and encoder-style models, but we're missing at least one encoder-decoder-style architecture, perhaps with a reusable cross-attention block.
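
To make the "reusable cross-attention block" idea concrete, here is a minimal sketch in plain PyTorch (not a proposed curated-transformers API):

import torch
from torch import nn


class CrossAttentionBlock(nn.Module):
    """Decoder states attend over the encoder output, with a residual connection."""

    def __init__(self, width: int, n_heads: int):
        super().__init__()
        self.attention = nn.MultiheadAttention(width, n_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(width)

    def forward(self, decoder_states: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor:
        attn_output, _ = self.attention(
            query=decoder_states, key=encoder_output, value=encoder_output
        )
        return self.layer_norm(decoder_states + attn_output)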
