|
a |
|
b/metrics.py |
|
|
1 |
import numpy as np |
|
|
2 |
import math |
|
|
3 |
import torch |
|
|
4 |
import torch.nn as nn |
|
|
5 |
import torch.nn.functional as F |
|
|
6 |
from sklearn.utils.extmath import cartesian |
|
|
7 |
from hausdorff import hausdorff_distance |
|
|
8 |
|
|
|
9 |
# Public API of this module. Every entry must be the name of an object defined
# in this file, otherwise `from <module> import *` raises AttributeError.
__all__ = ['IOU', 'BinaryDice', 'Dice', 'WeightedHausdorffDistance', 'HD']
|
|
10 |
|
|
|
11 |
|
|
|
12 |
class IOU(nn.Module):
    '''
    Calculate Intersection over Union (IoU) for semantic segmentation.

    Both inputs are converted to label maps with ``argmax`` over dim 1, so the
    target must carry a class dimension (e.g. a one-hot encoding), exactly
    like the prediction.

    Constructor args:
        num_classes (int): Number of classes.
        ignore_index (int | iterable of int | None): Class indices left out of
            the metric. ``None`` (the default) means ``[0]``.

    Forward args:
        logits (torch.Tensor): Predicted tensor of shape (batch_size, num_classes, height, width, (depth))
        target (torch.Tensor): Ground truth tensor with a class dimension at dim 1,
            e.g. (batch_size, num_classes, height, width, (depth))

    Returns:
        tensor: Mean IoU over the evaluated (non-ignored) classes; NaN when
            no class is evaluated.
        list: IoU score (Python float) for each evaluated class, in class order.
    '''
    def __init__(self, num_classes, ignore_index=None):
        super(IOU, self).__init__()
        self.num_classes = num_classes
        # Build a fresh list per instance instead of sharing one mutable
        # default object between all instances.
        if ignore_index is None:
            ignore_index = [0]
        elif isinstance(ignore_index, int):
            ignore_index = [ignore_index]
        self.ignore_index = list(ignore_index)

    def forward(self, logits, target):
        pred = logits.argmax(dim=1)
        target = target.argmax(dim=1)

        ious = []
        for cls in range(self.num_classes):
            if cls in self.ignore_index:
                continue
            pred_mask = (pred == cls)
            target_mask = (target == cls)

            intersection = (pred_mask & target_mask).sum().float()
            union = (pred_mask | target_mask).sum().float()

            # A class absent from both prediction and target counts as a
            # perfect match.
            if union == 0:
                iou = 1.0
            else:
                iou = (intersection / union).item()
            ious.append(iou)

        # Average over the classes actually evaluated. Deriving the count from
        # len(ignore_index) is wrong when it holds duplicates or indices
        # outside range(num_classes), and divides by zero when all are ignored.
        mean_iou = sum(ious) / len(ious) if ious else float('nan')
        return torch.tensor(mean_iou), ious
|
|
48 |
|
|
|
49 |
|
|
|
50 |
class BinaryDice(nn.Module):
    '''
    Calculate Binary Dice score and Dice loss for binary segmentation or each class in Multiclass segmentation

    The score is computed over the whole batch at once:
        dice = (2 * sum(logits * target) + smooth) / (sum(logits ** p) + sum(target ** p) + smooth)

    Constructor args:
        smooth (float): Constant added to numerator and denominator to avoid
            division by zero on empty masks.
        p (number): Exponent applied to each input in the denominator.

    Forward args:
        logits (torch.Tensor): Predicted tensor of shape (batch_size, height, width, (depth))
        target (torch.Tensor): Ground truth tensor of shape (batch_size, height, width, (depth))

    Returns:
        tensor: Dice score
        tensor: Dice loss (1 - Dice score)
    '''
    def __init__(self, smooth=1e-5, p=2):
        super(BinaryDice, self).__init__()
        self.smooth = smooth
        self.p = p

    def _power_sum(self, x):
        # Sum of x ** p. The p == 2 case keeps the plain product so the
        # default configuration computes exactly what it always did.
        if self.p == 2:
            return torch.sum(x * x)
        return torch.sum(x.pow(self.p))

    def forward(self, logits, target):
        assert logits.shape[0] == target.shape[0], "logits & Target batch size don't match"
        intersect = torch.sum(logits * target)
        y_sum = self._power_sum(target)
        z_sum = self._power_sum(logits)
        # Use the configured smoothing term rather than a hard-coded constant.
        dice = (2 * intersect + self.smooth) / (z_sum + y_sum + self.smooth)
        loss = 1 - dice
        return dice, loss
