|
a |
|
b/utils/direct_field/utils_df.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
from scipy.ndimage import distance_transform_edt as distance |
|
|
4 |
from utils.utils_loss import one_hot, simplex, class2one_hot |
|
|
5 |
|
|
|
6 |
def one_hot2dist(seg: np.ndarray) -> np.ndarray:
    """Compute a signed Euclidean distance map for each channel of a one-hot mask.

    For every channel ``c`` that contains at least one foreground pixel, the
    output at each pixel is:

    * outside the mask: the Euclidean distance to the nearest foreground pixel
      (non-negative), i.e. ``distance_transform_edt`` of the negated mask;
    * inside the mask: ``-(d - 1)``, where ``d`` is the distance to the nearest
      background pixel, so foreground pixels directly next to the background
      get 0 and deeper pixels become increasingly negative.

    Channels with no foreground pixel are left as all zeros.

    Args:
        seg: One-hot encoded array whose class axis is axis 0 (validated with
            ``one_hot(..., axis=0)``). The remaining axes are treated as
            spatial (presumably (H, W) — TODO confirm whether 3-D volumes are
            also passed here).

    Returns:
        Array with the same shape as ``seg``. Floating-point inputs keep their
        dtype; any other dtype (integer or bool masks) yields ``float64`` so
        fractional distances are not truncated.
    """
    # NOTE: assert-based validation is stripped when Python runs with -O.
    assert one_hot(torch.Tensor(seg), axis=0)
    C: int = len(seg)

    # zeros_like on an integer/bool array would silently truncate the
    # fractional distances written below, so force a floating dtype.
    out_dtype = seg.dtype if np.issubdtype(seg.dtype, np.floating) else np.float64
    res = np.zeros_like(seg, dtype=out_dtype)
    for c in range(C):
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
        # supported, behaviorally identical spelling.
        posmask = seg[c].astype(bool)

        if posmask.any():
            negmask = ~posmask
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res
|
|
18 |
|
|
|
19 |
def class2dist(seg: np.ndarray, C=4) -> np.ndarray:
    """Convert a class-index label map into per-class signed distance maps.

    ``seg`` is reduced to a single label map, one-hot encoded into ``C``
    classes with ``class2one_hot``, and the first entry of that encoding is
    passed to ``one_hot2dist``.

    Args:
        seg: Label map. Accepted ranks:
            * 2-D: used as-is (presumably (H, W));
            * 3-D: only ``seg[0]`` is used (presumably a leading batch or
              channel axis — TODO confirm);
            * 4-D: only ``seg[0, ..., 0]`` is used (presumably an
              (N, H, W, 1)-like layout — TODO confirm).
            All other entries along the dropped axes are ignored.
        C: Number of classes used for the one-hot encoding.

    Returns:
        res: signed distance maps, one per class — presumably (C, H, W).

    Raises:
        ValueError: if ``seg`` is not 2-, 3- or 4-dimensional.
    """
    if seg.ndim == 2:
        seg_tensor = torch.Tensor(seg)
    elif seg.ndim == 3:
        seg_tensor = torch.Tensor(seg[0])
    elif seg.ndim == 4:
        seg_tensor = torch.Tensor(seg[0, ..., 0])
    else:
        # Without this branch, seg_tensor is never bound and the call below
        # fails with an opaque UnboundLocalError.
        raise ValueError(
            f"class2dist expects a 2-, 3- or 4-D array, got ndim={seg.ndim}"
        )

    seg_onehot = class2one_hot(seg_tensor, C).to(torch.float32)

    # NOTE: assert-based validation is stripped when Python runs with -O.
    assert simplex(seg_onehot)
    # Index 0 drops the leading axis of the encoding — presumably a batch axis
    # added by class2one_hot; TODO confirm against its implementation.
    res = one_hot2dist(seg_onehot[0].numpy())
    return res
|
|
34 |
|