[98e649]: / libs / losses / surface_loss.py

Download this file

43 lines (30 with data), 1.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch import Tensor
from typing import List, Set, Iterable
from utils.utils_loss import class2one_hot
def uniq(a: Tensor) -> Set:
return set(torch.unique(a.cpu()).numpy())
def sset(a: Tensor, sub: Iterable) -> bool:
return uniq(a).issubset(sub)
def simplex(t: Tensor, axis=1) -> bool:
_sum = t.sum(axis).type(torch.float32)
_ones = torch.ones_like(_sum, dtype=torch.float32)
return torch.allclose(_sum, _ones)
def one_hot(t: Tensor, axis=1) -> bool:
return simplex(t, axis) and sset(t, [0, 1])
class SurfaceLoss():
def __init__(self, **kwargs):
# Self.idc is used to filter out some classes of the target mask. Use fancy indexing
self.idc: List[int] = kwargs["idc"]
print(f"Initialized {self.__class__.__name__} with {kwargs}")
def __call__(self, probs: Tensor, dist_maps: Tensor, gts: Tensor) -> Tensor:
assert simplex(probs)
assert not one_hot(dist_maps)
dist_maps = dist_maps.to(probs.device)
pc = probs[:, self.idc, ...]
dc = dist_maps[:, self.idc, ...]
multipled = torch.einsum("bcwh,bcwh->bcwh", pc, dc)
# gc = class2one_hot(gts)[:, self.idc, ...]
# multipled = torch.einsum("bcwh,bcwh->bcwh", pc - gc, dc)
loss = multipled.mean()
return loss