[7a84af]: / prod / predict_mat2.py

Download this file

156 lines (127 with data), 6.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import argparse
import json
import os
import numpy as np
from brats.utils import get_last_model_path
from fetal_net.normalize import normalize_data
from fetal_net.postprocess import postprocess_prediction as process_pred
from fetal_net.prediction import patch_wise_prediction
from fetal_net.training import load_old_model
from brats.preprocess import window_intensities_data
from scipy.io import loadmat, savemat
from fetal_net.utils.cut_relevant_areas import find_bounding_box, check_bounding_box
def secondary_prediction(mask, vol, config2, model2_path=None,
                         preprocess_method2=None, norm_params2=None,
                         overlap_factor=0.9):
    """Run the second-stage model on the region of ``vol`` selected by ``mask``.

    The bounding box of ``mask`` is obtained from ``find_bounding_box``,
    grown by a fixed margin of ``[16, 16, 8]`` and clipped to the mask
    extent. The cropped volume is preprocessed and normalized with
    ``preproc_and_norm``, predicted patch-wise, and the result is zero-padded
    back to ``vol.shape``.

    Args:
        mask: First-stage mask used to locate the region of interest.
        vol: Volume to crop and predict on; indexed along three axes.
        config2: Mapping with ``"patch_shape"`` (list) and ``"patch_depth"``.
        model2_path: Path handed to ``get_last_model_path`` to locate the model.
        preprocess_method2: Preprocess name forwarded to ``preproc_and_norm``.
        norm_params2: Normalization parameters forwarded to ``preproc_and_norm``.
        overlap_factor: Overlap between prediction patches.

    Returns:
        The squeezed prediction, zero-padded to the shape of ``vol``.
    """
    model2 = load_old_model(get_last_model_path(model2_path))

    bbox_start, bbox_end = find_bounding_box(mask)
    check_bounding_box(mask, bbox_start, bbox_end)

    # Grow the box by a fixed margin, clipped so it never leaves the mask.
    # np.asarray makes the arithmetic safe even if the corners come back as
    # plain sequences rather than arrays.
    padding = [16, 16, 8]
    bbox_start = np.maximum(np.asarray(bbox_start) - padding, 0)
    bbox_end = np.minimum(np.asarray(bbox_end) + padding, mask.shape)

    # np.float64 replaces the removed alias np.float (same dtype).
    data = vol.astype(np.float64)[
        bbox_start[0]:bbox_end[0],
        bbox_start[1]:bbox_end[1],
        bbox_start[2]:bbox_end[2]
    ]
    data = preproc_and_norm(data, preprocess_method2, norm_params2)

    prediction = patch_wise_prediction(
        model=model2,
        data=np.expand_dims(data, 0),
        overlap_factor=overlap_factor,
        patch_shape=config2["patch_shape"] + [config2["patch_depth"]])
    # NOTE(review): np.pad below receives three (before, after) pairs, so the
    # squeezed prediction is presumably 3-D (single output channel) - confirm.
    prediction = prediction.squeeze()

    # Amount of zero padding needed on each side to restore the full volume.
    padding2 = list(zip(bbox_start, np.array(vol.shape) - bbox_end))
    print(padding2)
    print(prediction.shape)
    prediction = np.pad(prediction, padding2, mode='constant', constant_values=0)
    return prediction
def preproc_and_norm(data, preprocess_method, norm_params):
    """Apply the optional preprocess step and optional normalization to ``data``.

    Args:
        data: Input passed through the enabled steps.
        preprocess_method: ``None`` to skip preprocessing, or ``'window_1_99'``
            to apply ``window_intensities_data``. Any other value raises.
        norm_params: ``None`` or a mapping with ``'mean'`` and ``'std'``.
            Normalization is skipped when it is ``None`` or when none of its
            values is truthy.

    Returns:
        ``data`` after the enabled steps (unchanged when both are skipped).

    Raises:
        Exception: If ``preprocess_method`` is not a known preprocess name.
    """
    if preprocess_method is not None:
        print('Applying preprocess by {}...'.format(preprocess_method))
        if preprocess_method != 'window_1_99':
            raise Exception('Unknown preprocess: {}'.format(preprocess_method))
        data = window_intensities_data(data)

    should_normalize = norm_params is not None and any(norm_params.values())
    if not should_normalize:
        return data

    return normalize_data(data, mean=norm_params['mean'], std=norm_params['std'])
