|
a |
|
b/src/ensemble/mean_ensemble.py |
|
|
1 |
import os |
|
|
2 |
import sys |
|
|
3 |
import torch |
|
|
4 |
import numpy as np |
|
|
5 |
|
|
|
6 |
from src.config import BratsConfiguration |
|
|
7 |
from src.dataset import brats_labels |
|
|
8 |
from src.dataset.utils import dataset |
|
|
9 |
from src.inference import crop_no_patch, return_to_size |
|
|
10 |
from src.models.io_model import load_model |
|
|
11 |
from src.models.unet3d import unet3d |
|
|
12 |
from src.models.vnet import asymm_vnet |
|
|
13 |
from src.test import predict |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def _load(network, model_path, checkpoint, device=None):
    """Move ``network`` to a device and load checkpoint weights into it.

    Args:
        network: Model instance to populate; ``.to(device)`` is called on it.
        model_path: Base directory that ``checkpoint`` is joined onto.
        checkpoint: Checkpoint path relative to ``model_path``.
        device: Target ``torch.device``. When ``None`` (the default), the
            module-level ``device`` global (bound by the ``__main__`` block)
            is used if it exists; otherwise CUDA is chosen when available,
            else CPU. Previously the global was required, so importing this
            module and calling a loader raised ``NameError``.

    Returns:
        The first element of the 4-tuple returned by ``load_model``.
    """
    if device is None:
        # Backward compatibility: honour a module-level ``device`` if the
        # script (or an importer) has set one.
        device = globals().get("device")
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    network.to(device)
    checkpoint_path = os.path.join(model_path, checkpoint)
    # The trailing ``None, False`` arguments are passed through unchanged
    # from the original call; their meaning is defined by ``load_model``.
    model, _, _, _ = load_model(network, checkpoint_path, device, None, False)
    return model
|
|
21 |
|
|
|
22 |
|
|
|
23 |
def load_model_1598550861(model_path):
    """Build the asymmetric V-Net of run ``model_1598550861`` and load its weights.

    The architecture is a 4-in / 4-class ``asymm_vnet.VNet`` with PReLU
    non-linearities, 32 initial feature maps and 5x5x5 kernels with
    padding 2. The checkpoint file name suggests epoch 215.

    Args:
        model_path: Directory that contains the ``model_1598550861`` folder.

    Returns:
        The model produced by ``_load``.
    """
    vnet = asymm_vnet.VNet(
        non_linearity="prelu",
        in_channels=4,
        classes=4,
        init_features_maps=32,
        kernel_size=5,
        padding=2,
    )
    checkpoint_file = (
        "model_1598550861/"
        "checkpoint_epoch_215_val_loss_0.2378825504485875_dice_0.7621174487349105.pth"
    )
    return _load(vnet, model_path, checkpoint_file)
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def load_model_1598639885(model_path):
    """Build the residual 3D U-Net of run ``model_1598639885`` and load its weights.

    The architecture is ``unet3d.ResidualUNet3D`` with 4 input and 4 output
    channels, no final sigmoid, 32 base feature maps, layer order "crg",
    4 levels, 4 groups and convolution padding 1. The checkpoint file name
    suggests epoch 198.

    Args:
        model_path: Directory that contains the ``model_1598639885`` folder.

    Returns:
        The model produced by ``_load``.
    """
    residual_unet = unet3d.ResidualUNet3D(
        in_channels=4,
        out_channels=4,
        final_sigmoid=False,
        f_maps=32,
        layer_order="crg",
        num_levels=4,
        num_groups=4,
        conv_padding=1,
    )
    checkpoint_file = (
        "model_1598639885/"
        "checkpoint_epoch_198_val_loss_0.19342842820572526_dice_0.8065715717942747.pth"
    )
    return _load(residual_unet, model_path, checkpoint_file)
