Diff of /test.py [000000] .. [408896]

Switch to unified view

a b/test.py
1
import os
2
import glob
3
import scipy
4
import numpy as np
5
import nibabel as nib
6
import tensorflow as tf
7
from tqdm import tqdm
8
from scipy.ndimage import zoom
9
10
from args import TestArgParser
11
from util import DiceCoefficient
12
from model import Model
13
14
15
class Interpolator(object):
16
    def __init__(self, modalities, order=3, mode='reflect'):
17
        self.modalities = modalities
18
        self.order = order
19
        self.mode = mode
20
21
    def __call__(self, path):
22
        """Extracts Numpy image and normalizes it to 1 mm^3."""
23
        # Extract raw images from each.
24
        image = []
25
        pixdim = []
26
        affine = []
27
        
28
        for name in self.modalities:
29
            file_card = glob.glob(os.path.join(path, '*' + name + '*' + '.nii' + '*'))[0]
30
            img = nib.load(file_card)
31
32
            image.append(np.array(img.dataobj).astype(np.float32))
33
            pixdim.append(img.header['pixdim'][:4])
34
            affine.append(np.stack([img.header['srow_x'],
35
                                    img.header['srow_y'],
36
                                    img.header['srow_z'],
37
                                    np.array([0., 0., 0., 1.])], axis=0))
38
39
        # Prepare image.
40
        image = np.stack(image, axis=-1)
41
        self.pixdim = np.mean(pixdim, axis=0, dtype=np.float32)
42
        self.affine = np.mean(affine, axis=0, dtype=np.float32)
43
44
        # Rescale and interpolate voxels spatially.
45
        if np.any(self.pixdim[:-1] != 1.0):
46
            image = zoom(image, self.pixdim[:-1] + [1.0], order=self.order, mode=self.mode)
47
        
48
        # Rescale and interpolate voxels depthwise (along time).
49
        if self.pixdim[-1] != 1.0:
50
            image = zoom(image, [1.0, 1.0, self.pixdim[-1], 1.0], order=self.order, mode=self.mode)
51
52
        # Mask out background voxels.
53
        mask = np.max(image, axis=-1, keepdims=True)
54
        mask = (mask > 0).astype(np.float32)
55
56
        return image, mask
57
58
    def reverse(self, output, path):
59
        """Reverses the interpolation performed in __call__."""
60
        # Scale back spatial voxel interpolation.
61
        if np.any(self.pixdim[:-1] != 1.0):
62
            output = zoom(output, 1.0 / self.pixdim[:-1], order=self.order, mode=self.mode)
63
64
        # Scale back depthwise voxel interpolation.
65
        if self.pixdim[-1] != 1.0:
66
            output = zoom(output, [1.0, 1.0, 1.0 / self.pixdim[-1]], order=self.order, mode=self.mode)
67
68
        # Save file.
69
        nib.save(nib.Nifti1Image(output, self.affine),
70
                    os.path.join(path, 'mask.nii'))
71
72
        return output
73
74
75
class TestTimeAugmentor(object):
76
    """Handles full inference on input with test-time augmentation."""
77
    def __init__(self,
78
                 mean,
79
                 std,
80
                 model,
81
                 model_data_format,
82
                 spatial_tta=True,
83
                 channel_tta=0,
84
                 threshold=0.5):
85
        self.mean = mean
86
        self.std = std
87
        self.model = model
88
        self.model_data_format = model_data_format
89
        self.channel_tta = channel_tta
90
        self.threshold = threshold
91
92
        self.channel_axis = -1 if self.model_data_format == 'channels_last' else 1
93
        self.spatial_axes = [1, 2, 3] if self.model_data_format == 'channels_last' else [2, 3, 4]
94
95
        if spatial_tta:
96
            self.augment_axes = [self.spatial_axes, []]
97
            for axis in self.spatial_axes:
98
                pairs = self.spatial_axes.copy()
99
                pairs.remove(axis)
100
                self.augment_axes.append([axis])
101
                self.augment_axes.append(pairs)
102
        else:
103
            self.augment_axes = [[]]
104
105
    def __call__(self, x, bmask):
