Comments (23)
Could not find the numpy implementation, so I made a new one instead. This one will work correctly with batches (which the previous one did not) and is compatible with nd tensors (2D, 3D, etc. segmentations):
import numpy as np

def soft_dice_numpy(y_pred, y_true, eps=1e-7):
    '''
    c is number of classes
    :param y_pred: b x c x X x Y( x Z...) network output, must sum to 1 over c channel (such as after softmax)
    :param y_true: b x c x X x Y( x Z...) one hot encoding of ground truth
    :param eps: small constant to avoid division by zero
    :return: negative mean soft dice score (a scalar loss)
    '''
    # sum over all spatial axes, keeping the batch (b) and class (c) axes
    axes = tuple(range(2, len(y_pred.shape)))
    intersect = np.sum(y_pred * y_true, axes)
    denom = np.sum(y_pred + y_true, axes)
    return -(2. * intersect / (denom + eps)).mean()
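For completeness, a minimal usage sketch (the shapes and the one hot conversion here are illustrative, not part of the snippet above):

# illustrative shapes: batch of 2, 3 classes, 4x4 images
b, c, h, w = 2, 3, 4, 4
labels = np.random.randint(0, c, size=(b, h, w))    # integer ground truth, b x h x w
y_true = np.eye(c)[labels].transpose(0, 3, 1, 2)    # one hot encoding, b x c x h x w
logits = np.random.rand(b, c, h, w)
y_pred = logits / logits.sum(1, keepdims=True)      # normalize so the c channel sums to 1
loss = soft_dice_numpy(y_pred, y_true)              # scalar; -1 would be a perfect score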
Hi Jeremy,
Whether you square any of these values is a design choice. Basically it depends on how you approximate the intersection / cardinality of the segmentations. Since the dice loss only approximates the dice, you can choose whatever approximation you want. Both are correct.
The first paper that used the dice loss (that I am aware of) is this one: https://arxiv.org/pdf/1608.04117.pdf
They use the same approximation as I do. I very much prefer this formulation and I am using it for all my research.
Regards,
Fabian
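To make the two approximations concrete, here is a sketch of both denominators side by side (the function names are mine, not from the thread):

import numpy as np

def soft_dice_plain(p, t, eps=1e-7):
    # denominator approximates |P| + |T| with plain sums of the soft values
    return 2. * np.sum(p * t) / (np.sum(p) + np.sum(t) + eps)

def soft_dice_squared(p, t, eps=1e-7):
    # denominator uses squared terms instead (the V-Net formulation)
    return 2. * np.sum(p * t) / (np.sum(p ** 2) + np.sum(t ** 2) + eps)

Note that for a one hot t, t ** 2 == t, so the two only differ in how the soft predictions enter the denominator; both equal the true dice when p is binary as well.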
Hi Alex,
sure. It's not like it's a secret or anything ;-)
import theano.tensor as T

def soft_dice(y_pred, y_true):
    # y_pred is the softmax output of shape (num_samples, num_classes)
    # y_true is the one hot encoding of the target (shape = (num_samples, num_classes))
    intersect = T.sum(y_pred * y_true, 0)
    denominator = T.sum(y_pred, 0) + T.sum(y_true, 0)
    dice_scores = T.constant(2) * intersect / (denominator + T.constant(1e-6))
    return dice_scores
Make sure that your input shapes are correct and that you use one hot encoding. My implementation will return a soft dice score for each class (output shape is (num_classes,)). I got some decent results with it (same as state of the art on BraTS 2015 train data). Note that this implementation ignores that there may be more than one sample in the batch. You should be able to modify my implementation so that it can deal with batches > 1 as well; a sketch is included after this comment.
If you have any suggestions for improvements please let me know.
Cheers,
Fabian
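A possible batch-aware modification along these lines (my sketch, not Fabian's code): keep the batch and class axes and sum only over the remaining axes, analogous to the numpy version at the top of the thread.

import theano.tensor as T

def soft_dice_batched(y_pred, y_true, eps=1e-6):
    # y_pred, y_true of shape (batch, num_classes, X, Y, ...)
    axes = tuple(range(2, y_pred.ndim))
    intersect = T.sum(y_pred * y_true, axis=axes)
    denominator = T.sum(y_pred, axis=axes) + T.sum(y_true, axis=axes)
    return 2. * intersect / (denominator + eps)  # shape (batch, num_classes)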
Hi Fabian,
Your answer makes sense. My bad, argmax is not differentiable. (How does Theano even run, though? Shouldn't it break? Or is it trying to approximate it?)
I will try another solution to convert the target images into one-hot ones, so that no argmax is required. The reason I am insisting on the dice coefficient is that I think it could be better than cross entropy for segmentation problems.
-Alex
@FabianIsensee I implemented the Dice score based on our discussion here. You can see the gist here: https://gist.github.com/mongoose54/71e174587fbec8c2fe970e8a1c14eff4 Although it does not complain when used as a metric, I am sometimes getting some weird numbers. Have you been able to implement it? Would it be possible to share?
-Alex
IIUC, should we take the negative of a weighted sum of the dice scores to compute the loss to backpropagate for a multi-class problem? @FabianIsensee
Hi FabianIsensee,
Can you please explain the dice loss calculation for a multi-class problem?
Best
Can you please explain the dice loss calculation for a multi-class problem?
The way Fabian implemented it, each class is treated as a binary problem, and you get a score per class in the end. For training, you can try using the negative sum of the scores as the total loss to minimize; see the sketch below.
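In Theano terms that could look like the following (a sketch, assuming Fabian's soft_dice from above):

import theano.tensor as T

def soft_dice_loss(y_pred, y_true):
    # negate the mean of the per-class scores so that minimizing the loss maximizes the dice
    return -T.mean(soft_dice(y_pred, y_true))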
@f0k,
Thanks for the response.
I have implemented it in Caffe; the code is listed below. However, the predictions I generate are binary. An output is attached below:
class DiceLossLayer(caffe.Layer):
    def forward(self, bottom, top):
        self.diff[...] = bottom[1].data
        top[0].data[...] = 1 - self.dice_coef_multi_class(bottom[0], bottom[1])

    def backward(self, top, propagate_down, bottom):
        if propagate_down[1]:
            raise Exception("label not diff")
        elif propagate_down[0]:
            a = (-2. * self.diff + self.dice) / self.sum
            bottom[0].diff[...] = a
        else:
            raise Exception("no diff")

    # =============================
    def dice_coef_multi_class(self, y_pred, y_true):
        n_classes = 5
        smooth = np.float32(1e-7)
        y_true = y_true.data
        y_pred = y_pred.data
        y_pred = np.argmax(y_pred, 1)
        y_pred = np.expand_dims(y_pred, 1)
        y_pred = np.ndarray.flatten(y_pred)
        y_true = np.ndarray.flatten(y_true)
        dice = np.zeros(n_classes)
        self.sum = np.zeros([n_classes])
        for i in range(n_classes):
            y_true_i = np.equal(y_true, i)
            y_pred_i = np.equal(y_pred, i)
            self.sum[i] = np.sum(y_true_i) + np.sum(y_pred_i) + smooth
            dice[i] = (2. * np.sum(y_true_i * y_pred_i) + smooth) / self.sum[i]
        self.sum = np.sum(self.sum)
        self.dice = np.sum(dice)
        return self.dice
Your implementation looks complicated; can't you directly translate Fabian's version to numpy?
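Such a direct translation could look like this (a sketch of what is being suggested, not code from the thread):

import numpy as np

def soft_dice_np(y_pred, y_true, eps=1e-6):
    # numpy version of Fabian's Theano implementation:
    # y_pred is the softmax output, y_true the one hot target, both (num_samples, num_classes)
    intersect = np.sum(y_pred * y_true, 0)
    denominator = np.sum(y_pred, 0) + np.sum(y_true, 0)
    return 2. * intersect / (denominator + eps)  # one score per class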
@f0k
Please look only at the methods def forward(...) and def backward(...); these methods are required by Caffe.
Hi, sorry for being late to the party. I would strongly suggest directly translating my implementation to numpy, because that is much easier to read. I have it somewhere as well (will try to find it).
One thing that I absolutely don't like about Caffe is that you need to implement the gradient manually, which makes it very easy to make mistakes (I did not check your backward implementation).
The implementation of dice_coef_multi_class looks fine to me.
For how long did you train the network? It very often happens that the network will start out like in your figure and get better over time.
edit: I saw a mistake in dice_coef_multi_class (and thus probably in the gradient as well). You want a loss function, so something that gets lower the better the network is. Therefore, like Jan mentioned previously, return -self.dice!
@mongoose54 Should the output of dice_coef_loss be negated?
def dice_coef_loss(y_pred, y_true):
    return -dice_coef(y_pred, y_true)
Yes. It's a loss function and Lasagne minimizes the loss. In order to maximize the dice, you need to minimize the negative dice.
According to this dissertation (page 72), in which the author discusses their paper V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation, the values in the denominator should be squared.
I've adapted @FabianIsensee's solution in this Gist and welcome critique/discussion of the changes.
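For readers who do not follow the link: the change amounts to squaring the terms in the denominator of the numpy implementation at the top of this thread (my sketch, not necessarily identical to the Gist):

import numpy as np

def soft_dice_numpy_squared(y_pred, y_true, eps=1e-7):
    # same as soft_dice_numpy above, but with the squared (V-Net style) denominator
    axes = tuple(range(2, len(y_pred.shape)))
    intersect = np.sum(y_pred * y_true, axes)
    denom = np.sum(y_pred ** 2 + y_true ** 2, axes)
    return -(2. * intersect / (denom + eps)).mean()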
Oh I see, thanks! :)
Hello @FabianIsensee,
When I use dice as the loss function, the predicted image is all zeros!
Could you help me analyze it?
Hi,
you are not exactly giving a lot of detail. My guess is that your learning rate etc. is not optimal, OR you should consider optimizing the background as well (i.e. also compute the dice loss for the background class).
Hi,
you are not exactly giving a lot of detail. My guess is that your learning rate etc. is not optimal, OR you should consider optimizing the background as well (i.e. also compute the dice loss for the background class).
Thank you for the reply!
This is my dice loss function, under my implementation of U-Net.
def dice_coef(y_true, y_pred):
    smooth = 1
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    print("dice loss")
    return 1 - dice_coef(y_true, y_pred)

...
model.compile(optimizer=Adam(lr=1e-4), loss=dice_coef_loss, metrics=['accuracy'])
So I am struggling to find out where it is going wrong.
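For what it's worth, one way to follow Fabian's suggestion (compute the dice for every class, background channel included) in Keras could be the sketch below; dice_coef_multiclass and the channels-last one hot layout are my assumptions, not code from the thread:

import keras.backend as K

def dice_coef_multiclass(y_true, y_pred, smooth=1e-5):
    # y_true, y_pred: (batch, ..., num_classes), one hot / softmax, background included as a channel
    axes = tuple(range(K.ndim(y_pred) - 1))  # sum over batch and spatial axes, keep the class axis
    intersection = K.sum(y_true * y_pred, axis=axes)
    denom = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes)
    return K.mean((2. * intersection + smooth) / (denom + smooth))

def dice_coef_multiclass_loss(y_true, y_pred):
    return 1 - dice_coef_multiclass(y_true, y_pred)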
Whether you square any of these values is a design choice. Basically it depends on how you approximate the intersection / cardinality of the segmentations. Since the dice loss only approximates the dice, you can choose whatever approximation you want. Both are correct.
It is really unfair to call it a design choice. The squared formulation is better because it has an obvious geometric meaning. Consider the law of cosines from high school:
https://www.mathsisfun.com/algebra/trig-cosine-law.html
Immediately from the law of cosines, the squared formulation is the cosine of the angle between the prediction and the target, viewed as vectors. The non-squared version is merely "not wrong" per se, but it takes some mental gymnastics to make it into something mathematically meaningful.
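Making the step explicit (my reading of the argument, not spelled out in the original comment): with the prediction p and the target t viewed as vectors, expanding \|p - t\|^2 gives

\|p - t\|^2 = \|p\|^2 + \|t\|^2 - 2 \langle p, t \rangle,
\qquad
D_{sq} = \frac{2 \langle p, t \rangle}{\|p\|^2 + \|t\|^2} = 1 - \frac{\|p - t\|^2}{\|p\|^2 + \|t\|^2},

so the squared formulation is one minus a normalized squared distance between prediction and target, and it coincides with the cosine of the angle between them exactly when \|p\| = \|t\|.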
Hi @liuyipei, I must admit I am probably not as good of a mathematician as I would like to be - there may certainly be theoretical advantages of squaring vs not squaring. I am a man of results though, and at least in my experiments squaring does not perform as well. That could be due to a number of reasons (hyperparameters tuned for non-squaring), so I am not saying that this observation is going to be true for everyone. But since my results perform really well on several segmentation leaderboards, I am confident that not squaring is a non-issue in practice :-)
Best,
Fabian