main_compute_segmentations.py

__author__ = 'adeb'

import os
import imp
import time
import pickle
import matplotlib
matplotlib.use('Agg')
import matplotlib.cm as cm
from matplotlib import pyplot as plt
import nibabel
import theano
import theano.tensor as T
import numpy as np
import scipy.io
from spynet.utils.utilities import create_img_from_pred, compute_dice_symb, compute_dice, error_rate
import spynet.training.trainer as trainer
from spynet.models.network import *
from network_brain_parcellation import *
from spynet.data.utils_3d.pick_voxel import *
from spynet.data.utils_3d.pick_patch import *
from spynet.data.utils_3d.pick_target import *
from data_brain_parcellation import DatasetBrainParcellation, DataGeneratorBrain, list_miccai_files, RegionCentroids
from spynet.utils.utilities import open_h5file


if __name__ == '__main__':
    """
    Compute the segmentations of the testing brains with the trained networks
    (the region centroids are approximated with a first network that does not rely on them).
    """
    experiment_path = "./experiments/test_iter_0/"
    data_path = "./datasets/test_iter/"
    cf_data = imp.load_source("cf_data", data_path + "cfg_testing_data_creation.py")
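    # cfg_testing_data_creation.py is assumed to describe how voxels, patches and targets
    # were extracted when the testing dataset was created; it is reused below so that the
    # evaluation data matches that configuration.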

    # Load the network
    net = NetworkUltimateConv()
    net.init(29, 29, 13, 134, 135)
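    # The arguments of init() are assumed to be the patch sizes and the numbers of
    # regions/classes used at training time; see NetworkUltimateConv.init for their
    # exact meaning.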
    net.load_parameters(open_h5file(experiment_path + "net.net"))
    n_out = net.n_out

    # Load the scaler
    scaler = pickle.load(open(experiment_path + "s.scaler", "rb"))
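    # The scaler is assumed to reproduce the input-feature normalisation applied during
    # training; it is passed to predict_from_generator below for that purpose.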

    # Files on which to evaluate the network
    file_list = list_miccai_files(**{"mode": "folder", "path": "./datasets/miccai/2/"})
    n_files = len(file_list)
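    # Each entry of file_list is assumed to hold the paths of one subject's MRI and
    # ground-truth label volume (data_gen.files[atlas_id][0] is used as the MRI path below).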

    # Options for the generation of the dataset
    # The generation/evaluation of the dataset has to be split into batches
    # as a whole brain does not fit into memory
    batch_size = 50000
    select_region = SelectWholeBrain()
    extract_vx = ExtractVoxelAll(1)
    pick_vx = PickVoxel(select_region, extract_vx)
    pick_patch = create_pick_features(cf_data)
    pick_tg = create_pick_target(cf_data)
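    # Inferred from the class and factory names: pick_vx selects every voxel of the whole
    # brain, while pick_patch and pick_tg extract the input features and the target label
    # of each selected voxel according to the data-creation config.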

    # Create the data generator
    data_gen = DataGeneratorBrain()
    data_gen.init_from(file_list, pick_vx, pick_patch, pick_tg)

    # Evaluate the centroids
    net_wo_centroids_path = "./experiments/report_3_patches_balanced_conv/"
    net_wo_centroids = NetworkThreePatchesConv()
    net_wo_centroids.init(29, 135)
    net_wo_centroids.load_parameters(open_h5file(net_wo_centroids_path + "net.net"))
    ds_testing = DatasetBrainParcellation()
    ds_testing.read(data_path + "train.h5")
    pred_wo_centroids = np.argmax(net_wo_centroids.predict(ds_testing.inputs, 1000), axis=1)
    region_centroids = RegionCentroids(134)
    region_centroids.update_barycentres(ds_testing.vx, pred_wo_centroids)
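    # The main network presumably uses distances to the region centroids as additional
    # input features; since the true segmentation is unknown at test time, the centroids
    # are first approximated from the predictions of a network that does not rely on them.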

    # Generate and evaluate the dataset
    start_time = time.clock()
    dices = np.zeros((n_files, 134))
    errors = np.zeros((n_files,))
    pred_functions = {}
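    # pred_functions is passed to predict_from_generator for every brain, presumably so
    # that the compiled Theano prediction functions can be reused instead of being rebuilt.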

    for atlas_id in xrange(n_files):
        print "Atlas: {}".format(atlas_id)
        ls_vx = []
        ls_pred = []
        brain_batches = data_gen.generate_single_atlas(atlas_id, None, region_centroids, batch_size, True)
        vx_all, pred_all = net.predict_from_generator(brain_batches, scaler, pred_functions)
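        # vx_all holds the coordinates of the evaluated voxels and pred_all the predicted
        # label of each voxel, which is enough to rebuild the segmented volume below.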

        # Construct the predicted image
        img_true = data_gen.atlases[atlas_id][1]
        img_pred = create_img_from_pred(vx_all, pred_all, img_true.shape)

        # Compute the dice coefficient and the error
        non_zo = img_pred.nonzero() or img_true.nonzero()
        pred = img_pred[non_zo]
        true = img_true[non_zo]
        dice_regions = compute_dice(pred, true, n_out)
        err_global = error_rate(pred, true)
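        # Note: `or` applied to the tuples returned by nonzero() yields the first operand
        # whenever img_pred has any non-zero voxel, so the metrics are effectively computed
        # on the non-background voxels of the prediction.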

        end_time = time.clock()
        print "It took {} seconds to evaluate the whole brain.".format(end_time - start_time)
        print "The mean dice is {}".format(dice_regions.mean())
        print "The error rate is {}".format(err_global)
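        # Note: start_time is set once before the loop, so the reported duration is
        # cumulative over all brains processed so far rather than per brain.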

        # Save the results
        errors[atlas_id] = err_global
        dices[atlas_id, :] = dice_regions

        # Diff Image
        img_diff = (img_pred == img_true).astype(np.uint8)
        img_diff += 1
        img_diff[img_pred == 0] = 0
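        # img_diff encodes each voxel as 0 = predicted background, 1 = disagreement with
        # the ground truth, 2 = agreement with the ground truth.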

        # Save the 3D images
        affine = data_gen.atlases[atlas_id][2]
        mri_file = data_gen.files[atlas_id][0]
        mri = data_gen.atlases[atlas_id][0]
        basename = os.path.splitext(os.path.basename(mri_file))[0]
        img_pred_nifti = nibabel.Nifti1Image(img_pred, affine)
        img_mri_nifti = nibabel.Nifti1Image(mri, affine)
        img_true_nifti = nibabel.Nifti1Image(img_true, affine)
        img_diff_nifti = nibabel.Nifti1Image(img_diff, affine)
        nibabel.save(img_pred_nifti, experiment_path + basename + "_pred.nii")
        nibabel.save(img_mri_nifti, experiment_path + basename + "_mri.nii")
        nibabel.save(img_true_nifti, experiment_path + basename + "_true.nii")
        nibabel.save(img_diff_nifti, experiment_path + basename + "_diff.nii")
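
    # Save the per-brain evaluation results: one row of per-region Dice coefficients per
    # brain in dice_brains.mat and the corresponding global error rates in error_brains.mat.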
    scipy.io.savemat(experiment_path + "dice_brains.mat", mdict={'arr': dices})
    scipy.io.savemat(experiment_path + "error_brains.mat", mdict={'arr': errors})