src/inference.py


"""Inference script for BraTS brain tumour segmentation.

Loads a trained model described by a configuration file, runs it over every case
listed in the dataset CSV, and saves the predicted segmentations and, optionally,
voxel-wise uncertainty maps.
"""
import os
import sys

import numpy as np
import torch

from src.compute_metric_results import compute_wt_tc_et
from src.config import BratsConfiguration
from src.dataset import brats_labels
from src.dataset.utils import dataset, nifi_volume
from src.logging_conf import logger
from src.models.io_model import load_model
from src.models.unet3d import unet3d
from src.models.vnet import vnet, asymm_vnet
from src.post_processing import post_process
from src.test import predict
from src.uncertainty import test_time_augmentation, test_time_dropout, uncertainty
from src.uncertainty.uncertainty import compute_normalization

def load_network(device, model_config, dataset_config, which_net):
    """Build the requested architecture, load its checkpoint and move it to `device`."""
    n_modalities = dataset_config.getint("n_modalities")
    n_classes = dataset_config.getint("classes")

    if which_net == "vnet":
        network = vnet.VNet(elu=model_config.getboolean("use_elu"), in_channels=n_modalities, classes=n_classes,
                            init_features_maps=model_config.getint("init_features_maps"))
    elif which_net == "vnet_asymm":
        network = asymm_vnet.VNet(non_linearity=model_config.get("non_linearity"), in_channels=n_modalities,
                                  classes=n_classes,
                                  init_features_maps=model_config.getint("init_features_maps"),
                                  kernel_size=model_config.getint("kernel_size"),
                                  padding=model_config.getint("padding"))
    elif which_net == "3dunet_residual":
        network = unet3d.ResidualUNet3D(in_channels=n_modalities, out_channels=n_classes, final_sigmoid=False,
                                        f_maps=model_config.getint("init_features_maps"),
                                        layer_order=model_config.get("unet_order"),
                                        num_levels=4, num_groups=4, conv_padding=1)
    elif which_net == "3dunet":
        network = unet3d.UNet3D(in_channels=n_modalities, out_channels=n_classes, final_sigmoid=False,
                                f_maps=model_config.getint("init_features_maps"),
                                layer_order=model_config.get("unet_order"),
                                num_levels=4, num_groups=4, conv_padding=1)
    else:
        raise ValueError(f"bad network {which_net}")

    network.to(device)

    checkpoint_path = os.path.join(model_config.get("model_path"), model_config.get("checkpoint"))
    model_path = os.path.dirname(checkpoint_path)
    model, _, _, _ = load_model(network, checkpoint_path, device, None, False)
    return model, model_path
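
# Helpers for the "no_patch" sampling mode: the full volume is centre-cropped to
# (160, 192, 128) for inference and predictions are later pasted back into the
# original (240, 240, 155) BraTS grid.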
def crop_no_patch(patch_size, images, brain_mask, sampling):
    """Centre-crop the volumes to (160, 192, 128) when no patch sampling is used; otherwise keep the full size."""
    if sampling == "no_patch":
        new_size = (160, 192, 128)
        x_1 = int((patch_size[0] - new_size[0]) / 2)
        x_2 = int(patch_size[0] - (patch_size[0] - new_size[0]) / 2)
        y_1 = int((patch_size[1] - new_size[1]) / 2)
        y_2 = int(patch_size[1] - (patch_size[1] - new_size[1]) / 2)
        z_1 = int((patch_size[2] - new_size[2]) / 2)
        z_2 = int(patch_size[2] - (patch_size[2] - new_size[2]) / 2)

        new_images = images[:, x_1:x_2, y_1:y_2, z_1:z_2]
        new_brain_mask = brain_mask[x_1:x_2, y_1:y_2, z_1:z_2]
        return x_1, x_2, y_1, y_2, z_1, z_2, new_images, new_brain_mask, new_size
    else:
        x_1, x_2 = 0, patch_size[0]
        y_1, y_2 = 0, patch_size[1]
        z_1, z_2 = 0, patch_size[2]
        return x_1, x_2, y_1, y_2, z_1, z_2, images, brain_mask, patch_size

def return_to_size(volume, sampling, x_1, x_2, y_1, y_2, z_1, z_2, final_size=(240, 240, 155)):
    """Place a cropped volume back into a zero-filled array with the original BraTS shape."""
    if sampling == "no_patch":
        output = np.zeros(final_size)
        output[x_1:x_2, y_1:y_2, z_1:z_2] = volume
        return output
    else:
        return volume
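
# Entry point: the script expects the path to a configuration file as its only
# command-line argument, e.g. (assuming it is run as a module from the repository root):
#   python -m src.inference <path/to/config.ini>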
if __name__ == "__main__":
    config = BratsConfiguration(sys.argv[1])
    model_config = config.get_model_config()
    dataset_config = config.get_dataset_config()
    basic_config = config.get_basic_config()
    unc_config = config.get_uncertainty_config()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    task = "segmentation_task"
    compute_metrics = False
    flag_post_process = True

    model, model_path = load_network(device, model_config, dataset_config, model_config["network"])

    setx = "train"
    data, data_test = dataset.read_brats(dataset_config.get("train_csv"))
    data.extend(data_test)

    sampling = dataset_config.get("sampling_method").split(".")[-1]

    uncertainty_flag = basic_config.getboolean("uncertainty_flag")
    uncertainty_type = unc_config.get("uncertainty_type")
    n_iterations = unc_config.getint("n_iterations")
    use_dropout = unc_config.getboolean("use_dropout")
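
    # Run inference case by case: load and crop the volumes, predict a segmentation
    # (or voxel-wise uncertainty maps), optionally post-process, and save the results.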
    for idx in range(0, len(data)):
        patch_size = data[idx].size
        images = data[idx].load_mri_volumes(normalize=True)
        brain_mask = data[idx].get_brain_mask()

        x_1, x_2, y_1, y_2, z_1, z_2, images, brain_mask, patch_size = crop_no_patch(patch_size, images, brain_mask,
                                                                                     sampling)
        results = {}

        if uncertainty_flag:

            def uncertainty_loop(uncertainty_type, model, images, device, n_iterations, brain_mask, use_dropout):
                """Run the chosen uncertainty estimation: test-time dropout (ttd) or test-time augmentation (tta)."""
                if uncertainty_type == "ttd":
                    return test_time_dropout.ttd_uncertainty_loop(model, images, device, n_iterations)
                elif uncertainty_type == "tta":
                    return test_time_augmentation.tta_uncertainty_loop(model, images, device, brain_mask,
                                                                       n_iterations, use_dropout)
                else:
                    raise ValueError(f"Wrong uncertainty type {uncertainty_type}. Should be ttd or tta")

            _, prediction_score_vectors = uncertainty_loop(uncertainty_type, model, images, device,
                                                           n_iterations, brain_mask, use_dropout)

            # Segmentation map: average the prediction scores over iterations and take the class
            # with the highest mean score at each voxel
            pred_scores = torch.stack(tuple(prediction_score_vectors)).cpu().numpy()
            pred_scores_mean = np.mean(pred_scores, axis=0)
            prediction_map = np.argmax(pred_scores_mean, axis=1).reshape(patch_size)

            wt_var, tc_var, et_var = uncertainty.get_variation_uncertainty(prediction_score_vectors, patch_size)
            global_unc = uncertainty.get_entropy_uncertainty(prediction_score_vectors, patch_size)

            wt_var = return_to_size(wt_var, sampling, x_1, x_2, y_1, y_2, z_1, z_2)
            tc_var = return_to_size(tc_var, sampling, x_1, x_2, y_1, y_2, z_1, z_2)
            et_var = return_to_size(et_var, sampling, x_1, x_2, y_1, y_2, z_1, z_2)
            global_unc = return_to_size(global_unc, sampling, x_1, x_2, y_1, y_2, z_1, z_2)

            task = f"uncertainty_task_{uncertainty_type}"
            results = {"whole": wt_var, "core": tc_var, "enhance": et_var, "entropy": global_unc}
        else:
            prediction_four_channels, vector_prediction_scores = predict.predict(model, images, device,
                                                                                 monte_carlo=False)
            if model_config["network"] == "vnet":
                best_scores_map = predict.get_scores_map_from_vector(vector_prediction_scores, patch_size)
            else:
                # the other networks already return the softmax predictions as a four-channel score map
                best_scores_map = vector_prediction_scores

            prediction_map = predict.get_prediction_map(prediction_four_channels)
            prediction_map = brats_labels.convert_to_brats_labels(prediction_map)
            prediction_map = return_to_size(prediction_map, sampling, x_1, x_2, y_1, y_2, z_1, z_2)

            if flag_post_process:
                # keep only whole-tumour connected components larger than the threshold
                threshold = 1
                segmentation_post = prediction_map.copy()
                pred_mask_wt = brats_labels.get_wt(segmentation_post)
                mask_removed_regions_wt = post_process.keep_conn_component_bigger_than_th(pred_mask_wt, th=threshold)
                elements_to_remove = pred_mask_wt - mask_removed_regions_wt
                segmentation_post[elements_to_remove == 1] = 0

                post_result = {"prediction": segmentation_post}
                predict.save_predictions(data[idx], post_result, model_path, f"{task}_post_processed")

            results["prediction"] = prediction_map

        predict.save_predictions(data[idx], results, model_path, task)

        if compute_metrics:
            patient_path = os.path.join(data[idx].data_path, data[idx].patch_name, data[idx].seg)
            data_path = os.path.join(data[idx].data_path, data[idx].patch_name, data[idx].flair)

            if os.path.exists(patient_path):
                volume_gt = data[idx].load_gt_mask()
                volume = nifi_volume.load_nifi_volume(data_path)
                metrics = compute_wt_tc_et(prediction_map, volume_gt, volume)
                logger.info(f"{data[idx].patient} | {metrics}")
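
    # Once every case has been processed, normalize the saved uncertainty maps
    # (the output goes to the "normalized" subfolder of the task directory).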
    print("Normalize uncertainty for brats!")
    if uncertainty_flag:
        input_dir = os.path.join(model_path, task)
        output_dir = os.path.join(model_path, task, "normalized")
        gt_path = data[0].data_path
        compute_normalization(input_dir=input_dir, output_dir=output_dir, ground_truth_path=gt_path)

    print("All done!!!! Be happy!")