Diff of /metrics.py [000000] .. [70e190]

Switch to unified view

a b/metrics.py
1
import numpy as np
2
import math
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
from sklearn.utils.extmath import cartesian
7
from hausdorff import hausdorff_distance
8
9
__all__ = ['Dice loss', 'Cross entropy', 'Focal loss', 'Dice Iou Cross entropy', 'Binary dice loss']
10
11
12
class IOU(nn.Module):
13
    '''
14
    Calculate Intersection over Union (IoU) for semantic segmentation.
15
    
16
    Args:
17
        logits (torch.Tensor): Predicted tensor of shape (batch_size, num_classes, height, width, (depth))
18
        target (torch.Tensor): Ground truth tensor of shape (batch_size, height, width, (depth))
19
        num_classes (int): Number of classes
20
21
    Returns:
22
        tensor: Mean Intersection over Union (IoU) for the batch.
23
        list: List of IOU score for each class
24
    '''
25
    def __init__(self, num_classes, ignore_index=[0]):
26
        super(IOU, self).__init__()
27
        self.num_classes = num_classes
28
        self.ignore_index = ignore_index
29
        
30
    def forward(self, logits, target):
31
        pred = logits.argmax(dim=1)        
32
        target = target.argmax(dim=1)       
33
        ious = []
34
        for cls in range(self.num_classes):
35
            if cls in self.ignore_index: continue
36
            pred_mask = (pred == cls)
37
            target_mask = (target == cls)
38
                            
39
            intersection = (pred_mask & target_mask).sum().float()
40
            union = (pred_mask | target_mask).sum().float()
41
            
42
            if union == 0: iou = 1.0 
43
            else: iou = (intersection / union).item()
44
            ious.append(iou)
45
        
46
        mean_iou = sum(ious) / (self.num_classes - len(self.ignore_index))
47
        return torch.tensor(mean_iou), ious
48
49
    
50
class BinaryDice(nn.Module):
51
    '''
52
    Calculate Binary Dice score and Dice loss for binary segmentation or each class in Multiclass segmentation
53
    
54
    Args:
55
        logits (torch.Tensor): Predicted tensor of shape (batch_size, height, width, (depth))
56
        target (torch.Tensor): Ground truth tensor of shape (batch_size, height, width. (depth))
57
        
58
    Returns:
59
        tensor: Dice score
60
        tensor: Dice loss
61
    '''
62
    def __init__(self, smooth=1e-5, p=2):
63
        super(BinaryDice, self).__init__()
64
        self.smooth = smooth
65
        self.p = p
66
67
    def forward(self, logits, target):
68
        assert logits.shape[0] == target.shape[0], "logits & Target batch size don't match"
69
        smooth = 1e-5
70
        intersect = torch.sum(logits * target)        
71
        y_sum = torch.sum(target * target)
72
        z_sum = torch.sum(logits * logits)
73
        dice = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
74
        loss = 1 - dice
75
        return dice, loss
76
        
77
78
class Dice(nn.Module):
79
    '''
80
    Calculate Dice score and Dice loss for multiclass semantic segmentation
81
    
82
    Args:
83
        output (torch.Tensor): Predicted tensor of shape (batch_size, num_classes, height, width, (depth))
84
        target (torch.Tensor): Ground truth tensor of shape (batch_size, height, width, (depth))
85
        num_classes (int): Number of classes 
86
        
87
    Returns:
88
        tensor: Mean dice score over classes
89
        tensor: Mean dice loss over classes
90
        list: dice score for each classes
91
        listL dice loss for each classes
92
    '''
93
    def __init__(self, num_classes, weight=None, softmax=True, ignore_index=[0]):
94
        super(Dice, self).__init__()
95
        self.num_classes = num_classes
96
        self.weight = weight
97
        self.softmax = softmax
98
        self.ignore_index = ignore_index
99
        self.binary_dice = BinaryDice()
100
101
    def forward(self, logits, target):
102
        assert logits.shape == target.shape, 'logits & Target shape do not match'
103
        if self.softmax: logits = F.softmax(logits, dim=1)
104
        
105
        DICE, LOSS = 0.0, 0.0
106
        CLS_DICE, CLS_LOSS = [], []
107
        for clx in range(target.shape[1]):
108
            if clx in self.ignore_index: continue
109
            dice, loss = self.binary_dice(logits[:, clx], target[:, clx])
110
            CLS_DICE.append(dice.item())
111
            CLS_LOSS.append(loss.item())
112
            if self.weight is not None: dice *= self.weights[clx]
113
            DICE += dice
114
            LOSS += loss
115
116
        num_valid_classes = self.num_classes - len(self.ignore_index)
117
        return DICE / num_valid_classes, LOSS / num_valid_classes, CLS_DICE, CLS_LOSS
118
    
119
    
120
class WeightedHausdorffDistance(nn.Module):
121
    def __init__(self, height, width, p=-9, return_2_terms=False, device=torch.device('cuda')):
122
        '''
123
        height         (int):  image height
124
        width          (int):  image width
125
        return_2_terms (bool): Whether to return the 2 terms
126
                               of the WHD instead of their sum.
127
        '''
128
        super().__init__()
129
        self.height, self.width = height, width
130
        self.size = torch.tensor([height, width], dtype=torch.get_default_dtype(), device=device)
131
        self.max_dist = math.sqrt(height**2 + width**2)
132
        self.n_pixels = height * width
133
        self.all_img_locations = torch.from_numpy(cartesian([np.arange(height), np.arange(width)]))
