|
a |
|
b/src/LiviaNet/generateNetwork.py |
|
|
1 |
""" |
|
|
2 |
Copyright (c) 2016, Jose Dolz. All rights reserved. |
|
|
3 |
|
|
|
4 |
Redistribution and use in source and binary forms, with or without modification, |
|
|
5 |
are permitted provided that the following conditions are met: |
|
|
6 |
|
|
|
7 |
1. Redistributions of source code must retain the above copyright notice, |
|
|
8 |
this list of conditions and the following disclaimer. |
|
|
9 |
2. Redistributions in binary form must reproduce the above copyright notice, |
|
|
10 |
this list of conditions and the following disclaimer in the documentation |
|
|
11 |
and/or other materials provided with the distribution. |
|
|
12 |
|
|
|
13 |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
|
|
14 |
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES |
|
|
15 |
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND |
|
|
16 |
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT |
|
|
17 |
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, |
|
|
18 |
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|
|
19 |
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR |
|
|
20 |
OTHER DEALINGS IN THE SOFTWARE. |
|
|
21 |
|
|
|
22 |
Jose Dolz. Dec, 2016. |
|
|
23 |
email: jose.dolz.upv@gmail.com |
|
|
24 |
LIVIA Department, ETS, Montreal. |
|
|
25 |
""" |
|
|
26 |
|
|
|
27 |
import pdb |
|
|
28 |
import os |
|
|
29 |
|
|
|
30 |
from LiviaNet import LiviaNet3D |
|
|
31 |
from Modules.General.Utils import dump_model_to_gzip_file |
|
|
32 |
from Modules.General.Utils import makeFolder |
|
|
33 |
from Modules.Parsers.parsersUtils import parserConfigIni |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
# Human-readable labels for the integer codes stored in the config file.
# Keys mirror the codes accepted by the original print-dispatch tables.
_ACTIVATION_NAMES = {0: "Linear",
                     1: "ReLU",
                     2: "PReLU",
                     3: "Leaky ReLU"}

_WEIGHT_INIT_NAMES = {0: "Random",
                      1: "Delving",
                      2: "PreTrained"}


def _lookupLabel(table, code, what):
    """Return the label for `code` in `table`.

    :raises ValueError: if `code` is not a key of `table`; the message names
        `what` and lists the accepted codes, so a bad config value is easy to spot.
    """
    try:
        return table[code]
    except KeyError:
        raise ValueError("Unsupported {} code in config file: {!r} (expected one of {})".format(
            what, code, sorted(table)))


def _printConfigSummary(cfg):
    """Print a summary of the configuration held by the parser object `cfg`."""
    print(" ********************** Starting creation model **********************")
    print(" ------------------------ General ------------------------ ")
    print(" - Network name: {}".format(cfg.networkName))
    print(" - Folder to save the outputs: {}".format(cfg.folderName))
    print(" ------------------------ CNN Architecture ------------------------ ")
    print(" - Number of classes: {}".format(cfg.n_classes))
    print(" - Layers: {}".format(cfg.layers))
    print(" - Kernel sizes: {}".format(cfg.kernels))
    print(" - Intermediate connected CNN layers: {}".format(cfg.intermediate_ConnectedLayers))
    print(" - Pooling: {}".format(cfg.pooling_scales))
    print(" - Dropout: {}".format(cfg.dropout_Rates))

    print(" --- Activation function: {}".format(
        _lookupLabel(_ACTIVATION_NAMES, cfg.activationType, "activation function")))
    for layerType, code in (("CNN", cfg.weight_Initialization_CNN),
                            ("FCN", cfg.weight_Initialization_FCN)):
        print(" --- Weights initialization ({} Layers): {}".format(
            layerType, _lookupLabel(_WEIGHT_INIT_NAMES, code, "weight initialization")))

    print(" ------------------------ Training Parameters ------------------------ ")
    learningRates = cfg.learning_rate
    if len(learningRates) == 1:
        # A single global rate: show the value itself, not a one-element list.
        print(" - Learning rate: {}".format(learningRates[0]))
    else:
        # One rate per layer; layers are reported 1-based.
        for layerIdx, rate in enumerate(learningRates, 1):
            print(" - Learning rate at layer {} : {} ".format(layerIdx, rate))

    print(" - Batch size: {}".format(cfg.batch_size))

    # Explicit comparison kept on purpose: the parser's type for this field is
    # not visible here, and `== True` is the semantics the original relied on.
    if cfg.applyBatchNorm == True:
        print(" - Apply batch normalization in {} epochs".format(cfg.BatchNormEpochs))

    print(" ------------------------ Size of samples ------------------------ ")
    print(" - Training: {}".format(cfg.sampleSize_Train))
    print(" - Testing: {}".format(cfg.sampleSize_Test))


