
focal_loss_pytorch's Introduction

Focal Loss for Dense Object Detection in PyTorch

A PyTorch implementation of the focal loss from the paper Focal Loss for Dense Object Detection.

Result

| Method             | Training set | Val set | mAP   |
|--------------------|--------------|---------|-------|
| Cross Entropy Loss | VOC2007      | VOC2007 | 63.36 |
| Focal Loss         | VOC2007      | VOC2007 | 65.26 |

focal_loss_pytorch's People

Contributors

clcarwin

focal_loss_pytorch's Issues

Normalisation & mean over classes

Should the loss take a mean over all the classes of an anchor? For example, if the input comes in as N x (A*num_classes) x H x W, where A is the number of anchors per cell, H is the height of the feature map, and W is its width, then the targets are N x A x H x W.

We then need to take a mean over the num_classes entries of each anchor in the loss computed on the N x (A*num_classes) x H x W input, reducing it to N x A x H x W (see the sketch below).

This N x A x H x W tensor should be summed and then normalised only by the number of anchors assigned to a ground-truth box, since the paper states that the loss from the remaining anchors is effectively zero.
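A minimal sketch of this reshaping and normalisation (my own illustration, not the repo's code; the per-anchor loss here is a placeholder):

    import torch

    # Assumed shapes: logits N x (A*C) x H x W, targets N x A x H x W,
    # with -1 marking anchors not assigned to any ground-truth box
    N, A, C, H, W = 2, 9, 20, 4, 4
    logits = torch.randn(N, A * C, H, W)
    targets = torch.randint(-1, C, (N, A, H, W))

    # N x (A*C) x H x W  =>  N x A x C x H x W: give each anchor its own class axis
    logits = logits.view(N, A, C, H, W)

    # Placeholder per-anchor loss, reduced over the class axis => N x A x H x W
    per_anchor_loss = -torch.log_softmax(logits, dim=2).mean(dim=2)

    # Sum, then normalise by the number of assigned anchors only, as the paper does
    assigned = targets >= 0
    loss = per_anchor_loss[assigned].sum() / assigned.sum().clamp(min=1)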

NameError: name 'long' is not defined

When I run this program, an error occurs:

"NameError: name 'long' is not defined"

It points to line 11 of focalloss.py:

if isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])
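In Python 3, `long` was merged into `int`, so the name no longer exists. A minimal Python-3-compatible version of that line:

    # Python 3 has no separate `long` type, so drop it from the isinstance check
    if isinstance(alpha, (float, int)):
        self.alpha = torch.Tensor([alpha, 1 - alpha])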

What are the best parameters?

gamma, alpha and size_average.
What do you recommend when initializing the loss function?
How do gamma and alpha affect the result?
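For what it's worth, the paper reports gamma = 2 with alpha = 0.25 as the best-performing combination in its experiments. A starting point, assuming this repo's constructor signature:

    # Paper's best-reported settings; treat as a starting point, not a universal answer
    criterion = FocalLoss(gamma=2, alpha=0.25, size_average=True)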

Loss becomes NaN!

While tuning gamma, the loss suffers a gradient explosion when gamma is between 0 and 1, but not when it exceeds 1. May I ask what causes this?
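A plausible explanation (mine, not from the thread): the modulating factor (1-pt)^{gamma} has derivative -gamma*(1-pt)^{gamma-1} with respect to pt, and for 0 < gamma < 1 the exponent gamma-1 is negative, so the gradient diverges as pt approaches 1; for gamma >= 1 it stays bounded, which matches the behaviour described.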

the problem of reshape

  • your code:
    input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W
    input = input.transpose(1,2) # N,C,H*W => N,H*W,C
    input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C

Hi, thank you for your code. I have a question: why can't you reshape the tensor from N,C,H,W to N*H*W,C directly?
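The short answer is that memory order matters: a direct view keeps the N,C,H,W storage order, so rows would mix values across classes and pixels. A small demonstration (my own, not the repo's code):

    import torch

    x = torch.arange(12).float().view(1, 3, 2, 2)  # N=1, C=3, H=W=2

    # Repo's route: flatten H,W; move C last; then merge N with H*W
    a = x.view(1, 3, -1).transpose(1, 2).contiguous().view(-1, 3)

    # Direct reshape: each row is just the next 3 values in N,C,H,W memory
    # order, so it does NOT hold one pixel's C class scores
    b = x.view(-1, 3)

    print(torch.equal(a, b))  # False: the direct view pairs values wrongly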

Why softmax instead of sigmoid is used?

Hi 👋

I've accidentally stumbled upon your code. Looks nice, but I have a question though.

Why do you use Softmax instead of Sigmoid function?

logpt = F.log_softmax(input)

In the paper the authors clearly state the use of the sigmoid function, which makes sense; softmax, on the other hand, doesn't (at least, it seems to me).

For reference, the implementation in fvcore uses a sigmoid.

I would be glad if someone could explain to me, whether it's a bug or there's some intuition behind this.

Thanks in advance!
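For comparison, here is a minimal sigmoid-based focal loss in the spirit of fvcore's sigmoid_focal_loss (a sketch, not this repo's code):

    import torch
    import torch.nn.functional as F

    def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
        """Per-class binary focal loss with a sigmoid, as described in the paper."""
        p = torch.sigmoid(logits)
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = p * targets + (1 - p) * (1 - targets)  # probability of the true label
        loss = ce * (1 - p_t) ** gamma
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        return loss.mean()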

Negative examples

Hello! Thanks for your focal loss implementation, but I have a question. I think we also have to account for the negative cases in the loss, i.e. the cases where pt = 1 - p. I only see the pt = p (positive) cases, and no pt = 1 - p cases. Could you comment on this?

Multi-class focal loss

I notice that focalloss.py seems to implement focal loss only for the binary classification setting. Are there any implementations of focal loss for a multi-class classification setting?
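A minimal multi-class sketch (my own, assuming integer class targets; not this repo's code):

    import torch
    import torch.nn.functional as F

    def multiclass_focal_loss(logits, target, gamma=2.0):
        """Focal loss over C classes: logits (N, C), target (N,) of class indices."""
        logpt = F.log_softmax(logits, dim=1)
        logpt = logpt.gather(1, target.unsqueeze(1)).squeeze(1)  # log p of true class
        pt = logpt.exp()
        return (-(1 - pt) ** gamma * logpt).mean()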

If I set alpha, the program crashes

~/FullyConnected/focalloss.py in forward(self, input, target)
     32 
     33         loss = -1 * (1-pt)**self.gamma * logpt
---> 34         if self.size_average: return loss.mean()
     35         else: return loss.sum()

RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1524586445097/work/aten/src/THC/generated/../THCReduceAll.cuh:339
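A likely cause (my assumption, not confirmed in the thread): when alpha is a single float, the code builds the length-2 table torch.Tensor([alpha, 1-alpha]) and gathers into it with the raw class indices, so any target class index >= 2 reads out of bounds and triggers the device-side assert on CUDA. If the constructor accepts a per-class list, something like this avoids it:

    # Hypothetical workaround: one alpha per class, so the gather never
    # indexes past the end of the table (num_classes is your C)
    criterion = FocalLoss(gamma=2, alpha=[0.25] * num_classes)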

FocalLoss vs CrossEntropyLoss

In my experiments, the loss from FocalLoss with gamma=0 is much lower than the loss from CrossEntropyLoss. What causes this?
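With gamma=0 and no alpha, focal loss should reduce exactly to cross entropy, so any gap likely comes from alpha weighting or a different reduction (mean vs. sum). A quick check of the equivalence (my own sketch):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(8, 5)
    target = torch.randint(0, 5, (8,))

    ce = F.cross_entropy(logits, target)  # mean-reduced cross entropy

    # Focal loss with gamma=0: the (1 - pt)**0 factor is 1, leaving -log pt
    logpt = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
    fl = (-logpt).mean()

    print(torch.allclose(ce, fl))  # True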

Shouldn't it be log_sigmoid instead of softmax ?

The paper mentions that the loss layer is combined with the sigmoid computation, not softmax. More specifically, this passage:

"Finally, we note that the implementation of the loss layer combines the sigmoid operation for computing p with the loss computation, resulting in greater numerical stability."

So isn't the author saying that we should use a sigmoid activation on the last layer? Using softmax could perhaps lead to lower accuracy.

alpha value is not consistent with paper?

It seems the loss is -alpha*(1-y')^{gamma}*log(y') when y=1 in the paper, but in the code it is -(1-alpha)*(1-y')^{gamma}*log(y') when y=1. So should alpha be set the opposite way?

In the code:

    if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha])
    at = self.alpha.gather(0,target.data.view(-1))
    logpt = logpt * Variable(at)

Is this right?
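A quick demonstration of what that gather picks out (my own snippet):

    import torch

    alpha = torch.Tensor([0.25, 0.75])  # [alpha, 1 - alpha] with alpha = 0.25
    target = torch.tensor([1, 0, 1])
    print(alpha.gather(0, target))      # tensor([0.7500, 0.2500, 0.7500])
    # The positive class (y = 1) is weighted by 1 - alpha, so to reproduce the
    # paper's alpha_t = 0.25 on positives you would pass alpha = 0.75 here.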

Should pt be backpropagated?

I see that pt is calculated from the .data of the log-probability, which means your implementation does not allow backpropagation through pt. So I'd ask: should this allow backprop?
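If one wanted gradients to flow through the modulating factor, the detached computation described above could be replaced like this (a sketch; which behaviour is intended is exactly the question here):

    # Detached: (1 - pt)**gamma acts as a constant per-example weight
    # pt = logpt.data.exp()
    # In-graph: the modulating factor also receives gradients
    pt = logpt.exp()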
