
Comments (10)

JunMa11 avatar JunMa11 commented on July 18, 2024 2

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt as distance
from skimage import segmentation as skimage_seg

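def compute_sdf(img_gt, out_shape):
    """
    compute the signed distance map of a binary mask
    input: img_gt: binary segmentation, shape = (batch_size, x, y, z)
           out_shape: shape of the softmax output, (batch_size, num_classes, ...)
    output: the Signed Distance Map (SDM), shape = out_shape
    sdf(x) = 0; x on the segmentation boundary
             -inf|x-y|; x inside the segmentation
             +inf|x-y|; x outside the segmentation
    where inf is the infimum over boundary points y
    """

    img_gt = img_gt.astype(np.uint8)

    gt_sdf = np.zeros(out_shape)

    for b in range(out_shape[0]): # batch size
        for c in range(1, out_shape[1]): # foreground channels; channel 0 stays zero
            posmask = img_gt[b].astype(bool)
            if not posmask.any():
                continue # empty mask: no boundary, leave the SDM at zero
            negmask = ~posmask
            posdis = distance(posmask) # distance to background, nonzero inside the mask
            negdis = distance(negmask) # distance to foreground, nonzero outside the mask
            boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
            sdf = negdis - posdis
            sdf[boundary==1] = 0
            gt_sdf[b][c] = sdf

    return gt_sdf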

def boundary_loss(outputs_soft, gt_sdf):
    """
    compute boundary loss for binary segmentation
    input: outputs_soft: softmax results,  shape=(b,2,x,y)
           gt_sdf: sdf of ground truth (can be original or normalized sdf); shape=(b,2,x,y)
    output: boundary_loss; scalar
    """
    pc = outputs_soft[:,1,...]  # foreground probabilities
    dc = gt_sdf[:,1,...]        # foreground signed distances
    multiplied = torch.einsum('bxy, bxy->bxy', pc, dc)  # element-wise product
    bd_loss = multiplied.mean()

    return bd_loss

        # inside the training loop, after outputs_soft and loss_dice are computed
        with torch.no_grad():
            gt_sdf_npy = compute_sdf(label_batch.cpu().numpy(), outputs_soft.shape)
            gt_sdf = torch.from_numpy(gt_sdf_npy).float().cuda(outputs_soft.device.index)

        loss_boundary = boundary_loss(outputs_soft, gt_sdf)
        loss = alpha*(loss_dice) + (1 - alpha) * loss_boundary # alpha starts at 1.0
        alpha -= 0.01
        if alpha <= 0.01:
            alpha = 0.01
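
The sign convention in the compute_sdf docstring can be checked on a toy mask. This is only a minimal sketch: `signed_distance_2d` is a hypothetical helper for a single 2D mask, and as a simplification it does not reset the inner boundary pixels to 0 the way compute_sdf does.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt


def signed_distance_2d(mask):
    """Signed distance map of one 2D binary mask: negative inside, positive outside."""
    mask = mask.astype(bool)
    if not mask.any() or mask.all():
        # No boundary exists, so the map is left at zero (an assumption of this sketch).
        return np.zeros(mask.shape)
    outside = distance_transform_edt(~mask)  # nonzero on background pixels
    inside = distance_transform_edt(mask)    # nonzero on foreground pixels
    return outside - inside


mask = np.zeros((7, 7), dtype=np.uint8)
mask[2:5, 2:5] = 1  # 3x3 foreground square in the middle
sdf = signed_distance_2d(mask)
print(np.round(sdf, 1))
```

Pixels inside the square come out negative and pixels outside come out positive, so multiplying by the foreground probability penalizes confident predictions far outside the object and rewards them deep inside it.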

from seglossodyssey.

MohamedAliRashad avatar MohamedAliRashad commented on July 18, 2024 1

@JunMa11 yeah and also make it multi-class 😅

JunMa11 avatar JunMa11 commented on July 18, 2024

Hi @MohamedAliRashad ,

We have a demo reimplementation at

https://github.com/JunMa11/SegWithDistMap/blob/8c3c7656f8259bff2d31509a881553d5cf5bb691/code/train_LA_BD.py#L121-L133

and

https://github.com/JunMa11/SegWithDistMap/blob/8c3c7656f8259bff2d31509a881553d5cf5bb691/code/train_LA_BD.py#L192-L203

Hope it could help you.

Best,
Jun

MohamedAliRashad avatar MohamedAliRashad commented on July 18, 2024

Thanks @JunMa11, but I still notice that the input and output have the shape (b,2,x,y,z). Is there a way to make the DistLoss take the same inputs as PyTorch's CrossEntropyLoss?
network_output=(Batch, Num_classes, H, W)
target_shape=(Batch, H, W)

JunMa11 avatar JunMa11 commented on July 18, 2024

Hi @MohamedAliRashad ,

You may need to one-hot encode the labels or manually add a dimension, like this:

https://github.com/JunMa11/SegWithDistMap/blob/8c3c7656f8259bff2d31509a881553d5cf5bb691/code/train_LA.py#L108-L110
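
One possible reading of this suggestion, as a sketch only: build a boolean mask per class from the integer label map (the same thing a one-hot encoding gives), compute one signed distance map per class, and average its product with the softmax probabilities. The names `one_hot_sdf` and `multiclass_boundary_loss` are hypothetical and not part of SegLossOdyssey or SegWithDistMap. Skipping the background channel, leaving absent classes at zero, and not resetting boundary pixels to 0 are assumptions of this sketch.

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt


def one_hot_sdf(target, num_classes):
    """Per-class signed distance maps for an integer label array.

    target: np.ndarray of ints, shape (B, H, W) or (B, D, H, W)
    returns: float32 array, shape (B, num_classes, *spatial),
             negative inside a class region and positive outside it.
    """
    shape = (target.shape[0], num_classes) + target.shape[1:]
    sdf = np.zeros(shape, dtype=np.float32)
    for b in range(target.shape[0]):
        for c in range(1, num_classes):  # assumption: channel 0 is background
            posmask = target[b] == c     # boolean mask of class c
            if not posmask.any() or posmask.all():
                continue  # class absent or filling the image: leave zeros
            sdf[b, c] = distance_transform_edt(~posmask) - distance_transform_edt(posmask)
    return sdf


def multiclass_boundary_loss(logits, target):
    """logits: (B, C, ...) raw network output; target: (B, ...) integer labels."""
    probs = F.softmax(logits, dim=1)
    with torch.no_grad():
        sdf_np = one_hot_sdf(target.cpu().numpy(), logits.shape[1])
        sdf = torch.from_numpy(sdf_np).to(logits.device)
    # Average over the foreground channels only.
    return (probs[:, 1:] * sdf[:, 1:]).mean()
```

With this interface the call looks like `multiclass_boundary_loss(network_output, target)`, where `network_output` has shape (Batch, Num_classes, H, W) and `target` has shape (Batch, H, W). The distance maps are recomputed on the CPU for every batch, which is simple but slow; precomputing them in the data loader would be the usual alternative.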

MohamedAliRashad avatar MohamedAliRashad commented on July 18, 2024

@JunMa11 How will this enable me to use it with the Boundary Loss?
And while we are on the subject, what compound loss would you suggest for an imbalanced task with a high IoU requirement?

JunMa11 avatar JunMa11 commented on July 18, 2024

Hi @MohamedAliRashad ,

To use boundary loss, please refer to this out-of-the-box code.

> what compound loss would you suggest for an imbalanced

Based on my experience, BD loss + Dice loss works better than Dice loss alone.
However, it should be noted that this is still an unsolved problem.

MohamedAliRashad avatar MohamedAliRashad commented on July 18, 2024

@JunMa11 I have some questions if you can help me.

  1. What is the precomputed distance map used for?
  2. Does the implementation of DC_and_BD_loss support normal (2D) images the way it supports volumetric ones?
  3. If I want to combine the two losses, is it better to just sum them or should I add weights? And if I should add weights, how do I determine them?
  4. What is the apply_nonlin parameter in DICE?

JunMa11 avatar JunMa11 commented on July 18, 2024

Hi @MohamedAliRashad ,
Sorry for my late reply.

  1. Boundary loss relies on the distance map.
  2. No.
  3. According to the original paper, the Dice loss should get the larger weight at the beginning, and the weight of the BD loss should then be increased gradually.
  4. apply_nonlin means applying softmax to the network's logits output.

Again, please refer to

https://github.com/JunMa11/SegWithDistMap/blob/8c3c7656f8259bff2d31509a881553d5cf5bb691/code/train_LA_BD.py#L192-L203

if you want to use boundary loss for your own tasks.
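
The weighting described in point 3 could be sketched as a linear schedule like the one below. The helper names `dice_weight` and `combined_loss` are hypothetical. The step of 0.01 and the floor of 0.01 mirror the snippet at the top of this thread, and driving the schedule by epoch is an assumption, not something the paper or this thread prescribes.

```python
def dice_weight(epoch, start=1.0, step=0.01, floor=0.01):
    """Weight of the Dice term at a given epoch; the BD term gets 1 - weight."""
    return max(start - step * epoch, floor)


def combined_loss(loss_dice, loss_boundary, epoch):
    """Weighted sum that shifts from Dice loss toward boundary loss over time."""
    alpha = dice_weight(epoch)
    return alpha * loss_dice + (1.0 - alpha) * loss_boundary


# Early epochs are dominated by Dice loss, late epochs by boundary loss.
for epoch in (0, 50, 200):
    print(epoch, dice_weight(epoch))
```

Keeping a small nonzero floor means the Dice term never vanishes entirely, which matches the clamp in the original training snippet.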

Best,
Jun

MohamedAliRashad avatar MohamedAliRashad commented on July 18, 2024

@JunMa11 I am sorry for the disturbance. I am just trying to figure out how to adjust the DC_and_BD_loss function for a 2D input.
