|
a |
|
b/src/LiviaNet/startTraining.py |
|
|
1 |
""" |
|
|
2 |
Copyright (c) 2016, Jose Dolz. All rights reserved. |
|
|
3 |
|
|
|
4 |
Redistribution and use in source and binary forms, with or without modification, |
|
|
5 |
are permitted provided that the following conditions are met: |
|
|
6 |
|
|
|
7 |
1. Redistributions of source code must retain the above copyright notice, |
|
|
8 |
this list of conditions and the following disclaimer. |
|
|
9 |
2. Redistributions in binary form must reproduce the above copyright notice, |
|
|
10 |
this list of conditions and the following disclaimer in the documentation |
|
|
11 |
and/or other materials provided with the distribution. |
|
|
12 |
|
|
|
13 |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
|
|
14 |
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES |
|
|
15 |
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND |
|
|
16 |
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT |
|
|
17 |
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, |
|
|
18 |
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|
|
19 |
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR |
|
|
20 |
OTHER DEALINGS IN THE SOFTWARE. |
|
|
21 |
|
|
|
22 |
Jose Dolz. Dec, 2016. |
|
|
23 |
email: jose.dolz.upv@gmail.com |
|
|
24 |
LIVIA Department, ETS, Montreal. |
|
|
25 |
""" |
|
|
26 |
|
|
|
27 |
import sys |
|
|
28 |
import time |
|
|
29 |
import numpy as np |
|
|
30 |
import random |
|
|
31 |
import math |
|
|
32 |
import os |
|
|
33 |
|
|
|
34 |
from Modules.General.Utils import getImagesSet |
|
|
35 |
from Modules.General.Utils import dump_model_to_gzip_file |
|
|
36 |
from Modules.General.Utils import load_model_from_gzip_file |
|
|
37 |
from Modules.General.Evaluation import computeDice |
|
|
38 |
from Modules.IO.sampling import getSamplesSubepoch |
|
|
39 |
from Modules.Parsers.parsersUtils import parserConfigIni |
|
|
40 |
from startTesting import segmentVolume |
|
|
41 |
import pdb |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def _printImageNames(names, gtNames, roiNames):
    """Print one line per image: index, image name, ground truth and ROI.

    The ROI column is printed only when ``roiNames`` is non-empty; in that
    case ``roiNames`` is indexed in lockstep with ``names``.
    """
    hasROI = len(roiNames) > 0
    for i, name in enumerate(names):
        if hasROI:
            print(" Image({}): {} | GT: {} | ROI {} ".format(i, name, gtNames[i], roiNames[i]))
        else:
            print(" Image({}): {} | GT: {} ".format(i, name, gtNames[i]))


def _halveLearningRate(net):
    """Divide the network's shared learning rate by two and report the change."""
    currentLR = net.learning_rate.get_value()
    newLR = currentLR / 2.0
    net.learning_rate.set_value(newLR)
    print(" ... Learning rate has been changed from {} to {}".format(currentLR, newLR))


def _saveModel(net):
    """Dump the network under <cwd>/outputFiles/<folderName>/Networks.

    The file name is ``<networkName>_Epoch<numberOfEpochsTrained>``.
    """
    netFolderName = os.path.join(os.getcwd(), 'outputFiles', net.folderName, 'Networks')
    modelName = net.networkName + "_Epoch" + str(net.numberOfEpochsTrained)

    dump_model_to_gzip_file(net, os.path.join(netFolderName, modelName))

    print(" Network model saved in " + netFolderName + " as " + modelName)


