Diff of /metrics/meandice.py [000000] .. [f77492]

Switch to unified view

a b/metrics/meandice.py
1
import numpy as np
2
import torch
3
4
5
def meandice(pred, gt, dices):
6
    """
7
    :return save img' dice value in IoUs
8
    """
9
    # dices = []
10
    pred[pred < 0.5] = 0
11
    pred[pred >= 0.5] = 1
12
    gt[gt < 0.5] = 0
13
    gt[gt >= 0.5] = 1
14
    pred = pred.type(torch.LongTensor)
15
    pred_np = pred.data.cpu().numpy()
16
    gt = gt.data.cpu().numpy()
17
    for x in range(pred.size()[0]):
18
        dice = np.sum(pred_np[x][gt[x] == 1]) * 2 / float(np.sum(pred_np[x]) + np.sum(gt[x]))
19
        dices.append(dice)
20
    return dices