# test.py
1
import numpy as np
2
import sys
3
import os
4
import glob
5
import nibabel as nib
6
import torch
7
# ---------------------------------------------------------------------------
# Paths. Two layouts exist; exactly one block should be active.
#
# Docker layout (to use it, uncomment these and comment out the local block):
#   fpx = '/wmhseg_code/'
#   inputDir = '/input'
#   outputDir = '/output'
#
# Local layout (active). These paths are relative to the working directory.
main_folder_path = '../Data/MS2017b/'  # forwarded to RS.normalizeScan
fpx = './'                             # code root: prefix for checkpoints and module dirs
inputDir = 'input/'
outputDir = 'output/'

# ---------------------------------------------------------------------------
# Inference parameters.
useGPU = 0       # 0 -> checkpoints are loaded onto the CPU via map_location
gpu0 = 0         # GPU index forwarded to the testPredict helpers
patch_size = 60
extra_patch = 5

# Checkpoints whose soft predictions are averaged.
# Alternatives that have been used here:
#   EXP3D_1x1x1x1_0_0_dice_1_best.pth
#   EXP3D_2x2x2x2_0_0_dice_1_best.pth
#   EXP3D_1x1x1x1_0_1_dice_1_best.pth
#   EXP3D_1x1x1x1_1_1_dice_1_best.pth
#   EXP3D_1x1x1x1_0_1_dice_1_best.pth (with ep = 16)
model_paths = [fpx + 'analysis/models/EXP3D_1x1x1x1_0_0_dice_1_best.pth']
weights = [1]    # NOTE(review): one entry per model, but not read by the prediction loop

# Make the project-local helper and architecture modules importable.
sys.path.extend(fpx + sub_dir for sub_dir in (
    'utils/',
    'architectures/deeplab_3D/',
    'architectures/unet_3D/',
    'architectures/hrnet_3D/',
    'architectures/experiment_nets_3D/',
))
sys.path.append('utils/')
36
37
import deeplab_resnet_3D
38
import unet_3D
39
import highresnet_3D
40
import exp_net_3D
41
42
import augmentations as AUG
43
import normalizations as NORM
44
import resizeScans as RS
45
import evalF as EF
46
import evalFP as EFP
47
import PP
48
import torch
49
50
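#step 1: read FLAIR.nii.gz from the input folder
#step 2: resize the image to 200x200x100 and apply histogram + variance normalization
#step 3: predict patch-wise with every model in model_paths and average the soft outputs
#        (test-time augmentation is currently disabled: test_augm=False)
#step 4: save the 200x200x100 prediction to the output folder
#step 5: resize the prediction back to the original image size (nearest-neighbour)
#step 6: re-save the result with the original FLAIR affine and header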
55
56
57
# Input scan, resampled intermediates, and final result locations.
img_path = os.path.join(inputDir, 'FLAIR.nii.gz')
img_path_rs = os.path.join(outputDir, 'FLAIR_rs.nii.gz')

wmh_path_rs = os.path.join(outputDir, 'wmh_rs.nii.gz')
wmh_path = os.path.join(outputDir, 'result.nii.gz')

# Every intermediate and the result are written into outputDir; create it up
# front so a fresh local checkout does not fail on the first write.
if not os.path.isdir(outputDir):
    os.makedirs(outputDir)

# Shape of the scan as returned by PP.numpyFromScan; the final prediction is
# resized back to this at the end of the script.
old_size = PP.numpyFromScan(img_path).shape

new_size = [200, 200, 100]
num_labels = 2

# Convert the scan to 200x200x100.
RS.convertSize2(img_path, img_path_rs, new_size)
# Affine of the resampled scan; the 200x200x100 prediction is saved with it.
# `.affine` replaces `get_affine()`, which is deprecated and removed in
# recent nibabel releases.
affine_rs = nib.load(img_path_rs).affine

# Normalize in place using histogram and variance normalization.
RS.normalizeScan(img_path_rs, img_path_rs, main_folder_path=main_folder_path)

# Read the preprocessed image. `affine` is kept for compatibility but is not
# used below; affine_rs is what the resampled prediction is saved with.
img, affine = PP.numpyFromScan(img_path_rs, get_affine=True)
# Move the last axis to the front, then prepend a batch axis of size 1.
# NOTE(review): assumes numpyFromScan returns a 4-D (x, y, z, channel) array - confirm.
img = img.transpose((3, 0, 1, 2))
img = img[np.newaxis, :]
80
81
print('Image ready')
print('Loading model')

# Running sum of the soft predictions of every model; averaged after the loop.
# NOTE(review): `weights` (defined with the configuration) is not applied
# here - every model contributes equally.
out = None
for model_path in model_paths:
    f_name = os.path.basename(model_path)
    isPriv = False

    # Select and build the architecture from the checkpoint file name.
    if 'EXP3D' in f_name:
        # e.g. 'EXP3D_1x1x1x1_0_0_dice_1_best.pth' -> experiment '1x1x1x1_0_0'
        experiment = f_name.replace('EXP3D_', '').replace('.pth', '').split('_')
        experiment = '_'.join(experiment[0:3])
        dilation_arr, isPriv, withASPP = PP.getExperimentInfo(experiment)
        model = exp_net_3D.getExpNet(num_labels, dilation_arr, isPriv, NoLabels2=209, withASPP=withASPP)
    elif 'HR3D' in f_name:
        model = highresnet_3D.getHRNet(num_labels)
    elif 'DL3D' in f_name:
        model = deeplab_resnet_3D.Res_Deeplab(num_labels)
    elif 'UNET3D' in f_name:
        model = unet_3D.UNet3D(1, num_labels)
    else:
        print('No model available for this .pth')
        # Non-zero status so whoever launched the script sees the failure
        # (a bare sys.exit() reports success).
        sys.exit(1)

    if useGPU:
        saved_state_dict = torch.load(model_path)
    else:
        # Keep every storage on the CPU so GPU-saved checkpoints load without CUDA.
        saved_state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(saved_state_dict)
    model.float()
    model.eval()
    print('Model ready')
    print('Predicting...')

    # Both predictors use identical patch settings; the EFP variant takes the
    # extra 209 argument (same value as NoLabels2 above).
    if isPriv:
        pred = EFP.testPredict(img, model, num_labels, 209, 1, gpu0, useGPU, stride=50, patch_size=patch_size, test_augm=False, extra_patch=extra_patch, get_soft=True)
    else:
        pred = EF.testPredict(img, model, num_labels, 1, gpu0, useGPU, stride=50, patch_size=patch_size, test_augm=False, extra_patch=extra_patch, get_soft=True)

    if out is None:
        out = pred
    else:
        out += pred

out /= float(len(model_paths))
126
# Turn the averaged soft predictions into a label map by taking the argmax over axis 0.
# NOTE(review): assumes axis 0 of `out` is the label axis as returned by testPredict - confirm.
out = np.argmax(out, axis=0)
# Drop any remaining singleton axes (e.g. a leftover batch axis).
out = out.squeeze()
print('Prediction complete')
print('Saving...')
# Save the 200x200x100 label map with the affine of the resampled scan.
PP.saveScan(out.astype(np.float64), affine_rs, wmh_path_rs)

# Resize the label map back to the original input size (this is the final
# result); nearest-neighbour interpolation keeps the label values discrete.
RS.convertSize2(wmh_path_rs, wmh_path, old_size, interpolation='nearest')

# Re-save the result with the same affine and header as the original FLAIR
# image so it lines up with the input scan.
print('Saving final wmh file')
orig_flair = nib.load(img_path)
# np.asanyarray(img.dataobj) is nibabel's documented replacement for the
# deprecated/removed get_data(); `.affine` likewise replaces get_affine().
wmh_final = np.asanyarray(nib.load(wmh_path).dataobj)
PP.saveScan(wmh_final, orig_flair.affine, wmh_path, header=orig_flair.header)
print('Done')