|
|
35 |
|
|
|
36 |
|
|
|
37 |
def load_model_1598640035(model_path):
    """Build the residual 3D U-Net of run ``model_1598640035`` and load its weights.

    The architecture matches ``load_model_1598639885`` (``ResidualUNet3D``,
    4->4 channels, 32 feature maps, "crg", 4 levels, 4 groups, padding 1).
    Only the checkpoint differs; its file name suggests epoch 142.

    Args:
        model_path: Directory that contains the ``model_1598640035`` folder.

    Returns:
        The model produced by ``_load``.
    """
    residual_unet = unet3d.ResidualUNet3D(
        in_channels=4,
        out_channels=4,
        final_sigmoid=False,
        f_maps=32,
        layer_order="crg",
        num_levels=4,
        num_groups=4,
        conv_padding=1,
    )
    checkpoint_file = (
        "model_1598640035/"
        "checkpoint_epoch_142_val_loss_0.21437616135976087_dice_0.7856238380039416.pth"
    )
    return _load(residual_unet, model_path, checkpoint_file)
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def load_model_1598640005(model_path):
    """Build the plain 3D U-Net of run ``model_1598640005`` and load its weights.

    The architecture is ``unet3d.UNet3D`` (non-residual) with 4 input and
    4 output channels, no final sigmoid, 32 base feature maps, layer order
    "crg", 4 levels, 4 groups and convolution padding 1. The checkpoint
    file name suggests epoch 168.

    Args:
        model_path: Directory that contains the ``model_1598640005`` folder.

    Returns:
        The model produced by ``_load``.
    """
    plain_unet = unet3d.UNet3D(
        in_channels=4,
        out_channels=4,
        final_sigmoid=False,
        f_maps=32,
        layer_order="crg",
        num_levels=4,
        num_groups=4,
        conv_padding=1,
    )
    checkpoint_file = (
        "model_1598640005/"
        "checkpoint_epoch_168_val_loss_0.20105469390137554_dice_0.7989453060986245.pth"
    )
    return _load(plain_unet, model_path, checkpoint_file)
|
|
50 |
|
|
|
51 |
|
|
|
52 |
if __name__ == "__main__":
    # Mean ensemble: average the four-channel score maps of one V-Net and
    # three U-Nets, take the per-voxel argmax, convert to BraTS labels,
    # restore the original volume size and save one prediction per patient.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <config_file>")

    config = BratsConfiguration(sys.argv[1])
    model_config = config.get_model_config()
    dataset_config = config.get_dataset_config()

    # Bound at module scope on purpose: _load() relies on the global ``device``.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    sampling = dataset_config.get("sampling_method").split(".")[-1]

    models_gen_path = model_config.get("model_path")
    task = "ensemble_segmentation"

    # Load order matches the original script: V-Net first, then the U-Nets.
    model_vnet = load_model_1598550861(models_gen_path)
    unet_models = [
        load_model_1598639885(models_gen_path),
        load_model_1598640035(models_gen_path),
        load_model_1598640005(models_gen_path),
    ]

    data, _ = dataset.read_brats(dataset_config.get("val_csv"))

    for patient in data:
        images = patient.load_mri_volumes(normalize=True)

        x_1, x_2, y_1, y_2, z_1, z_2, images, _brain_mask, patch_size = crop_no_patch(
            patient.size, images, patient.get_brain_mask(), sampling
        )

        # Inference only: no autograd graph is needed for the score maps.
        with torch.no_grad():
            _, vnet_scores = predict.predict(model_vnet, images, device, False)
            # The V-Net output is explicitly reshaped to (4, *patch_size); the
            # U-Net outputs are indexed with [0], presumably dropping a leading
            # batch dimension -- TODO confirm against predict.predict.
            channel_scores = [
                vnet_scores.view((4, patch_size[0], patch_size[1], patch_size[2]))
            ]
            for unet in unet_models:
                _, unet_scores = predict.predict(unet, images, device, False)
                channel_scores.append(unet_scores[0])

            pred_scores = torch.stack(channel_scores).cpu().numpy()

        # Average over the model axis, then argmax over the channel axis.
        pred_scores_mean = np.mean(pred_scores, axis=0)
        prediction_map = np.argmax(pred_scores_mean, axis=0)

        prediction_map = brats_labels.convert_to_brats_labels(prediction_map)
        prediction_map = return_to_size(prediction_map, sampling, x_1, x_2, y_1, y_2, z_1, z_2)

        results = {"prediction": prediction_map}
        predict.save_predictions(patient, results, models_gen_path, task)