a b/src/LiviaNet/startTraining.py
1
""" 
2
Copyright (c) 2016, Jose Dolz. All rights reserved.
3
4
Redistribution and use in source and binary forms, with or without modification,
5
are permitted provided that the following conditions are met:
6
7
    1. Redistributions of source code must retain the above copyright notice,
8
       this list of conditions and the following disclaimer.
9
    2. Redistributions in binary form must reproduce the above copyright notice,
10
       this list of conditions and the following disclaimer in the documentation
11
       and/or other materials provided with the distribution.
12
13
    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14
    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
15
    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
16
    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
17
    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
18
    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
20
    OTHER DEALINGS IN THE SOFTWARE.
21
22
Jose Dolz. Dec, 2016.
23
email: jose.dolz.upv@gmail.com
24
LIVIA Department, ETS, Montreal.
25
"""
26
27
import sys
28
import time
29
import numpy as np
30
import random
31
import math
32
import os
33
                  
34
from Modules.General.Utils import getImagesSet
35
from Modules.General.Utils import dump_model_to_gzip_file
36
from Modules.General.Utils import load_model_from_gzip_file
37
from Modules.General.Evaluation import computeDice
38
from Modules.IO.sampling import getSamplesSubepoch
39
from Modules.Parsers.parsersUtils import parserConfigIni
40
from startTesting import segmentVolume
41
import pdb
42
43
44
def _printImageSet(names, gtNames, roiNames):
    """Print one line per image: index, image name, GT name and (if ROIs were given) ROI name."""
    # Decide once per set; each set must be checked against its OWN ROI list,
    # otherwise a set without ROIs would be indexed out of range.
    hasROI = len(roiNames) > 0
    for i, name in enumerate(names):
        if hasROI:
            print(" Image({}): {}  |  GT: {}  |  ROI {} ".format(i, name, gtNames[i], roiNames[i]))
        else:
            print(" Image({}): {}  |  GT: {}  ".format(i, name, gtNames[i]))


def _updateLearningRate(myLiviaNet3D, epoch, lastChangeEpoch):
    """Halve the learning rate on scheduled epochs; return the epoch of the latest change (None if never).

    The first change happens at the first epoch >= myLiviaNet3D.firstEpochChangeLR;
    later changes happen exactly myLiviaNet3D.frequencyChangeLR epochs after the previous one.
    """
    if epoch < myLiviaNet3D.firstEpochChangeLR:
        return lastChangeEpoch
    # None (not 0) marks "never changed", so a legitimate first change at
    # epoch 0 is not mistaken for "no change yet" on every following epoch.
    if lastChangeEpoch is not None and epoch != lastChangeEpoch + myLiviaNet3D.frequencyChangeLR:
        return lastChangeEpoch
    currentLR = myLiviaNet3D.learning_rate.get_value()
    newLR = currentLR / 2.0
    myLiviaNet3D.learning_rate.set_value(newLR)
    print(" ... Learning rate has been changed from {} to {}".format(currentLR, newLR))
    return epoch


def _saveModel(myLiviaNet3D):
    """Dump the network to <cwd>/outputFiles/<folderName>/Networks/<networkName>_Epoch<numberOfEpochsTrained>."""
    netFolderName = os.path.join(os.getcwd(), 'outputFiles', myLiviaNet3D.folderName, 'Networks')
    baseName = myLiviaNet3D.networkName + "_Epoch" + str(myLiviaNet3D.numberOfEpochsTrained)
    dump_model_to_gzip_file(myLiviaNet3D, os.path.join(netFolderName, baseName))
    print(" Network model saved in " + netFolderName + " as " + baseName)


