|
a |
|
b/pathflowai/losses.py |
|
|
1 |
""" |
|
|
2 |
losses.py |
|
|
3 |
======================= |
|
|
4 |
Additional loss functions that can be called from the pipeline; some of them are still to be implemented. |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
import torch, numpy as np |
|
|
8 |
from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union |
|
|
9 |
from torch import Tensor, einsum |
|
|
10 |
import torch.nn.functional as F |
|
|
11 |
from scipy.ndimage import distance_transform_edt as distance |
|
|
12 |
from torch import nn |
|
|
13 |
|
|
|
14 |
def assert_(condition, message='', exception_type=AssertionError):
    """Raise ``exception_type(message)`` when ``condition`` is falsy.

    Works like the ``assert`` statement, but the exception class is
    configurable and the check is an ordinary ``if`` (not stripped by ``-O``).

    Adapted from
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/exceptions.py
    """
    if condition:
        return
    raise exception_type(message)
|
|
19 |
|
|
|
20 |
class ShapeError(ValueError):
    """Raised when a tensor does not have the shape an operation expects.

    Adapted from
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/exceptions.py
    """
|
|
22 |
|
|
|
23 |
def flatten_samples(input_):
    """Move the channel axis (axis 1) to the front and collapse all other axes.

    Shape mapping::

        (N, C)          --> (C, N)
        (N, C, H, W)    --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)

    Adapted from
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/torch_utils.py

    :param input_: tensor with at least two dimensions.
    :return: 2-D tensor of shape ``(C, product of the remaining sizes)``.
    :raises ShapeError: if ``input_`` has fewer than two dimensions.
    """
    ndim = input_.dim()
    assert_(ndim >= 2,
            "Tensor or variable must be atleast 2D. Got one of dim {}."
            .format(ndim),
            ShapeError)
    # Swapping axes 0 and 1 turns e.g. NCHW into CNHW; contiguous() makes view() legal.
    channels_first = input_.transpose(0, 1).contiguous()
    return channels_first.view(input_.size(1), -1)
|
|
47 |
|
|
|
48 |
class GeneralizedDiceLoss(nn.Module):
    """Scalar Generalized Dice Loss defined in https://arxiv.org/abs/1707.03237.

    Adapted from
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/extensions/criteria/set_similarity_measures.py

    This version works for multiple classes and expects predictions for every
    class (e.g. softmax output) and one-hot targets for every class. Each class
    is weighted by ``1 / (target volume) ** 2`` (clamped by ``eps``).

    :param weight: optional tensor of shape ``(nb_channels,)``. It is only used
        when ``channelwise=True``, to weight the per-channel losses before summing.
    :param channelwise: if True, inputs carry an extra channel axis
        ``(batch_size, nb_channels, nb_classes, ...)``. A Dice loss is computed
        per channel and the results are summed.
    :param eps: lower bound applied to denominators to avoid division by zero.
    :param add_softmax: if True, apply a softmax over the class axis first.
        The class axis is 1, or 2 when ``channelwise=True``.
    """

    def __init__(self, weight=None, channelwise=False, eps=1e-6, add_softmax=False):
        super(GeneralizedDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        self.channelwise = channelwise
        self.eps = eps
        self.add_softmax = add_softmax

    @staticmethod
    def _flatten_and_preserve_channels(tensor):
        """Reshape ``(B, C, K, ...)`` into ``(C, K, B * prod(...))``."""
        tensor_dim = tensor.dim()
        assert tensor_dim >= 3
        num_channels = tensor.size(1)
        num_classes = tensor.size(2)
        # (B, C, K, ...) -> (C, K, B, ...)
        permute_axes = list(range(tensor_dim))
        permute_axes[0], permute_axes[1], permute_axes[2] = permute_axes[1], permute_axes[2], permute_axes[0]
        permuted = tensor.permute(*permute_axes).contiguous()
        return permuted.view(num_channels, num_classes, -1)

    def _weighted_terms(self, pred, tgt):
        """Return the class-weighted (intersection, sum) for ``(..., nb_classes, N)`` tensors.

        The last axis (samples) is summed per class. Classes are weighted by
        ``1 / volume**2``, and the class axis is then summed away.
        """
        sum_targets = tgt.sum(-1)
        class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)
        numer = ((pred * tgt).sum(-1) * class_weights).sum(-1)
        denom = ((pred + tgt).sum(-1) * class_weights).sum(-1)
        return numer, denom

    def forward(self, input, target):
        """
        input: torch.FloatTensor or torch.cuda.FloatTensor
        target: torch.FloatTensor or torch.cuda.FloatTensor (cast to float internally)

        Expected shape of the inputs:
            - if not channelwise: (batch_size, nb_classes, ...)
            - if channelwise: (batch_size, nb_channels, nb_classes, ...)

        Returns a scalar (0-d) tensor.
        """
        assert input.size() == target.size()
        if self.add_softmax:
            # Normalize over the class axis, which is axis 2 in channelwise mode.
            input = F.softmax(input, dim=2 if self.channelwise else 1)

        if not self.channelwise:
            # Flatten input and target to have the shape (nb_classes, N),
            # where N is the number of samples
            numer, denom = self._weighted_terms(flatten_samples(input),
                                                flatten_samples(target).float())
            return 1. - 2. * numer / denom.clamp(min=self.eps)

        # Flatten input and target to have the shape (nb_channels, nb_classes, N).
        # The target is cast to float as in the non-channelwise path, so integer
        # one-hot targets do not reach clamp()/division as integer tensors.
        input = self._flatten_and_preserve_channels(input)
        target = self._flatten_and_preserve_channels(target).float()

        numer, denom = self._weighted_terms(input, target)
        channelwise_loss = 1. - 2. * numer / denom.clamp(min=self.eps)  # (nb_channels,)

        if self.weight is not None:
            assert self.weight.size() == channelwise_loss.size(),\
                """`weight` should have shape (nb_channels, ),
