"""
losses.py
=======================
Additional loss functions that can be called from the pipeline; some are still to be implemented.
"""

import torch
import numpy as np
from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union
from torch import Tensor, einsum
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import nn

def assert_(condition, message='', exception_type=AssertionError):
    """https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/exceptions.py
    Like assert, but with arbitrary exception types."""
    if not condition:
        raise exception_type(message)


class ShapeError(ValueError):
    """https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/exceptions.py"""
    pass

def flatten_samples(input_):
    """
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/torch_utils.py
    Flattens a tensor or a variable such that the channel axis is first and the sample axis
    is second. The shapes are transformed as follows:
        (N, C, H, W) --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)
        (N, C) --> (C, N)
    The input must be at least 2D.
    """
    assert_(input_.dim() >= 2,
            "Tensor or variable must be at least 2D. Got one of dim {}."
            .format(input_.dim()),
            ShapeError)
    # Get number of channels
    num_channels = input_.size(1)
    # Permute the channel axis to first
    permute_axes = list(range(input_.dim()))
    permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0]
    # For input shape (say) NCHW, this should have the shape CNHW
    permuted = input_.permute(*permute_axes).contiguous()
    # Now flatten out all but the first axis and return
    flattened = permuted.view(num_channels, -1)
    return flattened

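# A minimal usage sketch for flatten_samples (shapes are illustrative):
#
#     x = torch.randn(4, 3, 32, 32)   # (N, C, H, W)
#     flat = flatten_samples(x)       # -> (C, N * H * W) == (3, 4096)
#
# Each channel is flattened across the whole batch: 4 * 32 * 32 = 4096.
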
class GeneralizedDiceLoss(nn.Module):
    """
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/extensions/criteria/set_similarity_measures.py
    Computes the scalar Generalized Dice Loss defined in https://arxiv.org/abs/1707.03237

    This version works for multiple classes and expects predictions for every class (e.g. softmax output) and
    one-hot targets for every class.
    """
    def __init__(self, weight=None, channelwise=False, eps=1e-6, add_softmax=False):
        super(GeneralizedDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        self.channelwise = channelwise
        self.eps = eps
        self.add_softmax = add_softmax

    def forward(self, input, target):
        """
        input: torch.FloatTensor or torch.cuda.FloatTensor
        target: torch.FloatTensor or torch.cuda.FloatTensor

        Expected shape of the inputs:
            - if not channelwise: (batch_size, nb_classes, ...)
            - if channelwise:     (batch_size, nb_channels, nb_classes, ...)
        """
        assert input.size() == target.size()
        if self.add_softmax:
            input = F.softmax(input, dim=1)
        if not self.channelwise:
            # Flatten input and target to have the shape (nb_classes, N),
            # where N is the number of samples
            input = flatten_samples(input)
            target = flatten_samples(target).float()

            # Find class weights:
            sum_targets = target.sum(-1)
            class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)

            # Compute generalized Dice loss:
            numer = ((input * target).sum(-1) * class_weights).sum()
            denom = ((input + target).sum(-1) * class_weights).sum()

            loss = 1. - 2. * numer / denom.clamp(min=self.eps)
        else:
            def flatten_and_preserve_channels(tensor):
                tensor_dim = tensor.dim()
                assert tensor_dim >= 3
                num_channels = tensor.size(1)
                num_classes = tensor.size(2)
                # Permute the channel axis to first
                permute_axes = list(range(tensor_dim))
                permute_axes[0], permute_axes[1], permute_axes[2] = permute_axes[1], permute_axes[2], permute_axes[0]
                permuted = tensor.permute(*permute_axes).contiguous()
                flattened = permuted.view(num_channels, num_classes, -1)
                return flattened

            # Flatten input and target to have the shape (nb_channels, nb_classes, N)
            input = flatten_and_preserve_channels(input)
            target = flatten_and_preserve_channels(target)

            # Find class weights:
            sum_targets = target.sum(-1)
            class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)

            # Compute generalized Dice loss:
            numer = ((input * target).sum(-1) * class_weights).sum(-1)
            denom = ((input + target).sum(-1) * class_weights).sum(-1)

            channelwise_loss = 1. - 2. * numer / denom.clamp(min=self.eps)

            if self.weight is not None:
                if channelwise_loss.dim() == 2:
                    channelwise_loss = channelwise_loss.squeeze(1)
                assert self.weight.size() == channelwise_loss.size(),\
                    """`weight` should have shape (nb_channels, ),
                       `target` should have shape (batch_size, nb_channels, nb_classes, ...)"""
                # Apply channel weights:
                channelwise_loss = self.weight * channelwise_loss

            loss = channelwise_loss.sum()

        return loss

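# A minimal usage sketch for GeneralizedDiceLoss, following the one-hot target
# format described in the class docstring (shapes are illustrative):
#
#     criterion = GeneralizedDiceLoss(add_softmax=True)
#     logits = torch.randn(2, 3, 64, 64)                          # (batch_size, nb_classes, H, W)
#     labels = torch.randint(0, 3, (2, 64, 64))
#     target = F.one_hot(labels, 3).permute(0, 3, 1, 2).float()   # one-hot, same shape as logits
#     loss = criterion(logits, target)
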
class FocalLoss(nn.Module): # add boundary loss
    """
    https://raw.githubusercontent.com/Hsuxu/Loss_ToolBox-PyTorch/master/FocalLoss/FocalLoss.py
    This is an implementation of Focal Loss, with support for label-smoothed cross entropy, as proposed in
    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
        Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
    :param num_class: number of classes
    :param alpha: (list, np.ndarray, or float) per-class weighting factor for this criterion
    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
                    putting more focus on hard, misclassified examples
    :param smooth: (float, double) label-smoothing value applied to the one-hot targets
    :param balance_index: (int) index of the class to balance; should be specified when alpha is a float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, np.ndarray)):
            assert len(self.alpha) == self.num_class
            self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
            self.alpha = self.alpha / self.alpha.sum()
        elif isinstance(self.alpha, float):
            alpha = torch.ones(self.num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[balance_index] = self.alpha
            self.alpha = alpha
        else:
            raise TypeError('Unsupported alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0, 1]')

    def forward(self, logit, target):

        # logit = F.softmax(input, dim=1)

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = target.view(-1, 1)

        # N = input.size(0)
        # alpha = torch.ones(N, self.num_class)
        # alpha = alpha * (1 - self.alpha)
        # alpha = alpha.scatter_(1, target.long(), self.alpha)
        epsilon = 1e-10
        alpha = self.alpha
        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        idx = target.cpu().long()

        one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth/(self.num_class-1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + epsilon
        logpt = pt.log()

        gamma = self.gamma

        # Select the per-sample alpha and flatten to shape (N,) so it
        # broadcasts elementwise with pt and logpt
        alpha = alpha[idx].view(-1)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

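# A minimal usage sketch for FocalLoss (shapes are illustrative). Note that the
# forward pass takes the log of the class scores directly (the softmax line is
# commented out above), so it expects probabilities rather than raw logits:
#
#     criterion = FocalLoss(num_class=3, gamma=2)
#     probs = F.softmax(torch.randn(8, 3), dim=1)   # (N, num_class)
#     target = torch.randint(0, 3, (8,))
#     loss = criterion(probs, target)
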
def uniq(a: Tensor) -> Set:
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py"""
    return set(torch.unique(a.cpu()).numpy())


def sset(a: Tensor, sub: Iterable) -> bool:
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py"""
    return uniq(a).issubset(sub)


def eq(a: Tensor, b) -> bool:
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py"""
    return torch.eq(a, b).all()


def simplex(t: Tensor, axis=1) -> bool:
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py"""
    _sum = t.sum(axis).type(torch.float32)
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones)


def one_hot(t: Tensor, axis=1) -> bool:
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py"""
    return simplex(t, axis) and sset(t, [0, 1])

def class2one_hot(seg: Tensor, C: int) -> Tensor:
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py"""
    if len(seg.shape) == 2:  # Only (w, h); used by the dataloader
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C)))

    b, w, h = seg.shape  # type: Tuple[int, int, int]

    res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32)
    assert res.shape == (b, C, w, h)
    assert one_hot(res)

    return res

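# Illustrative: class2one_hot turns an integer label map into a stacked
# one-hot encoding (shapes are hypothetical):
#
#     seg = torch.randint(0, 3, (2, 16, 16))   # (b, w, h) integer labels
#     oh = class2one_hot(seg, C=3)             # -> (2, 3, 16, 16), int32
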
def one_hot2dist(seg: np.ndarray) -> np.ndarray:
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py"""
    assert one_hot(torch.Tensor(seg), axis=0)
    C: int = len(seg)

    res = np.zeros_like(seg)
    for c in range(C):
        posmask = seg[c].astype(bool)

        if posmask.any():
            negmask = ~posmask
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res

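# Illustrative: one_hot2dist converts a one-hot mask of shape (C, w, h) into
# per-class signed distance maps (negative inside the object, positive outside).
# Passing a float array avoids the distances being truncated by np.zeros_like:
#
#     oh = class2one_hot(torch.randint(0, 2, (16, 16)), C=2)[0]  # (C, w, h)
#     dist = one_hot2dist(oh.numpy().astype(np.float32))
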
class SurfaceLoss():
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/losses.py"""
    def __init__(self, **kwargs):
        # self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = kwargs["idc"]
        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        multiplied = einsum("bcwh,bcwh->bcwh", pc, dc)

        loss = multiplied.mean()

        return loss

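# A minimal usage sketch for SurfaceLoss, building distance maps from integer
# labels with the helpers above (shapes are illustrative; the final argument
# is unused by the loss):
#
#     labels = torch.randint(0, 2, (16, 16))
#     target = class2one_hot(labels, C=2)                  # (1, 2, 16, 16)
#     dist = torch.from_numpy(np.stack(
#         [one_hot2dist(t.numpy().astype(np.float32)) for t in target]))
#     probs = F.softmax(torch.randn(1, 2, 16, 16), dim=1)
#     loss = SurfaceLoss(idc=[1])(probs, dist, None)
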
class GeneralizedDice():
    """https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/losses.py"""
    def __init__(self, **kwargs):
        # self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = kwargs["idc"]
        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        w: Tensor = 1 / ((einsum("bcwh->bc", tc).type(torch.float32) + 1e-10) ** 2)
        intersection: Tensor = w * einsum("bcwh,bcwh->bc", pc, tc)
        union: Tensor = w * (einsum("bcwh->bc", pc) + einsum("bcwh->bc", tc))

        divided: Tensor = 1 - 2 * (einsum("bc->b", intersection) + 1e-10) / (einsum("bc->b", union) + 1e-10)

        loss = divided.mean()

        return loss

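# A minimal usage sketch for GeneralizedDice (shapes are illustrative; the
# final argument is unused by the loss):
#
#     probs = F.softmax(torch.randn(2, 3, 16, 16), dim=1)
#     target = class2one_hot(torch.randint(0, 3, (2, 16, 16)), C=3)
#     loss = GeneralizedDice(idc=[0, 1, 2])(probs, target, None)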