Comments (21)
@izeigerman I will try to create a repro for you! Thank you for looking into it. FYI, to be clear, it isn't just multiclass; the binary cases seem to produce the same issue (for me).
from m2cgen.
Hi -- version:
m2cgen 0.6.0
Note: installed using pip on 2020-02-19
I've got a repro. tree_method='hist' seems to be causing this:
```
expected=[0.7084146 0.14699726 0.14458819], actual=[0.7084145564550915, 0.14699725651673898, 0.14458818702816942]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.7084146 0.14699726 0.14458819], actual=[0.7084145564550915, 0.14699725651673898, 0.14458818702816942]
expected=[0.14489177 0.7067961 0.14831214], actual=[0.14489178040434617, 0.7067960730189182, 0.14831214657673555]
expected=[0.14527214 0.7086516 0.14607628], actual=[0.14527215259960496, 0.7086515652569474, 0.14607628214344756]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.14489177 0.7067961 0.14831214], actual=[0.14489178040434617, 0.7067960730189182, 0.14831214657673555]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.7084146 0.14699726 0.14458819], actual=[0.7084145564550915, 0.14699725651673898, 0.14458818702816942]
expected=[0.7084146 0.14699726 0.14458819], actual=[0.7084145564550915, 0.14699725651673898, 0.14458818702816942]
expected=[0.14527214 0.7086516 0.14607628], actual=[0.14527215259960496, 0.7086515652569474, 0.14607628214344756]
expected=[0.15254477 0.66059124 0.18686394], actual=[0.14527215259960496, 0.7086515652569474, 0.14607628214344756]
```
(The last row is the one that doesn't match.)
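For scale, a quick check in plain Python (with the last row's values copied, rounded, from the dump above) shows the divergence is far beyond float32-vs-float64 rounding noise:

```python
# Last row from the dump above: XGBoost's predict_proba vs the m2cgen score().
expected = [0.15254477, 0.66059124, 0.18686394]  # XGBoost
actual = [0.14527215, 0.70865157, 0.14607628]    # m2cgen

# Largest per-class disagreement for this row.
max_abs_diff = max(abs(e - a) for e, a in zip(expected, actual))
print(round(max_abs_diff, 3))  # 0.048 -- orders of magnitude above rounding error
```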
from m2cgen.
from m2cgen.
Hey @eafpres! Thank you so much for reporting this. This is indeed quite bad.
It'd be very helpful if you could provide some steps to reproduce the issue: perhaps the model's hyperparameters, or some sort of synthetic dataset on which the problem reproduces.
UPD: If you can confirm whether the same problem occurs in other languages (e.g. Java), that'd be just fantastic.
Something is clearly wrong there. E.g. in row 13, the probabilities don't even add up to 1.0 in m2cgen's predictions, which doesn't make sense.
@eafpres Got it! May I please also ask you to share the m2cgen version that you're using (`m2cgen -v`)?
Excellent, thank you. Another thing you may try is to downgrade to version 0.5.0. We had some significant changes in XGBoost-related code in 0.6.0 which may or may not have caused this behavior.
Hello @eafpres !
Brief question: are you using DART in your XGBoost model?
@StrikerRUS -- I am not using DART, here is my configuration:
```python
model = xgb.XGBClassifier(
    colsample_bytree=colsample_bytree,
    gamma=gamma,
    learning_rate=learning_rate,
    max_depth=max_depth,
    min_child_weight=min_child_weight,
    n_estimators=n_iters,
    reg_alpha=reg_alpha,
    reg_lambda=reg_lambda,
    subsample=subsample,
    tree_method='hist',
    objective=objective,
    booster='gbtree',
    eval_metric=eval_metric,
    num_class=n_class,
)
```
@eafpres for the sake of this investigation, can you please try without tree_method='hist' to see whether the problem persists?
@izeigerman -- I will test it and report back. Thanks for the amazing support.
Hi @izeigerman -- I changed my configuration to use tree_method = 'exact' and the results from the two models are now exactly identical. This was for the multi-class case. I'll also test the binary case and confirm that.
@eafpres Got it, thank you so much for helping to triage the problem. We'll then look into why it misbehaves when tree_method is set to 'hist'.
@eafpres @izeigerman Can you guys please post a repro for this error here? I cannot spot the issue by simply adding tree_method='hist' here (Lines 93 to 94 in d058b8f).
@StrikerRUS as far as I remember, this is exactly how I reproduced it. Have you tried the LARGE version of this test? Or perhaps increase the sample size, because some rows still match with this tree method.
@izeigerman Hmm, interesting... Thank you for the hint! I'll continue investigation.
@eafpres the issue has been fixed in version 0.7.0. Please feel free to reopen this issue if you notice it again.
> @eafpres the issue has been fixed in version 0.7.0. Please feel free to reopen this issue if you notice it again.

I saw the release yesterday! I have to update a model that I'm using this for this week; I'll test it both ways and advise whether it works. I apologize for not getting the repro to you. Thanks for working on this so promptly!
Hi,
Thanks for the work on this package. I'm using m2cgen to convert an XGBoost model into VBA code, but when using the code produced by m2cgen I get some predictions that are really different from the ones I get after training my model in Python. Here are some examples:
I've also looked in the XGBoost booster after training and compared it to the output (in Python) from m2cgen. Here is what I have from m2cgen:
```python
import math


def sigmoid(x):
    if x < 0.0:
        z = math.exp(x)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(-x))


def score(input):
    if input[1] < 2.0:
        if input[1] < 1.0:
            var0 = -0.3193863
        else:
            var0 = -0.046659842
    else:
        if input[7] < 867.94:
            var0 = 0.058621403
        else:
            var0 = 0.25975806
    if input[4] < 5654.47:
        if input[3] < 0.38662624:
            var1 = -0.029487507
        else:
            var1 = 0.16083813
    else:
        if input[1] < 2.0:
            var1 = -0.32378462
        else:
            var1 = -0.08247565
    if input[0] < 1.0:
        if input[2] < 0.8:
            var2 = -0.15353489
        else:
            var2 = 0.081936955
    else:
        if input[4] < 2989.61:
            var2 = 0.13463722
        else:
            var2 = -0.042515814
    if input[5] < 0.11556604:
        if input[12] < 0.11059804:
            var3 = -0.1621976
        else:
            var3 = 0.30593434
    else:
        if input[11] < 661.39:
            var3 = 0.0063493266
        else:
            var3 = 0.15387529
    if input[9] < 0.12683104:
        if input[19] < 197.56:
            var4 = -0.25690553
        else:
            var4 = 0.06560632
    else:
        if input[8] < 0.11749347:
            var4 = -0.018011741
        else:
            var4 = 0.10678521
    if input[7] < 1790.11:
        if input[8] < 0.11749347:
            var5 = -0.091719724
        else:
            var5 = 0.048037946
    else:
        if input[1] < 3.0:
            var5 = 0.058297392
        else:
            var5 = 0.18175843
    if input[6] < 1351.78:
        if input[10] < 3.0:
            var6 = -0.0012290713
        else:
            var6 = 0.10081242
    else:
        if input[17] < 0.07381933:
            var6 = -0.12741692
        else:
            var6 = 0.038392954
    if input[1] < 3.0:
        if input[15] < 0.12838633:
            var7 = -0.081163615
        else:
            var7 = 0.019387348
    else:
        if input[20] < 0.29835963:
            var7 = 0.1156334
        else:
            var7 = -0.17409053
    if input[5] < 0.062735535:
        if input[3] < 0.5642857:
            var8 = -0.2049814
        else:
            var8 = 0.12192867
    else:
        if input[13] < 17.0:
            var8 = -0.0035746796
        else:
            var8 = 0.10629323
    if input[19] < 179.98:
        if input[4] < 15379.7:
            var9 = -0.010353668
        else:
            var9 = -0.19715081
    else:
        if input[21] < 1744.96:
            var9 = 0.08414988
        else:
            var9 = -0.31387258
    if input[9] < 0.12683104:
        if input[19] < 90.45:
            var10 = -0.15493616
        else:
            var10 = 0.05997152
    else:
        if input[11] < -1390.57:
            var10 = -0.12933072
        else:
            var10 = 0.028274538
    if input[14] < 3.0:
        if input[7] < 652.72:
            var11 = -0.061523404
        else:
            var11 = 0.018090146
    else:
        if input[20] < -0.015413969:
            var11 = 0.122180216
        else:
            var11 = -0.07323579
    if input[18] < 35.0:
        if input[17] < 0.105689526:
            var12 = -0.058067013
        else:
            var12 = 0.035271224
    else:
        if input[20] < 0.42494825:
            var12 = 0.067990474
        else:
            var12 = -0.13910332
    if input[8] < 0.11749347:
        if input[22] < 0.06889495:
            var13 = -0.109115146
        else:
            var13 = -0.011202088
    else:
        if input[16] < -161.82:
            var13 = -0.01581455
        else:
            var13 = 0.10806873
    if input[18] < 8.0:
        if input[17] < 0.0007647209:
            var14 = -0.10060249
        else:
            var14 = 0.04555326
    else:
        if input[15] < 0.15912667:
            var14 = 0.0012086431
        else:
            var14 = 0.061486576
    if input[11] < -1708.65:
        if input[1] < 4.0:
            var15 = -0.14637202
        else:
            var15 = 0.10264576
    else:
        if input[19] < 2421.29:
            var15 = 0.008009123
        else:
            var15 = 0.17349313
    if input[20] < 0.21551265:
        if input[20] < -0.14049701:
            var16 = -0.069627054
        else:
            var16 = 0.012490782
    else:
        if input[7] < 4508.38:
            var16 = -0.13310793
        else:
            var16 = 0.2982378
    if input[4] < 10364.37:
        if input[18] < 46.0:
            var17 = -0.00067418563
        else:
            var17 = 0.07025912
    else:
        if input[19] < 32.3:
            var17 = -0.11449907
        else:
            var17 = 0.102952585
    if input[12] < 0.11059804:
        if input[9] < 0.06418919:
            var18 = -0.12425961
        else:
            var18 = -0.0036558604
    else:
        if input[9] < 0.06418919:
            var18 = 0.3158906
        else:
            var18 = 0.06434954
    var19 = sigmoid(var0 + var1 + var2 + var3 + var4 + var5 + var6 + var7
                    + var8 + var9 + var10 + var11 + var12 + var13 + var14
                    + var15 + var16 + var17 + var18)
    return [1.0 - var19, var19]
```
And this is what I have in the booster:
```python
def score_booster(input):
    if input[1] < 2:
        if input[1] < 1:
            var0 = -0.319386303
        else:
            var0 = -0.0466598421
    else:
        if input[7] < 867.940002:
            var0 = 0.0586214028
        else:
            var0 = 0.259758055
    if input[4] < 5654.47021:
        if input[3] < 0.386626244:
            var1 = -0.0294875074
        else:
            var1 = 0.160838127
    else:
        if input[1] < 2:
            var1 = -0.32378462
        else:
            var1 = -0.0824756473
    if input[0] < 1:
        if input[2] < 0.800000012:
            var2 = -0.153534889
        else:
            var2 = 0.0819369555
    else:
        if input[4] < 2989.61011:
            var2 = 0.134637222
        else:
            var2 = -0.0425158143
    if input[5] < 0.115566038:
        if input[12] < 0.110598043:
            var3 = -0.162197605
        else:
            var3 = 0.30593434
    else:
        if input[11] < 661.390015:
            var3 = 0.00634932658
        else:
            var3 = 0.153875291
    if input[9] < 0.12683104:
        if input[19] < 197.559998:
            var4 = -0.256905526
        else:
            var4 = 0.0656063184
    else:
        if input[8] < 0.117493473:
            var4 = -0.0180117413
        else:
            var4 = 0.106785208
    if input[7] < 1790.10999:
        if input[8] < 0.117493473:
            var5 = -0.0917197242
        else:
            var5 = 0.0480379462
    else:
        if input[1] < 3:
            var5 = 0.058297392
        else:
            var5 = 0.181758434
    if input[6] < 1351.78003:
        if input[10] < 3:
            var6 = -0.00122907129
        else:
            var6 = 0.10081242
    else:
        if input[17] < 0.0738193318:
            var6 = -0.127416924
        else:
            var6 = 0.0383929536
    if input[1] < 3:
        if input[15] < 0.128386334:
            var7 = -0.081163615
        else:
            var7 = 0.0193873476
    else:
        if input[20] < 0.298359632:
            var7 = 0.115633398
        else:
            var7 = -0.174090534
    if input[5] < 0.0627355352:
        if input[3] < 0.564285696:
            var8 = -0.204981402
        else:
            var8 = 0.12192867
    else:
        if input[13] < 17:
            var8 = -0.00357467961
        else:
            var8 = 0.106293231
    if input[19] < 179.979996:
        if input[4] < 15379.7002:
            var9 = -0.0103536677
        else:
            var9 = -0.197150812
    else:
        if input[21] < 1744.95996:
            var9 = 0.0841498822
        else:
            var9 = -0.313872576
    if input[9] < 0.12683104:
        if input[19] < 90.4499969:
            var10 = -0.154936165
        else:
            var10 = 0.0599715188
    else:
        if input[11] < -1390.56995:
            var10 = -0.129330724
        else:
            var10 = 0.028274538
    if input[14] < 3:
        if input[7] < 652.719971:
            var11 = -0.061523404
        else:
            var11 = 0.0180901457
    else:
        if input[20] < -0.0154139688:
            var11 = 0.122180216
        else:
            var11 = -0.0732357875
    if input[18] < 35:
        if input[17] < 0.105689526:
            var12 = -0.0580670126
        else:
            var12 = 0.0352712236
    else:
        if input[20] < 0.424948245:
            var12 = 0.0679904744
        else:
            var12 = -0.139103323
    if input[8] < 0.117493473:
        if input[22] < 0.0688949525:
            var13 = -0.109115146
        else:
            var13 = -0.0112020876
    else:
        if input[16] < -161.820007:
            var13 = -0.0158145502
        else:
            var13 = 0.108068727
    if input[18] < 8:
        if input[17] < 0.000764720899:
            var14 = -0.100602493
        else:
            var14 = 0.0455532596
    else:
        if input[15] < 0.159126669:
            var14 = 0.00120864308
        else:
            var14 = 0.0614865758
    if input[11] < -1708.65002:
        if input[1] < 4:
            var15 = -0.14637202
        else:
            var15 = 0.102645762
    else:
        if input[19] < 2421.29004:
            var15 = 0.00800912268
        else:
            var15 = 0.173493132
    if input[20] < 0.215512648:
        if input[20] < -0.140497014:
            var16 = -0.069627054
        else:
            var16 = 0.012490782
    else:
        if input[7] < 4508.37988:
            var16 = -0.13310793
        else:
            var16 = 0.298237801
    if input[4] < 10364.3701:
        if input[18] < 46:
            var17 = -0.000674185634
        else:
            var17 = 0.0702591166
    else:
        if input[19] < 32.2999992:
            var17 = -0.11449907
        else:
            var17 = 0.102952585
    if input[12] < 0.110598043:
        if input[9] < 0.0641891882:
            var18 = -0.124259613
        else:
            var18 = -0.00365586043
    else:
        if input[9] < 0.0641891882:
            var18 = 0.31589061
        else:
            var18 = 0.0643495396
    return (var0, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10,
            var11, var12, var13, var14, var15, var16, var17, var18)
```
So it seems that in the function generated by m2cgen the floats in the if/else conditions are 32-bit floats printed with less precision than in the booster. So if one of my samples has a value exactly at a split, m2cgen does not give back the right value. Is there a trick to force the floats to 64 bits?
Thanks in advance for your reply,
Antoine
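The boundary effect can be sketched in plain Python. 867.94 is one of the thresholds printed in the m2cgen code above; round-tripping it through IEEE-754 single precision (via `struct`) recovers the value the booster actually stores, and a sample sitting exactly at the printed threshold takes different branches in the two versions:

```python
import struct

def to_f32(x):
    """Round-trip a Python float through IEEE-754 single precision."""
    return struct.unpack('f', struct.pack('f', x))[0]

threshold_printed = 867.94          # threshold as printed in the m2cgen code
threshold_booster = to_f32(867.94)  # threshold as stored in the booster (float32)
print(threshold_booster)            # 867.9400024414062

sample = 867.94                     # a feature value exactly at the printed split
print(sample < threshold_booster)   # True  -> the booster goes left
print(sample < threshold_printed)   # False -> the generated code goes right
```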
Hi @izeigerman, do you have any thoughts on my previous comment, or should I open a new issue for my problem?