|
a |
|
b/metrics/meandice.py |
|
|
1 |
import numpy as np |
|
|
2 |
import torch |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
def meandice(pred, gt, dices): |
|
|
6 |
""" |
|
|
7 |
:return save img' dice value in IoUs |
|
|
8 |
""" |
|
|
9 |
# dices = [] |
|
|
10 |
pred[pred < 0.5] = 0 |
|
|
11 |
pred[pred >= 0.5] = 1 |
|
|
12 |
gt[gt < 0.5] = 0 |
|
|
13 |
gt[gt >= 0.5] = 1 |
|
|
14 |
pred = pred.type(torch.LongTensor) |
|
|
15 |
pred_np = pred.data.cpu().numpy() |
|
|
16 |
gt = gt.data.cpu().numpy() |
|
|
17 |
for x in range(pred.size()[0]): |
|
|
18 |
dice = np.sum(pred_np[x][gt[x] == 1]) * 2 / float(np.sum(pred_np[x]) + np.sum(gt[x])) |
|
|
19 |
dices.append(dice) |
|
|
20 |
return dices |