Diff of /tools/test_utils.py [000000] .. [98e649]

Switch to unified view

a b/tools/test_utils.py
1
import torch
2
import torch.nn as nn
3
import h5py
4
import numpy as np
5
import math
6
7
from utils.image_list import to_image_list
8
9
# acdc data
10
def personTo4Ddata(personname, test_list):
11
    sliceofp = []
12
    for tl in test_list:
13
        if '/'.join(personname.split('-')) in tl:
14
            sliceofp.append(tl)
15
16
    imgs = [[], []]
17
    gts = [[], []]
18
    for ti, time_i in enumerate(["ES", "ED"]):
19
        time_path = []
20
        
21
        for sp in sliceofp:
22
            if time_i in sp:
23
                time_path.append(sp)
24
        
25
        for tp in time_path:
26
            imgs[ti].append(h5py.File(tp, 'r')['image'])
27
            gts[ti].append(h5py.File(tp, 'r')['label'])
28
29
    imgs = np.array(imgs).transpose(1,2,3,0)
30
    gts = np.array(gts).transpose(1,2,3,0)
31
    return imgs, gts
32
33
def test_it(model, data, device="cuda", used_df=False):
34
    model.eval()
35
    imgs = data
36
37
    imgs = imgs.to(device)
38
    # gts = gts.to(device)
39
40
    net_out = model(imgs)
41
    if used_df:
42
        preds_out = net_out[0]
43
        preds_df = net_out[1]
44
    else:
45
        preds_out = net_out[0]
46
        preds_df = None
47
    preds_out = nn.functional.softmax(preds_out, dim=1)
48
    _, preds = torch.max(preds_out, 1)
49
    preds = preds.unsqueeze(1)  # (N, 1, *)
50
51
    return preds, preds_df
52
53
def test_person(model, imgs, multi_batches=False, used_df=False):
54
    """ imgs: (times, slices, H, W)
55
        preds: (times, slices, H, W)
56
    """
57
    preds = []
58
    for i in range(len(imgs)):
59
        preds_timei = np.zeros([imgs[i].size(0), imgs[i].size(2), imgs[i].size(3)])
60
61
        if multi_batches:
62
            batch_size = 32
63
            for bs in range(math.ceil(len(imgs[i]) / batch_size)):
64
                st = batch_size * bs
65
                end = st + batch_size if (st+batch_size) <= len(imgs[i]) else len(imgs[i])
66
                # data, origin_shape = to_image_list(imgs[i][st:st+batch_size], size_divisible=32, return_size=True)
67
                data = imgs[i][st:end]
68
                origin_shape = imgs[i].shape[-2:]
69
70
                pred, _ = test_it(model, data, used_df=used_df)
71
                preds_timei[st:end, ...] = pred.cpu().numpy()[:, 0, :origin_shape[0], :origin_shape[1]]
72
            # ===========================
73
            # data, origin_shape = to_image_list(imgs[i], size_divisible=32, return_size=True)
74
            # pred, _ = test_it(model, data, used_df=used_df)
75
            
76
            # for j in range(imgs[i].shape[0]):
77
            #     preds_timei[j, ...] = pred.cpu().numpy()[j, :, :origin_shape[j][0], :origin_shape[j][1]]
78
        else:
79
            for j, pt in enumerate(imgs[i]):
80
                data = [pt]
81
                data, origin_shape = to_image_list(data, size_divisible=32, return_size=True)
82
                pred, _ = test_it(model, data)
83
                preds_timei[j, ...] = pred.cpu().numpy()[0, 0, :origin_shape[0][0], :origin_shape[0][1]]
84
        
85
        preds.append(preds_timei)
86
87
    return preds