a b/utils/direct_field/utils_df.py
1
import torch
2
import numpy as np
3
from scipy.ndimage import distance_transform_edt as distance
4
from utils.utils_loss import one_hot, simplex, class2one_hot
5
6
def one_hot2dist(seg: np.ndarray) -> np.ndarray:
    """Compute a per-class signed Euclidean distance map from a one-hot mask.

    For every channel ``c`` of ``seg`` that contains at least one positive
    pixel, the output holds:

    * outside the mask: the Euclidean distance to the nearest positive pixel;
    * inside the mask: ``-(d - 1)``, where ``d`` is the distance to the
      nearest non-positive pixel. Pixels on the mask border get 0 and deeper
      pixels get increasingly negative values.

    Channels with no positive pixel are left at 0.

    Args:
        seg: One-hot array whose first axis indexes the classes. It is
            validated with the project's ``one_hot`` helper along axis 0.

    Returns:
        Array with the same shape as ``seg``. Its dtype matches ``seg`` when
        ``seg`` is floating point; otherwise it is ``float64``, so the
        distances are not truncated.
    """
    assert one_hot(torch.Tensor(seg), axis=0)
    C: int = len(seg)

    # Writing distances into an integer/bool buffer would silently truncate
    # (or booleanize) them, so only floating inputs keep their own dtype.
    out_dtype = seg.dtype if np.issubdtype(seg.dtype, np.floating) else np.float64
    res = np.zeros_like(seg, dtype=out_dtype)
    for c in range(C):
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
        # supported spelling and behaves identically here.
        posmask = seg[c].astype(bool)

        if posmask.any():
            negmask = ~posmask
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res
18
19
def class2dist(seg: np.ndarray, C=4) -> np.ndarray:
    """Convert a class-index label map into per-class signed distance maps.

    Only one 2-D slice of ``seg`` is used:

    * 2-D input: ``seg`` itself;
    * 3-D input: ``seg[0]``;
    * 4-D input: ``seg[0, ..., 0]``.

    NOTE(review): these presumably correspond to (H, W), (B, H, W) and
    (B, H, W, 1) layouts -- confirm against callers.

    The slice is one-hot encoded into ``C`` classes with the project helper
    ``class2one_hot``, checked with ``simplex``. The first element of the
    encoded tensor is then passed to ``one_hot2dist``.

    Args:
        seg: Label map with 2, 3 or 4 dimensions.
        C: Number of classes used for the one-hot encoding.

    Returns:
        The ``one_hot2dist`` result, shaped (C, H, W) per the original
        docstring (assumes ``class2one_hot`` prepends a batch axis -- confirm).

    Raises:
        ValueError: If ``seg`` does not have 2, 3 or 4 dimensions.
    """
    if seg.ndim == 2:
        seg_tensor = torch.Tensor(seg)
    elif seg.ndim == 3:
        seg_tensor = torch.Tensor(seg[0])
    elif seg.ndim == 4:
        seg_tensor = torch.Tensor(seg[0, ..., 0])
    else:
        # Without this branch ``seg_tensor`` would be unbound below and the
        # caller would only see an opaque UnboundLocalError.
        raise ValueError(
            f"class2dist expects a 2-D, 3-D or 4-D array, got ndim={seg.ndim}"
        )

    seg_onehot = class2one_hot(seg_tensor, C).to(torch.float32)

    assert simplex(seg_onehot)
    res = one_hot2dist(seg_onehot[0].numpy())
    return res
34