Diff of /predict.py [000000] .. [7b3b92]

Switch to unified view

a b/predict.py
1
import os
2
import numpy as np
3
import nibabel as nib
4
from glob import glob
5
from tensorflow.keras.models import load_model
6
7
8
def read_brain(brain_dir, mode='train', x0=42, x1=194, y0=29, y1=221, z0=2, z1=146):
9
10
    """
11
    A function that reads and crops a brain modalities (nii.gz format)
12
    
13
    Parameters
14
    ----------
15
    brain_dir : string
16
        The path to a folder that contains MRI modalities of a specific brain
17
    mode : string
18
        'train' or 'validation' mode. The default is 'train'.
19
    x0, x1, y0, y1, z0, z1 : int
20
        The coordinates to crop 3D brain volume. For example, a brain volume with the 
21
        shape [x,y,z,modalites] is cropped [x0:x1, y0:y1, z0:z1, :] to have the shape
22
        [x1-x0, y1-y0, z1-z0, modalities]. One can calculate the x0,x1,... by calculating
23
        none zero pixels through dataset. Note that the final three shapes must be divisible
24
        by the network downscale rate.
25
        
26
    Returns
27
    -------
28
    all_modalities : array
29
        The cropped modalities (+ gt if mode='train')
30
    brain_affine : array
31
        The affine matrix of the input brain volume
32
    brain_name : str
33
        The name of the input brain volume
34
35
    """
36
    
37
    brain_dir = os.path.normpath(brain_dir)
38
    flair     = glob( os.path.join(brain_dir, '*_flair*.nii.gz'))
39
    t1        = glob( os.path.join(brain_dir, '*_t1*.nii.gz'))
40
    t1ce      = glob( os.path.join(brain_dir, '*_t1ce*.nii.gz'))
41
    t2        = glob( os.path.join(brain_dir, '*_t2*.nii.gz'))
42
    
43
    if mode=='train':
44
        gt             = glob( os.path.join(brain_dir, '*_seg*.nii.gz'))
45
        modalities_dir = [flair[0], t1[0], t1ce[0], t2[0], gt[0]]
46
        
47
    elif mode=='validation':
48
        modalities_dir = [flair[0], t1[0], t1ce[0], t2[0]]   
49
    
50
    all_modalities = []    
51
    for modality in modalities_dir:      
52
        nifti_file   = nib.load(modality)
53
        brain_numpy  = np.asarray(nifti_file.dataobj)    
54
        all_modalities.append(brain_numpy)
55
        
56
    # all modalities have the same affine, so we take one of them (the last one in this case),
57
    # affine is just saved for preparing the predicted nii.gz file in the future.       
58
    brain_affine   = nifti_file.affine
59
    all_modalities = np.array(all_modalities)
60
    all_modalities = np.rint(all_modalities).astype(np.int16)
61
    all_modalities = all_modalities[:, x0:x1, y0:y1, z0:z1]
62
    # to fit keras channel last model
63
    all_modalities = np.transpose(all_modalities) 
64
    # tumor grade + name
65
    brain_name     = os.path.basename(os.path.split(brain_dir)[0]) + '_' + os.path.basename(brain_dir) 
66
67
    return all_modalities, brain_affine, brain_name
68
    
69
    
70
71
def normalize_slice(slice):
72
    
73
    """
74
    Removes 1% of the top and bottom intensities and perform
75
    normalization on the input 2D slice.
76
    """
77
    
78
    b = np.percentile(slice, 99)
79
    t = np.percentile(slice, 1)
80
    slice = np.clip(slice, t, b)
81
    if np.std(slice)==0:
82
        return slice
83
    else:
84
        slice = (slice - np.mean(slice)) / np.std(slice)
85
        return slice
86
    
87
88
def normalize_volume(input_volume):
89
    
90
    """
91
    Perform a slice-based normalization on each modalities of input volume.
92
    """
