import torch
from torch import Tensor
from typing import List, Set, Iterable
def uniq(a: Tensor) -> Set:
    """Return the set of distinct values contained in ``a``.

    The tensor is detached and copied to the CPU before being converted to
    NumPy, so tensors that require grad are accepted as well.

    Args:
        a: Tensor whose distinct values are collected.

    Returns:
        A set holding one NumPy scalar per distinct value in ``a``.
    """
    # ``Tensor.numpy()`` raises on tensors that require grad, and the output
    # of ``torch.unique`` inherits that flag, so detach before anything else.
    distinct = torch.unique(a.detach().cpu())
    return set(distinct.numpy())
def sset(a: Tensor, sub: Iterable) -> bool:
    """Tell whether every distinct value of ``a`` appears in ``sub``.

    Args:
        a: Tensor whose values are checked.
        sub: Iterable of permitted values.

    Returns:
        True when the distinct values of ``a`` form a subset of ``sub``.
    """
    found = uniq(a)
    return found <= set(sub)
def simplex(t: Tensor, axis=1) -> bool:
    """Check that ``t`` sums to one along ``axis``.

    The sums are cast to float32 and compared against ones with
    ``torch.allclose`` (default tolerances). Only the sums are inspected;
    individual entries are not range-checked.

    Args:
        t: Tensor to test.
        axis: Dimension along which the values are summed.

    Returns:
        True when every sum along ``axis`` is close to 1.
    """
    totals = t.sum(axis).to(torch.float32)
    return torch.allclose(totals, torch.ones_like(totals))
def one_hot(t: Tensor, axis=1) -> bool:
    """Check that ``t`` is a one-hot encoding along ``axis``.

    A tensor qualifies when it sums to one along ``axis`` and contains no
    values other than 0 and 1.

    Args:
        t: Tensor to test.
        axis: Dimension along which the encoding is expected.

    Returns:
        True when both conditions hold.
    """
    # Same short-circuit as an ``and``: skip the value scan if sums are off.
    if not simplex(t, axis):
        return False
    return sset(t, [0, 1])
# Conversions between representations: probabilities, class indices, one-hot
def probs2class(probs: Tensor) -> Tensor:
    """Convert a probability tensor into a map of class indices.

    Args:
        probs: 4-D tensor (unpacked as b, C, w, h) that must sum to one
            along dimension 1.

    Returns:
        Tensor of shape (b, w, h) holding the argmax over dimension 1.

    Raises:
        AssertionError: If ``probs`` does not sum to one along dimension 1.
    """
    batch, _, width, height = probs.shape
    assert simplex(probs)

    labels = torch.argmax(probs, dim=1)
    assert labels.shape == (batch, width, height)
    return labels
def class2one_hot(seg: Tensor, C: int) -> Tensor:
    """Expand a class-index tensor into a one-hot encoding.

    Args:
        seg: Tensor of class indices, either (b, w, h) or a single (w, h)
            map, which gains a leading batch dimension of size 1.
        C: Number of classes; every value in ``seg`` must lie in
            ``range(C)``.

    Returns:
        int32 tensor of shape (b, C, w, h) with a 1 where ``seg`` equals
        the channel index and 0 elsewhere.

    Raises:
        AssertionError: If ``seg`` holds a value outside ``range(C)``.
    """
    if seg.dim() == 2:  # bare (w, h) map, as produced by the dataloader
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C)))

    batch, width, height = seg.shape
    # One boolean mask per class, stacked along a new channel dimension.
    masks = [seg == label for label in range(C)]
    encoded = torch.stack(masks, dim=1).type(torch.int32)

    assert encoded.shape == (batch, C, width, height)
    assert one_hot(encoded)
    return encoded
def probs2one_hot(probs: Tensor) -> Tensor:
    """Turn a probability tensor into the one-hot encoding of its argmax.

    Args:
        probs: 4-D tensor that must sum to one along dimension 1.

    Returns:
        One-hot tensor with the same shape as ``probs``.

    Raises:
        AssertionError: If ``probs`` does not sum to one along dimension 1.
    """
    _batch, num_classes, _width, _height = probs.shape
    assert simplex(probs)

    labels = probs2class(probs)
    encoded = class2one_hot(labels, num_classes)

    assert encoded.shape == probs.shape
    assert one_hot(encoded)
    return encoded