Switch to unified view

a b/tools/test_acdc_leadboard.py
1
import os
2
import sys
3
import numpy as np
4
import nibabel as nib
5
import time
6
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
11
import _init_paths
12
from libs.network import U_Net, U_NetDF
13
from utils.image_list import to_image_list
14
import libs.datasets.augment as standard_augment
15
import libs.datasets.joint_augment as joint_augment
16
17
def progress_bar(curr_idx, max_idx, time_step, repeat_elem = "_"):
18
    max_equals = 55
19
    step_ms = int(time_step*1000)
20
    num_equals = int(curr_idx*max_equals/float(max_idx))
21
    len_reverse =len('Step:%d ms| %d/%d ['%(step_ms, curr_idx, max_idx)) + num_equals
22
    sys.stdout.write("Step:%d ms|%d/%d [%s]" %(step_ms, curr_idx, max_idx, " " * max_equals,))
23
    sys.stdout.flush()
24
    sys.stdout.write("/b" * (max_equals+1))
25
    sys.stdout.write(repeat_elem * num_equals)
26
    sys.stdout.write("/b"*len_reverse)
27
    sys.stdout.flush()
28
    if curr_idx == max_idx:
29
        print('/n')
30
31
def load_nii(img_path):
32
    """
33
    Function to load a 'nii' or 'nii.gz' file, The function returns
34
    everyting needed to save another 'nii' or 'nii.gz'
35
    in the same dimensional space, i.e. the affine matrix and the header
36
37
    Parameters
38
    ----------
39
40
    img_path: string
41
    String with the path of the 'nii' or 'nii.gz' image file name.
42
43
    Returns
44
    -------
45
    Three element, the first is a numpy array of the image values,
46
    the second is the affine transformation of the image, and the
47
    last one is the header of the image.
48
    """
49
    nimg = nib.load(img_path)
50
    return nimg.get_fdata(), nimg.affine, nimg.header
51
52
def save_nii(vol, affine, hdr, path, prefix, suffix):
53
    vol = nib.Nifti1Image(vol, affine, hdr)
54
    vol.set_data_dtype(np.uint8)
55
    nib.save(vol, os.path.join(path, prefix+'_'+suffix + ".nii.gz"))
56
57
58
def get_person_names(root_path=None):
59
    persons_name = os.listdir(root_path)
60
    persons_name = [pn for pn in persons_name if "patient" in pn]
61
    persons_name.sort()
62
    return persons_name
63
64
def get_patient_data(patient, root_path):
65
    patient_data = {}
66
    infocfg_p = os.path.join(root_path, patient, "Info.cfg")
67
68
    with open(infocfg_p) as f_in:
69
        for line in f_in:
70
            l = line.rstrip().split(": ")
71
            patient_data[l[0]] = l[1]
72
73
    ed_path = os.path.join(root_path, patient, "%s_frame%02d.nii.gz" % (patient, int(patient_data['ED'])))
74
    es_path = os.path.join(root_path, patient, "%s_frame%02d.nii.gz" % (patient, int(patient_data['ES'])))
75
    img_4d_path = os.path.join(root_path, patient, "{}_4d.nii.gz".format(patient))
76
    # ed_gt_path = os.path.join(root_path, patient, "%s_frame%02d_gt.nii.gz" % (patient, int(patient_data['ED'])))
77
    # es_gt_path = os.path.join(root_path, patient, "%s_frame%02d_gt.nii.gz" % (patient, int(patient_data['ES'])))
78
79
    ed, affine, hdr = load_nii(ed_path)
80
    patient_data['ED_VOL'] = np.swapaxes(ed, 0, -1)
81
    patient_data['3D_affine'] = affine
82
    patient_data['3D_hdr'] = hdr
83
84
    es, _, _ = load_nii(es_path)  # (w, h, slices)
85
    patient_data['ES_VOL'] = np.swapaxes(es, 0, -1)
86
87
    img_4d, affine_4d, hdr_4d = load_nii(img_4d_path)  # (w, h, slices, times)
