nbertagnolli commented on June 11, 2024

I'm not sure if this is helpful, but I was working on deploying some of these models using ONNX, and this is what I've come up with so far. If others are looking for a place to start, here is some code that converts the base model and the head to ONNX so you can run them separately. I haven't been able to merge them into one graph yet, but hopefully it's a start while we wait for #8 :)

```python
from pathlib import Path

import numpy as np
from onnxruntime import InferenceSession
from setfit import SetFitModel
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from transformers import AutoTokenizer
from transformers.convert_graph_to_onnx import convert


def mean_pooling_np(model_output, attention_mask: np.ndarray) -> np.ndarray:
    # model_output[0] holds the token embeddings: (batch, seq_len, hidden)
    token_embeddings = model_output[0]
    # Broadcast the (batch, seq_len) attention mask over the hidden dimension
    input_mask_expanded = np.broadcast_to(
        np.expand_dims(attention_mask, axis=2), token_embeddings.shape
    )
    # Average only over real (non-padding) tokens
    sum_embeddings = np.sum(input_mask_expanded * token_embeddings, axis=1)
    sum_mask = np.clip(input_mask_expanded.sum(1), 1e-9, None)
    return sum_embeddings / sum_mask


trained_model_path = "path/to/your/trained/setfit/model"
# convert() requires the directory of the output file to be empty (or not exist yet)
onnx_sentence_model_path = "/path/to/save/onnx/to/model_body.onnx"
onnx_head_model_path = "/path/to/save/onnx/to/model_head.onnx"

model = SetFitModel.from_pretrained(trained_model_path)

# Convert the sentence transformer body to ONNX (framework, model, output, opset, tokenizer)
convert(
    "pt",
    trained_model_path,
    Path(onnx_sentence_model_path).absolute(),
    15,
    trained_model_path,
)

# Convert the sklearn head to ONNX; its input width must match the body's embedding size
embedding_dim = model.model_body.get_sentence_embedding_dimension()
initial_type = [("model_head", FloatTensorType([None, embedding_dim]))]
onx = convert_sklearn(model.model_head, initial_types=initial_type, target_opset=15)
with open(onnx_head_model_path, "wb") as f:
    f.write(onx.SerializeToString())


# Load and use the models, with the tokenizer saved alongside the trained body
text = ["some text to do stuff with"]
tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
session = InferenceSession(onnx_sentence_model_path)
head_session = InferenceSession(onnx_head_model_path)
tokens = tokenizer(text, truncation=True, return_tensors="np")
preds = session.run(None, dict(tokens))
# The head was declared with FloatTensorType, so cast the float64 pooled output to float32
pooled_preds = mean_pooling_np(preds, tokens["attention_mask"]).astype(np.float32)
print(head_session.run(None, {"model_head": pooled_preds}))
```
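Because `mean_pooling_np` is pure NumPy, it can be sanity-checked without exporting anything. The sketch below repeats the helper (with `None` as the upper clip bound, which works on Python 3) and runs it on a made-up toy batch whose padding positions hold large values, so any leak of padding into the average would be obvious. The shapes and numbers are illustrative only.

```python
import numpy as np


def mean_pooling_np(model_output, attention_mask):
    # model_output[0] holds the token embeddings: (batch, seq_len, hidden)
    token_embeddings = model_output[0]
    # Broadcast the (batch, seq_len) mask over the hidden dimension
    mask = np.broadcast_to(
        np.expand_dims(attention_mask, axis=2), token_embeddings.shape
    )
    # Sum only the non-padding token embeddings
    sum_embeddings = np.sum(mask * token_embeddings, axis=1)
    # Count non-padding tokens; the clip avoids division by zero
    sum_mask = np.clip(mask.sum(axis=1), 1e-9, None)
    return sum_embeddings / sum_mask


# Toy batch: 2 sentences, 3 token positions, hidden size 2
token_embeddings = np.array(
    [
        [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last position is padding
        [[5.0, 6.0], [100.0, 100.0], [100.0, 100.0]],  # only first token is real
    ],
    dtype=np.float32,
)
attention_mask = np.array([[1, 1, 0], [1, 0, 0]], dtype=np.int64)

pooled = mean_pooling_np([token_embeddings], attention_mask)
print(pooled)
```

The first row averages only `[1, 2]` and `[3, 4]`, and the second row is just `[5, 6]`, so the padded `100.0` entries never reach the result.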

kgourgou commented on June 11, 2024

Thank you, @nbertagnolli !

A few issues that I had with the script and how I resolved them:

  • For the paths where the ONNX models are stored, the convert script expects the directory that will contain them to be empty. Just create an empty onnx directory and put them in there.
  • Be careful with the initial_type: it should have the right shape for your LM.
  • You should also use the appropriate tokenizer for your model.
  • My model head was expecting np.float32, not np.float64.
  • For Python 3+, I would replace sys.maxint with sys.maxsize.
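The float32 and shape points can be checked with plain NumPy. The sketch below uses random arrays as stand-ins for the ONNX body output and the tokenizer's attention mask (the hidden size of 384 is an arbitrary example value, not tied to any particular model): multiplying an int64 mask by float32 embeddings promotes the result to float64, so the pooled vector needs a cast before it goes into a head declared with `FloatTensorType`, and the head's input width can be read off the body output instead of being hard-coded.

```python
import numpy as np

# Stand-ins (assumed values): batch of 1, 4 tokens, hidden size 384
hidden_size = 384
token_embeddings = np.random.rand(1, 4, hidden_size).astype(np.float32)
attention_mask = np.array([[1, 1, 1, 0]], dtype=np.int64)

# Masked mean pooling over the token axis
mask = np.broadcast_to(attention_mask[:, :, None], token_embeddings.shape)
pooled = (mask * token_embeddings).sum(axis=1) / np.clip(mask.sum(axis=1), 1e-9, None)
print(pooled.dtype)  # int64 * float32 promotes to float64

# A head declared with FloatTensorType expects float32 input
pooled_f32 = pooled.astype(np.float32)

# The head's input width should match the body's last dimension
embedding_dim = token_embeddings.shape[-1]
print(pooled_f32.dtype, embedding_dim)
```

For a SetFit model whose body is a sentence-transformers model, `model.model_body.get_sentence_embedding_dimension()` should give the same width.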

lewtun commented on June 11, 2024

Hey @nbertagnolli and @blakechi, super cool that you're excited to work on the ONNX export for the two heads 🔥 !!

I agree it's best to have an initial PR first so we can hone the design and then iterate from there.

nbertagnolli commented on June 11, 2024

I like that idea; I think it could make working with the sklearn heads easier. : )

nbertagnolli commented on June 11, 2024

@AnshulP10 please take a look at the PR we've been working on #156. @kgourgou pointed out the above script has some things that you need to modify for some models. This PR hopefully addresses those concerns. In the PR there is a function called export_onnx which should do what you want. Let me know if you still have trouble.

blakechi commented on June 11, 2024

Hi @lewtun,
Thanks for your reply! onnx_export for SetFitModel sounds good to me.

Sure, happy to lend a hand with adding ONNX support after #8. 😀

blakechi commented on June 11, 2024

Sorry for the wait, and good news: issue #8 has been resolved!

Just a suggestion: maybe @nbertagnolli can work on converting the sklearn head and I can work on the PyTorch version? That way we can each work on the part we are familiar with. :)

What do you think? @lewtun @nbertagnolli

nbertagnolli commented on June 11, 2024

Sounds good @blakechi I'll open a PR soon with an initial script for onnx conversion and we can go from there! : )

blakechi commented on June 11, 2024

okay, sounds good to me :)

Make sure you follow @lewtun's suggestion below:

> Hi @blakechi, thanks for your interest in our work! Instead of compiling SetFitTrainer into ONNX, I think it would be better to have something like an onnx_export() function that lives inside a setfit.onnx module.
>
> Ideally, this function would take a SetFitModel as input and output model.onnx, similar to what was done for Stable Diffusion here: https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py#L32
>
> Since the current implementation of SetFitModel uses scikit-learn estimators for the classification head, it might be best to wait until we implement a pure PyTorch version in #8.

blakechi commented on June 11, 2024

Do we also need a Python API for ONNX models? I think it would be much easier for users since we can handle the tokenizer and ONNX runtime for them.
I found this pretty helpful from Diffusion. We could possibly have it in a setfit.onnx module as well.
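As a rough illustration of the kind of Python API being proposed here, the sketch below wraps a tokenizer and two ONNX sessions (body and head) behind a single call. Everything in it is hypothetical: the class name `OnnxSetFitPipeline`, the `head_input_name` default (borrowed from the `"model_head"` input name used in the script at the top of this thread), and the toy tokenizer and sessions that stand in for a real tokenizer and `onnxruntime.InferenceSession` objects. It assumes the body's first output is the token embeddings and applies masked mean pooling before calling the head.

```python
from typing import Any, Callable, Dict, List, Sequence

import numpy as np


class OnnxSetFitPipeline:
    """Hypothetical wrapper: tokenizer -> ONNX body -> mean pooling -> ONNX head.

    The sessions only need a run(output_names, input_feed) method, which is
    the calling convention of onnxruntime.InferenceSession.
    """

    def __init__(
        self,
        tokenizer: Callable[[List[str]], Dict[str, np.ndarray]],
        body_session: Any,
        head_session: Any,
        head_input_name: str = "model_head",
    ) -> None:
        self.tokenizer = tokenizer
        self.body_session = body_session
        self.head_session = head_session
        self.head_input_name = head_input_name

    def __call__(self, texts: Sequence[str]) -> List[np.ndarray]:
        tokens = self.tokenizer(list(texts))
        # First body output: token embeddings of shape (batch, seq_len, hidden)
        token_embeddings = self.body_session.run(None, dict(tokens))[0]
        # Masked mean pooling over the sequence axis
        mask = tokens["attention_mask"][:, :, None]
        summed = (token_embeddings * mask).sum(axis=1)
        counts = np.clip(mask.sum(axis=1), 1e-9, None)
        # Cast to float32 to match a FloatTensorType head input
        pooled = (summed / counts).astype(np.float32)
        return self.head_session.run(None, {self.head_input_name: pooled})


# Toy stand-ins so the sketch runs without real models
def fake_tokenizer(texts: List[str]) -> Dict[str, np.ndarray]:
    # One token per word with ids 1..n, right-padded with 0
    lengths = [len(t.split()) for t in texts]
    ids = np.zeros((len(texts), max(lengths)), dtype=np.int64)
    for i, n in enumerate(lengths):
        ids[i, :n] = np.arange(1, n + 1)
    return {"input_ids": ids, "attention_mask": (ids > 0).astype(np.int64)}


class FakeBody:
    def run(self, output_names, input_feed):
        ids = input_feed["input_ids"].astype(np.float32)
        # Embed each token id as the 2-d vector [id, 2 * id]
        return [np.stack([ids, 2 * ids], axis=-1)]


class FakeHead:
    def run(self, output_names, input_feed):
        x = input_feed["model_head"]
        # Label 1 if the first pooled feature exceeds 1.5, else 0
        return [(x[:, 0] > 1.5).astype(np.int64)]


pipeline = OnnxSetFitPipeline(fake_tokenizer, FakeBody(), FakeHead())
print(pipeline(["a b c", "d"]))
```

Keeping the sessions behind a plain `run(output_names, input_feed)` interface means the same wrapper could, in principle, accept real `InferenceSession` objects or test doubles like the ones above.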

tomaarsen commented on June 11, 2024

With #156 merged, this feature request has been implemented :)

lewtun commented on June 11, 2024

Hi @blakechi, thanks for your interest in our work! Instead of compiling SetFitTrainer into ONNX, I think it would be better to have something like an onnx_export() function that lives inside a setfit.onnx module.

Ideally, this function would take a SetFitModel as input and output model.onnx, similar to what was done for Stable Diffusion here: https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py#L32

Since the current implementation of SetFitModel uses scikit-learn estimators for the classification head, it might be best to wait until we implement a pure PyTorch version in #8.

nbertagnolli commented on June 11, 2024

Happy to help, @blakechi! Do you have a branch you're currently working on? Want me to create one and put an initial ONNX script together? Nice work on #8!

blakechi commented on June 11, 2024

Thanks, @nbertagnolli!

Or maybe we can open two PRs separately? It seems like you are almost ready to open one, so I don't want to hold you back. 😅 But I'm fine either way :)

Maybe you can give us a suggestion, @lewtun? Which way is more convenient for you to manage?

AnshulP10 commented on June 11, 2024

@nbertagnolli I tried this approach, but I'm getting an issue regarding the inputs. I trained a multi-label model (one-vs-rest classifier) on text inputs.