def startTraining(networkModelName,configIniName):
    """Train a previously created network, validating and saving after each epoch.

    The network is loaded from ``networkModelName`` and the training settings
    are read from ``configIniName``. For every epoch the function:

      1. runs ``numberOfSubEpochs`` sub-epochs, each on a freshly sampled set
         of training segments, and averages the batch costs;
      2. halves the learning rate when the schedule requires it (first at
         ``firstEpochChangeLR``, then every ``frequencyChangeLR`` epochs);
      3. segments every validation image with ``segmentVolume``;
      4. increments ``numberOfEpochsTrained`` and saves the model.

    Args:
        networkModelName: Path handed to ``load_model_from_gzip_file``.
        configIniName: Path of the configuration file handed to
            ``parserConfigIni.readConfigIniFile``.

    Returns:
        None. Progress is reported on stdout and models are written to disk.
    """
    print(" ************************************************ STARTING TRAINING **************************************************")
    print(" ********************** Starting training model (Reading parameters) **********************")

    myParserConfigIni = parserConfigIni()
    # NOTE(review): the second argument (1) presumably selects the training
    # section of the ini file -- confirm against parserConfigIni.
    myParserConfigIni.readConfigIniFile(configIniName,1)

    # Image type (0: Nifti, 1: Matlab)
    imageType = myParserConfigIni.imageTypesTrain

    print(" --- Do training in {} epochs with {} subEpochs each...".format(myParserConfigIni.numberOfEpochs, myParserConfigIni.numberOfSubEpochs))
    print("-------- Reading Images names used in training/validation -------------")

    # -- Get list of images used for training -- #
    trainIdx = myParserConfigIni.indexesForTraining
    (imageNames_Train, names_Train) = getImagesSet(myParserConfigIni.imagesFolder, trainIdx)  # Images
    (groundTruthNames_Train, gt_names_Train) = getImagesSet(myParserConfigIni.GroundTruthFolder, trainIdx)  # Ground truth
    (roiNames_Train, roi_names_Train) = getImagesSet(myParserConfigIni.ROIFolder, trainIdx)  # ROI

    # -- Get list of images used for validation -- #
    valIdx = myParserConfigIni.indexesForValidation
    (imageNames_Val, names_Val) = getImagesSet(myParserConfigIni.imagesFolder, valIdx)  # Images
    (groundTruthNames_Val, gt_names_Val) = getImagesSet(myParserConfigIni.GroundTruthFolder, valIdx)  # Ground truth
    (roiNames_Val, roi_names_Val) = getImagesSet(myParserConfigIni.ROIFolder, valIdx)  # ROI

    # Print names
    print(" ================== Images for training ================")
    _printImageNames(names_Train, gt_names_Train, roi_names_Train)
    print(" ================== Images for validation ================")
    # The validation listing must look at the *validation* ROI list; checking
    # the training one could index an empty list or hide existing ROIs.
    _printImageNames(names_Val, gt_names_Val, roi_names_Val)
    print(" ===============================================================")

    # --------------- Load my LiviaNet3D object ---------------
    print(" ... Loading model from {}".format(networkModelName))
    myLiviaNet3D = load_model_from_gzip_file(networkModelName)
    print(" ... Network architecture successfully loaded....")

    # Assign parameters to loaded Net
    myLiviaNet3D.numberOfEpochs = myParserConfigIni.numberOfEpochs
    myLiviaNet3D.numberOfSubEpochs = myParserConfigIni.numberOfSubEpochs
    myLiviaNet3D.numberOfSamplesSupEpoch = myParserConfigIni.numberOfSamplesSupEpoch
    myLiviaNet3D.firstEpochChangeLR = myParserConfigIni.firstEpochChangeLR
    myLiviaNet3D.frequencyChangeLR = myParserConfigIni.frequencyChangeLR

    numberOfEpochs = myLiviaNet3D.numberOfEpochs
    numberOfSubEpochs = myLiviaNet3D.numberOfSubEpochs
    numberOfSamplesSupEpoch = myLiviaNet3D.numberOfSamplesSupEpoch

    # --------------- -------------- ---------------
    # --------------- Start TRAINING ---------------
    # --------------- -------------- ---------------
    # Get sample dimension values
    receptiveField = myLiviaNet3D.receptiveField
    sampleSize_Train = myLiviaNet3D.sampleSize_Train

    trainingCost = []

    applyPadding = (myParserConfigIni.applyPadding == 1)

    # Epoch at which the learning rate was last halved. None means "not yet";
    # a numeric 0 cannot be used as the marker because epoch 0 is a valid epoch.
    lastLRChangeEpoch = None

    # Run over all the (remaining) epochs and subepochs
    for e_i in range(numberOfEpochs):
        # Recover last trained epoch
        numberOfEpochsTrained = myLiviaNet3D.numberOfEpochsTrained

        print(" ============== EPOCH: {}/{} =================".format(numberOfEpochsTrained+1,numberOfEpochs))

        costsOfEpoch = []

        for subE_i in range(numberOfSubEpochs):
            epoch_nr = subE_i+1
            print(" --- SubEPOCH: {}/{}".format(epoch_nr,myLiviaNet3D.numberOfSubEpochs))

            # Get all the samples that will be used in this sub-epoch
            [imagesSamplesAll,
             gt_samplesAll] = getSamplesSubepoch(numberOfSamplesSupEpoch,
                                                 imageNames_Train,
                                                 groundTruthNames_Train,
                                                 roiNames_Train,
                                                 imageType,
                                                 sampleSize_Train,
                                                 receptiveField,
                                                 applyPadding
                                                 )

            # Variable that will contain weights for the cost function
            # --- In its current implementation, all the classes have the same weight
            weightsCostFunction = np.ones(myLiviaNet3D.n_classes, dtype='float32')

            # Only whole batches are trained; floor division keeps this an
            # integer on both Python 2 and Python 3.
            # NOTE(review): if fewer samples than one batch are available this
            # is 0 and the mean below raises ZeroDivisionError.
            numberBatches = len(imagesSamplesAll) // myLiviaNet3D.batch_Size

            myLiviaNet3D.trainingData_x.set_value(imagesSamplesAll, borrow=True)
            myLiviaNet3D.trainingData_y.set_value(gt_samplesAll, borrow=True)

            costsOfBatches = []

            for b_i in range(numberBatches):
                # TODO: Make a line that adds a point at each trained batch (Or percentage being updated)
                costErrors = myLiviaNet3D.networkModel_Train(b_i, weightsCostFunction)
                # First returned value is used as the mean cost of the batch
                costsOfBatches.append(costErrors[0])
                myLiviaNet3D.updateLayersMatricesBatchNorm()

            #======== Calculate and Report accuracy over subepoch
            meanCostOfSubepoch = sum(costsOfBatches) / float(numberBatches)
            print(" ---------- Cost of this subEpoch: {}".format(meanCostOfSubepoch))

            # Release data: swap the sample arrays for minimal placeholders
            myLiviaNet3D.trainingData_x.set_value(np.zeros([1,1,1,1,1], dtype="float32"))
            myLiviaNet3D.trainingData_y.set_value(np.zeros([1,1,1,1], dtype="float32"))

            # Get mean cost epoch
            costsOfEpoch.append(meanCostOfSubepoch)

        meanCostOfEpoch = sum(costsOfEpoch) / float(numberOfSubEpochs)

        # Include the epoch cost to the main training cost and update current mean
        trainingCost.append(meanCostOfEpoch)
        currentMeanCost = sum(trainingCost) / float(e_i + 1)

        print(" ---------- Training on Epoch #" + str(e_i) + " finished ----------" )
        print(" ---------- Cost of Epoch: {} / Mean training error {}".format(meanCostOfEpoch,currentMeanCost))
        print(" -------------------------------------------------------- " )

        # ------------- Update Learning Rate if required ----------------#
        # First halving at the first epoch >= firstEpochChangeLR, then again
        # each time exactly frequencyChangeLR epochs have passed since the
        # previous halving.
        if e_i >= myLiviaNet3D.firstEpochChangeLR:
            if (lastLRChangeEpoch is None
                    or e_i == lastLRChangeEpoch + myLiviaNet3D.frequencyChangeLR):
                _halveLearningRate(myLiviaNet3D)
                lastLRChangeEpoch = e_i

        # ---------------------- Start validation ---------------------- #
        numberImagesToSegment = len(imageNames_Val)
        print(" ********************** Starting validation **********************")

        # Run over the images to segment
        for i_d in range(numberImagesToSegment):
            print("------------- Segmenting subject: {} ....total: {}/{}... -------------".format(names_Val[i_d],str(i_d+1),str(numberImagesToSegment)))
            strideValues = myLiviaNet3D.lastLayer.outputShapeTest[2:]

            segmentVolume(myLiviaNet3D,
                          i_d,
                          imageNames_Val,  # Full path
                          names_Val,       # Only image name
                          groundTruthNames_Val,
                          roiNames_Val,
                          imageType,
                          applyPadding,
                          receptiveField,
                          sampleSize_Train,
                          strideValues,
                          myLiviaNet3D.batch_Size,
                          0  # Validation (0) or testing (1)
                          )

        print(" ********************** Validation DONE ********************** ")

        # ------ In this point the training is done at Epoch n ---------#
        # Increase number of epochs trained
        myLiviaNet3D.numberOfEpochsTrained += 1

        # --------------- Save the model ---------------
        _saveModel(myLiviaNet3D)

    print("................ The whole Training is done.....")
    print(" ************************************************************************************ ")