88
    patient_data['4D'] = np.swapaxes(img_4d, 0, 1)
89
    patient_data['4D_affine'] = affine_4d
90
    patient_data['4D_hdr'] = hdr_4d
91
92
    patient_data['size'] = img_4d.shape[:2][::-1]
93
    patient_data['pid'] = patient
94
95
    # ed_gt = load_nii(ed_gt_path)
96
    # patient_data['ED_GT'] = np.swapaxes(ed_gt, 0, 1)
97
98
    # es_gt = load_nii(es_gt_path)
99
    # patient_data['ES_GT'] = np.swapaxes(es_gt, 0, 1)
100
    return patient_data
101
102
def test_it(model, data, device="cuda", used_df=True):
103
    model.eval()
104
    imgs = data
105
106
    imgs = imgs.to(device)
107
    # gts = gts.to(device)
108
109
    net_out = model(imgs)
110
    if used_df:
111
        preds_out = net_out[0]
112
        preds_df = net_out[1]
113
    else:
114
        preds_out = net_out[0]
115
        preds_df = None
116
    preds_out = nn.functional.softmax(preds_out, dim=1)
117
    _, preds = torch.max(preds_out, 1)
118
    preds = preds.unsqueeze(1)  # (N, 1, *)
119
120
    return preds, preds_df
121
122
def transform(imgs):
123
    mean = 63.19523533061758
124
    std = 70.74166957523165
125
    trans = standard_augment.Compose([standard_augment.To_PIL_Image(),
126
                                    # joint_augment.RandomAffine(0,translate=(0.125, 0.125)),
127
                                    # joint_augment.RandomRotate((-180,180)),
128
                                    # joint_augment.FixResize(224),
129
                                    standard_augment.to_Tensor(),
130
                                    standard_augment.normalize([mean], [std]),
131
                                      ])
132
    return trans(imgs)
133
134
def test_voxel(model, imgs, used_df, multi_batches=False, resize=None):
135
    """ imgs: (slices, H, W)
136
        preds: (slices, 1, H, W)
137
    """
138
    imgs = imgs[..., None].astype(np.float32)
139
    B, _, _, C = imgs.shape
140
141
    if multi_batches:
142
        data, origin_shape = to_image_list(imgs, size_divisible=32, return_size=True)
143
        preds, _ = test_it(model, data)
144
        
145
        # for j in range(imgs.shape[0]):
146
        #     preds[j, ...] = pred.cpu().numpy()[j, :, :origin_shape[j][0], :origin_shape[j][1]]
147
    else:
148
        preds = torch.zeros(B, C, resize[0], resize[1])
149
        for j, pt in enumerate(imgs):
150
            data = [transform(pt)]
151
            data, origin_shape = to_image_list(data, size_divisible=32, return_size=True)
152
            pred, _ = test_it(model, data, used_df=used_df)
153
            preds[j, ...] = pred[0, 0, :origin_shape[0][0], :origin_shape[0][1]]
154
155
    if resize is not None:
156
        # preds = F.interpolate(preds, size=resize, mode='nearest')
157
        preds = preds[..., :resize[0], :resize[1]]
158
159
    return preds.cpu().numpy()[:, 0, ...]
160
161
def create_model(model_name, selfeat):
162
    if model_name == 'U_NetDF':
163
        model = U_NetDF(selfeat=selfeat, auxseg=True)
164
    elif model_name == 'U_Net':
165
        model = U_Net()
166
    # elif model_name == 'Resnet18_DfUnet':
167
    #     model = Resnet18_DfUnet()
168
    # elif model_name == 'DenseNet':
169
    #     model = DenseNet()
170
    # elif model_name == 'DenseNet_DF':
171
    #     model = DenseNet_DF(selfeat=selfeat)
172
173
    return model
174
175
def test(mgpus, model_name, model_path, used_df, selfeat, log_path):
176
177
    model = create_model(model_name, selfeat=selfeat)
178
    if mgpus is not None and len(mgpus) > 2:
179
        model = nn.DataParallel(model)
180
    model.cuda()
