Diff of /utils/metrics.py [000000] .. [98e649]

Switch to unified view

a b/utils/metrics.py
1
import torch
2
import numpy as np
3
from hausdorff import hausdorff_distance
4
from medpy.metric.binary import hd, dc
5
6
def dice(pred, target):
7
    pred = pred.contiguous()
8
    target = target.contiguous()
9
    smooth = 0.00001
10
11
    # intersection = (pred * target).sum(dim=2).sum(dim=2)
12
    pred_flat = pred.view(1, -1)
13
    target_flat = target.view(1, -1)
14
15
    intersection = (pred_flat * target_flat).sum().item()
16
17
    # loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
18
    dice = (2 * intersection + smooth) / (pred_flat.sum().item() + target_flat.sum().item() + smooth)
19
    return dice
20
21
def dice3D(img_gt, img_pred, voxel_size):
22
    """
23
    Function to compute the metrics between two segmentation maps given as input.
24
25
    Parameters
26
    ----------
27
    img_gt: np.array
28
    Array of the ground truth segmentation map.
29
30
    img_pred: np.array
31
    Array of the predicted segmentation map.
32
33
    voxel_size: list, tuple or np.array
34
    The size of a voxel of the images used to compute the volumes.
35
36
    Return
37
    ------
38
    A list of metrics in this order, [Dice LV, Volume LV, Err LV(ml),
39
    Dice RV, Volume RV, Err RV(ml), Dice MYO, Volume MYO, Err MYO(ml)]
40
    """
41
42
    if img_gt.ndim != img_pred.ndim:
43
        raise ValueError("The arrays 'img_gt' and 'img_pred' should have the "
44
                         "same dimension, {} against {}".format(img_gt.ndim,
45
                                                                img_pred.ndim))
46
47
    res = []
48
    # Loop on each classes of the input images
49
    for c in [3, 1, 2]:
50
        # Copy the gt image to not alterate the input
51
        gt_c_i = np.copy(img_gt)
52
        gt_c_i[gt_c_i != c] = 0
53
54
        # Copy the pred image to not alterate the input
55
        pred_c_i = np.copy(img_pred)
56
        pred_c_i[pred_c_i != c] = 0
57
58
        # Clip the value to compute the volumes
59
        gt_c_i = np.clip(gt_c_i, 0, 1)
60
        pred_c_i = np.clip(pred_c_i, 0, 1)
61
62
        # Compute the Dice
63
        dice = dc(gt_c_i, pred_c_i)
64
65
        # Compute volume
66
        # volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.
67
        # volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.
68
69
        # res += [dice, volpred, volpred-volgt]
70
        res += [dice]
71
72
    return res
73
74
def hd_3D(img_pred, img_gt, labels=[3, 1, 2]):
75
    res = []
76
    for c in labels:
77
        gt_c_i = np.copy(img_gt)
78
        gt_c_i[gt_c_i != c] = 0
79
80
        pred_c_i = np.copy(img_pred)
81
        pred_c_i[pred_c_i != c] = 0
82
83
        gt_c_i = np.clip(gt_c_i, 0, 1)
84
        pred_c_i = np.clip(pred_c_i, 0, 1)
85
86
        if np.sum(pred_c_i) == 0 or np.sum(gt_c_i) == 0:
87
            hausdorff = 0
88
        else:
89
            hausdorff = hd(pred_c_i, gt_c_i)
90
91
        res += [hausdorff]
92
93
    return res
94
95
def cal_hausdorff_distance(pred,target):
96
97
    pred = np.array(pred.contiguous())
98
    target = np.array(target.contiguous())
99
    result = hausdorff_distance(pred,target,distance="euclidean")
100
101
    return result
102
103
def make_one_hot(input, num_classes):
104
    """Convert class index tensor to one hot encoding tensor.
105
    Args:
106
         input: A tensor of shape [N, 1, *]
107
         num_classes: An int of number of class
108
    Returns:
109
        A tensor of shape [N, num_classes, *]
110
    """
111
    shape = np.array(input.shape)
112
    shape[1] = num_classes
113
    shape = tuple(shape)
114
    result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1)
115
    # result = result.scatter_(1, input.cpu(), 1)
116
117
    return result
118
119
def match_pred_gt(pred, gt):
120
    """ pred: (1, C, H, W)
121
        gt: (1, C, H, W)
122
    """
123
    gt_labels = torch.unique(gt, sorted=True)[1:]
124
    pred_labels = torch.unique(pred, sorted=True)[1:]
125
126
    if len(gt_labels) != 0 and len(pred_labels) != 0:
127
        dice_Matrix = torch.zeros((len(pred_labels), len(gt_labels)))
128
        for i, pl in enumerate(pred_labels):
129
            pred_i = torch.tensor(pred==pl, dtype=torch.float)
130
            for j, gl in enumerate(gt_labels):
131
                dice_Matrix[i, j] = dice(make_one_hot(pred_i, 2)[0], make_one_hot(gt==gl, 2)[0])
132
133
        # max_axis0 = np.max(dice_Matrix, axis=0)
134
        max_arg0 = np.argmax(dice_Matrix, axis=0)
135
    else:
136
        return torch.zeros_like(pred)
137
138
    pred_match = torch.zeros_like(pred)
139
    for i, arg in enumerate(max_arg0):
140
        pred_match[pred==pred_labels[arg]] = i + 1
141
    return pred_match
142
143
if __name__ == "__main__":
144
    npy_path = "/home/fcheng/Cardia/source_code/logs/logs_df_50000/eval_pp_test/200.npy"
145
    pred_df, gt_df = np.load(npy_p)