[96354c]: / src / losses / region_based_loss.py

Download this file

36 lines (23 with data), 1.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from src.losses.dice_loss import DiceLoss
from torch import nn
class RegionBasedDiceLoss3D(nn.Module):
    """Weighted sum of a plain Dice loss and a region-based Dice loss.

    Two ``DiceLoss`` instances are built with identical ``classes`` and
    ``sigmoid_normalization`` settings; they differ only in ``eval_regions``
    (``False`` for the plain term, ``True`` for the region-based term).
    """

    def __init__(self, classes: int, sigmoid_normalization: bool = True):
        """
        :param classes: number of classes, forwarded to both ``DiceLoss`` instances
        :param sigmoid_normalization: forwarded unchanged to both ``DiceLoss`` instances
        """
        super().__init__()
        # Plain (per-class) Dice term.
        self.dice_loss = DiceLoss(classes=classes,
                                  sigmoid_normalization=sigmoid_normalization,
                                  eval_regions=False)
        # Region-based Dice term; same settings, only eval_regions differs.
        self.dice_loss_region_based = DiceLoss(classes=classes,
                                               sigmoid_normalization=sigmoid_normalization,
                                               eval_regions=True)

    def forward(self, input: torch.Tensor, target: torch.Tensor,
                weight_reg: float = 1, weight_dice: float = 1):
        """Compute ``weight_dice * dice_loss + weight_reg * dice_loss_reg``.

        The two weights do not need to sum to one; any values may be used.

        :param input: network output, torch.Tensor (NxCxDxHxW)
        :param target: ground truth, torch.Tensor (NxDxHxW)
        :param weight_reg: multiplier for the region-based Dice loss term
        :param weight_dice: multiplier for the plain Dice loss term
        :return: tuple ``(total_loss, dice_loss, dice_score, dice_loss_reg, subregions)``
            where ``dice_loss``/``dice_score`` come from the plain ``DiceLoss`` and
            ``dice_loss_reg``/``subregions`` from the region-based ``DiceLoss``
        """
        # Each DiceLoss call returns a 3-tuple; only the needed parts are kept.
        dice_loss, dice_score, _ = self.dice_loss(input, target)
        dice_loss_reg, _, subregions = self.dice_loss_region_based(input, target)
        total_loss = weight_dice * dice_loss + weight_reg * dice_loss_reg
        return total_loss, dice_loss, dice_score, dice_loss_reg, subregions