134
        self.all_img_locations = self.all_img_locations.to(device=device, dtype=torch.get_default_dtype())
135
        self.return_2_terms = return_2_terms
136
        self.p = p
137
        
138
    def _assert_no_grad(self, variables):
139
        for var in variables:
140
            assert not var.requires_grad, \
141
                "nn criterions don't compute the gradient w.r.t. targets - please " \
142
                "mark these variables as volatile or not requiring gradients"
143
                
144
    def cdist(self, x, y):
145
        '''
146
        Compute distance between each pair of the two collections of inputs.
147
        x: Nxd Tensor
148
        y: Mxd Tensor
149
        return: NxM matrix where dist[i,j] is the norm between x[i,:] and y[j,:]
150
                i.e. dist[i,j] = || x[i,:] - y[j,:] ||
151
        '''
152
        difs = x.unsqueeze(1) - y.unsqueeze(0)
153
        dists = torch.sum(difs**2, -1).sqrt()
154
        return dists
155
    
156
    def generalize_mean(self, tensor, dim, p=-9, keepdim=False):
157
        assert p < 0
158
        res= torch.mean((tensor + 1e-6)**p, dim, keepdim=keepdim)**(1./p)
159
        return res
160
        
161
    def forward(self, prob_map, gt, orig_sizes):
162
        '''
163
        prob_map: (B x H x W) Tensor of the probability map of the estimation.
164
                              B is batch size, H is height and W is width.
165
                              Values must be between 0 and 1.
166
                              
167
        gt: List of Tensors of the Ground Truth points.
168
            Must be of size B as in prob_map.
169
            Each element in the list must be a 2D Tensor,
170
            where each row is the (y, x), i.e, (row, col) of a GT point.
171
        
172
        orig_sizes: Bx2 Tensor containing the size
173
                    of the original images.
174
                    B is batch size.
175
                    The size must be in (height, width) format. 
176
                    
177
        return: Single-scalar Tensor with the Weighted Hausdorff Distance.
178
                If self.return_2_terms=True, then return a tuple containing
179
                the two terms of the Weighted Hausdorff Distance.
180
        '''
181
        
182
        self._assert_no_grad(gt)
183
        assert prob_map.dim() == 3, 'The probability map must be (B x H x W)'
184
        assert prob_map.size()[1:3] == (self.height, self.width), \
185
            'You must configure the WeightedHausdorffDistance with the height and width of the ' \
186
            'probability map that you are using, got a probability map of size %s'\
187
            % str(prob_map.size())
188
            
189
        batch_size = prob_map.shape[0]
190
        assert batch_size == len(gt)
191
        
192
        terms_1 = []
193
        terms_2 = []
194
        for b in range(batch_size):
195
            
196
            # One by one
197
            prob_map_b = prob_map[b, :, :]
198
            gt_b = gt[b]
199
            orig_size_b = orig_sizes[b, :]
200
            norm_factor = (orig_size_b / self.size).unsqueeze(0)
201
            n_gt_pts = gt_b.size()[0]
202
            
203
            # Corner case: no GT points
204
            if gt_b.ndimension() == 1 and (gt_b < 0).all().item() == 0:
205
                terms_1.append(torch.tensor([0], 
206
                                            dtype=torch.get_default_dtype()))
207
                terms_2.append(torch.tensor([self.max_dist],
208
                                            dtype=torch.get_default_dtype())) 
209
                continue
210
            
211
            # Pairwise distances between all possible locations and the GTed locations
212
            n_gt_pts = gt_b.size()[0]
213
            normalized_x = norm_factor.repeat(self.n_pixels, 1) * self.all_img_locations
214
            normalized_y = norm_factor.repeat(len(gt_b), 1) * gt_b
215
            d_matrix = self.cdist(normalized_x, normalized_y)
216
            
217
            # Reshape probability map as a long column vector
218
            # and prepare it for mulitplication
219
            p = prob_map_b.view(prob_map_b.nelement())
220
            n_est_pts = p.sum()
221
            p_replicated = p.view(-1, 1).repeat(1, n_gt_pts)
222
            
223
            # Weighted Hausdorff Distance
224
            term_1 = (1 / (n_est_pts + 1e-6)) * torch.sum(p * torch.min(d_matrix, 1)[0])
225
            weighted_d_matrix = (1 - p_replicated)*self.max_dist + p_replicated*d_matrix
226
            minn = self.generalize_mean(weighted_d_matrix,
227
                                  p=self.p,
228
                                  dim=0, keepdim=False)
229
            term_2 = torch.mean(minn)
230
231
            terms_1.append(term_1)
232
            terms_2.append(term_2)
233
            
234
        terms_1 = torch.stack(terms_1)
235
        terms_2 = torch.stack(terms_2)
236
        
237
        if self.return_2_terms: res = terms_1.mean(), terms_2.means()
238
        else: res = terms_1.mean() + terms_2.mean()
239
        return res
240
    
241
242
class HD(nn.Module):
243
    def __init__(self):
244
        super().__init__()
245
        
246
    def forward(self, logits, target):
247
        _,logits = torch.max(logits, dim=1)
248
        _,target = torch.max(target, dim=1)
249
        
250
        logits = logits.detach().cpu().numpy()
251
        target = target.detach().cpu().numpy()
252
        
253
        hd = 0
254
        for index in range(logits.shape[0]):
255
            hd += hausdorff_distance(logits[index], target[index], distance='euclidean')
256
        
257
        return hd / logits.shape[0]