|
a |
|
b/tools/test_utils.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import h5py |
|
|
4 |
import numpy as np |
|
|
5 |
import math |
|
|
6 |
|
|
|
7 |
from utils.image_list import to_image_list |
|
|
8 |
|
|
|
9 |
# acdc data |
|
|
10 |
def personTo4Ddata(personname, test_list): |
|
|
11 |
sliceofp = [] |
|
|
12 |
for tl in test_list: |
|
|
13 |
if '/'.join(personname.split('-')) in tl: |
|
|
14 |
sliceofp.append(tl) |
|
|
15 |
|
|
|
16 |
imgs = [[], []] |
|
|
17 |
gts = [[], []] |
|
|
18 |
for ti, time_i in enumerate(["ES", "ED"]): |
|
|
19 |
time_path = [] |
|
|
20 |
|
|
|
21 |
for sp in sliceofp: |
|
|
22 |
if time_i in sp: |
|
|
23 |
time_path.append(sp) |
|
|
24 |
|
|
|
25 |
for tp in time_path: |
|
|
26 |
imgs[ti].append(h5py.File(tp, 'r')['image']) |
|
|
27 |
gts[ti].append(h5py.File(tp, 'r')['label']) |
|
|
28 |
|
|
|
29 |
imgs = np.array(imgs).transpose(1,2,3,0) |
|
|
30 |
gts = np.array(gts).transpose(1,2,3,0) |
|
|
31 |
return imgs, gts |
|
|
32 |
|
|
|
33 |
def test_it(model, data, device="cuda", used_df=False): |
|
|
34 |
model.eval() |
|
|
35 |
imgs = data |
|
|
36 |
|
|
|
37 |
imgs = imgs.to(device) |
|
|
38 |
# gts = gts.to(device) |
|
|
39 |
|
|
|
40 |
net_out = model(imgs) |
|
|
41 |
if used_df: |
|
|
42 |
preds_out = net_out[0] |
|
|
43 |
preds_df = net_out[1] |
|
|
44 |
else: |
|
|
45 |
preds_out = net_out[0] |
|
|
46 |
preds_df = None |
|
|
47 |
preds_out = nn.functional.softmax(preds_out, dim=1) |
|
|
48 |
_, preds = torch.max(preds_out, 1) |
|
|
49 |
preds = preds.unsqueeze(1) # (N, 1, *) |
|
|
50 |
|
|
|
51 |
return preds, preds_df |
|
|
52 |
|
|
|
53 |
def test_person(model, imgs, multi_batches=False, used_df=False): |
|
|
54 |
""" imgs: (times, slices, H, W) |
|
|
55 |
preds: (times, slices, H, W) |
|
|
56 |
""" |
|
|
57 |
preds = [] |
|
|
58 |
for i in range(len(imgs)): |
|
|
59 |
preds_timei = np.zeros([imgs[i].size(0), imgs[i].size(2), imgs[i].size(3)]) |
|
|
60 |
|
|
|
61 |
if multi_batches: |
|
|
62 |
batch_size = 32 |
|
|
63 |
for bs in range(math.ceil(len(imgs[i]) / batch_size)): |
|
|
64 |
st = batch_size * bs |
|
|
65 |
end = st + batch_size if (st+batch_size) <= len(imgs[i]) else len(imgs[i]) |
|
|
66 |
# data, origin_shape = to_image_list(imgs[i][st:st+batch_size], size_divisible=32, return_size=True) |
|
|
67 |
data = imgs[i][st:end] |
|
|
68 |
origin_shape = imgs[i].shape[-2:] |
|
|
69 |
|
|
|
70 |
pred, _ = test_it(model, data, used_df=used_df) |
|
|
71 |
preds_timei[st:end, ...] = pred.cpu().numpy()[:, 0, :origin_shape[0], :origin_shape[1]] |
|
|
72 |
# =========================== |
|
|
73 |
# data, origin_shape = to_image_list(imgs[i], size_divisible=32, return_size=True) |
|
|
74 |
# pred, _ = test_it(model, data, used_df=used_df) |
|
|
75 |
|
|
|
76 |
# for j in range(imgs[i].shape[0]): |
|
|
77 |
# preds_timei[j, ...] = pred.cpu().numpy()[j, :, :origin_shape[j][0], :origin_shape[j][1]] |
|
|
78 |
else: |
|
|
79 |
for j, pt in enumerate(imgs[i]): |
|
|
80 |
data = [pt] |
|
|
81 |
data, origin_shape = to_image_list(data, size_divisible=32, return_size=True) |
|
|
82 |
pred, _ = test_it(model, data) |
|
|
83 |
preds_timei[j, ...] = pred.cpu().numpy()[0, 0, :origin_shape[0][0], :origin_shape[0][1]] |
|
|
84 |
|
|
|
85 |
preds.append(preds_timei) |
|
|
86 |
|
|
|
87 |
return preds |