93
    normalized_slices = np.zeros_like(input_volume).astype(np.float32)
94
    for slice_ix in range(4):
95
        normalized_slices[slice_ix] = input_volume[slice_ix]
96
        for mode_ix in range(input_volume.shape[1]):
97
            normalized_slices[slice_ix][mode_ix] = normalize_slice(input_volume[slice_ix][mode_ix])
98
99
    return normalized_slices    
100
101
102
def save_predicted_results(prediction, brain_affine, view, output_dir,  z_main=155, z0=2, z1=146, y_main=240, y0=29, y1=221, x_main=240, x0=42, x1=194):
103
    
104
    """
105
    Save the segmented results into a .nii.gz file, so that it can be uploaded to the BraTS server.
106
    Note that to correctly save the segmented brains, it is necessery to set x0, x1, ... correctly.
107
    
108
    Parameters
109
    ----------
110
    prediction : array
111
        The predictred brain.
112
    brain_affine : array
113
        The affine matrix of the predicted brain volume
114
    view : str
115
        'axial', 'sagittal' or 'coronal'. The 'view' is needed to reconstruct output axes.
116
    output_dir : str
117
        The path to save .nii.gz file.
118
119
120
    """
121
    
122
    prediction = np.argmax(prediction, axis=-1).astype(np.uint16)            
123
    prediction[prediction==3] = 4
124
    
125
    if view=="axial":
126
        prediction    = np.pad(prediction, ((z0, z_main-z1), (y0, y_main-y1), (x0, x_main-x1)), 'constant')
127
        prediction    = prediction.transpose(2,1,0)
128
    elif view=="sagital":
129
        prediction    = np.pad(prediction, ((x0, x_main-x1), (y0, y_main-y1), (z0 , z_main-z1)), 'constant')
130
    elif view=="coronal":
131
        prediction    = np.pad(prediction, ((y0, y_main-y1), (x0, x_main-x1), (z0 , z_main-z1)), 'constant')
132
        prediction    = prediction.transpose(1,0,2)
133
    #
134
    prediction_ni    = nib.Nifti1Image(prediction, brain_affine)
135
    prediction_ni.to_filename(output_dir+ '.nii.gz')
136
137
138
139
140
141
142
if __name__ == '__main__':
143
       
144
    val_data_dir       = '/path/to/data/*'
145
    view               = 'axial'
146
    saved_model_dir    = '/path/to/a/trained/model.hdf5'  #ex './save/axial_fold0/model.hdf5'
147
    save_pred_dir      = './predict/'
148
    batch_size         = 32
149
150
    
151
    if not os.path.isdir(save_pred_dir):
152
        os.mkdir(save_pred_dir)
153
       
154
    all_brains_dir = glob(val_data_dir)
155
    all_brains_dir.sort()
156
    
157
    if view == 'axial':
158
        view_axes = (0, 1, 2, 3)            
159
    elif view == 'sagittal': 
160
        view_axes = (2, 1, 0, 3)
161
    elif view == 'coronal':
162
        view_axes = (1, 2, 0, 3)            
163
    else:
164
        ValueError('unknown input view => {}'.format(view))
165
    
166
    
167
    model        = load_model(saved_model_dir, compile=False)
168
    for brain_dir in all_brains_dir:    
169
        if os.path.isdir(brain_dir):
170
            print("Volume ID: ", os.path.basename(brain_dir))
171
            all_modalities, brain_affine, _ = read_brain(brain_dir, mode='validation')
172
            all_modalities                  = all_modalities.transpose(view_axes)
173
            all_modalities                  = normalize_volume(all_modalities)
174
            prediction                      = model.predict(all_modalities, batch_size=batch_size, verbose=1)
175
            output_dir                      = os.path.join(save_pred_dir, os.path.basename(brain_dir))
176
            save_predicted_results(prediction, brain_affine, view, output_dir)
177
            
178