a b/src/ensemble/mean_ensemble.py
1
import os
2
import sys
3
import torch
4
import numpy as np
5
6
from src.config import BratsConfiguration
7
from src.dataset import brats_labels
8
from src.dataset.utils import dataset
9
from src.inference import crop_no_patch, return_to_size
10
from src.models.io_model import load_model
11
from src.models.unet3d import unet3d
12
from src.models.vnet import asymm_vnet
13
from src.test import predict
14
15
16
def _load(network, model_path, checkpoint, device=None):
    """Move ``network`` to a device and load checkpoint weights into it.

    Args:
        network: Model instance to populate; it is moved with ``network.to``.
        model_path: Directory that ``checkpoint`` is relative to.
        checkpoint: Checkpoint file path, relative to ``model_path``.
        device: Target ``torch.device``. When omitted, CUDA is used if it is
            available and the CPU otherwise.

    Returns:
        The first element of the 4-tuple returned by ``load_model``.
    """
    if device is None:
        # Resolve the device here rather than reading a module global that
        # only exists when this file is executed as a script; importing the
        # module and calling a loader would otherwise raise NameError.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    network.to(device)
    checkpoint_path = os.path.join(model_path, checkpoint)
    # The trailing ``None, False`` arguments are forwarded unchanged; their
    # meaning is defined by ``load_model``.
    model, _, _, _ = load_model(network, checkpoint_path, device, None, False)
    return model
21
22
23
def load_model_1598550861(model_path):
    """Build the asymmetric V-Net and load its ``model_1598550861`` checkpoint.

    Args:
        model_path: Root directory containing the ``model_1598550861`` folder.

    Returns:
        The model returned by ``_load`` for the epoch-215 checkpoint file.
    """
    vnet = asymm_vnet.VNet(
        non_linearity="prelu",
        in_channels=4,
        classes=4,
        init_features_maps=32,
        kernel_size=5,
        padding=2,
    )
    return _load(
        vnet,
        model_path,
        "model_1598550861/checkpoint_epoch_215_val_loss_0.2378825504485875_dice_0.7621174487349105.pth",
    )
28
29
30
def load_model_1598639885(model_path):
    """Build a residual 3D U-Net and load its ``model_1598639885`` checkpoint.

    Args:
        model_path: Root directory containing the ``model_1598639885`` folder.

    Returns:
        The model returned by ``_load`` for the epoch-198 checkpoint file.
    """
    res_unet = unet3d.ResidualUNet3D(
        in_channels=4,
        out_channels=4,
        final_sigmoid=False,
        f_maps=32,
        layer_order="crg",
        num_levels=4,
        num_groups=4,
        conv_padding=1,
    )
    return _load(
        res_unet,
        model_path,
        "model_1598639885/checkpoint_epoch_198_val_loss_0.19342842820572526_dice_0.8065715717942747.pth",
    )
35
36
37
def load_model_1598640035(model_path):
    """Build a residual 3D U-Net and load its ``model_1598640035`` checkpoint.

    Args:
        model_path: Root directory containing the ``model_1598640035`` folder.

    Returns:
        The model returned by ``_load`` for the epoch-142 checkpoint file.
    """
    res_unet = unet3d.ResidualUNet3D(
        in_channels=4,
        out_channels=4,
        final_sigmoid=False,
        f_maps=32,
        layer_order="crg",
        num_levels=4,
        num_groups=4,
        conv_padding=1,
    )
    return _load(
        res_unet,
        model_path,
        "model_1598640035/checkpoint_epoch_142_val_loss_0.21437616135976087_dice_0.7856238380039416.pth",
    )
42
43
44
def load_model_1598640005(model_path):
    """Build a plain 3D U-Net and load its ``model_1598640005`` checkpoint.

    Args:
        model_path: Root directory containing the ``model_1598640005`` folder.

    Returns:
        The model returned by ``_load`` for the epoch-168 checkpoint file.
    """
    # Unlike the two residual loaders, this run uses the non-residual UNet3D.
    unet = unet3d.UNet3D(
        in_channels=4,
        out_channels=4,
        final_sigmoid=False,
        f_maps=32,
        layer_order="crg",
        num_levels=4,
        num_groups=4,
        conv_padding=1,
    )
    return _load(
        unet,
        model_path,
        "model_1598640005/checkpoint_epoch_168_val_loss_0.20105469390137554_dice_0.7989453060986245.pth",
    )
50
51
52
if __name__ == "__main__":
    # Script entry point: average the score volumes of four trained models for
    # every case listed in the validation CSV and save the resulting label map.
    #
    # Usage: mean_ensemble.py <config_file>
    if len(sys.argv) < 2:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("usage: mean_ensemble.py <config_file>")

    config = BratsConfiguration(sys.argv[1])
    model_config = config.get_model_config()
    dataset_config = config.get_dataset_config()
    basic_config = config.get_basic_config()

    # ``device`` must stay a module-level name: the loader helpers read it.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Keep only the last dotted component of the configured sampling method.
    sampling = dataset_config.get("sampling_method").split(".")[-1]

    models_gen_path = model_config.get("model_path")
    task = "ensemble_segmentation"
    compute_metrics = False  # not read anywhere in the visible script

    model_vnet = load_model_1598550861(models_gen_path)
    model_2 = load_model_1598639885(models_gen_path)
    model_3 = load_model_1598640035(models_gen_path)
    model_4 = load_model_1598640005(models_gen_path)

    data, _ = dataset.read_brats(dataset_config.get("val_csv"))

    # Iterate the cases directly rather than indexing through range(len(data)).
    for patient in data:
        images = patient.load_mri_volumes(normalize=True)

        # The crop bounds are kept so the prediction can be mapped back to the
        # original volume extent by return_to_size below.
        x_1, x_2, y_1, y_2, z_1, z_2, images, brain_mask, patch_size = crop_no_patch(
            patient.size, images, patient.get_brain_mask(), sampling
        )

        _, scores_vnet = predict.predict(model_vnet, images, device, False)
        _, scores_2 = predict.predict(model_2, images, device, False)
        _, scores_3 = predict.predict(model_3, images, device, False)
        _, scores_4 = predict.predict(model_4, images, device, False)

        # The V-Net scores are reshaped to (4, *patch_size) so they can be
        # stacked with element [0] of each of the other three outputs.
        scores_vnet = scores_vnet.view((4, patch_size[0], patch_size[1], patch_size[2]))

        # Axis 0 of the stack indexes the models; averaging over it leaves one
        # score volume per channel, and argmax over the new axis 0 then picks
        # the winning channel for each voxel.
        pred_scores = torch.stack([scores_vnet, scores_2[0], scores_3[0], scores_4[0]]).cpu().numpy()
        pred_scores_mean = np.mean(pred_scores, axis=0)
        prediction_map = np.argmax(pred_scores_mean, axis=0)

        prediction_map = brats_labels.convert_to_brats_labels(prediction_map)
        prediction_map = return_to_size(prediction_map, sampling, x_1, x_2, y_1, y_2, z_1, z_2)

        results = {"prediction": prediction_map}
        predict.save_predictions(patient, results, models_gen_path, task)