|
|
76 |
|
|
|
77 |
|
|
|
78 |
class Dice(nn.Module):
    '''
    Calculate Dice score and Dice loss for multiclass semantic segmentation

    Each non-ignored channel is scored independently with ``BinaryDice`` and
    the per-class results are averaged.

    Constructor args:
        num_classes (int): Number of classes
        weight (sequence | None): Optional per-class factors, indexed by class
            id, applied to each class's dice score before averaging.
        softmax (bool): Apply softmax over dim 1 of ``logits`` first.
        ignore_index (int | iterable of int | None): Class indices to skip.
            ``None`` (the default) means ``[0]``.

    Forward args:
        logits (torch.Tensor): Predicted tensor of shape (batch_size, num_classes, height, width, (depth))
        target (torch.Tensor): Ground truth tensor with the same shape as
            ``logits`` (e.g. one-hot encoded); this is asserted.

    Returns:
        tensor: Mean dice score over the evaluated classes
        tensor: Mean dice loss over the evaluated classes
        list: dice score (float) for each evaluated class
        list: dice loss (float) for each evaluated class
        Both means are NaN tensors when no class is evaluated.
    '''
    def __init__(self, num_classes, weight=None, softmax=True, ignore_index=None):
        super(Dice, self).__init__()
        self.num_classes = num_classes
        self.weight = weight
        self.softmax = softmax
        # Build a fresh list per instance instead of sharing one mutable
        # default object between all instances.
        if ignore_index is None:
            ignore_index = [0]
        elif isinstance(ignore_index, int):
            ignore_index = [ignore_index]
        self.ignore_index = list(ignore_index)
        self.binary_dice = BinaryDice()

    def forward(self, logits, target):
        assert logits.shape == target.shape, 'logits & Target shape do not match'
        if self.softmax:
            logits = F.softmax(logits, dim=1)

        DICE, LOSS = 0.0, 0.0
        CLS_DICE, CLS_LOSS = [], []
        for clx in range(target.shape[1]):
            if clx in self.ignore_index:
                continue
            dice, loss = self.binary_dice(logits[:, clx], target[:, clx])
            # The per-class lists always hold the unweighted values.
            CLS_DICE.append(dice.item())
            CLS_LOSS.append(loss.item())
            if self.weight is not None:
                # Out-of-place multiply: `dice` is part of the autograd graph
                # that `loss` was derived from.
                # NOTE(review): only the dice score is weighted, as in the
                # original code; confirm whether the loss should be too.
                dice = dice * self.weight[clx]
            DICE += dice
            LOSS += loss

        # Average over the classes actually evaluated, so the result stays
        # correct when ignore_index has duplicates or out-of-range entries or
        # when the channel count differs from num_classes.
        num_valid_classes = len(CLS_DICE)
        if num_valid_classes == 0:
            undefined = torch.tensor(float('nan'))
            return undefined, undefined, CLS_DICE, CLS_LOSS
        return DICE / num_valid_classes, LOSS / num_valid_classes, CLS_DICE, CLS_LOSS
