Comments (11)

FelixGruen commented on July 17, 2024

I think you maybe didn't understand my argument, because I made a mistake in my explanation. I should have read the code more closely before posting. But I'll try to explain the problem in more detail (correctly this time ;) and outline the implementation.

So let's imagine you have ten times more pixels with label a than with label b. To balance this out, you want the gradients that come from pixels with label b to count ten times as much as the gradients that come from pixels with label a, so that when you add it all up, learning for both cases happens at the same speed.

The gradients are the derivatives of the loss with respect to the network outputs. The math is a bit involved, but I found this explanation of how to compute the derivatives of the last output layer in the case of a subsequent softmax activation function with cross-entropy loss.

The basic takeaway is that the gradient is ∂L/∂o_i = p_i − y_i, where L is the loss, o_i is the output of the last layer (the logits), y_i is the (one-hot) label vector, and p_i is the result of the softmax activation function: p_i = e^{o_i} / Σ_j e^{o_j}.

Now if you multiply the loss L by 10, the gradient will also be multiplied by 10. But if you multiply the logits o_i by 10, it will only influence the gradient through the result of the softmax p_i.
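
A quick numeric check of the two cases (illustrative values, nothing from the actual network):

```python
import numpy as np

o = np.array([2.0, 0.5, -1.0])     # logits o_i
y = np.array([0.0, 1.0, 0.0])      # one-hot label y_i
p = np.exp(o) / np.exp(o).sum()    # softmax p_i

grad = p - y                       # dL/do_i for softmax + cross entropy

# Scaling the loss by 10 scales every gradient by exactly 10:
grad_scaled_loss = 10 * (p - y)

# Scaling the logits by 10 instead changes the gradient only through the
# softmax: p saturates towards a hard 0/1 vector, it does not scale linearly.
p10 = np.exp(10 * o) / np.exp(10 * o).sum()
grad_scaled_logits = p10 - y
```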

Specifically, the way it is implemented now (and here is where I was wrong above), the values o_i (or o_{b,w,h,i}, to be consistent with TensorFlow dimensions) are multiplied by a certain weight w_i across the whole feature map, depending only on their class i and irrespective of the label of their pixel.

I sent you an (untested) PR for the implementation. You should multiply the label array with the weights and sum across the last dimension (the one that defines the classes), so that you have a weight map which corresponds to the pixel labels. Then reshape it into a 1D vector and multiply it elementwise with the loss. That way you get larger gradients for pixels whose label has a larger weight, and smaller gradients for pixels whose label has a smaller weight.
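
In code, the idea looks roughly like this (an untested sketch with assumed names, not the exact PR code):

```python
import tensorflow as tf

def weighted_cross_entropy(logits, labels, class_weights, n_classes):
    # logits, labels: [batch, h, w, n_classes], labels one-hot;
    # class_weights: [n_classes]. All names here are illustrative.
    flat_logits = tf.reshape(logits, [-1, n_classes])
    flat_labels = tf.reshape(labels, [-1, n_classes])

    # Multiplying the one-hot labels by the weights and summing over the
    # class axis picks out the weight of each pixel's true class:
    # weight_map has shape [batch * h * w].
    weight_map = tf.reduce_sum(flat_labels * class_weights, axis=-1)

    # Per-pixel cross entropy, also a 1D tensor of length batch * h * w.
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=flat_labels,
                                                   logits=flat_logits)

    # Elementwise product: pixels of up-weighted classes now produce
    # proportionally larger gradients.
    return tf.reduce_mean(loss * weight_map)
```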

jakeret commented on July 17, 2024

I'm not sure I understand your argument. The return value of softmax_cross_entropy_with_logits is a 1D tensor. How should one apply a class weight to this?
My intention with the implementation was, in case of class imbalance, to dampen the activations of the dominant class and amplify the others. However, it could well be that I misunderstood the concept.

jakeret commented on July 17, 2024

OK, that's interesting. I need to think about this a little bit. Thanks for the PR, I'll have a closer look.

jakeret commented on July 17, 2024

I had a closer look at your PR and the referenced explanation. I think I understand the concept, but I'm struggling a bit with the implementation.
Anyway, I checked out your branch and let it run on a problem with a class imbalance. Something doesn't seem to be right: after a few epochs the loss started to drift and then exploded (from ~1 to 10^6).

FelixGruen commented on July 17, 2024

Hm, that's indeed not good :)

As I said, I just wrote it down in an editor and didn't have time to test it. I'll see if I have time to look at it again. But if you find an error, either in the concept or in the implementation, I'd of course be grateful.

jakeret commented on July 17, 2024

I just pushed a little extension to the toy problem so that the U-Net has to segment background (85%), circles (12%) and rectangles (2%). Maybe this will help to track down the issue.
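
For a split like that, one common choice of class weights, though by no means the only one, is inverse class frequency:

```python
import numpy as np

# Pixel frequencies from the toy problem: background, circles, rectangles.
freq = np.array([0.85, 0.12, 0.02])

# Inverse-frequency weights, normalized to sum to 1.
class_weights = (1.0 / freq) / (1.0 / freq).sum()
# -> roughly [0.02, 0.14, 0.84]: the rare rectangles get the largest weight.
```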

nicolov commented on July 17, 2024

Some resources I found while looking at this issue:

jakeret commented on July 17, 2024

Hi nicolov, thanks for also looking into this.
The solution referenced on SO is essentially what is implemented in the master branch. The second solution in the post is what FelixGruen implemented in his new branch.

tf.nn.weighted_cross_entropy_with_logits sounds interesting. If I understand correctly, the weights are supposed to do something slightly different from what we are trying to achieve here. But maybe it could be used for our purpose. Any thoughts?

nicolov commented on July 17, 2024

Yep, I agree with your analysis. I believe tf.nn.weighted_cross_entropy_with_logits is for sigmoid activations, not softmax. Also interesting: this paper shows that a different loss function, based on the Dice coefficient, works better than re-weighting in the case of class imbalance.
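
For reference, a minimal usage sketch (illustrative values): per the TensorFlow docs, tf.nn.weighted_cross_entropy_with_logits is a sigmoid cross entropy in which pos_weight scales only the positive term, i.e. each class is treated as an independent binary decision rather than as part of a softmax:

```python
import tensorflow as tf

# Each output unit is an independent binary decision; pos_weight trades
# recall against precision per class, it does not re-weight mutually
# exclusive softmax classes. Values below are purely illustrative.
labels = tf.constant([[1.0, 0.0, 0.0]])
logits = tf.constant([[0.3, -0.8, 1.2]])
loss = tf.nn.weighted_cross_entropy_with_logits(
    labels=labels, logits=logits, pos_weight=10.0)
```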

jakeret commented on July 17, 2024

I just pushed a new branch to make it easier to add new cost functions. Furthermore, I also included an implementation of the Dice coefficient loss.
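
A rough sketch of a soft Dice loss of that kind (not necessarily the exact code in the branch):

```python
import tensorflow as tf

def dice_loss(logits, labels, eps=1e-5):
    # logits, labels: [batch, h, w, n_classes], labels one-hot.
    predictions = tf.nn.softmax(logits)
    intersection = tf.reduce_sum(predictions * labels)
    union = tf.reduce_sum(predictions) + tf.reduce_sum(labels)
    # The Dice coefficient 2*|A.B| / (|A| + |B|) is 1 for a perfect match,
    # so we minimize 1 - Dice; eps guards against division by zero.
    return 1.0 - (2.0 * intersection + eps) / (union + eps)
```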

jakeret commented on July 17, 2024

I merged the branch quite a while ago.