def main(input_mat_path, output_mat_path, overlap_factor,
         config, model_path, preprocess_method=None, norm_params=None,
         config2=None, model2_path=None, preprocess_method2=None, norm_params2=None):
    """Predict masks for the volume in a ``.mat`` file and save the result.

    Loads ``input_mat_path``, runs the primary model patch-wise on
    ``mat['volume']`` and stores three post-processed variants of the
    prediction in ``mat['masks'][0, 7..9]``. When ``config2`` is given, a
    second model is run on the region selected by ``mat['masks'][0, 7]`` and
    its three variants are stored in ``mat['masks'][0, 4..6]``. The updated
    mat dictionary is written to ``output_mat_path``.

    Args:
        input_mat_path: Path of the ``.mat`` file to read.
        output_mat_path: Path of the ``.mat`` file to write.
        overlap_factor: Overlap between prediction patches (primary model).
        config: Mapping with ``"patch_shape"`` (list) and ``"patch_depth"``.
        model_path: Path handed to ``get_last_model_path`` for the primary model.
        preprocess_method: Preprocess name for the primary model, or ``None``.
        norm_params: Normalization parameters for the primary model, or ``None``.
        config2: Configuration of the secondary model; ``None`` disables it.
        model2_path: Path handed to ``get_last_model_path`` for the secondary model.
        preprocess_method2: Preprocess name for the secondary model, or ``None``.
        norm_params2: Normalization parameters for the secondary model, or ``None``.
    """
    print(model_path)
    model = load_old_model(get_last_model_path(model_path))

    print('Loading mat from {}...'.format(input_mat_path))
    mat = loadmat(input_mat_path)

    print('Predicting mask...')
    # np.float64 replaces the removed alias np.float (same dtype).
    data = mat['volume'].astype(np.float64)
    data = preproc_and_norm(data, preprocess_method, norm_params)
    prediction = patch_wise_prediction(
        model=model,
        data=np.expand_dims(data, 0),
        overlap_factor=overlap_factor,
        patch_shape=config["patch_shape"] + [config["patch_depth"]])

    print('Post-processing mask...')
    # When the last axis holds several entries, keep only index 1.
    if prediction.shape[-1] > 1:
        prediction = prediction[..., 1]
    prediction = prediction.squeeze()

    # NOTE(review): mat['masks'] is indexed as [0, k] for k up to 9, so it is
    # presumably a 1xN MATLAB cell array with N >= 10 - confirm against input.
    print("Storing prediction in [7-9], 7 should be the best...")
    mat['masks'][0, 9] = process_pred(prediction, gaussian_std=0, threshold=0.2)
    mat['masks'][0, 8] = process_pred(prediction, gaussian_std=1, threshold=0.5)
    mat['masks'][0, 7] = process_pred(prediction, gaussian_std=0.5, threshold=0.5)

    if config2 is not None:
        print('Making secondary prediction... [6]')
        # NOTE(review): the secondary stage always uses overlap_factor=0.9 and
        # ignores this function's overlap_factor argument - confirm intent.
        prediction = secondary_prediction(
            mat['masks'][0, 7],
            vol=mat['volume'].astype(np.float64),
            config2=config2, model2_path=model2_path,
            preprocess_method2=preprocess_method2, norm_params2=norm_params2,
            overlap_factor=0.9)
        mat['masks'][0, 6] = process_pred(prediction, gaussian_std=0, threshold=0.2)
        mat['masks'][0, 5] = process_pred(prediction, gaussian_std=1, threshold=0.5)
        mat['masks'][0, 4] = process_pred(prediction, gaussian_std=0.5, threshold=0.5)

    print('Saving mat to {}'.format(output_mat_path))
    savemat(output_mat_path, mat)
    print('Finished.')
def get_params(config_dir):
    """Load the model configuration stored in ``config_dir``.

    Reads ``config.json`` and ``norm_params.json`` from ``config_dir`` and
    resolves the model file to a path inside ``config_dir``, keeping only the
    base name of the configured ``'model_file'`` entry.

    Args:
        config_dir: Directory containing ``config.json`` and ``norm_params.json``.

    Returns:
        A ``(config, norm_params, model_path)`` tuple.
    """
    parsed = []
    for json_name in ('config.json', 'norm_params.json'):
        with open(os.path.join(config_dir, json_name), 'r') as handle:
            parsed.append(json.load(handle))
    settings, normalization = parsed

    # Only the file name of the configured model is kept; it is looked up
    # next to the JSON files rather than at its recorded location.
    model_file_name = os.path.basename(settings['model_file'])
    return settings, normalization, os.path.join(config_dir, model_file_name)
if __name__ == '__main__':
    # Command-line entry point: parse the options, load the configuration of
    # the primary (and optionally the secondary) model, then run main().
    parser = argparse.ArgumentParser()
    cli_arguments = [
        ("--input_mat", {"help": "specifies mat file dir path",
                         "type": str, "required": True}),
        ("--output_mat", {"help": "specifies mat file dir path",
                          "type": str, "required": True}),
        ("--overlap_factor", {"help": "specifies overlap between prediction patches",
                              "type": float, "default": 0.9}),
        # Params for primary prediction
        ("--config_dir", {"help": "specifies config dir path",
                          "type": str, "required": True}),
        ("--preprocess", {"help": "what preprocess to do",
                          "type": str, "required": False, "default": None}),
        # Params for secondary prediction
        ("--config2_dir", {"help": "specifies config dir path",
                           "type": str, "required": False, "default": None}),
        ("--preprocess2", {"help": "what preprocess to do",
                           "type": str, "required": False, "default": None}),
    ]
    for flag, spec in cli_arguments:
        parser.add_argument(flag, **spec)
    opts = parser.parse_args()

    # Primary model parameters.
    _config, _norm_params, _model_path = get_params(opts.config_dir)

    # Secondary model parameters, only when a second config dir was given.
    if opts.config2_dir is None:
        _config2 = _norm_params2 = _model2_path = None
    else:
        _config2, _norm_params2, _model2_path = get_params(opts.config2_dir)

    main(
        input_mat_path=opts.input_mat,
        output_mat_path=opts.output_mat,
        overlap_factor=opts.overlap_factor,
        config=_config,
        model_path=_model_path,
        preprocess_method=opts.preprocess,
        norm_params=_norm_params,
        config2=_config2,
        model2_path=_model2_path,
        preprocess_method2=opts.preprocess2,
        norm_params2=_norm_params2,
    )