|
|
118 |
|
|
|
119 |
|
|
|
120 |
class WeightedHausdorffDistance(nn.Module):
    '''
    Weighted Hausdorff Distance between a probability map and a set of
    ground-truth points, computed per sample and averaged over the batch.
    '''
    def __init__(self, height, width, p=-9, return_2_terms=False, device=torch.device('cuda')):
        '''
        height (int): image height
        width (int): image width
        p (number): exponent of the generalized mean used as a soft minimum;
                    must be negative.
        return_2_terms (bool): Whether to return the 2 terms
                               of the WHD instead of their sum.
        device (torch.device): device on which the cached size and pixel
                               coordinate tensors are created.
        '''
        super().__init__()
        self.height, self.width = height, width
        self.size = torch.tensor([height, width], dtype=torch.get_default_dtype(), device=device)
        self.max_dist = math.sqrt(height**2 + width**2)
        self.n_pixels = height * width
        # Every (row, col) pixel coordinate, column index varying fastest, so
        # row i of this table matches element i of a flattened (H x W) map.
        self.all_img_locations = torch.cartesian_prod(
            torch.arange(height), torch.arange(width)
        ).to(device=device, dtype=torch.get_default_dtype())
        self.return_2_terms = return_2_terms
        self.p = p

    def _assert_no_grad(self, variables):
        '''Assert that none of the given tensors requires gradients.'''
        for var in variables:
            assert not var.requires_grad, \
                "nn criterions don't compute the gradient w.r.t. targets - please " \
                "mark these variables as volatile or not requiring gradients"

    def cdist(self, x, y):
        '''
        Compute distance between each pair of the two collections of inputs.
        x: Nxd Tensor
        y: Mxd Tensor
        return: NxM matrix where dist[i,j] is the norm between x[i,:] and y[j,:]
                i.e. dist[i,j] = || x[i,:] - y[j,:] ||
        '''
        difs = x.unsqueeze(1) - y.unsqueeze(0)
        dists = torch.sum(difs**2, -1).sqrt()
        return dists

    def generalize_mean(self, tensor, dim, p=-9, keepdim=False):
        '''
        Generalized (power) mean of `tensor` along `dim` with negative
        exponent `p`. A small epsilon is added before exponentiation so that
        zeros do not produce a division by zero.
        '''
        assert p < 0
        res = torch.mean((tensor + 1e-6)**p, dim, keepdim=keepdim)**(1./p)
        return res

    def forward(self, prob_map, gt, orig_sizes):
        '''
        prob_map: (B x H x W) Tensor of the probability map of the estimation.
                  B is batch size, H is height and W is width.
                  Values must be between 0 and 1.

        gt: List of Tensors of the Ground Truth points.
            Must be of size B as in prob_map.
            Each element in the list must be a 2D Tensor,
            where each row is the (y, x), i.e, (row, col) of a GT point.

        orig_sizes: Bx2 Tensor containing the size
                    of the original images.
                    B is batch size.
                    The size must be in (height, width) format.

        return: Single-scalar Tensor with the Weighted Hausdorff Distance.
                If self.return_2_terms=True, then return a tuple containing
                the two terms of the Weighted Hausdorff Distance.
        '''

        self._assert_no_grad(gt)
        assert prob_map.dim() == 3, 'The probability map must be (B x H x W)'
        assert prob_map.size()[1:3] == (self.height, self.width), \
            'You must configure the WeightedHausdorffDistance with the height and width of the ' \
            'probability map that you are using, got a probability map of size %s'\
            % str(prob_map.size())

        batch_size = prob_map.shape[0]
        assert batch_size == len(gt)

        term_dtype = torch.get_default_dtype()
        terms_1 = []
        terms_2 = []
        for b in range(batch_size):

            # One by one
            prob_map_b = prob_map[b, :, :]
            gt_b = gt[b]
            orig_size_b = orig_sizes[b, :]
            norm_factor = (orig_size_b / self.size).unsqueeze(0)

            # Corner case: no GT points
            # NOTE(review): condition kept exactly as found (a 1-D tensor whose
            # entries are not all negative); confirm it matches the encoding
            # the data pipeline uses for "no points".
            if gt_b.ndimension() == 1 and (gt_b < 0).all().item() == 0:
                # 0-dim tensors on the input's device, so they can be stacked
                # with the regular per-sample terms below.
                terms_1.append(torch.tensor(0.0, dtype=term_dtype, device=prob_map.device))
                terms_2.append(torch.tensor(self.max_dist, dtype=term_dtype, device=prob_map.device))
                continue

            # Pairwise distances between all possible locations and the GTed locations
            # (broadcasting norm_factor replaces the explicit row-wise repeat)
            normalized_x = norm_factor * self.all_img_locations
            normalized_y = norm_factor * gt_b
            d_matrix = self.cdist(normalized_x, normalized_y)

            # Flatten the probability map into a vector (reshape also accepts
            # non-contiguous slices) and into a column that broadcasts
            # against the (n_pixels x n_gt_pts) distance matrix.
            p = prob_map_b.reshape(-1)
            n_est_pts = p.sum()
            p_column = p.view(-1, 1)

            # Weighted Hausdorff Distance
            term_1 = (1 / (n_est_pts + 1e-6)) * torch.sum(p * torch.min(d_matrix, 1)[0])
            weighted_d_matrix = (1 - p_column)*self.max_dist + p_column*d_matrix
            minn = self.generalize_mean(weighted_d_matrix,
                                        p=self.p,
                                        dim=0, keepdim=False)
            term_2 = torch.mean(minn)

            terms_1.append(term_1)
            terms_2.append(term_2)

        terms_1 = torch.stack(terms_1)
        terms_2 = torch.stack(terms_2)

        if self.return_2_terms:
            res = terms_1.mean(), terms_2.mean()
        else:
            res = terms_1.mean() + terms_2.mean()
        return res
|
|
240 |
|
|
|
241 |
|
|
|
242 |
class HD(nn.Module):
    '''
    Batch-averaged Hausdorff distance between predicted and target label maps.

    Forward args:
        logits (torch.Tensor): Prediction with the class dimension at dim 1.
        target (torch.Tensor): Ground truth with the class dimension at dim 1.

    Returns:
        The sum of the per-sample ``hausdorff_distance`` values (euclidean
        base distance) divided by the batch size.

    NOTE(review): each sample's label array is handed to
    ``hausdorff_distance`` as-is; confirm that this matches the point-set
    layout that function expects.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, logits, target):
        # Collapse the class dimension to the index of the largest entry,
        # then move both label maps to host memory as numpy arrays.
        pred_labels = torch.max(logits, dim=1)[1].detach().cpu().numpy()
        true_labels = torch.max(target, dim=1)[1].detach().cpu().numpy()

        n_samples = pred_labels.shape[0]
        total = sum(
            hausdorff_distance(pred_labels[i], true_labels[i], distance='euclidean')
            for i in range(n_samples)
        )
        return total / n_samples