mkocabas / focal-loss-keras
Focal Loss implementation in Keras
License: MIT License
Hi, I want to know whether this function can be applied directly to a multi-label task. I have been using binary_crossentropy as the loss function for my multi-label task, and now I want to replace it with focal loss. Do I need to make any changes? Can anybody help?
Is there a version that works with the Theano backend?
I guess it would be safer to add epsilon inside the log, something like:
return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1 + K.epsilon())) - K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
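A plain-Python sketch of the same epsilon idea, assuming binary 0/1 labels and EPS = 1e-7 (the default value of K.epsilon() in Keras). It mirrors the two terms of the return line above, one loop iteration per element, so you can see that a prediction of exactly 0 or 1 no longer hits log(0). The function name is hypothetical, not part of this repo.

```python
import math

EPS = 1e-7  # assumption: same value as Keras's default K.epsilon()

def binary_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25, eps=EPS):
    """Summed binary focal loss over flat lists, with eps inside each log."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        if t == 1:
            # positive term: -alpha * (1 - p)^gamma * log(p + eps)
            total -= alpha * (1.0 - p) ** gamma * math.log(p + eps)
        else:
            # negative term: -(1 - alpha) * p^gamma * log(1 - p + eps)
            total -= (1.0 - alpha) * p ** gamma * math.log(1.0 - p + eps)
    return total

# A completely wrong, fully confident prediction stays finite thanks to eps.
print(binary_focal_loss([1, 0], [0.0, 1.0]))
```

Without eps, the same call would evaluate math.log(0.0) and raise a ValueError in Python (or produce -inf and then NaN in a tensor backend).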
model.compile(loss=focal_loss, optimizer=sgd, metrics=[my_iou_metric])
File "/.local/lib/python3.6/site-packages/keras/engine/training.py", line 830, in compile
sample_weight, mask)
File "/.local/lib/python3.6/site-packages/keras/engine/training.py", line 442, in weighted
ndim = K.ndim(score_array)
File "~/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 610, in ndim
dims = x.get_shape()._dims
AttributeError: 'function' object has no attribute 'get_shape'
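This error is consistent with focal_loss being a factory: focal_loss(gamma, alpha) returns the inner loss function (named focal_loss_fixed, as the ValueError later in this thread shows), so Keras has to receive focal_loss(), not focal_loss. The sketch below reproduces the mechanism in plain Python; fake_keras_compile is a hypothetical stand-in for Keras calling loss(y_true, y_pred) and expecting a value back, and the loss body is a simplified assumption.

```python
import math

def focal_loss(gamma=2.0, alpha=0.25):
    """Factory: returns the actual loss function, closing over gamma and alpha."""
    def focal_loss_fixed(y_true, y_pred):
        eps = 1e-7
        return -sum(
            alpha * (1 - p) ** gamma * math.log(p + eps) if t == 1
            else (1 - alpha) * p ** gamma * math.log(1 - p + eps)
            for t, p in zip(y_true, y_pred)
        )
    return focal_loss_fixed

def fake_keras_compile(loss):
    """Hypothetical stand-in for Keras: calls loss(y_true, y_pred), expects a number."""
    result = loss([1, 0], [0.9, 0.1])
    if callable(result):
        # Passing the factory means loss(y_true, y_pred) is really
        # focal_loss(gamma=y_true, alpha=y_pred), which returns a function,
        # so Keras later fails when it calls .get_shape() on it.
        raise AttributeError("'function' object has no attribute 'get_shape'")
    return result

print(fake_keras_compile(focal_loss()))  # works: the factory was called first
```

In Keras terms this corresponds to model.compile(loss=focal_loss(), optimizer=sgd, ...), and for a two-output model such as the one further down this thread, loss=[focal_loss(), focal_loss()].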
For my data (an NLP task), when I use focal_loss(gamma=1.5, alpha=.2), I get the following results:
Train for 2000 steps, validate for 32 steps
Epoch 1/4
1999/2000 [============================>.] - ETA: 0s - loss: 0.0436 - auc: 0.9015
ROC-AUC - epoch: 1 - score: 0.875631
2000/2000 [==============================] - 1737s 869ms/step - loss: 0.0436 - auc: 0.9015 - val_loss: 0.0453 - val_auc: 0.8754
Epoch 2/4
1999/2000 [============================>.] - ETA: 0s - loss: 0.0241 - auc: 0.9758
ROC-AUC - epoch: 2 - score: 0.903044
2000/2000 [==============================] - 1610s 805ms/step - loss: 0.0241 - auc: 0.9758 - val_loss: 0.0360 - val_auc: 0.9030
Epoch 3/4
1999/2000 [============================>.] - ETA: 0s - loss: 0.0215 - auc: 0.9812
ROC-AUC - epoch: 3 - score: 0.911900
2000/2000 [==============================] - 1612s 806ms/step - loss: 0.0215 - auc: 0.9812 - val_loss: 0.0352 - val_auc: 0.9118
Epoch 4/4
1999/2000 [============================>.] - ETA: 0s - loss: 0.0191 - auc: 0.9852
ROC-AUC - epoch: 4 - score: 0.909566
2000/2000 [==============================] - 1610s 805ms/step - loss: 0.0191 - auc: 0.9852 - val_loss: 0.0445 - val_auc: 0.9074
CPU times: user 4min 15s, sys: 23.7 s, total: 4min 39s
Wall time: 1h 49min 31s
Now, when I try focal_loss(gamma=2.0, alpha=.2), I get:
Train for 1896 steps, validate for 32 steps
Epoch 1/3
1895/1896 [============================>.] - ETA: 0s - loss: 0.0107 - auc: 0.9842
ROC-AUC - epoch: 1 - score: 0.544101
1896/1896 [==============================] - 1644s 867ms/step - loss: 0.0107 - auc: 0.9842 - val_loss: 0.0860 - val_auc: 0.5423
Epoch 2/3
1895/1896 [============================>.] - ETA: 0s - loss: 0.0083 - auc: 0.9904
ROC-AUC - epoch: 2 - score: 0.573175
1896/1896 [==============================] - 1522s 803ms/step - loss: 0.0083 - auc: 0.9904 - val_loss: 0.0659 - val_auc: 0.5210
Epoch 3/3
1895/1896 [============================>.] - ETA: 0s - loss: 0.0070 - auc: 0.9946
ROC-AUC - epoch: 3 - score: 0.375477
1896/1896 [==============================] - 1522s 803ms/step - loss: 0.0070 - auc: 0.9945 - val_loss: 0.0396 - val_auc: 0.4966
CPU times: user 2min 54s, sys: 17.4 s, total: 3min 12s
Wall time: 1h 18min 9s
That is a terrible val_auc, right?
Please help me choose alpha and gamma for focal loss.
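One way to build intuition for gamma is to look at the modulating factor (1 - p_t)^gamma on its own, where p_t is the predicted probability of the true class. The sketch below (a hypothetical helper, not part of this repo) compares an easy example (p_t = 0.9) with a hard one (p_t = 0.1): larger gamma suppresses easy examples much more strongly. For reference, the RetinaNet paper reports gamma = 2 with alpha = 0.25 working best in its detection experiments; for other data, a small sweep judged on validation AUC is the usual approach rather than a fixed rule.

```python
def modulating_factor(p_t, gamma):
    """(1 - p_t)^gamma: how much focal loss down-weights an example."""
    return (1.0 - p_t) ** gamma

for gamma in (0.0, 1.0, 1.5, 2.0):
    easy = modulating_factor(0.9, gamma)  # well-classified example
    hard = modulating_factor(0.1, gamma)  # misclassified example
    print(f"gamma={gamma}: easy={easy:.4f}, hard={hard:.4f}, ratio={hard / easy:.1f}")
```

With gamma = 0 the factor is 1 and the loss reduces to (alpha-weighted) cross-entropy; each increase in gamma widens the gap between hard and easy examples.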
In the original paper, the authors say:
We adopt this form in our experiments as it yields slightly improved accuracy over the non-α-balanced form. Finally, we note that the implementation of the loss layer combines the sigmoid operation for computing p with the loss computation, resulting in greater numerical stability.
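A sketch of what "combining the sigmoid with the loss computation" can look like, assuming the loss receives raw logits instead of probabilities: log p and log(1 - p) are computed directly with a stable log-sigmoid, so nothing ever takes the log of a probability that has rounded to 0 or 1. This is an illustration under that assumption, not the paper's actual implementation; both function names are hypothetical.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)), equal to -softplus(-x)."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    # For negative x, rewrite to avoid exp() of a large positive number.
    return x - math.log1p(math.exp(x))

def focal_loss_from_logit(logit, target, gamma=2.0, alpha=0.25):
    """Binary focal loss for one logit, using log-probabilities from the logit."""
    log_p = log_sigmoid(logit)       # log p
    log_1mp = log_sigmoid(-logit)    # log(1 - p)
    p = math.exp(log_p)
    if target == 1:
        return -alpha * (1.0 - p) ** gamma * log_p
    return -(1.0 - alpha) * p ** gamma * log_1mp

print(focal_loss_from_logit(-1000.0, 1))  # extreme logit, still finite
```

A naive version that first computes 1 / (1 + math.exp(1000)) would raise OverflowError, and one that takes math.log of a probability that underflowed to 0.0 would fail as well.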
Does using sigmoid as the last activation for a multi-class problem mean that more than one class can get a high probability?
Why does sigmoid give better stability than softmax?
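On the first question: yes, a sigmoid scores each class independently, so several classes can be close to 1 at once, whereas softmax produces a single distribution that sums to 1. On the second, the quoted passage attributes the stability gain to fusing the sigmoid with the loss computation, not to sigmoid being inherently more stable than softmax. The small comparison below assumes made-up logits purely for illustration.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(xs):
    m = max(xs)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 4.0, -4.0]  # hypothetical scores for three classes
print([round(sigmoid(x), 3) for x in logits])  # independent per-class probabilities
print([round(p, 3) for p in softmax(logits)])  # one distribution summing to 1
```

With these logits the first two sigmoid outputs are both about 0.98, while softmax has to split its mass between the same two classes.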
@mkocabas
Hi, thanks for your cool code.
But I am a little confused about how to use it. I want to use it in Faster R-CNN. Should I just replace the softmax with focal loss, or change something else? Can you tell me how to use it?
Thanks so much.
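This repo's loss is the binary (sigmoid) form. For a softmax classifier such as the classification head of a two-stage detector, a commonly used multi-class generalization multiplies the cross-entropy of the true class by (1 - p_t)^gamma. The sketch below is that swapped-in variant, not code from this repo; it assumes one example with raw logits and an integer class index, and the function names are hypothetical. The box-regression loss would be left as it is.

```python
import math

def softmax(logits):
    m = max(logits)  # stabilize exp()
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def categorical_focal_loss(logits, true_class, gamma=2.0, alpha=0.25, eps=1e-7):
    """Multi-class focal loss for one example: -alpha * (1 - p_t)^gamma * log(p_t)."""
    p_t = softmax(logits)[true_class]
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t + eps)

print(categorical_focal_loss([2.0, 1.0, 0.1], true_class=0))
```

Setting gamma = 0 and alpha = 1 recovers plain softmax cross-entropy, which makes it easy to check the swap against an existing baseline.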
@mkocabas
When I use this loss function, it shows me
AttributeError: 'function' object has no attribute 'get_shape'
when I call
model.compile(loss=focal_loss,optimizer=sgd,metrics=['accuracy'])
But when I use this instead
model.compile(loss='binary_crossentropy',optimizer=sgd,metrics=['accuracy'])
it works well.
I think the alpha value should be between 0 and 1.
In the https://github.com/mkocabas/focal-loss-keras/blob/master/focal_loss.py file, you forgot to import TensorFlow: import tensorflow as tf
The focal loss value is huge. Why is that?
For example: loss: 38107.2309
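One likely cause, assuming the implementation reduces with K.sum as in the epsilon-patched return line quoted earlier in this thread, is that the reported value is a total over every element in the batch rather than a per-element average, so it grows with image size and batch size. The sketch below uses a hypothetical 256x256 mask where every pixel contributes a loss of 0.5; in Keras, switching the reduction to K.mean would keep the value on the per-element scale. A large sum is not a bug by itself, but it does scale the gradients and therefore the effective learning rate.

```python
def reduce_loss(per_element_losses, reduction="sum"):
    """Reduce element-wise losses; 'sum' grows with element count, 'mean' does not."""
    total = sum(per_element_losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(per_element_losses)
    raise ValueError(f"unknown reduction: {reduction}")

# Hypothetical 256x256 segmentation mask, each pixel contributing a loss of 0.5.
pixel_losses = [0.5] * (256 * 256)
print(reduce_loss(pixel_losses, "sum"))   # → 32768.0
print(reduce_loss(pixel_losses, "mean"))  # → 0.5
```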
I am trying to use your implementation with a U-Net for binary semantic segmentation. The model compiles fine and saves to an .h5 file.
But when I load the model and start training, I get an error stating that no such loss function has been declared:
model = load_model(MODEL_PATH)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/models.py", line 274, in load_model
sample_weight_mode=sample_weight_mode)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/engine/training.py", line 634, in compile
loss_functions = [losses.get(l) for l in loss]
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/engine/training.py", line 634, in
loss_functions = [losses.get(l) for l in loss]
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/losses.py", line 122, in get
return deserialize(identifier)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/losses.py", line 114, in deserialize
printable_module_name='loss function')
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 164, in deserialize_keras_object
':' + function_name)
ValueError: Unknown loss function:focal_loss_fixed
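The error names focal_loss_fixed because the model was compiled with the function returned by the factory, and that inner function's __name__ is what ends up in the saved model. A minimal sketch of building the matching lookup table, assuming a factory with this repo's shape (the placeholder body below is hypothetical and omits the real computation):

```python
def focal_loss(gamma=2.0, alpha=0.25):
    # Hypothetical stand-in with the same factory shape as this repo's focal_loss.
    def focal_loss_fixed(y_true, y_pred):
        return 0.0  # placeholder body for this illustration only
    return focal_loss_fixed

loss_fn = focal_loss(gamma=2.0, alpha=0.25)
# The key must match the name in the error message: the inner function's __name__.
custom_objects = {loss_fn.__name__: loss_fn}
print(list(custom_objects))
```

With Keras this dictionary is passed as load_model(MODEL_PATH, custom_objects=custom_objects). Since gamma and alpha live only in the closure and are not stored in the .h5 file, pass the same values that were used for training.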
The role of alpha is to act as a weight between the two classes: the positive and negative classes get weights alpha and 1-alpha respectively, so alpha should be chosen between 0 and 1.
In the current implementation, alpha is the same for both classes, so it only rescales the loss uniformly and does not rebalance the classes.
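The difference between the two weightings can be checked numerically. The sketch below splits the unweighted focal loss into its positive and negative sums (hypothetical helper names, binary 0/1 labels assumed) and shows that using the same alpha on both terms is just alpha times the unweighted loss, while alpha / 1-alpha actually shifts weight between the classes.

```python
import math

EPS = 1e-7

def focal_terms(y_true, y_pred, gamma=2.0):
    """Return the unweighted (positive, negative) focal-loss sums."""
    pos = neg = 0.0
    for t, p in zip(y_true, y_pred):
        if t == 1:
            pos -= (1 - p) ** gamma * math.log(p + EPS)
        else:
            neg -= p ** gamma * math.log(1 - p + EPS)
    return pos, neg

def balanced(y_true, y_pred, alpha, gamma=2.0):
    pos, neg = focal_terms(y_true, y_pred, gamma)
    return alpha * pos + (1 - alpha) * neg  # alpha on positives, 1 - alpha on negatives

def same_alpha(y_true, y_pred, alpha, gamma=2.0):
    pos, neg = focal_terms(y_true, y_pred, gamma)
    return alpha * pos + alpha * neg  # identical weight: only rescales the total

yt, yp = [1, 0, 0, 1], [0.3, 0.2, 0.6, 0.8]
print(balanced(yt, yp, 0.25), same_alpha(yt, yp, 0.25))
```

Note that balanced(..., alpha=0.5) is also a uniform scale (half the unweighted loss), so alpha only rebalances the classes when it differs from 0.5.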
This loss function implements binary focal loss.
The loss turns to NaN after an epoch. I added K.epsilon() inside the log, which works fine when gamma is 2 or 1, but when gamma is less than 1 the loss turns to NaN again.
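A plausible explanation for the gamma < 1 case: epsilon protects the log, but not the modulating factor. The derivative of (1 - p_t)^gamma is -gamma * (1 - p_t)^(gamma - 1), which blows up as p_t approaches 1 whenever gamma < 1 (at exactly p_t = 1 it becomes infinite, and inf * 0 later yields NaN). The sketch below shows this numerically and, as one possible mitigation, clips p_t into [eps, 1 - eps] before the power; the helper names are hypothetical.

```python
import math

def modulating_grad(p_t, gamma):
    """d/dp_t of (1 - p_t)^gamma, i.e. -gamma * (1 - p_t)^(gamma - 1)."""
    return -gamma * (1.0 - p_t) ** (gamma - 1.0)

def clipped(p, eps=1e-7):
    """Keep p strictly inside (0, 1) so neither log() nor pow() sees 0."""
    return min(max(p, eps), 1.0 - eps)

for p_t in (0.9, 0.999, 0.999999):
    # gamma=0.5 grows without bound near 1; gamma=2.0 shrinks toward 0.
    print(p_t, modulating_grad(p_t, 0.5), modulating_grad(p_t, 2.0))

# At p_t == 1.0 with gamma < 1, Python raises ZeroDivisionError (a tensor
# backend gives inf instead); after clipping, the gradient is large but finite.
print(math.isfinite(modulating_grad(clipped(1.0), 0.5)))
```

In a Keras implementation the analogous step would be to clip y_pred with K.clip(y_pred, K.epsilon(), 1. - K.epsilon()) before computing pt_1 and pt_0.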
Hello, thank you for your code.
But when I try to use this loss function, I get the following error. Could you please help me?
BTW: my versions are keras==2.1.4 and tensorflow-gpu==1.4.1
File "/home/lx/Personal/MyScript/DeepLevelSet/model.py", line 83, in get_dcan
model.compile(optimizer=Adam(lr=1e-5),loss=[focal_loss,focal_loss], loss_weights=[1,0.4],metrics=[dice_coef])
File "/home/lx/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 830, in compile
sample_weight, mask)
File "/home/lx/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 442, in weighted
ndim = K.ndim(score_array)
File "/home/lx/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 606, in ndim
dims = x.get_shape()._dims
AttributeError: 'function' object has no attribute 'get_shape'
Can you provide a multi-label version of focal-loss-keras? Thanks.
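For a multi-label setup with independent sigmoid outputs (the same setup in which binary_crossentropy is normally used), one straightforward approach is to apply the binary focal loss to every (example, label) entry and average. The sketch below assumes 0/1 targets and probabilities given as nested lists; the function name is hypothetical. With gamma = 0 and alpha = 0.5 it reduces to exactly half of binary cross-entropy, which is a convenient sanity check when swapping it in.

```python
import math

def multilabel_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Mean binary focal loss over a batch of multi-label targets.

    y_true, y_pred: one row per example, one column per label; y_pred holds
    independent sigmoid probabilities (not a softmax distribution).
    """
    total, count = 0.0, 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            p = min(max(p, eps), 1.0 - eps)  # keep log() and pow() finite
            if t == 1:
                total -= alpha * (1.0 - p) ** gamma * math.log(p)
            else:
                total -= (1.0 - alpha) * p ** gamma * math.log(1.0 - p)
            count += 1
    return total / count

yt = [[1, 0, 1], [0, 1, 0]]
yp = [[0.8, 0.3, 0.6], [0.1, 0.7, 0.4]]
print(multilabel_focal_loss(yt, yp))
```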