106
        # Normalize and prepare input (assumes input data format of 'channels_last').
107
        x = (x - self.mean) / self.std
108
109
        # Transpose to channels_first data format if required by model.
110
        if self.model_data_format == 'channels_first':
111
            x = tf.transpose(x, (3, 0, 1, 2))
112
            bmask = tf.transpose(bmask, (3, 0, 1, 2))
113
        x = tf.expand_dims(x, axis=0)
114
        bmask = tf.expand_dims(bmask, axis=0)
115
116
        # Initialize list of inputs to feed model.
117
        y = []
118
119
        # Create shape for intensity shifting.
120
        shape = [1, 1, 1]
121
        shape.insert(self.channel_axis, x.shape[self.channel_axis])
122
123
        if self.channel_tta:
124
            _, var = tf.nn.moments(x, axes=self.spatial_axes, keepdims=True)
125
            std = tf.sqrt(var)
126
127
        # Apply spatial augmentation.
128
        for flip in self.augment_axes:
129
130
            # Run inference on spatially augmented input.
131
            aug = tf.reverse(x, axis=flip)
132
133
            aug, *_ = self.model(aug, training=False, inference=True)
134
            y.append(tf.reverse(aug, axis=flip))
135
136
            for _ in range(self.channel_tta):
137
                shift = tf.random.uniform(shape, -0.1, 0.1)
138
                scale = tf.random.uniform(shape, 0.9, 1.1)
139
140
                # Run inference on channel augmented input.
141
                aug = (aug + shift * std) * scale
142
                aug = self.model(aug, training=False, inference=True)
143
                aug = tf.reverse(aug, axis=flip)
144
                y.append(aug)
145
146
        # Aggregate outputs.
147
        y = tf.concat(y, axis=0)
148
        y = tf.reduce_mean(y, axis=0, keepdims=True)
149
150
        # Mask out zero-valued voxels.
151
        y *= bmask
152
153
        # Take the argmax to determine label.
154
        # y = tf.argmax(y, axis=self.channel_axis, output_type=tf.int32)
155
156
        # Transpose back to channels_last data format.
157
        y = tf.squeeze(y, axis=0)
158
        if self.model_data_format == 'channels_first':
159
            y = tf.transpose(y, (1, 2, 3, 0))
160
161
        return y
162
163
164
def pad_to_spatial_res(res, x, mask):
165
    # Assumes that x and mask are channels_last data format.
166
    res = tf.convert_to_tensor([res])
167
    shape = tf.convert_to_tensor(x.shape[:-1], dtype=tf.int32)
168
    shape = res - (shape % res)
169
    pad = [[0, shape[0]],
170
           [0, shape[1]],
171
           [0, shape[2]],
172
           [0, 0]]
173
174
    orig_shape = list(x.shape[:-1])
175
    x = tf.pad(x, pad, mode='CONSTANT', constant_values=0.0)
176
    mask = tf.pad(mask, pad, mode='CONSTANT', constant_values=0.0)
177
178
    return x, mask, orig_shape
179
    
180
181
def main(args):
182
    in_ch = len(args.modalities)
183
184
    # Initialize model(s) and load weights / preprocessing stats.
185
    tumor_model = Model(**args.tumor_model_args)
186
    tumor_crop_size = args.tumor_model_args['crop_size']
187
    _ = tumor_model(tf.zeros(shape=[1] + tumor_crop_size + [in_ch] if args.tumor_model_args['data_format'] == 'channels_last' \
188
                                    else [1, in_ch] + tumor_crop_size, dtype=tf.float32))
189
    tumor_model.load_weights(os.path.join(args.tumor_model, 'chkpt.hdf5'))
190
    tumor_mean = tf.convert_to_tensor(args.tumor_prepro['norm']['mean'], dtype=tf.float32)
191
    tumor_std = tf.convert_to_tensor(args.tumor_prepro['norm']['std'], dtype=tf.float32)
192
193
    if args.skull_strip:
194
        skull_model = Model(**args.skull_model_args)
195
        skull_crop_size = args.skull_model_args['crop_size']
