

eafpres commented on May 13, 2024

@izeigerman I will try to create a repro for you! Thank you for looking into it. FYI, to be clear, it isn't just multiclass: the binary cases seem to produce the same issue (for me).


eafpres commented on May 13, 2024

Hi, here's the version:

m2cgen 0.6.0

note: installed using pip on 2002-02-19


izeigerman commented on May 13, 2024

I've got a repro. tree_method='hist' seems to be causing this:

expected=[0.7084146  0.14699726 0.14458819], actual=[0.7084145564550915, 0.14699725651673898, 0.14458818702816942]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.7084146  0.14699726 0.14458819], actual=[0.7084145564550915, 0.14699725651673898, 0.14458818702816942]
expected=[0.14489177 0.7067961  0.14831214], actual=[0.14489178040434617, 0.7067960730189182, 0.14831214657673555]
expected=[0.14527214 0.7086516  0.14607628], actual=[0.14527215259960496, 0.7086515652569474, 0.14607628214344756]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.14489177 0.7067961  0.14831214], actual=[0.14489178040434617, 0.7067960730189182, 0.14831214657673555]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.14497484 0.14953361 0.7054915 ], actual=[0.14497486252688704, 0.14953362429410053, 0.7054915131790125]
expected=[0.7084146  0.14699726 0.14458819], actual=[0.7084145564550915, 0.14699725651673898, 0.14458818702816942]
expected=[0.7084146  0.14699726 0.14458819], actual=[0.7084145564550915, 0.14699725651673898, 0.14458818702816942]
expected=[0.14527214 0.7086516  0.14607628], actual=[0.14527215259960496, 0.7086515652569474, 0.14607628214344756]
expected=[0.15254477 0.66059124 0.18686394], actual=[0.14527215259960496, 0.7086515652569474, 0.14607628214344756]

(The mismatch is in the last row.)
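Rows like the last one can be flagged automatically with a per-row tolerance check. A minimal sketch using three of the rows above (expected = XGBoost, actual = m2cgen); the absolute tolerance of 1e-5 is an assumption, chosen to absorb the float32 (XGBoost) vs float64 (generated code) rounding visible in the matching rows.

```python
import numpy as np

# Three rows copied from the output above.
expected = np.array([
    [0.7084146, 0.14699726, 0.14458819],
    [0.14527214, 0.7086516, 0.14607628],
    [0.15254477, 0.66059124, 0.18686394],
])
actual = np.array([
    [0.7084145564550915, 0.14699725651673898, 0.14458818702816942],
    [0.14527215259960496, 0.7086515652569474, 0.14607628214344756],
    [0.14527215259960496, 0.7086515652569474, 0.14607628214344756],
])

# A row is OK only if every class probability is within the tolerance.
row_ok = np.isclose(expected, actual, rtol=0.0, atol=1e-5).all(axis=1)
mismatched_rows = np.flatnonzero(~row_ok)
print(mismatched_rows)  # → [2]
```

Differences of around 1e-8 in the first two rows are just precision noise; only the third row differs materially.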


eafpres commented on May 13, 2024


izeigerman commented on May 13, 2024

Hey @eafpres ! Thank you so much for reporting this. This is indeed quite bad.

It'd be very helpful if you could provide some steps to reproduce the issue, perhaps the model's hyperparameters or some sort of synthetic dataset on which the problem reproduces.

UPD: If you could confirm whether the same problem occurs in other languages (e.g. Java), that'd be just fantastic.


izeigerman commented on May 13, 2024

Something is clearly wrong there. E.g. in row 13, the probabilities in m2cgen's predictions don't even add up to 1.0, which doesn't make sense.
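The "probabilities must add up to 1.0" sanity check can be run over every prediction at once. A minimal sketch: `rows_sum_to_one` is a hypothetical helper, the 1e-6 tolerance is an assumption, the first row is copied from the m2cgen output above, and the second row is a made-up bad output for illustration only.

```python
import numpy as np


def rows_sum_to_one(probs, atol=1e-6):
    """Return a boolean mask: True where a row of class probabilities sums to 1."""
    probs = np.asarray(probs, dtype=np.float64)
    return np.isclose(probs.sum(axis=1), 1.0, rtol=0.0, atol=atol)


preds = [
    [0.7084145564550915, 0.14699725651673898, 0.14458818702816942],  # from above
    [0.70, 0.15, 0.05],  # hypothetical broken row
]
print(rows_sum_to_one(preds))  # → [ True False]
```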


izeigerman commented on May 13, 2024

@eafpres Got it! May I please also ask you to share the m2cgen version that you're using (m2cgen -v)?


izeigerman commented on May 13, 2024

Excellent, thank you. Another thing you may try is to downgrade to version 0.5.0. We had some significant changes in XGBoost-related code in 0.6.0 which may or may not have caused this behavior.


StrikerRUS commented on May 13, 2024

Hello @eafpres !

Brief question: are you using DART in your XGBoost model?


eafpres commented on May 13, 2024

@StrikerRUS, I am not using DART. Here is my configuration:

        model = xgb.XGBClassifier(colsample_bytree = colsample_bytree,
                                  gamma = gamma,
                                  learning_rate = learning_rate,
                                  max_depth = max_depth,
                                  min_child_weight = min_child_weight,
                                  n_estimators = n_iters,
                                  reg_alpha = reg_alpha,
                                  reg_lambda = reg_lambda,
                                  subsample = subsample,
                                  tree_method = 'hist',
                                  objective = objective,
                                  booster = 'gbtree',
                                  eval_metric = eval_metric,
                                  num_class = n_class)


izeigerman commented on May 13, 2024

@eafpres for the sake of this investigation, can you please try without tree_method='hist' to see whether the problem persists?


eafpres commented on May 13, 2024

@izeigerman, I will test it and report back. Thanks for the amazing support.


eafpres commented on May 13, 2024

Hi @izeigerman, I changed my configuration to use tree_method = 'exact' and the results from the two models are now exactly identical. This was for the multi-class case; I'll also test the binary case to confirm.


izeigerman commented on May 13, 2024

@eafpres Got it, thank you so much for helping to triage the problem. We'll now look into why it misbehaves when tree_method is set to 'hist'.


StrikerRUS commented on May 13, 2024

@eafpres @izeigerman
Can you guys please post a repro for this error here? I cannot spot the issue by simply adding tree_method='hist' here:

XGBOOST_PARAMS = dict(base_score=0.6, n_estimators=10,
                      random_state=RANDOM_SEED)


izeigerman commented on May 13, 2024

@StrikerRUS as far as I remember, this is exactly how I reproduced it. Have you tried the LARGE version of this test? Or perhaps try increasing the sample size, because some rows still match even with this tree method.


StrikerRUS commented on May 13, 2024

@izeigerman Hmm, interesting... Thank you for the hint! I'll continue investigation.


izeigerman commented on May 13, 2024

@eafpres the issue has been fixed in version 0.7.0. Please feel free to reopen this issue if you notice it again.


eafpres commented on May 13, 2024

> @eafpres the issue has been fixed in version 0.7.0. Please feel free to reopen this issue if you notice it again.

I saw the release yesterday! I have to update a model that I'm using this for this week; I'll test it both ways and let you know if it works. I apologize for not getting the repro to you. Thanks for working on this so promptly!


antoinemertz commented on May 13, 2024

Hi,

Thanks for the work on this package. I'm using m2cgen to convert an XGBoost model into VBA code, but the code produced by m2cgen gives some predictions that are very different from the ones I get after training my model in Python. Here are some examples:
[Screenshot: Capture]

I've also looked at the XGBoost booster after training and compared it to the Python output from m2cgen. Here is what I get from m2cgen:

import math
def sigmoid(x):
    if x < 0.0:
        z = math.exp(x)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(-x))
def score(input):
    if input[1] < 2.0:
        if input[1] < 1.0:
            var0 = -0.3193863
        else:
            var0 = -0.046659842
    else:
        if input[7] < 867.94:
            var0 = 0.058621403
        else:
            var0 = 0.25975806
    if input[4] < 5654.47:
        if input[3] < 0.38662624:
            var1 = -0.029487507
        else:
            var1 = 0.16083813
    else:
        if input[1] < 2.0:
            var1 = -0.32378462
        else:
            var1 = -0.08247565
    if input[0] < 1.0:
        if input[2] < 0.8:
            var2 = -0.15353489
        else:
            var2 = 0.081936955
    else:
        if input[4] < 2989.61:
            var2 = 0.13463722
        else:
            var2 = -0.042515814
    if input[5] < 0.11556604:
        if input[12] < 0.11059804:
            var3 = -0.1621976
        else:
            var3 = 0.30593434
    else:
        if input[11] < 661.39:
            var3 = 0.0063493266
        else:
            var3 = 0.15387529
    if input[9] < 0.12683104:
        if input[19] < 197.56:
            var4 = -0.25690553
        else:
            var4 = 0.06560632
    else:
        if input[8] < 0.11749347:
            var4 = -0.018011741
        else:
            var4 = 0.10678521
    if input[7] < 1790.11:
        if input[8] < 0.11749347:
            var5 = -0.091719724
        else:
            var5 = 0.048037946
    else:
        if input[1] < 3.0:
            var5 = 0.058297392
        else:
            var5 = 0.18175843
    if input[6] < 1351.78:
        if input[10] < 3.0:
            var6 = -0.0012290713
        else:
            var6 = 0.10081242
    else:
        if input[17] < 0.07381933:
            var6 = -0.12741692
        else:
            var6 = 0.038392954
    if input[1] < 3.0:
        if input[15] < 0.12838633:
            var7 = -0.081163615
        else:
            var7 = 0.019387348
    else:
        if input[20] < 0.29835963:
            var7 = 0.1156334
        else:
            var7 = -0.17409053
    if input[5] < 0.062735535:
        if input[3] < 0.5642857:
            var8 = -0.2049814
        else:
            var8 = 0.12192867
    else:
        if input[13] < 17.0:
            var8 = -0.0035746796
        else:
            var8 = 0.10629323
    if input[19] < 179.98:
        if input[4] < 15379.7:
            var9 = -0.010353668
        else:
            var9 = -0.19715081
    else:
        if input[21] < 1744.96:
            var9 = 0.08414988
        else:
            var9 = -0.31387258
    if input[9] < 0.12683104:
        if input[19] < 90.45:
            var10 = -0.15493616
        else:
            var10 = 0.05997152
    else:
        if input[11] < -1390.57:
            var10 = -0.12933072
        else:
            var10 = 0.028274538
    if input[14] < 3.0:
        if input[7] < 652.72:
            var11 = -0.061523404
        else:
            var11 = 0.018090146
    else:
        if input[20] < -0.015413969:
            var11 = 0.122180216
        else:
            var11 = -0.07323579
    if input[18] < 35.0:
        if input[17] < 0.105689526:
            var12 = -0.058067013
        else:
            var12 = 0.035271224
    else:
        if input[20] < 0.42494825:
            var12 = 0.067990474
        else:
            var12 = -0.13910332
    if input[8] < 0.11749347:
        if input[22] < 0.06889495:
            var13 = -0.109115146
        else:
            var13 = -0.011202088
    else:
        if input[16] < -161.82:
            var13 = -0.01581455
        else:
            var13 = 0.10806873
    if input[18] < 8.0:
        if input[17] < 0.0007647209:
            var14 = -0.10060249
        else:
            var14 = 0.04555326
    else:
        if input[15] < 0.15912667:
            var14 = 0.0012086431
        else:
            var14 = 0.061486576
    if input[11] < -1708.65:
        if input[1] < 4.0:
            var15 = -0.14637202
        else:
            var15 = 0.10264576
    else:
        if input[19] < 2421.29:
            var15 = 0.008009123
        else:
            var15 = 0.17349313
    if input[20] < 0.21551265:
        if input[20] < -0.14049701:
            var16 = -0.069627054
        else:
            var16 = 0.012490782
    else:
        if input[7] < 4508.38:
            var16 = -0.13310793
        else:
            var16 = 0.2982378
    if input[4] < 10364.37:
        if input[18] < 46.0:
            var17 = -0.00067418563
        else:
            var17 = 0.07025912
    else:
        if input[19] < 32.3:
            var17 = -0.11449907
        else:
            var17 = 0.102952585
    if input[12] < 0.11059804:
        if input[9] < 0.06418919:
            var18 = -0.12425961
        else:
            var18 = -0.0036558604
    else:
        if input[9] < 0.06418919:
            var18 = 0.3158906
        else:
            var18 = 0.06434954
    var19 = sigmoid(var0 + var1 + var2 + var3 + var4 + var5 + var6 + var7 + var8 + var9 + var10 + var11 + var12 + var13 + var14 + var15 + var16 + var17 + var18)
    return [1.0 - var19, var19]

And this is what I have in the booster:

def score_booster(input):
    if input[1]<2:
        if input[1]<1:
            var0=-0.319386303
        else:
            var0=-0.0466598421
    else:
        if input[7]<867.940002:
            var0=0.0586214028
        else:
            var0=0.259758055

    if input[4]<5654.47021:
        if input[3]<0.386626244:
            var1=-0.0294875074
        else:
            var1=0.160838127
    else:
        if input[1]<2:
            var1=-0.32378462
        else:
            var1=-0.0824756473

    if input[0]<1:
        if input[2]<0.800000012:
            var2=-0.153534889
        else:
            var2=0.0819369555
    else:
        if input[4]<2989.61011:
            var2=0.134637222
        else:
            var2=-0.0425158143

    if input[5]<0.115566038:
        if input[12]<0.110598043:
            var3=-0.162197605
        else:
            var3=0.30593434
    else:
        if input[11]<661.390015:
            var3=0.00634932658
        else:
            var3=0.153875291

    if input[9]<0.12683104:
        if input[19]<197.559998:
            var4=-0.256905526
        else:
            var4=0.0656063184
    else:
        if input[8]<0.117493473:
            var4=-0.0180117413
        else:
            var4=0.106785208

    if input[7]<1790.10999:
        if input[8]<0.117493473:
            var5=-0.0917197242
        else:
            var5=0.0480379462
    else:
        if input[1]<3:
            var5=0.058297392
        else:
            var5=0.181758434

    if input[6]<1351.78003:
        if input[10]<3:
            var6=-0.00122907129
        else:
            var6=0.10081242
    else:
        if input[17]<0.0738193318:
            var6=-0.127416924
        else:
            var6=0.0383929536

    if input[1]<3:
        if input[15]<0.128386334:
            var7=-0.081163615
        else:
            var7=0.0193873476
    else:
        if input[20]<0.298359632:
            var7=0.115633398
        else:
            var7=-0.174090534

    if input[5]<0.0627355352:
        if input[3]<0.564285696:
            var8=-0.204981402
        else:
            var8=0.12192867
    else:
        if input[13]<17:
            var8=-0.00357467961
        else:
            var8=0.106293231

    if input[19]<179.979996:
        if input[4]<15379.7002:
            var9=-0.0103536677
        else:
            var9=-0.197150812
    else:
        if input[21]<1744.95996:
            var9=0.0841498822
        else:
            var9=-0.313872576
    
    if input[9]<0.12683104:
        if input[19]<90.4499969:
            var10=-0.154936165
        else:
            var10=0.0599715188
    else:
        if input[11]<-1390.56995:
            var10=-0.129330724
        else:
            var10=0.028274538

    if input[14]<3:
        if input[7]<652.719971:
            var11=-0.061523404
        else:
            var11=0.0180901457
    else:
        if input[20]<-0.0154139688:
            var11=0.122180216
        else:
            var11=-0.0732357875

    if input[18]<35:
        if input[17]<0.105689526:
            var12=-0.0580670126
        else:
            var12=0.0352712236
    else:
        if input[20]<0.424948245:
            var12=0.0679904744
        else:
            var12=-0.139103323

    if input[8]<0.117493473:
        if input[22]<0.0688949525:
            var13=-0.109115146
        else:
            var13=-0.0112020876
    else:
        if input[16]<-161.820007:
            var13=-0.0158145502
        else:
            var13=0.108068727

    if input[18]<8:
        if input[17]<0.000764720899:
            var14=-0.100602493
        else:
            var14=0.0455532596
    else:
        if input[15]<0.159126669:
            var14=0.00120864308
        else:
            var14=0.0614865758
    
    if input[11]<-1708.65002:
        if input[1]<4:
            var15=-0.14637202
        else:
            var15=0.102645762
    else:
        if input[19]<2421.29004:
            var15=0.00800912268
        else:
            var15=0.173493132

    if input[20]<0.215512648:
        if input[20]<-0.140497014:
            var16=-0.069627054
        else:
            var16=0.012490782
    else:
        if input[7]<4508.37988:
            var16=-0.13310793
        else:
            var16=0.298237801
    
    if input[4]<10364.3701:
        if input[18]<46:
            var17=-0.000674185634
        else:
            var17=0.0702591166
    else:
        if input[19]<32.2999992:
            var17=-0.11449907
        else:
            var17=0.102952585

    if input[12]<0.110598043:
        if input[9]<0.0641891882:
            var18=-0.124259613
        else:
            var18=-0.00365586043
    else:
        if input[9]<0.0641891882:
            var18=0.31589061
        else:
            var18=0.0643495396
    
    return (var0, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10, var11, var12, var13, var14, var15, var16, var17, var18)
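To compare the two listings directly, the leaf values returned by `score_booster` still have to be summed and passed through a sigmoid, which is what the last two lines of `score` do. A minimal sketch, assuming a `binary:logistic` objective with `base_score=0.5` (whose margin offset is logit(0.5) = 0, consistent with `score` adding no bias term); `booster_proba` is a hypothetical helper name.

```python
import math


def booster_proba(leaf_values):
    """Turn per-tree leaf values into [P(class 0), P(class 1)].

    Assumes binary:logistic with base_score=0.5, so no bias is added
    to the summed margin (the generated score() adds none either).
    """
    margin = math.fsum(leaf_values)
    # Numerically stable sigmoid, same branching as the generated sigmoid().
    if margin < 0.0:
        z = math.exp(margin)
        p = z / (1.0 + z)
    else:
        p = 1.0 / (1.0 + math.exp(-margin))
    return [1.0 - p, p]


# Usage with the two functions above, for one input row:
#   booster_proba(score_booster(row))  vs  score(row)
```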

So it seems that in the function returned by m2cgen, the floats in the if/else conditions are 32-bit floats, while in the booster they are not. So if a sample in my data has the split value, m2cgen does not give back the right value. Is there a trick to force the floats to 64 bits?
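The effect can be reproduced without XGBoost. A minimal sketch, assuming (as XGBoost's DMatrix does) that feature values are converted to float32 before being compared with the float32 split thresholds: `867.94` (m2cgen) and `867.940002` (booster dump) are two printings of the same float32 threshold, yet a float64 input just below 867.94 goes left in the generated code and right in XGBoost. The input value `x` and the `lt_f32` helper are illustrative assumptions, not part of m2cgen.

```python
import numpy as np

# Split threshold as XGBoost stores it (float32); 867.94 is the input[7] split above.
threshold_f32 = np.float32(867.94)

# The two printed forms above denote the same float32 value:
shortest = str(threshold_f32)    # "867.94"     (shortest round-trip repr)
dumped = "%.9g" % threshold_f32  # "867.940002" (9 significant digits)

# Illustrative float64 feature value just below 867.94 (not from the dataset).
x = 867.93999

# Assumption: XGBoost rounds the feature to float32 before the split test.
xgb_goes_left = bool(np.float32(x) < threshold_f32)

# Generated Python compares the float64 input with the float64 literal 867.94.
gen_goes_left = x < float(shortest)


def lt_f32(value, threshold):
    """Hypothetical helper: emulate a float32 split test on both operands."""
    return bool(np.float32(value) < np.float32(threshold))


print(shortest, dumped)              # 867.94 867.940002
print(xgb_goes_left, gen_goes_left)  # False True
print(lt_f32(x, float(shortest)))    # False, agrees with XGBoost
```

In this example, rounding both operands to float32 (as `lt_f32` does) makes the comparison agree with XGBoost, whereas writing the threshold as a longer 64-bit literal alone would not, because `x` itself is what XGBoost rounds up to the threshold.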

Thanks in advance for your reply,

Antoine


antoinemertz commented on May 13, 2024

Hi @izeigerman, do you have any thoughts on my previous comment, or do I need to open a new issue for my problem?
