I have implemented your code in PyTorch and it works properly, but I have the following concerns.
Now, I followed your code and implemented focal loss as-is, but my loss values are coming out very small. For example, random predictions give a loss of about 0.12, and the loss quickly drops to 0.0012 and smaller.
class FocalLoss_tensorflow(nn.Module):
    """RetinaNet-style loss: smooth-L1 box regression + sigmoid focal classification.

    Ported from the TensorFlow implementation at
    https://github.com/ailias/Focal-Loss-implement-on-Tensorflow/blob/master/focal_loss.py
    Column 0 of the class predictions is treated as background and excluded
    from the focal term.

    Args:
        num_classes: number of columns in ``cls_preds`` (column 0 = background).
        focusing_param: the focal exponent gamma.
        balance_param: the class-balance weight alpha.
    """

    def __init__(self, num_classes=20,
                 focusing_param=2.0,
                 balance_param=0.25):
        # BUG FIX: the original called super(FocalLoss_2, self).__init__();
        # FocalLoss_2 is undefined here, so instantiation raised NameError.
        super(FocalLoss_tensorflow, self).__init__()
        self.num_classes = num_classes
        self.focusing_param = focusing_param
        self.balance_param = balance_param

    def focal_loss(self, x, y):
        """Summed sigmoid focal loss over all (anchor, class) entries.

        Args:
            x: (num_anchors, num_classes) raw logits; background column 0 is
               dropped before computing the loss.
            y: (num_anchors,) long tensor of labels, 0 = background.

        Returns:
            Scalar tensor: the SUM of per-entry focal losses. The caller
            normalizes by the positive-anchor count. (The original returned
            ``.mean()`` over every (anchor, class) entry, which is why the
            reported loss values were tiny, ~1e-3.)
        """
        # Drop the background column; remaining columns are classes 1..C-1.
        x = x[:, 1:]
        sigmoid_p = torch.sigmoid(x)
        anchors, classes = x.shape

        # One-hot encode over the full num_classes columns, then drop the
        # background column so t aligns with the sliced logits. Built on
        # x.device so this works on CPU and GPU alike (original hard-coded
        # .cuda() and a CPU round-trip).
        t = torch.zeros(anchors, classes + 1, device=x.device)
        t.scatter_(1, y.view(-1, 1), 1.0)
        t = t[:, 1:]

        zeros = torch.zeros_like(sigmoid_p)
        # (t - p) on positive entries, else 0.  Equivalent to the original
        # (t >= sigmoid_p) masking because sigmoid output is strictly in (0,1).
        pos_p_sub = torch.where(t > zeros, t - sigmoid_p, zeros)
        # p on negative entries, else 0.
        neg_p_sub = torch.where(t > zeros, zeros, sigmoid_p)

        per_entry_cross_ent = (
            -self.balance_param * (pos_p_sub ** self.focusing_param)
            * torch.log(torch.clamp(sigmoid_p, 1e-8, 1.0))
            - (1 - self.balance_param) * (neg_p_sub ** self.focusing_param)
            * torch.log(torch.clamp(1.0 - sigmoid_p, 1e-8, 1.0))
        )
        return per_entry_cross_ent.sum()

    def forward(self, loc_preds, loc_targets, cls_preds, cls_targets):
        """Compute (loc_loss, cls_loss), each normalized by #positive anchors.

        Args:
            loc_preds: (batch, num_anchors, 4) predicted box offsets.
            loc_targets: (batch, num_anchors, 4) target box offsets.
            cls_preds: (batch, num_anchors, num_classes) class logits.
            cls_targets: (batch, num_anchors) long labels; 0 = background,
                -1 = ignore, >0 = object class.

        Returns:
            Tuple of scalar tensors ``(loc_loss, cls_loss)``.
        """
        batch_size, num_boxes = cls_targets.size()
        pos = cls_targets > 0
        # Guard the all-background batch: the original divided by a possibly
        # zero num_pos, producing inf/NaN losses.
        num_pos = pos.long().sum().clamp(min=1)

        # Box regression loss over positive anchors only.
        mask = pos.unsqueeze(2).expand_as(loc_preds)
        masked_loc_preds = loc_preds[mask].view(-1, 4)
        masked_loc_targets = loc_targets[mask].view(-1, 4)
        loc_loss = F.smooth_l1_loss(masked_loc_preds, masked_loc_targets,
                                    reduction='sum') / num_pos

        # Classification loss over positive + negative anchors (label -1 is
        # ignored). Normalize the summed focal loss by the positive-anchor
        # count, as in RetinaNet, instead of averaging over all entries.
        pos_neg = cls_targets > -1
        mask = pos_neg.unsqueeze(2).expand_as(cls_preds)
        masked_cls_preds = cls_preds[mask].view(-1, self.num_classes)
        cls_loss = self.focal_loss(masked_cls_preds, cls_targets[pos_neg]) / num_pos
        return loc_loss, cls_loss