`target` should have shape (batch_size, nb_channels, nb_classes, ...)"""
            # Apply channel weights:
            channelwise_loss = self.weight * channelwise_loss

        return channelwise_loss.sum()
|
|
129 |
|
|
|
130 |
class FocalLoss(nn.Module):  # add boundary loss
    """Focal loss with optional label smoothing.

    Adapted from
    https://raw.githubusercontent.com/Hsuxu/Loss_ToolBox-PyTorch/master/FocalLoss/FocalLoss.py
    Proposed in 'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002)::

        FL = -alpha * (1 - pt) ** gamma * log(pt)

    :param num_class: number of classes.
    :param alpha: per-class weighting. The accepted forms are:
        ``None``: every class weighted 1.
        list / tuple / ndarray / tensor of length ``num_class``: normalized to sum to 1.
        float: ``alpha`` for class ``balance_index`` and ``1 - alpha`` for every other class.
    :param gamma: focusing parameter. gamma > 0 reduces the relative loss for
        well-classified examples (pt > 0.5), putting more focus on hard,
        misclassified examples.
    :param balance_index: class index that receives ``alpha`` when ``alpha`` is a float.
    :param smooth: label-smoothing value in [0, 1], or None to disable smoothing.
    :param size_average: if True (default), average the per-element losses; otherwise sum them.
    :raises TypeError: if ``alpha`` has an unsupported type.
    :raises ValueError: if ``smooth`` is outside [0, 1].
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, tuple, np.ndarray, torch.Tensor)):
            assert len(self.alpha) == self.num_class
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32).reshape(self.num_class, 1)
            self.alpha = self.alpha / self.alpha.sum()
        elif isinstance(self.alpha, float):
            alpha = torch.ones(self.num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[balance_index] = self.alpha
            self.alpha = alpha
        else:
            raise TypeError('Not support alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        """Compute the focal loss.

        :param logit: tensor of shape (N, C) or (N, C, d1, d2, ...). Its values
            are used directly as per-class probabilities (no softmax is applied
            here; ``log`` is taken of the selected value), so pass softmax outputs.
        :param target: integer class indices with one entry per sample/pixel,
            i.e. ``N * d1 * d2 * ...`` elements in total.
        :return: scalar tensor, the mean if ``size_average`` else the sum.
        """
        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...) -> N,m,C -> N*m,C
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = target.reshape(-1, 1)

        epsilon = 1e-10
        # self.alpha is a plain attribute (not a registered buffer), so
        # Module.to() does not move it; bring it to the logits' device here.
        alpha = self.alpha.to(logit.device)

        idx = target.long().to(logit.device)
        one_hot_key = torch.zeros(target.size(0), self.num_class,
                                  dtype=torch.float32, device=logit.device)
        one_hot_key.scatter_(1, idx, 1)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + epsilon  # (M,)
        logpt = pt.log()

        # alpha is (num_class, 1). Index it with the flat class ids so the per-element
        # weight is (M,), matching pt. Indexing with the (M, 1) ids would give
        # (M, 1, 1) and broadcast to an (M, 1, M) loss.
        alpha = alpha[idx.view(-1)].view(-1)
        loss = -1 * alpha * torch.pow((1 - pt), self.gamma) * logpt

        return loss.mean() if self.size_average else loss.sum()
|
|
214 |
|
|
|
215 |
|
|
|
216 |
def uniq(a: Tensor) -> Set:
    """Return the distinct values of ``a`` as a Python set of NumPy scalars.

    The tensor is moved to the CPU first so ``.numpy()`` is valid.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    return {value for value in torch.unique(a.cpu()).numpy()}
|
|
219 |
|
|
|
220 |
def sset(a: Tensor, sub: Iterable) -> bool:
    """Return True when every distinct value of ``a`` also appears in ``sub``.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    values_present = uniq(a)
    return values_present.issubset(sub)
|
|
223 |
|
|
|
224 |
|
|
|
225 |
def eq(a: Tensor, b) -> bool:
    """Return True if ``a`` equals ``b`` at every position (``torch.eq`` broadcasting applies).

    The result is a Python ``bool``, matching the annotation, rather than a 0-d
    tensor. It has the same truthiness in ``if``/``assert`` contexts.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    return bool(torch.eq(a, b).all())
|
|
228 |
|
|
|
229 |
|
|
|
230 |
def simplex(t: Tensor, axis=1) -> bool:
    """Check whether ``t`` sums to one along ``axis``, within ``torch.allclose`` tolerance.

    Sums are cast to float32 before comparing, so integer tensors work too.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    totals = t.sum(axis).type(torch.float32)
    return torch.allclose(totals, torch.ones_like(totals, dtype=torch.float32))
|
|
235 |
|
|
|
236 |
def one_hot(t: Tensor, axis=1) -> bool:
    """Check that ``t`` is a one-hot encoding along ``axis``.

    The tensor must sum to one along ``axis`` and contain only the values 0 and 1.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    if not simplex(t, axis):
        return False
    return sset(t, [0, 1])
|
|
239 |
|
|
|
240 |
def class2one_hot(seg: Tensor, C: int) -> Tensor:
    """Convert a label map into a one-hot ``int32`` tensor.

    :param seg: labels with values in ``range(C)``, shape ``(b, w, h)`` or
        ``(w, h)``. A 2-D map is given a leading batch axis of size 1.
    :param C: number of classes.
    :return: tensor of shape ``(b, C, w, h)`` where ``res[:, c] == (seg == c)``.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    if seg.dim() == 2:  # Only w, h, used by the dataloader
        seg = seg[None]
    assert sset(seg, list(range(C)))

    b, w, h = seg.shape  # type: Tuple[int, int, int]

    planes = [seg == c for c in range(C)]
    res = torch.stack(planes, dim=1).type(torch.int32)
    assert res.shape == (b, C, w, h)
    assert one_hot(res)

    return res
|
|
253 |
|
|
|
254 |
def one_hot2dist(seg: np.ndarray, resolution=None, dtype=None) -> np.ndarray:
    """Compute a signed Euclidean distance map for each class of a one-hot array.

    For every class ``c`` that is present, pixels outside the class receive
    their distance to the class. Pixels inside receive
    ``-(distance to the nearest non-class pixel - 1)``. Classes with no pixels
    stay all zeros.

    :param seg: one-hot array with the class axis first (``len(seg) == C``).
    :param resolution: optional per-axis pixel spacing, forwarded to
        ``scipy.ndimage.distance_transform_edt`` as ``sampling``.
    :param dtype: dtype of the returned array; ``None`` keeps ``seg.dtype``.
        Pass a float dtype when ``seg`` is an integer array, otherwise the
        distances are truncated when stored.
    :return: array with the same shape as ``seg``.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    assert one_hot(torch.Tensor(seg), axis=0)
    C: int = len(seg)

    res = np.zeros_like(seg, dtype=dtype)
    for c in range(C):
        # The `np.bool` alias was removed in NumPy 1.24; builtin `bool` is equivalent.
        posmask = seg[c].astype(bool)

        if posmask.any():
            negmask = ~posmask
            res[c] = (distance(negmask, sampling=resolution) * negmask
                      - (distance(posmask, sampling=resolution) - 1) * posmask)
    return res
|
|
267 |
|
|
|
268 |
class SurfaceLoss():
    """Surface (boundary) loss: mean of class probabilities times distance maps.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/losses.py

    Keyword arguments:
        idc (List[int]): indices along axis 1 of the classes included in the loss.
    """
    def __init__(self, **kwargs):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = kwargs["idc"]
        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
        """Compute the loss.

        :param probs: probabilities of shape ``(b, C, ...)`` that sum to 1 over axis 1.
        :param dist_maps: per-class distance maps of the same shape as ``probs``;
            they must not be one-hot.
        :param _: unused third argument (same call signature as ``GeneralizedDice``).
        :return: scalar tensor.
        """
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        # Elementwise product. Unlike a fixed "bcwh" einsum, this also accepts
        # inputs with other numbers of spatial axes (e.g. 3-D volumes).
        multiplied = pc * dc

        return multiplied.mean()
|
|
287 |
|
|
|
288 |
class GeneralizedDice():
    """Generalized Dice loss, averaged over the batch.

    Each class is weighted by ``1 / (target volume + 1e-10) ** 2``.

    Source: https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/losses.py

    Keyword arguments:
        idc (List[int]): indices along axis 1 of the classes included in the loss.
    """
    def __init__(self, **kwargs):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = kwargs["idc"]
        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
        """Compute the loss.

        :param probs: probabilities of shape ``(b, C, ...)`` that sum to 1 over axis 1.
        :param target: target of the same shape that also sums to 1 over axis 1.
        :param _: unused third argument (same call signature as ``SurfaceLoss``).
        :return: scalar tensor.
        """
        assert simplex(probs) and simplex(target)

        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        def spatial_sum(x: Tensor) -> Tensor:
            # (b, c, ...) -> (b, c): sum over every axis after the class axis,
            # so 2-D images and 3-D volumes are both accepted.
            return x.flatten(start_dim=2).sum(-1)

        w: Tensor = 1 / ((spatial_sum(tc) + 1e-10) ** 2)
        intersection: Tensor = w * spatial_sum(pc * tc)
        union: Tensor = w * (spatial_sum(pc) + spatial_sum(tc))

        divided: Tensor = 1 - 2 * (intersection.sum(1) + 1e-10) / (union.sum(1) + 1e-10)

        loss = divided.mean()

        return loss