
focalloss-for-lightgbm-xgboost

Focal loss in LightGBM (XGBoost) for multi-class classification

This loss function implements focal loss [1]. It currently supports only LightGBM multi-class tasks (more than two classes); support for XGBoost and binary classification tasks is planned.

focal loss

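alpha (α) and gamma (γ) are the parameters of the focal loss, which is defined as:

$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$

where p_t is the predicted probability of the true class.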

alpha re-weights samples to counter class imbalance, and gamma down-weights well-classified samples so that training focuses on hard-to-learn ones. In the multi-class task, alpha appears to have no effect.
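To make the effect of gamma concrete, here is a minimal NumPy sketch (not the repository's implementation) that evaluates the multi-class focal loss -(1 - p_t)^gamma * log(p_t) from raw scores, with alpha omitted. The helper names `softmax` and `multiclass_focal_loss` are hypothetical. Setting gamma = 0 recovers ordinary cross-entropy, and a larger gamma shrinks the loss of confidently correct samples much more than that of hard ones.

```python
import numpy as np


def softmax(raw_scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps np.exp from overflowing."""
    shifted = raw_scores - raw_scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def multiclass_focal_loss(y_true, raw_scores, gamma=2.0, eps=1e-12):
    """Mean of -(1 - p_t)^gamma * log(p_t) over samples (alpha omitted).

    y_true: integer class labels, shape (n_samples,)
    raw_scores: unnormalized scores, shape (n_samples, n_classes)
    """
    prob = softmax(raw_scores)
    # p_t: predicted probability of each sample's true class
    p_t = prob[np.arange(len(y_true)), y_true]
    p_t = np.clip(p_t, eps, 1.0)
    return np.mean(-((1.0 - p_t) ** gamma) * np.log(p_t))


if __name__ == "__main__":
    y = np.array([0, 1])
    scores = np.array([[3.0, 0.0, 0.0],   # easy: true class clearly ahead
                       [0.0, 0.2, 0.0]])  # hard: nearly uniform scores
    ce = multiclass_focal_loss(y, scores, gamma=0.0)  # gamma = 0 is cross-entropy
    fl = multiclass_focal_loss(y, scores, gamma=2.0)
    print(f"cross-entropy: {ce:.4f}, focal (gamma=2): {fl:.4f}")
```

For a single sample, the ratio of focal loss to cross-entropy is exactly (1 - p_t)^gamma, which is why the easy sample above is suppressed far more strongly than the hard one.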

In XGBoost and LightGBM, a custom objective must provide the gradient (grad) and Hessian (hess) of the loss with respect to the model's raw scores.
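As a reference point, the sketch below shows the shape of a multi-class custom objective using plain softmax cross-entropy, i.e. the gamma = 0 special case of the focal loss. It assumes the scikit-learn interface of LightGBM 4.x, where a callable objective receives `(y_true, y_pred)` with `y_pred` as a 2-D array of raw scores of shape `(n_samples, n_classes)` and returns `grad` and `hess` of the same shape; older LightGBM versions pass a flattened, class-major 1-D array instead. The function name `softmax_ce_objective` is hypothetical, and the Hessian is the usual diagonal p(1 - p).

```python
import numpy as np


def softmax_ce_objective(y_true, y_pred):
    """Softmax cross-entropy as a custom objective (the gamma = 0 case).

    Assumes y_pred is a 2-D array of raw scores, shape (n_samples, n_classes).
    """
    labels = np.asarray(y_true, dtype=int)  # LightGBM may pass labels as floats
    shifted = y_pred - y_pred.max(axis=1, keepdims=True)
    prob = np.exp(shifted)
    prob /= prob.sum(axis=1, keepdims=True)
    onehot = np.eye(y_pred.shape[1])[labels]
    grad = prob - onehot            # dL/dz_j = p_j - 1[j == true class]
    hess = prob * (1.0 - prob)      # diagonal of the softmax cross-entropy Hessian
    return grad, hess
```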

The first derivative (grad) and second derivative (hess) of the multi-class focal loss are taken from the derivation at https://zhuanlan.zhihu.com/p/149419189.
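For reference, a self-contained derivation sketch follows. It is not a transcription of the linked post, and its notation is an assumption: raw scores z_j, softmax probabilities p_j, true class t, Kronecker delta δ_tj, and alpha omitted. It only uses the chain rule together with the softmax derivative.

```latex
p_j = \frac{e^{z_j}}{\sum_k e^{z_k}}, \qquad
\frac{\partial p_t}{\partial z_j} = p_t(\delta_{tj} - p_j), \qquad
L = -(1 - p_t)^{\gamma} \log p_t

\frac{\partial L}{\partial z_j} = g(p_t)\,(\delta_{tj} - p_j), \qquad
g(p) = \gamma\, p\,(1 - p)^{\gamma - 1} \log p - (1 - p)^{\gamma}

\frac{\partial^2 L}{\partial z_j^2} = g'(p_t)\, p_t\, (\delta_{tj} - p_j)^2 - g(p_t)\, p_j (1 - p_j)

g'(p) = \gamma (1 - p)^{\gamma - 1}\bigl(\log p + 2\bigr) - \gamma(\gamma - 1)\, p\, (1 - p)^{\gamma - 2} \log p
```

With gamma = 0, g ≡ -1 and g' ≡ 0, which recovers the cross-entropy gradient p_j - δ_tj and Hessian p_j(1 - p_j). Numerically, the factor (1 - p_t)^(gamma - 2) blows up as p_t approaches 1 when gamma < 2, so an implementation should clip p_t away from 1.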

To keep the loss computation fast, these formulas are evaluated with vectorized NumPy operations.
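A minimal NumPy sketch of such a vectorized grad/hess computation is shown below. It is written from the softmax derivative rules, using dL/dz_j = g(p_t)(δ_tj - p_j) with g(p) = gamma * p * (1 - p)^(gamma - 1) * log(p) - (1 - p)^gamma, and is not the code of the repository's `lossfunction` module. It assumes LightGBM 4.x passes `y_pred` to a multi-class custom objective as a 2-D array of raw scores (older versions pass a flattened 1-D array that would need reshaping first). The function name `focal_loss_objective` and its `eps` and `min_hess` parameters are hypothetical.

```python
import numpy as np


def focal_loss_objective(y_true, y_pred, gamma=0.5, eps=1e-9, min_hess=None):
    """Grad and diagonal Hessian of the multi-class focal loss (alpha omitted).

    Assumes y_pred holds raw scores of shape (n_samples, n_classes), the layout
    LightGBM 4.x uses for multi-class custom objectives.
    """
    labels = np.asarray(y_true, dtype=int)
    n_samples, n_classes = y_pred.shape

    # Row-wise softmax, shifted by the row max for numerical stability
    prob = np.exp(y_pred - y_pred.max(axis=1, keepdims=True))
    prob /= prob.sum(axis=1, keepdims=True)

    # p_t: probability of the true class, clipped so log and negative powers stay finite
    p_t = np.clip(prob[np.arange(n_samples), labels], eps, 1.0 - eps)[:, None]
    log_p = np.log(p_t)
    q = 1.0 - p_t

    # g(p_t) and its derivative g'(p_t)
    g = gamma * p_t * q ** (gamma - 1) * log_p - q ** gamma
    g_prime = (gamma * q ** (gamma - 1) * (log_p + 2.0)
               - gamma * (gamma - 1) * p_t * q ** (gamma - 2) * log_p)

    diff = np.eye(n_classes)[labels] - prob  # delta_tj - p_j
    grad = g * diff
    # Exact diagonal of the Hessian; it can be negative for badly misclassified samples
    hess = g_prime * p_t * diff ** 2 - g * prob * (1.0 - prob)
    if min_hess is not None:
        hess = np.maximum(hess, min_hess)  # optional floor, a heuristic rather than part of the math
    return grad, hess
```

With these formulas the true-class Hessian entry can turn negative for badly misclassified samples (for instance gamma = 0.5 and p_t near 0.01), while Newton-style boosting expects non-negative curvature, so a small positive `min_hess` is one pragmatic workaround. With the scikit-learn interface it could be wrapped in a two-argument callable, e.g. `lgb.LGBMClassifier(objective=lambda y_true, y_pred: focal_loss_objective(y_true, y_pred, gamma=0.5, min_hess=1e-6))`.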

Usage:

1. Import the loss function library (lossfunction) along with LightGBM and NumPy

import lightgbm as lgb
import lossfunction as lf
import numpy as np

2. Initialize the loss function and the classifier

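focal_loss_lgb = lf.ComplexLoss(gamma=0.5)  # gamma: focusing parameter of the focal loss
param_dist = {'objective': focal_loss_lgb.focal_loss}  # custom objective supplying grad and hess
param_dist['num_class'] = 3  # number of classes in your labels (an integer)
clf_lgb = lgb.LGBMClassifier(**param_dist, random_state=2021)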

3. Train the model on your dataset

clf_lgb.fit(X_train, y_train)

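4. Get predicted probabilities for your dataset

Here val_data is your validation dataset. Because the model uses a custom objective, LightGBM's predict_proba returns raw scores rather than probabilities, so they are converted with a softmax.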

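lgb_prob = clf_lgb.predict_proba(val_data)  # raw scores here, because the objective is custom
lgb_prob = np.exp(lgb_prob - lgb_prob.max(axis=1, keepdims=True))  # shift by the row max to avoid overflow
results_lgb_prob = lgb_prob / lgb_prob.sum(axis=1, keepdims=True)  # normalize each row (softmax)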

results_lgb_prob then holds the per-class probabilities for each sample in your validation dataset.

Ref:

[1] Lin T Y, Goyal P, Girshick R, et al. Focal Loss for Dense Object Detection[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017, PP(99): 2999-3007.

Contributors

y2019xcj
