--- a
+++ b/utils/utils_loss.py
@@ -0,0 +1,98 @@
+from typing import Iterable, List, Set, Tuple
+
+import torch
+from torch import Tensor
+
+
+def uniq(a: Tensor) -> Set:
+    """Return the set of distinct values in ``a``.
+
+    The tensor is moved to the CPU first so it can be converted to NumPy.
+    """
+    return set(torch.unique(a.cpu()).numpy())
+
+
+def sset(a: Tensor, sub: Iterable) -> bool:
+    """Return True if every distinct value of ``a`` is contained in ``sub``."""
+    return uniq(a).issubset(sub)
+
+
+def simplex(t: Tensor, axis=1) -> bool:
+    """Return True if ``t`` sums to one along ``axis``.
+
+    The comparison uses ``torch.allclose`` with its default tolerances.
+    """
+    # Cast to float32 so integer tensors (e.g. one-hot encodings) can be
+    # compared against a float tensor of ones.
+    _sum = t.sum(axis).type(torch.float32)
+    _ones = torch.ones_like(_sum, dtype=torch.float32)
+    return torch.allclose(_sum, _ones)
+
+
+def one_hot(t: Tensor, axis=1) -> bool:
+    """Return True if ``t`` sums to one along ``axis`` and holds only 0s and 1s."""
+    return simplex(t, axis) and sset(t, [0, 1])
+
+
+# switch between representations
+def probs2class(probs: Tensor) -> Tensor:
+    """Convert per-class probabilities to a class-index map.
+
+    Args:
+        probs: 4-D tensor whose dim 1 indexes classes and sums to one.
+
+    Returns:
+        Tensor of shape ``(b, w, h)`` holding the argmax over dim 1.
+    """
+    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
+    assert simplex(probs)
+
+    res = probs.argmax(dim=1)
+    assert res.shape == (b, w, h)
+
+    return res
+
+
+def class2one_hot(seg: Tensor, C: int) -> Tensor:
+    """Encode a class-index map as an int32 one-hot tensor.
+
+    Args:
+        seg: Either a ``(b, w, h)`` batch or a single ``(w, h)`` map, to
+            which a batch dimension of size 1 is added. Values must lie
+            in ``range(C)``.
+        C: Number of classes.
+
+    Returns:
+        int32 tensor of shape ``(b, C, w, h)``; channel ``c`` is 1 where
+        ``seg == c`` and 0 elsewhere.
+    """
+    if seg.dim() == 2:  # Only w, h, used by the dataloader
+        seg = seg.unsqueeze(dim=0)
+    assert sset(seg, list(range(C)))
+
+    b, w, h = seg.shape  # type: Tuple[int, int, int]
+
+    res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32)
+    assert res.shape == (b, C, w, h)
+    assert one_hot(res)
+
+    return res
+
+
+def probs2one_hot(probs: Tensor) -> Tensor:
+    """Convert per-class probabilities to a one-hot map of the argmax class.
+
+    Args:
+        probs: 4-D tensor whose dim 1 indexes classes and sums to one.
+
+    Returns:
+        int32 one-hot tensor with the same shape as ``probs``.
+    """
+    _, C, _, _ = probs.shape
+    assert simplex(probs)
+
+    res = class2one_hot(probs2class(probs), C)
+    assert res.shape == probs.shape
+    assert one_hot(res)
+
+    return res