def _createOutputFolders(folderName):
    """Create the output folder tree under <cwd>/outputFiles/<folderName>.

    Creates 'Networks', 'Pred/Validation' and 'Pred/Testing' via makeFolder.

    :return: path of the 'Networks' folder, where the model file is saved.
    """
    outputRoot = os.path.join(os.getcwd(), 'outputFiles', folderName)
    netFolderName = os.path.join(outputRoot, 'Networks')
    predFolderName = os.path.join(outputRoot, 'Pred')

    makeFolder(netFolderName, "Networks")
    makeFolder(os.path.join(predFolderName, 'Validation'), "to store predictions (Validation)")
    makeFolder(os.path.join(predFolderName, 'Testing'), "to store predictions (Testing)")
    return netFolderName


def generateNetwork(configIniName):
    """Build, compile and save an untrained LiviaNet3D model described by a config file.

    Steps:
      1. Parse the config file `configIniName` and print a summary of it.
      2. Create the network with LiviaNet3D.createNetwork.
      3. Set up training parameters with initTrainingParameters.
      4. Call compileTheanoFunctions.
      5. Create the output folders under <cwd>/outputFiles/<folderName>/.
      6. Dump the model with dump_model_to_gzip_file as '<networkName>_Epoch0'
         inside the 'Networks' folder.

    :param configIniName: path to the configuration (.ini) file.
    :return: path of the saved model file (as passed to dump_model_to_gzip_file).
    :raises ValueError: if the activation or weight-initialization code in the
        config is not one of the supported values.
    """
    myParserConfigIni = parserConfigIni()
    # NOTE(review): the second argument (0) presumably selects the
    # "network generation" mode/section of the parser — confirm in parsersUtils.
    myParserConfigIni.readConfigIniFile(configIniName, 0)

    _printConfigSummary(myParserConfigIni)

    # --------------- Create my LiviaNet3D object ---------------
    myLiviaNet3D = LiviaNet3D()

    # --------------- Create the whole architecture (Conv layers + fully connected layers + classification layer) ---------------
    myLiviaNet3D.createNetwork(myParserConfigIni.networkName,
                               myParserConfigIni.folderName,
                               myParserConfigIni.layers,
                               myParserConfigIni.kernels,
                               myParserConfigIni.intermediate_ConnectedLayers,
                               myParserConfigIni.n_classes,
                               myParserConfigIni.sampleSize_Train,
                               myParserConfigIni.sampleSize_Test,
                               myParserConfigIni.batch_size,
                               myParserConfigIni.applyBatchNorm,
                               myParserConfigIni.BatchNormEpochs,
                               myParserConfigIni.activationType,
                               myParserConfigIni.dropout_Rates,
                               myParserConfigIni.pooling_scales,
                               myParserConfigIni.weight_Initialization_CNN,
                               myParserConfigIni.weight_Initialization_FCN,
                               myParserConfigIni.weightsFolderName,
                               myParserConfigIni.weightsTrainedIdx,
                               myParserConfigIni.tempSoftMax
                               )
    # TODO: Specify also the weights if pre-trained

    # --------------- Initialize all the training parameters ---------------
    myLiviaNet3D.initTrainingParameters(myParserConfigIni.costFunction,
                                        myParserConfigIni.L1_reg_C,
                                        myParserConfigIni.L2_reg_C,
                                        myParserConfigIni.learning_rate,
                                        myParserConfigIni.momentumType,
                                        myParserConfigIni.momentumValue,
                                        myParserConfigIni.momentumNormalized,
                                        myParserConfigIni.optimizerType,
                                        myParserConfigIni.rho_RMSProp,
                                        myParserConfigIni.epsilon_RMSProp
                                        )

    # --------------- Compile the functions (Training/Validation/Testing) ---------------
    myLiviaNet3D.compileTheanoFunctions()

    # --------------- Save the model ---------------
    netFolderName = _createOutputFolders(myParserConfigIni.folderName)

    modelBaseName = myParserConfigIni.networkName + "_Epoch0"
    # os.path.join keeps the separator consistent with the other paths built above.
    modelFileName = os.path.join(netFolderName, modelBaseName)
    dump_model_to_gzip_file(myLiviaNet3D, modelFileName)

    print(" Network model saved in {} as {}".format(netFolderName, modelBaseName))

    return modelFileName
|
|
174 |
|
|
|
175 |
|