Comments (13)

prabhat00155 avatar prabhat00155 commented on May 27, 2024

Could you share your code and dataset?

from onnxmltools.

xadupre avatar xadupre commented on May 27, 2024

This is part of PR:

#144

Run the unit tests in class TestNaiveBayesConverter to generate the model, inputs, and expected outputs, all of which are stored on disk. The runtime is then run on them in class TestBackendWithOnnxRuntime. The code is in branch https://github.com/xadupre/onnxmltools/tree/testrt/onnxmltools.

prabhat00155 avatar prabhat00155 commented on May 27, 2024

I see you have made changes to the NaiveBayes.py file, which contains the NB converters. I tested test_NaiveBayesConverter.py and the tests ran fine. Are you saying that the original file was causing test failures, or did you see the mismatch after your changes?

xadupre avatar xadupre commented on May 27, 2024

When you say running, do you mean just the conversion, or the conversion plus the execution of the converted model with an ONNX backend? The latter is what I did today. The converter was working fine, but the execution of the converted ONNX model with onnxruntime was either failing or producing different results. The mismatch was there before my changes.

prabhat00155 avatar prabhat00155 commented on May 27, 2024

I mean conversion + execution. Here is what I did:

import numpy as np
import onnxmltools
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
from onnxmltools import convert_sklearn
from onnxmltools.convert.common.data_types import FloatTensorType

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)

model_MNB = MultinomialNB().fit(X_train, y_train)
model_BNB = BernoulliNB().fit(X_train, y_train)
onnx_MNB = convert_sklearn(model_MNB, "model MNB", [('input', FloatTensorType([1, 4]))])
onnx_BNB = convert_sklearn(model_BNB, "model BNB", [('input', FloatTensorType([1, 4]))])
onnxmltools.utils.save_model(onnx_MNB, "onnx_MNB.onnx")
onnxmltools.utils.save_model(onnx_BNB, "onnx_BNB.onnx")

scikit_result_MNB = np.hstack([model_MNB.predict_proba(X_test), model_MNB.predict(X_test).reshape((-1, 1))])
scikit_result_BNB = np.hstack([model_BNB.predict_proba(X_test), model_BNB.predict(X_test).reshape((-1, 1))])

np.mean(onnx_res_mnb[:, 3] == scikit_result_MNB[:, 3]) # Gives 1 as output
np.mean(onnx_res_bnb[:, 3] == scikit_result_BNB[:, 3]) # Gives 1 as output

This means all the predictions match between scikit and onnx models.

np.mean(np.isclose(onnx_res_mnb, scikit_result_MNB)) # Gives 1 as output
np.mean(np.isclose(onnx_res_bnb, scikit_result_BNB)) # Gives 0.25 as output

MNB outputs (probabilities + labels) seem to match, whereas BNB probability values seem to vary a little although their labels match in this example.

xadupre avatar xadupre commented on May 27, 2024

How did you get onnx_res_mnb?

prabhat00155 avatar prabhat00155 commented on May 27, 2024

$ ./onnxruntime_exec.exe -m onnx_MNB.onnx -t iris_test.csv > result_MNB.csv
'onnx_MNB.onnx' loaded successfully.
Done loading model: onnx_MNB.onnx
Execution Status: OK

prroy@B115FFDGPUN03 /cygdrive/c/Users/prroy/LotusRT/Lotus_Oct9/Lotus/onnxruntime/cmake_build/Debug
$ ./onnxruntime_exec.exe -m onnx_BNB.onnx -t iris_test.csv > result_BNB.csv
'onnx_BNB.onnx' loaded successfully.
Done loading model: onnx_BNB.onnx
Execution Status: OK
In Python notebook:
onnx_res_mnb = np.loadtxt(fname='result_MNB.csv', delimiter=',')
onnx_res_bnb = np.loadtxt(fname='result_BNB.csv', delimiter=',')

xadupre avatar xadupre commented on May 27, 2024

What is the version of onnxruntime you are using? I observed differences between versions. I'm currently testing against the version released on pypi (1.3.0). The tests I put in place test all outputs, prediction and probabilities.

prabhat00155 avatar prabhat00155 commented on May 27, 2024

How do I check the version of onnxruntime? I cloned the Lotus repo on Oct 9 and built the onnxruntime project.

xadupre avatar xadupre commented on May 27, 2024

This should work then, except that I think onnxruntime only produces the predicted labels and not the scores. If you build it yourself, you can add an option to build the Python package too and use it to check the converted model. Documentation for onnxruntime is here: https://docs.microsoft.com/en-us/python/api/overview/azure/onnx/examples-md?view=azure-onnx-py. I'll check with the runtime tomorrow on my side.
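
On the version question: if the Python package is installed, `onnxruntime.__version__` reports it. The sketch below is a more general check that reads the installed package metadata and does not fail when a package is missing. The helper name `installed_version` is made up for this example. It only covers Python packages; an executable built from a clone carries no such metadata, so the commit of the checkout is the closest equivalent there.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a Python package, or None if absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


# Report the versions relevant to this thread.
for name in ("onnxruntime", "onnxmltools", "scikit-learn"):
    print(name, installed_version(name) or "not installed as a Python package")
```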

xadupre avatar xadupre commented on May 27, 2024

Here is what I get with the latest version of onnxruntime and the current onnxmltools. The onnxruntime Python package and the executable give the same results. On this example, MNB works, BNB does not.

import os
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
from onnxmltools import convert_sklearn
from onnxmltools.convert.common.data_types import FloatTensorType
import onnxmltools
import numpy as np

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)
model_MNB = MultinomialNB().fit(X_train, y_train)
model_BNB = BernoulliNB().fit(X_train, y_train)
onnx_MNB = convert_sklearn(model_MNB, "model MNB", [('input', FloatTensorType([1, 4]))])
onnx_BNB = convert_sklearn(model_BNB, "model BNB", [('input', FloatTensorType([1, 4]))])

xt = np.zeros((X_test.shape[0], X_test.shape[1]+1))
xt[:, 1:] = X_test
xt[:, 0] = y_test
np.savetxt("iris_test.csv", xt, delimiter=',', fmt='%f')

onnxmltools.utils.save_model(onnx_MNB, "onnx_MNB.onnx")
onnxmltools.utils.save_model(onnx_BNB, "onnx_BNB.onnx")

fold = r"path_to_exec"
# "True or" forces the executable to run even if a result file already exists.
if True or not os.path.exists("result_BNB.csv"):
    cmd = fold + r"\onnxruntime_exec.exe -m onnx_BNB.onnx -t iris_test.csv > result_BNB.csv"
    os.system(cmd)
if True or not os.path.exists("result_MNB.csv"):
    cmd = fold + r"\onnxruntime_exec.exe -m onnx_MNB.onnx -t iris_test.csv > result_MNB.csv"
    os.system(cmd)

onnx_res_mnb = np.loadtxt("result_MNB.csv", delimiter=',')
onnx_res_bnb = np.loadtxt("result_BNB.csv", delimiter=',')

scikit_result_MNB = np.hstack([model_MNB.predict_proba(X_test), model_MNB.predict(X_test).reshape((-1, 1))])
scikit_result_BNB = np.hstack([model_BNB.predict_proba(X_test), model_BNB.predict(X_test).reshape((-1, 1))])

print(np.mean(onnx_res_mnb[:, 3] == scikit_result_MNB[:, 3])) # Reported as 1 earlier; prints 0.0 in this run
print(np.mean(onnx_res_bnb[:, 3] == scikit_result_BNB[:, 3])) # Reported as 1 earlier; prints 0.0 in this run

# A value of 1 would mean all the predictions match between the scikit-learn and ONNX models.
print(np.mean(np.isclose(onnx_res_mnb, scikit_result_MNB))) # Reported as 1 earlier; prints 0.0 in this run
print(np.mean(np.isclose(onnx_res_bnb, scikit_result_BNB))) # Reported as 0.25 earlier; prints 0.0 in this run

import onnxruntime

mnb = onnxruntime.InferenceSession('onnx_MNB.onnx')
mnb_prd = mnb.run(None, {'input': X_test[:1].astype(np.float32)})
print("MNB")
print("ONNX-PY ", mnb_prd[1])
print("ONNX-EXE", onnx_res_mnb[:1, 1:])
print("SKL ", scikit_result_MNB[:1, :3])

bnb = onnxruntime.InferenceSession('onnx_BNB.onnx')
bnb_prd = bnb.run(None, {'input': X_test[:1].astype(np.float32)})
print("BNB")
print("ONNX-PY ", bnb_prd[1])
print("ONNX-EXE", onnx_res_bnb[:1, 1:])
print("SKL ", scikit_result_BNB[:1, :3])

Outputs:

0.0
0.0
0.0
0.0
MNB
ONNX-PY [{0: 0.04780237749218941, 1: 0.5113309621810913, 2: 0.4408663213253021}]
ONNX-EXE [[0.04780238 0.51133096 0.44086632]]
SKL [[0.04780234 0.51133139 0.44086628]]
BNB
ONNX-PY [{0: 0.3253963887691498, 1: 0.4300372302532196, 2: 0.24456651508808136}]
ONNX-EXE [[0.32539639 0.43003723 0.24456652]]
SKL [[0.33333227 0.34244143 0.3242263 ]]
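
To make the comparison explicit, here is a small sketch that takes the probabilities printed above for the first test row and checks, per model, whether the predicted label agrees and whether the probabilities agree within NumPy's default tolerances. The helper name `compare` is made up for this example, and the default `np.allclose` tolerances are an assumption about what counts as "close enough".

```python
import numpy as np

# Probabilities for the first test row, copied from the outputs above.
mnb_onnx = np.array([0.04780238, 0.51133096, 0.44086632])
mnb_skl = np.array([0.04780234, 0.51133139, 0.44086628])
bnb_onnx = np.array([0.32539639, 0.43003723, 0.24456652])
bnb_skl = np.array([0.33333227, 0.34244143, 0.3242263])


def compare(name, onnx_probs, skl_probs):
    """Compare the argmax label and the probabilities of two runtimes."""
    same_label = int(np.argmax(onnx_probs)) == int(np.argmax(skl_probs))
    close = bool(np.allclose(onnx_probs, skl_probs))  # default rtol=1e-5, atol=1e-8
    max_diff = float(np.max(np.abs(onnx_probs - skl_probs)))
    print(f"{name}: same label={same_label}, probabilities close={close}, "
          f"max abs diff={max_diff:.2e}")
    return same_label, close, max_diff


mnb_result = compare("MNB", mnb_onnx, mnb_skl)
bnb_result = compare("BNB", bnb_onnx, bnb_skl)
```

On this row the label agrees for both models, but only the MNB probabilities agree; a label-only check would therefore miss the BNB discrepancy.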

xadupre avatar xadupre commented on May 27, 2024

BernoulliNB: the binarization of features is not part of the converter.

https://github.com/scikit-learn/scikit-learn/blob/bac89c2/sklearn/naive_bayes.py#L938
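
A minimal NumPy sketch of why this matters, assuming the usual Bernoulli naive Bayes formulation in which features are thresholded (scikit-learn's `binarize` parameter, default 0.0) before the log-likelihood is computed. The function name `bernoulli_nb_proba`, the parameter values, and the sample row are made up for illustration; they are not taken from the fitted iris model or from the converter.

```python
import numpy as np


def bernoulli_nb_proba(X, feature_log_prob, class_log_prior, binarize=0.0):
    """Class probabilities of a Bernoulli naive Bayes model (illustrative sketch)."""
    X = np.asarray(X, dtype=np.float64)
    if binarize is not None:
        # Values strictly greater than the threshold become 1, the rest 0.
        X = (X > binarize).astype(np.float64)
    neg_log_prob = np.log1p(-np.exp(feature_log_prob))  # log(1 - p)
    jll = X @ (feature_log_prob - neg_log_prob).T
    jll += class_log_prior + neg_log_prob.sum(axis=1)
    # Normalize with a numerically stable softmax.
    jll -= jll.max(axis=1, keepdims=True)
    proba = np.exp(jll)
    return proba / proba.sum(axis=1, keepdims=True)


# Hypothetical model parameters: 3 classes, 4 features.
feature_prob = np.array([[0.90, 0.80, 0.70, 0.60],
                         [0.90, 0.85, 0.80, 0.70],
                         [0.95, 0.80, 0.90, 0.80]])
feature_log_prob = np.log(feature_prob)
class_log_prior = np.log(np.full(3, 1.0 / 3.0))

# An illustrative row of strictly positive measurements, as in iris.
x = np.array([[6.1, 2.8, 4.7, 1.2]])

with_binarize = bernoulli_nb_proba(x, feature_log_prob, class_log_prior)
without_binarize = bernoulli_nb_proba(x, feature_log_prob, class_log_prior, binarize=None)
print("with binarization   ", with_binarize)
print("without binarization", without_binarize)
```

Because every feature in the row is positive, binarization at 0.0 turns it into all ones, so the two calls see very different inputs and return different probabilities. A converted model that skips the thresholding step behaves like the second call.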

prabhat00155 avatar prabhat00155 commented on May 27, 2024

Yup, here is the PR: #162
