Comments (15)
I'm not sure if this is helpful, but I was working on deploying some of these models using ONNX and this is what I came up with so far. If others are looking for a place to start, here is some code that converts the base model and the head so that you can run them separately. I haven't been able to merge them into one graph yet, but hopefully it's a start while we wait for #8 :).
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from onnxruntime import InferenceSession
from pathlib import Path
from setfit import SetFitModel
from transformers import AutoTokenizer
from transformers.convert_graph_to_onnx import convert
import torch
import numpy as np
import sys


def mean_pooling_np(model_output: list, attention_mask: np.ndarray) -> np.ndarray:
    # The first ONNX output holds the per-token embeddings: (batch, seq_len, hidden)
    token_embeddings = model_output[0]
    # Expand the mask to the embedding shape so padding tokens contribute zero
    input_mask_expanded = np.broadcast_to(
        np.expand_dims(attention_mask, axis=2), token_embeddings.shape
    )
    sum_embeddings = np.sum(input_mask_expanded * token_embeddings, axis=1)
    # sys.maxint only exists in Python 2; sys.maxsize is the Python 3 equivalent
    sum_mask = np.clip(input_mask_expanded.sum(1), 1e-9, sys.maxsize)
    return sum_embeddings / sum_mask


trained_model_path = "path/to/your/trained/setfit/model"
# Placeholders: the two ONNX files must be saved under different file names
onnx_sentence_model_path = "/path/to/save/onnx/to"
onnx_head_model_path = "/path/to/save/onnx/to"

model = SetFitModel.from_pretrained(trained_model_path)

# Convert the sentence transformer model to onnx
convert(
    "pt",
    trained_model_path,
    Path(onnx_sentence_model_path).absolute(),
    15,
    trained_model_path,
)

# Convert sklearn head into ONNX format
# 768 is the embedding size of the model used here; adjust it for other models
initial_type = [("model_head", FloatTensorType([None, 768]))]
onx = convert_sklearn(model.model_head, initial_types=initial_type, target_opset=15)
with open(onnx_head_model_path, "wb") as f:
    f.write(onx.SerializeToString())

# Load and use the models
text = ["some text to do stuff with"]
tokenizer = AutoTokenizer.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2"
)
session = InferenceSession(onnx_sentence_model_path)
head_session = InferenceSession(onnx_head_model_path)
tokens = tokenizer(text, truncation=True, return_tensors="np")
preds = session.run(None, dict(tokens))
# The head input is declared as FloatTensorType (float32), so cast the pooled output
pooled_preds = mean_pooling_np(preds, tokens["attention_mask"]).astype(np.float32)
print(head_session.run(None, {"model_head": pooled_preds}))
from setfit.
Thank you, @nbertagnolli !
A few issues that I had with the script and how I resolved them:
- For the paths to store the onnx models, the convert script expects that the directory that will contain them will be empty. Just create an empty onnx directory and put them in there.
- Be careful with the initial_type -- it should have the right shape for your LM.
- You should also use the appropriate tokenizer for your model.
- My model-head was expecting np.float32, not np.float64.
- For python 3+, I would replace sys.maxint with sys.maxsize.
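The float32 and sys.maxsize points in the list above can be checked on toy data without any model. The following is only a minimal sketch: the function takes the token embeddings directly rather than the full ONNX output list, and the toy arrays are made up for illustration.

```python
import sys

import numpy as np


def mean_pooling_np(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings over non-padding positions and return float32."""
    # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts over the hidden axis
    mask = np.expand_dims(attention_mask, axis=2).astype(np.float32)
    summed = np.sum(token_embeddings * mask, axis=1)
    # sys.maxsize is the Python 3 replacement for sys.maxint
    counts = np.clip(mask.sum(axis=1), 1e-9, sys.maxsize)
    # Cast explicitly: a float64 input would otherwise give a float64 result
    return (summed / counts).astype(np.float32)


# Toy batch: one sentence, three token positions (the last one is padding), hidden size 2
token_embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # float64 by default
attention_mask = np.array([[1, 1, 0]])

pooled = mean_pooling_np(token_embeddings, attention_mask)
print(pooled, pooled.dtype)
```

The padding position is ignored, so the pooled vector is the mean of the first two token embeddings, and the explicit cast keeps the dtype at float32 regardless of the input dtype.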
from setfit.
Hey @nbertagnolli and @blakechi super cool that you're excited to work on the ONNX export for the two heads 🔥 !!
I agree it's best to have an initial PR first so we can hone the design and then iterate from there.
from setfit.
I like that idea, I think that could make working with the sklearn heads easier. : )
from setfit.
@AnshulP10 please take a look at #156, the PR we've been working on. @kgourgou pointed out that the above script needs some modifications for some models, and this PR hopefully addresses those concerns. In the PR there is a function called export_onnx which should do what you want. Let me know if you still have trouble.
from setfit.
Hi @lewtun,
Thanks for your reply! onnx_export for SetFitModel sounds good to me.
Sure, happy to lend a hand on adding onnx in this work after #8. 😀
from setfit.
Sorry for the wait. The good news is that issue #8 has been resolved!
Just a suggestion: maybe @nbertagnolli can work on converting the sklearn head and I can work on the PyTorch version? That way we can each work on the part we are familiar with. :)
What do you think, @lewtun @nbertagnolli?
from setfit.
Sounds good @blakechi I'll open a PR soon with an initial script for onnx conversion and we can go from there! : )
from setfit.
okay, sounds good to me :)
Make sure you follow @lewtun 's suggestion as below:
Hi @blakechi thanks for your interest in our work! Instead of compiling SetFitTrainer into ONNX, I think it would be better to have something like an onnx_export() function that lives inside a setfit.onnx module.
Ideally, this function would take a SetFitModel as input and output model.onnx, similar to what was done for Stable Diffusion here: https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py#L32
Since the current implementation of SetFitModel uses scikit-learn estimators for the classification head, it might be best to wait until we implement a pure PyTorch version in #8
from setfit.
Do we also need a Python API for ONNX models? I think it would be much easier for users since we can handle the tokenizer and ONNX runtime for them.
I found this pretty helpful from Diffusion. We could possibly have it in a setfit.onnx module as well.
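To make the Python API idea above more concrete, here is a rough sketch of what such a wrapper could look like. Everything in it is hypothetical: the OnnxSetFitPipeline class and its parameters are not part of setfit, and the small stand-in tokenizer and sessions only exist so the sketch runs without onnxruntime. In practice the two sessions would be onnxruntime InferenceSession objects and the tokenizer a Hugging Face tokenizer.

```python
from typing import Any, List

import numpy as np


class OnnxSetFitPipeline:
    """Hypothetical wrapper: tokenizer -> ONNX body -> mean pooling -> ONNX head."""

    def __init__(self, tokenizer: Any, body_session: Any, head_session: Any,
                 head_input_name: str = "model_head") -> None:
        self.tokenizer = tokenizer
        self.body_session = body_session
        self.head_session = head_session
        self.head_input_name = head_input_name

    def __call__(self, texts: List[str]) -> list:
        tokens = dict(self.tokenizer(texts, padding=True, truncation=True, return_tensors="np"))
        # First output of the body is assumed to be the per-token embeddings
        token_embeddings = self.body_session.run(None, tokens)[0]
        mask = np.expand_dims(tokens["attention_mask"], axis=2).astype(np.float32)
        pooled = (token_embeddings * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1e-9)
        # The sklearn-converted head expects float32 input
        return self.head_session.run(None, {self.head_input_name: pooled.astype(np.float32)})


# Stand-ins for illustration only; they mimic the call signatures used above.
class FakeTokenizer:
    def __call__(self, texts, **kwargs):
        lengths = [len(t.split()) for t in texts]
        mask = np.array([[1] * n + [0] * (max(lengths) - n) for n in lengths], dtype=np.int64)
        return {"input_ids": mask * 7, "attention_mask": mask}


class FakeBody:
    def run(self, output_names, input_feed):
        ids = input_feed["input_ids"].astype(np.float32)
        return [np.stack([ids, 2 * ids], axis=2)]  # (batch, seq_len, 2)


class FakeHead:
    def run(self, output_names, input_feed):
        x = input_feed["model_head"]
        return [x.argmax(axis=1), x]  # labels, then the features it received


pipeline = OnnxSetFitPipeline(FakeTokenizer(), FakeBody(), FakeHead())
outputs = pipeline(["one two three", "one"])
print(outputs[0])
```

Passing the tokenizer and sessions in, rather than loading them inside the class, keeps the sketch independent of how the ONNX files were exported and makes it easy to test with stubs.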
from setfit.
With #156 merged, this feature request has been implemented :)
from setfit.
Hi @blakechi thanks for your interest in our work! Instead of compiling SetFitTrainer into ONNX, I think it would be better to have something like an onnx_export() function that lives inside a setfit.onnx module.
Ideally, this function would take a SetFitModel as input and output model.onnx, similar to what was done for Stable Diffusion here: https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py#L32
Since the current implementation of SetFitModel uses scikit-learn estimators for the classification head, it might be best to wait until we implement a pure PyTorch version in #8
from setfit.
Happy to help @blakechi do you have a branch you're currently working on? Want me to create one and put an initial ONNX script together? Nice work on #8!
from setfit.
Thanks, @nbertagnolli!
Or maybe we can open two separate PRs? It seems like you are almost ready to open one, so I don't want to hold you back. 😅 But I'm fine either way :)
@lewtun, maybe you can suggest which way is more convenient for you to manage?
from setfit.
@nbertagnolli I tried this approach, but I'm getting an issue regarding the inputs. I trained a multi-label model (one-vs-rest classifier) using text inputs
from setfit.