196
        _ = model(tf.zeros(shape=[1] + skull_crop_size + [in_ch] if args.skull_model_args['data_format'] == 'channels_last' \
197
                                    else [1, in_ch] + skull_crop_size, dtype=tf.float32))
198
        skull_model.load_weights(os.path.join(args.skull_model, 'chkpt.hdf5'))
199
        skull_mean = tf.convert_to_tensor(args.skull_prepro['norm']['mean'], dtype=tf.float32)
200
        skull_std = tf.convert_to_tensor(args.skull_prepro['norm']['std'], dtype=tf.float32)
201
202
    # Initialize helper classes for inference and evaluation (optional).
203
    dice_fn = DiceCoefficient(data_format='channels_last')
204
    interpolator = Interpolator(args.modalities, order=args.order, mode=args.mode)
205
    tumor_ttaugmentor = TestTimeAugmentor(
206
                                tumor_mean,
207
                                tumor_std,
208
                                tumor_model,
209
                                args.tumor_model_args['data_format'],
210
                                spatial_tta=args.spatial_tta,
211
                                channel_tta=args.channel_tta,
212
                                threshold=args.threshold)
213
    if args.skull_strip:
214
        skull_ttaugmentor = TestTimeAugmentor(
215
                                    skull_mean,
216
                                    skull_std,
217
                                    skull_model,
218
                                    args.skull_model_args['data_format'],
219
                                    spatial_tta=args.spatial_tta,
220
                                    channel_tta=args.channel_tta,
221
                                    threshold=args.threshold)
222
223
    for loc in args.in_locs:
224
        for path in tqdm(glob.glob(os.path.join(loc, '*'))):
225
            with tf.device(args.device):
226
                # If data is labeled, extract label.
227
                try:
228
                    file_card = glob.glob(os.path.join(path, '*' + args.truth + '*' + '.nii' + '*'))[0]
229
                    y = np.array(nib.load(file_card).dataobj).astype(np.float32)
230
                    y = tf.expand_dims(y, axis=-1)
231
                except:
232
                    y = None
233
234
                # Rescale and interpolate input image.
235
                x, mask = interpolator(path)
236
237
                # Strip MRI brain of skull and eye sockets.
238
                if args.skull_strip:
239
                    x, pad_mask, pad = pad_to_spatial_res(          # Pad to spatial resolution
240
                                        args.skull_spatial_res,
241
                                        x,
242
                                        mask)
243
                    skull_mask = skull_ttaugmentor(x, pad_mask)     # Inference with test time augmentation.
244
                    skull_mask = 1.0 - skull_mask                   # Convert skull positives into negatives.
245
                    x *= skull_mask                                 # Mask out skull voxels.
246
                    x = tf.slice(x,                                 # Remove padding.
247
                                 [0, 0, 0, 0],
248
                                 pad + [-1])
249
250
                # Label brain tumor categories per voxel.
251
                x, pad_mask, pad = pad_to_spatial_res(              # Pad to spatial resolution.
252
                                        args.tumor_spatial_res,
253
                                        x,
254
                                        mask)
255
                tumor_mask = tumor_ttaugmentor(x, pad_mask)         # Inference with test time augmentation.
256
                tumor_mask = tf.slice(tumor_mask,                   # Remove padding.
257
                                      [0, 0, 0, 0],
258
                                      pad + [-1])
259
                tumor_mask += 1                                     # Convert [0,1,2] to [1,2,3] for label consistency.
260
                tumor_mask = tumor_mask.numpy()
261
                np.place(tumor_mask, tumor_mask >= 3, [4])          # Replace label `3` with `4` for label consistency.
262
263
                # Reverse interpolation and save as .nii.
264
                y_pred = interpolator.reverse(tumor_mask, path)
265
                
266
                # If label is available, score the prediction.
267
                if y is not None:
268
                    macro, micro = dice_fn(y, y_pred)
269
                    print('{}. Macro: {ma: 1.4f}. Micro: {mi: 1.4f}'
270
                            .format(path.split('/')[-1], ma=macro, mi=micro), flush=True)
271
272
273
if __name__ == '__main__':
274
    parser = TestArgParser()
275
    args = parser.parse_args()
276
    print('Test args: {}'.format(args))
277
    main(args)