# experiment_brain_parcellation.py
__author__ = 'adeb'
from shutil import copy2
import inspect
import PIL
import pickle
from spynet.utils.utilities import analyse_classes
from data_brain_parcellation import DatasetBrainParcellation
from network_brain_parcellation import *
from spynet.models.network import *
from spynet.models.neuron_type import *
from spynet.data.dataset import *
from spynet.training.trainer import *
from spynet.training.monitor import *
from spynet.training.parameters_selector import *
from spynet.training.stopping_criterion import *
from spynet.training.cost_function import *
from spynet.training.learning_update import *
from spynet.experiment import Experiment
from spynet.utils.utilities import tile_raster_images
import theano
class ExperimentBrain(Experiment):
    """
    Main experiment to train a network on a brain-parcellation dataset.

    Loads train/test HDF5 datasets, standardizes the last 134 input
    columns, builds a NetworkUltimateConv, trains it with SGD + momentum
    under early stopping on a held-out validation split, then saves the
    error-rate plot and the trained network parameters into the
    experiment folder (``self.path``).
    """

    def __init__(self, exp_name, data_path):
        Experiment.__init__(self, exp_name, data_path)

    def copy_file_virtual(self):
        # Archive a copy of this very script next to the results so the
        # exact configuration that produced them can be reproduced later.
        copy2(inspect.getfile(inspect.currentframe()), self.path)

    def run(self):
        """Run the full train/evaluate/save pipeline for this experiment."""
        ###### Create the datasets

        # aa = CostNegLLWeighted(np.array([0.9, 0.1]))
        # e = theano.function(inputs=[], outputs=aa.test())
        # print e()

        ## Load the data
        training_data_path = self.data_path + "train.h5"
        ds_training = DatasetBrainParcellation()
        ds_training.read(training_data_path)
        # Hold out 5% of the training set for validation-based monitoring
        # and model selection.
        [ds_training, ds_validation] = ds_training.split_dataset_proportions([0.95, 0.05])

        testing_data_path = self.data_path + "test.h5"
        ds_testing = DatasetBrainParcellation()
        ds_testing.read(testing_data_path)

        ## Display data sample
        # image = PIL.Image.fromarray(tile_raster_images(X=ds_training.inputs[0:50],
        #                                                img_shape=(29, 29), tile_shape=(5, 10),
        #                                                tile_spacing=(1, 1)))
        # image.save(self.path + "filters_corruption_30.png")

        ## Few stats about the targets
        classes, proportion_class = analyse_classes(np.argmax(ds_training.outputs, axis=1), "Training data:")
        print(classes)

        ## Scale some part of the data
        # Only the trailing 134 input columns are standardized; the scaler
        # is fitted on the training inputs and applied to all three splits.
        print("Scaling")
        s = Scaler([slice(-134, None, None)])
        s.compute_parameters(ds_training.inputs)
        s.scale(ds_training.inputs)
        s.scale(ds_validation.inputs)
        s.scale(ds_testing.inputs)
        # Persist the fitted scaler for later reuse at prediction time.
        # Context manager closes the handle deterministically (the original
        # `pickle.dump(s, open(...))` leaked it).
        with open(self.path + "s.scaler", "wb") as scaler_file:
            pickle.dump(s, scaler_file)

        ###### Create the network
        # NOTE(review): init arguments presumably encode patch sizes /
        # feature counts / number of classes (135) — confirm against
        # NetworkUltimateConv's definition.
        net = NetworkUltimateConv()
        net.init(33, 29, 5, 134, 135)
        print(net)

        ###### Configure the trainer

        # Cost function
        cost_function = CostNegLL(net.ls_params)

        # Learning update
        learning_rate = 0.05
        momentum = 0.5
        lr_update = LearningUpdateGDMomentum(learning_rate, momentum)

        # Create monitors and add them to the trainer
        freq = 1
        # err_training = MonitorErrorRate(freq, "Train", ds_training)
        # err_testing = MonitorErrorRate(freq, "Test", ds_testing)
        err_validation = MonitorErrorRate(freq, "Val", ds_validation)
        # dice_training = MonitorDiceCoefficient(freq, "Train", ds_training, 135)
        dice_testing = MonitorDiceCoefficient(freq, "Test", ds_testing, 135)
        # dice_validation = MonitorDiceCoefficient(freq, "Val", ds_validation, 135)

        # Create stopping criteria and add them to the trainer
        max_epoch = MaxEpoch(300)
        early_stopping = EarlyStopping(err_validation, 10, 0.99, 5)

        # Create the network selector: keep the parameters that achieved
        # the best monitored validation error.
        params_selector = ParamSelectorBestMonitoredValue(err_validation)

        # Create the trainer object
        batch_size = 200
        t = Trainer(net, cost_function, params_selector, [max_epoch, early_stopping],
                    lr_update, ds_training, batch_size,
                    [err_validation, dice_testing])

        ###### Train the network
        t.train()

        ###### Plot the records
        # pred = np.argmax(t.net.predict(ds_testing.inputs, 10000), axis=1)
        # d = compute_dice(pred, np.argmax(ds_testing.outputs, axis=1), 134)
        # print "Dice test: {}".format(np.mean(d))
        # print "Error rate test: {}".format(error_rate(np.argmax(ds_testing.outputs, axis=1), pred))

        save_records_plot(self.path, [err_validation], "err", t.n_train_batches, "upper right")
        # save_records_plot(self.path, [dice_testing], "dice", t.n_train_batches, "lower right")

        ###### Save the network
        net.save_parameters(self.path + "net.net")
if __name__ == '__main__':
    # Launch the paper's "ultimate conv" experiment on its bundled dataset.
    experiment = ExperimentBrain("paper_ultimate_conv", "./datasets/paper_ultimate_conv/")
    experiment.run()