def startTraining(networkModelName, configIniName):
    """Train a previously created LiviaNet3D model, validating and saving it after every epoch.

    Args:
        networkModelName: path of the gzip-dumped model, loaded with load_model_from_gzip_file.
        configIniName: training configuration file, read with parserConfigIni.readConfigIniFile(..., 1).

    Side effects:
        For each epoch: trains over numberOfSubEpochs sub-epochs of freshly drawn samples,
        possibly halves the learning rate, segments every validation image with
        segmentVolume, increments myLiviaNet3D.numberOfEpochsTrained and dumps the model
        under <cwd>/outputFiles/<folderName>/Networks.

    Raises:
        ValueError: if a sub-epoch yields fewer samples than one batch.
    """
    print(" ************************************************  STARTING TRAINING **************************************************")
    print(" **********************  Starting training model (Reading parameters) **********************")

    myParserConfigIni = parserConfigIni()
    myParserConfigIni.readConfigIniFile(configIniName, 1)

    # Image type (0: Nifti, 1: Matlab)
    imageType = myParserConfigIni.imageTypesTrain

    print(" --- Do training in {} epochs with {} subEpochs each...".format(myParserConfigIni.numberOfEpochs, myParserConfigIni.numberOfSubEpochs))
    print("-------- Reading Images names used in training/validation -------------")

    # -- Get list of images used for training: full paths and bare names -- #
    imageNames_Train, names_Train = getImagesSet(myParserConfigIni.imagesFolder, myParserConfigIni.indexesForTraining)
    groundTruthNames_Train, gt_names_Train = getImagesSet(myParserConfigIni.GroundTruthFolder, myParserConfigIni.indexesForTraining)
    roiNames_Train, roi_names_Train = getImagesSet(myParserConfigIni.ROIFolder, myParserConfigIni.indexesForTraining)

    # -- Get list of images used for validation -- #
    imageNames_Val, names_Val = getImagesSet(myParserConfigIni.imagesFolder, myParserConfigIni.indexesForValidation)
    groundTruthNames_Val, gt_names_Val = getImagesSet(myParserConfigIni.GroundTruthFolder, myParserConfigIni.indexesForValidation)
    roiNames_Val, roi_names_Val = getImagesSet(myParserConfigIni.ROIFolder, myParserConfigIni.indexesForValidation)

    print(" ================== Images for training ================")
    _printImageSet(names_Train, gt_names_Train, roi_names_Train)
    print(" ================== Images for validation ================")
    _printImageSet(names_Val, gt_names_Val, roi_names_Val)
    print(" ===============================================================")

    # --------------- Load my LiviaNet3D object  ---------------
    print(" ... Loading model from {}".format(networkModelName))
    myLiviaNet3D = load_model_from_gzip_file(networkModelName)
    print(" ... Network architecture successfully loaded....")

    # Assign the training schedule from the config to the loaded net
    myLiviaNet3D.numberOfEpochs = myParserConfigIni.numberOfEpochs
    myLiviaNet3D.numberOfSubEpochs = myParserConfigIni.numberOfSubEpochs
    myLiviaNet3D.numberOfSamplesSupEpoch = myParserConfigIni.numberOfSamplesSupEpoch
    myLiviaNet3D.firstEpochChangeLR = myParserConfigIni.firstEpochChangeLR
    myLiviaNet3D.frequencyChangeLR = myParserConfigIni.frequencyChangeLR

    numberOfEpochs = myLiviaNet3D.numberOfEpochs
    numberOfSubEpochs = myLiviaNet3D.numberOfSubEpochs
    numberOfSamplesSupEpoch = myLiviaNet3D.numberOfSamplesSupEpoch

    # --------------- Start TRAINING  ---------------
    receptiveField = myLiviaNet3D.receptiveField
    sampleSize_Train = myLiviaNet3D.sampleSize_Train
    batchSize = myLiviaNet3D.batch_Size

    applyPadding = myParserConfigIni.applyPadding == 1

    trainingCost = []
    # Epoch index of the most recent learning-rate change; None = not changed yet.
    learningRateModifiedEpoch = None

    # NOTE(review): this always runs numberOfEpochs iterations with e_i starting at 0,
    # even if the loaded model already has numberOfEpochsTrained > 0. The banner and
    # the saved file names use numberOfEpochsTrained, while the LR schedule uses e_i.
    # Confirm whether a resumed run should only cover the remaining epochs.
    for e_i in range(numberOfEpochs):
        numberOfEpochsTrained = myLiviaNet3D.numberOfEpochsTrained
        print(" ============== EPOCH: {}/{} =================".format(numberOfEpochsTrained + 1, numberOfEpochs))

        costsOfEpoch = []

        for subE_i in range(numberOfSubEpochs):
            print(" --- SubEPOCH: {}/{}".format(subE_i + 1, myLiviaNet3D.numberOfSubEpochs))

            # Get all the samples that will be used in this sub-epoch
            imagesSamplesAll, gt_samplesAll = getSamplesSubepoch(numberOfSamplesSupEpoch,
                                                                 imageNames_Train,
                                                                 groundTruthNames_Train,
                                                                 roiNames_Train,
                                                                 imageType,
                                                                 sampleSize_Train,
                                                                 receptiveField,
                                                                 applyPadding)

            # Weights for the cost function; currently all classes weigh the same
            weightsCostFunction = np.ones(myLiviaNet3D.n_classes, dtype='float32')

            # Incomplete trailing batch is dropped (floor division)
            numberBatches = len(imagesSamplesAll) // batchSize
            if numberBatches == 0:
                raise ValueError(
                    "Sub-epoch produced {} samples, fewer than batch_Size={}; "
                    "no training batch can be formed".format(len(imagesSamplesAll), batchSize))

            myLiviaNet3D.trainingData_x.set_value(imagesSamplesAll, borrow=True)
            myLiviaNet3D.trainingData_y.set_value(gt_samplesAll, borrow=True)

            costsOfBatches = []
            for b_i in range(numberBatches):
                # TODO: Make a line that adds a point at each trained batch (Or percentage being updated)
                costErrors = myLiviaNet3D.networkModel_Train(b_i, weightsCostFunction)
                costsOfBatches.append(costErrors[0])  # first output = mean batch cost
                myLiviaNet3D.updateLayersMatricesBatchNorm()

            #======== Calculate and Report cost over subepoch
            meanCostOfSubepoch = sum(costsOfBatches) / float(numberBatches)
            print(" ---------- Cost of this subEpoch: {}".format(meanCostOfSubepoch))

            # Release data held by the training buffers by replacing them with tiny arrays
            myLiviaNet3D.trainingData_x.set_value(np.zeros([1, 1, 1, 1, 1], dtype="float32"))
            myLiviaNet3D.trainingData_y.set_value(np.zeros([1, 1, 1, 1], dtype="float32"))

            costsOfEpoch.append(meanCostOfSubepoch)

        meanCostOfEpoch = sum(costsOfEpoch) / float(numberOfSubEpochs)

        # Include the epoch cost in the training history and update the running mean
        trainingCost.append(meanCostOfEpoch)
        currentMeanCost = sum(trainingCost) / float(len(trainingCost))

        print(" ---------- Training on Epoch #" + str(e_i) + " finished ----------")
        print(" ---------- Cost of Epoch: {} / Mean training error {}".format(meanCostOfEpoch, currentMeanCost))
        print(" -------------------------------------------------------- ")

        # ------------- Update Learning Rate if required ----------------#
        learningRateModifiedEpoch = _updateLearningRate(myLiviaNet3D, e_i, learningRateModifiedEpoch)

        # ---------------------- Start validation ---------------------- #
        numberImagesToSegment = len(imageNames_Val)
        print(" ********************** Starting validation **********************")

        for i_d in range(numberImagesToSegment):
            print("-------------  Segmenting subject: {} ....total: {}/{}... -------------".format(names_Val[i_d], str(i_d + 1), str(numberImagesToSegment)))
            strideValues = myLiviaNet3D.lastLayer.outputShapeTest[2:]

            segmentVolume(myLiviaNet3D,
                          i_d,
                          imageNames_Val,  # Full path
                          names_Val,       # Only image name
                          groundTruthNames_Val,
                          roiNames_Val,
                          imageType,
                          applyPadding,
                          receptiveField,
                          sampleSize_Train,
                          strideValues,
                          batchSize,
                          0  # Validation (0) or testing (1)
                          )

        print(" ********************** Validation DONE ********************** ")

        # ------ At this point training of epoch n is done ---------#
        myLiviaNet3D.numberOfEpochsTrained += 1

        #  --------------- Save the model ---------------
        _saveModel(myLiviaNet3D)

    print("................ The whole Training is done.....")
    print(" ************************************************************************************ ")