# side_analysis/experiment_iterative.py

__author__ = 'adeb'

from shutil import copy2
import inspect
import pickle

import numpy as np

from data_brain_parcellation import DatasetBrainParcellation, DataGeneratorBrain
from network_brain_parcellation import *
from spynet.models.network import *
from spynet.models.neuron_type import *
from spynet.data.dataset import *
from spynet.training.trainer import *
from spynet.training.monitor import *
from spynet.training.parameters_selector import *
from spynet.training.stopping_criterion import *
from spynet.training.cost_function import *
from spynet.training.learning_update import *
from spynet.experiment import Experiment
from spynet.utils.utilities import load_config


class ExperimentBrain(Experiment):
    """
    Experiment in which the network is iteratively re-trained on the
    datapoints it misclassifies, a form of boosting.
    """
    def __init__(self, exp_name, data_path):
        Experiment.__init__(self, exp_name, data_path)

    def copy_file_virtual(self):
        # Copy this script into the experiment folder so the run is reproducible
        copy2(inspect.getfile(inspect.currentframe()), self.path)

    def run(self):
        ###### Create the datasets

        # Testing set: generated once and kept fixed across iterations
        data_cf_test = load_config("cfg_testing_data_creation.py")
        data_cf_test.general["file_path"] = data_cf_test.data_path + "test_temp.h5"
        data_generator_test = DataGeneratorBrain()
        data_generator_test.init_from_config(data_cf_test)
        ds_test_temp = DatasetBrainParcellation()
        vx, inputs, outputs, file_ids = data_generator_test.generate_parallel(data_cf_test.general["n_data"])
        ds_test_temp.populate(inputs, outputs, vx, file_ids)
        ds_test_temp.shuffle_data()
        # Training set, with a fraction held out for validation
        prop_validation = 0.1
        data_cf_train = load_config("cfg_training_data_creation.py")
        data_cf_train.general["file_path"] = data_cf_train.data_path + "train_temp.h5"
        data_generator_train = DataGeneratorBrain()
        data_generator_train.init_from_config(data_cf_train)
        ds_train_temp = DatasetBrainParcellation()
        vx, inputs, outputs, file_ids = data_generator_train.generate_parallel(data_cf_train.general["n_data"])
        ds_train_temp.populate(inputs, outputs, vx, file_ids)
        ds_train_temp.shuffle_data()
        [ds_train_temp, ds_validation_temp] = ds_train_temp.split_dataset_proportions(
            [1 - prop_validation, prop_validation])
        ## Scale part of the data (the last 134 input features)
        print "Scaling"
        s = Scaler([slice(-134, None, None)])
        s.compute_parameters(ds_train_temp.inputs)
        s.scale(ds_train_temp.inputs)
        s.scale(ds_validation_temp.inputs)
        s.scale(ds_test_temp.inputs)
        pickle.dump(s, open(self.path + "s.scaler", "wb"))
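
        # At prediction time the same scaling must be applied to new inputs
        # before calling net.predict. A minimal sketch of reloading the pickled
        # scaler (assuming the Scaler object keeps its scale() method), left
        # commented out since it is not part of this training script:
        # s = pickle.load(open(self.path + "s.scaler", "rb"))
        # s.scale(new_inputs)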
        ###### Create the network
        net = NetworkUltimateConv()
        net.init(29, 29, 13, 134, 135)
        print net
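        # The init arguments above are not documented here; from the rest of
        # this script, 134 presumably matches the number of scaled input
        # features (see the Scaler slice above) and 135 the number of output
        # classes (see MonitorDiceCoefficient below). The remaining values
        # (29, 29, 13) are assumed to be patch dimensions.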
        ###### Configure the trainer

        # Cost function
        cost_function = CostNegLL()

        # Learning update
        learning_rate = 0.05
        momentum = 0.5
        lr_update = LearningUpdateGDMomentum(learning_rate, momentum)

        # Create monitors and add them to the trainer
        freq = 1
        err_training = MonitorErrorRate(freq, "Train", ds_train_temp)
        err_testing = MonitorErrorRate(freq, "Test", ds_test_temp)
        err_validation = MonitorErrorRate(freq, "Val", ds_validation_temp)
        # dice_training = MonitorDiceCoefficient(freq, "Train", ds_train_temp, 135)
        dice_testing = MonitorDiceCoefficient(freq, "Test", ds_test_temp, 135)
        # dice_validation = MonitorDiceCoefficient(freq, "Val", ds_validation_temp, 135)
        # Create stopping criteria and add them to the trainer
        max_epoch = MaxEpoch(100)
        # early_stopping = EarlyStopping(err_validation, 10, 0.99, 5)

        # Create the network selector
        params_selector = ParamSelectorBestMonitoredValue(err_validation)

        # Create the trainer object
        batch_size = 200
        t = Trainer(net, cost_function, params_selector, [max_epoch],
                    lr_update, ds_train_temp, batch_size,
                    [err_training, err_validation, err_testing, dice_testing])
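
        # Iterative scheme: train, measure the error rates, extract the
        # misclassified validation datapoints, then rebuild the training set
        # from freshly generated data augmented with those hard examples.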
        ###### Train the network
        n_iter = 1
        evolution = np.zeros((2, n_iter))  # rows: test and validation error rates
        for i in xrange(n_iter):
            # Train
            t.train()
            net.save_parameters(self.path + "net.net")

            # Evaluate on the test and validation sets
            print "ERRORS::::"
            out_pred = net.predict(ds_test_temp.inputs, 1000)
            errors = np.argmax(out_pred, axis=1) != np.argmax(ds_test_temp.outputs, axis=1)
            print np.mean(errors)
            evolution[0, i] = np.mean(errors)

            out_pred = net.predict(ds_validation_temp.inputs, 1000)
            errors = np.argmax(out_pred, axis=1) != np.argmax(ds_validation_temp.outputs, axis=1)
            print np.mean(errors)
            evolution[1, i] = np.mean(errors)

            # Collect the misclassified validation datapoints
            vx_errors = ds_validation_temp.vx[errors]
            file_ids_errors = ds_validation_temp.file_ids[errors]
            inputs_errors = ds_validation_temp.inputs[errors]
            outputs_errors = ds_validation_temp.outputs[errors]
            # Update datasets, trainer, monitors: regenerate 90% of the training
            # data and append the misclassified validation datapoints
            n_data = int(round(0.9 * data_cf_train.general["n_data"]))
            prop_validation = 0.1
            vx, inputs, outputs, file_ids = data_generator_train.generate_parallel(n_data)
            s.scale(inputs)
            split = int(round(n_data * (1 - prop_validation)))
            ds_train_temp.populate(np.concatenate([inputs[0:split], inputs_errors], axis=0),
                                   np.concatenate([outputs[0:split], outputs_errors], axis=0),
                                   np.concatenate([vx[0:split], vx_errors], axis=0),
                                   np.concatenate([file_ids[0:split], file_ids_errors], axis=0))
            ds_train_temp.shuffle_data()
            ds_validation_temp.populate(inputs[split:n_data], outputs[split:n_data],
                                        vx[split:n_data], file_ids[split:n_data])
            ds_validation_temp.shuffle_data()
            print ds_train_temp.n_data

            # Reset the trainer so the next iteration starts from a clean state
            t.init()
        print evolution

        ###### Plot the records
        # save_records_plot(self.path, [err_training, err_validation, err_testing], "err", t.n_train_batches, "upper right")
        # save_records_plot(self.path, [dice_testing], "dice", t.n_train_batches, "lower right")

        ###### Save the network
        net.save_parameters(self.path + "net.net")
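
        # A minimal sketch for also persisting the per-iteration error rates
        # ("evolution.npy" is a hypothetical file name, not part of the
        # original script):
        # np.save(self.path + "evolution.npy", evolution)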

if __name__ == '__main__':
    exp_name = "test_iter"
    data_path = "./datasets/test_iter_f1/"
    exp = ExperimentBrain(exp_name, data_path)
    exp.run()