181
182
    checkpoint = torch.load(model_path, map_location='cpu')
183
    model.load_state_dict(checkpoint['model_state'])
184
185
    root_path = "MICCAIACDC2017/ACDC_DataSet/testing/testing/"
186
    root_path = "/root/ACDC_DataSet/testing/testing/"
187
    persons_name = get_person_names(root_path)
188
    for j, pn in enumerate(persons_name):
189
        s_time = time.time()
190
        patient_data = get_patient_data(pn, root_path)
191
192
        # (slices, h, w)
193
        es_pred = test_voxel(model, patient_data['ES_VOL'], used_df=used_df, resize=patient_data['size'])
194
        ed_pred = test_voxel(model, patient_data['ED_VOL'], used_df=used_df, resize=patient_data['size'])
195
        es_pred = np.transpose(es_pred, (2, 1, 0))
196
        ed_pred = np.transpose(ed_pred, (2, 1, 0))
197
198
        img_4D = patient_data['4D']
199
        h, w, s, t = img_4D.shape
200
        pred_4D = np.zeros((w, h, s, t))
201
        for i in range(img_4D.shape[-1]):
202
            pred = test_voxel(model, np.transpose(img_4D[...,i], (2, 0, 1)), used_df=used_df, resize=patient_data['size'])
203
            pred_4D[..., i] = np.transpose(pred, (2, 1, 0))
204
        
205
        save_path = os.path.join(log_path, "all_predictions")
206
        os.makedirs(save_path, exist_ok=True)
207
        CheckSizeAndSaveVolume(pred_4D, patient_data, save_path)
208
        progress_bar(j%(len(persons_name)+1), len(persons_name),time.time()-s_time)
209
210
211
def CheckSizeAndSaveVolume(seg_4D, patient_data, save_path):
212
    """
213
    TODO:
214
    """ 
215
    prefix = patient_data['pid']
216
    suffix = '4D'
217
218
    # save_nii(seg_4D, patient_data['4D_affine'], patient_data['4D_hdr'], save_path, prefix, suffix)
219
    suffix = 'ED'
220
    ED_phase_n = int(patient_data['ED'])
221
    ED_pred = seg_4D[:,:,:,ED_phase_n]
222
    save_nii(ED_pred.astype(np.uint8), patient_data['3D_affine'], patient_data['3D_hdr'], save_path, prefix, suffix)
223
224
    suffix = 'ES'
225
    ES_phase_n = int(patient_data['ES'])
226
    ES_pred = seg_4D[:,:,:,ES_phase_n]
227
    save_nii(ES_pred.astype(np.uint8), patient_data['3D_affine'], patient_data['3D_hdr'], save_path, prefix, suffix)
228
229
    # ED_GT = patient_data.get('ED_GT', None)
230
    results = []
231
    return results
232
233
234
235
if __name__ == "__main__":
236
    # get_person_names()
237
    import argparse
238
    parser = argparse.ArgumentParser(description="arg parser")
239
    parser.add_argument('--mgpus', type=str, default='0', required=False, help='whether to use multiple gpu')
240
    parser.add_argument('--used_df', type=str, default='True', help='whether to use df')
241
    parser.add_argument('--model', type=str, default='', help='whether to use df')
242
    parser.add_argument('--selfeat', type=bool, default=True, help='whether to use feature select')
243
    parser.add_argument('--model_path', type=str, default=None, help='whether to train with evaluation')
244
    parser.add_argument('--output_dir', type=str, default="logs/acdc_logs/logs_supcat_auxseg/predtions", required=False, help='specify an output directory if needed')
245
    parser.add_argument('--log_file', type=str, default="../log_predtion.txt", help="the file to write logging")
246
247
    args = parser.parse_args()
248
    if args.mgpus is not None:
249
        os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus
250
251
    model_path = "logs/acdc_logs/logs_supcat_auxseg/ckpt/checkpoint_epoch_70.pth"
252
253
    test(args.mgpus, "U_NetDF", model_path, args.used_df, args.selfeat, args.output_dir)