Switch to unified view

a b/src/LiviaNet/generateNetwork.py
1
""" 
2
Copyright (c) 2016, Jose Dolz .All rights reserved.
3
4
Redistribution and use in source and binary forms, with or without modification,
5
are permitted provided that the following conditions are met:
6
7
    1. Redistributions of source code must retain the above copyright notice,
8
       this list of conditions and the following disclaimer.
9
    2. Redistributions in binary form must reproduce the above copyright notice,
10
       this list of conditions and the following disclaimer in the documentation
11
       and/or other materials provided with the distribution.
12
13
    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14
    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
15
    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
16
    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
17
    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
18
    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
20
    OTHER DEALINGS IN THE SOFTWARE.
21
22
Jose Dolz. Dec, 2016.
23
email: jose.dolz.upv@gmail.com
24
LIVIA Department, ETS, Montreal.
25
"""
26
27
import pdb
28
import os
29
30
from LiviaNet import LiviaNet3D
31
from Modules.General.Utils import dump_model_to_gzip_file
32
from Modules.General.Utils import makeFolder
33
from Modules.Parsers.parsersUtils import parserConfigIni
34
35
36
def generateNetwork(configIniName) :   
37
38
    myParserConfigIni = parserConfigIni()
39
40
    myParserConfigIni.readConfigIniFile(configIniName,0)
41
    print " **********************  Starting creation model **********************"
42
    print " ------------------------ General ------------------------ "
43
    print " - Network name: {}".format(myParserConfigIni.networkName)
44
    print " - Folder to save the outputs: {}".format(myParserConfigIni.folderName)
45
    print " ------------------------ CNN Architecture ------------------------  "
46
    print " - Number of classes: {}".format(myParserConfigIni.n_classes)
47
    print " - Layers: {}".format(myParserConfigIni.layers)
48
    print " - Kernel sizes: {}".format(myParserConfigIni.kernels)
49
50
    print " - Intermediate connected CNN layers: {}".format(myParserConfigIni.intermediate_ConnectedLayers)
51
   
52
    print " - Pooling: {}".format(myParserConfigIni.pooling_scales)
53
    print " - Dropout: {}".format(myParserConfigIni.dropout_Rates)
54
    
55
    def Linear():
56
        print " --- Activation function: Linear"
57
 
58
    def ReLU():
59
        print " --- Activation function: ReLU"
60
 
61
    def PReLU():
62
        print " --- Activation function: PReLU"
63
64
    def LeakyReLU():
65
        print " --- Activation function: Leaky ReLU"
66
                  
67
    printActivationFunction = {0 : Linear,
68
                               1 : ReLU,
69
                               2 : PReLU,
70
                               3 : LeakyReLU}
71
72
    printActivationFunction[myParserConfigIni.activationType]()
73
        
74
    def Random(layerType):
75
        print " --- Weights initialization (" +layerType+ " Layers): Random"
76
 
77
    def Delving(layerType):
78
        print " --- Weights initialization (" +layerType+ " Layers): Delving"
79
 
80
    def PreTrained(layerType):
81
        print " --- Weights initialization (" +layerType+ " Layers): PreTrained"
82
        
83
    printweight_Initialization_CNN = {0 : Random,
84
                                      1 : Delving,
85
                                      2 : PreTrained}
86
                               
87
    printweight_Initialization_CNN[myParserConfigIni.weight_Initialization_CNN]('CNN')
88
    printweight_Initialization_CNN[myParserConfigIni.weight_Initialization_FCN]('FCN')
89
90
    print " ------------------------ Training Parameters ------------------------  "
91
    if len(myParserConfigIni.learning_rate) == 1:
92
        print " - Learning rate: {}".format(myParserConfigIni.learning_rate)
93
    else:
94
        for i in xrange(len(myParserConfigIni.learning_rate)):
95
            print " - Learning rate at layer {} : {} ".format(str(i+1),myParserConfigIni.learning_rate[i])
96
    
97
    print " - Batch size: {}".format(myParserConfigIni.batch_size)
98
99
    if myParserConfigIni.applyBatchNorm == True:
100
        print " - Apply batch normalization in {} epochs".format(myParserConfigIni.BatchNormEpochs)
101
        
102
    print " ------------------------ Size of samples ------------------------  "
103
    print " - Training: {}".format(myParserConfigIni.sampleSize_Train)
104
    print " - Testing: {}".format(myParserConfigIni.sampleSize_Test)
105
106
    # --------------- Create my LiviaNet3D object  --------------- 
107
    myLiviaNet3D = LiviaNet3D()
108
    
109
    # --------------- Create the whole architecture (Conv layers + fully connected layers + classification layer)  --------------- 
110
    myLiviaNet3D.createNetwork(myParserConfigIni.networkName,
111
                               myParserConfigIni.folderName,
112
                               myParserConfigIni.layers,
113
                               myParserConfigIni.kernels,
114
                               myParserConfigIni.intermediate_ConnectedLayers,
115
                               myParserConfigIni.n_classes,
116
                               myParserConfigIni.sampleSize_Train,
117
                               myParserConfigIni.sampleSize_Test,
118
                               myParserConfigIni.batch_size,
119
                               myParserConfigIni.applyBatchNorm,
120
                               myParserConfigIni.BatchNormEpochs,
121
                               myParserConfigIni.activationType,
122
                               myParserConfigIni.dropout_Rates,
123
                               myParserConfigIni.pooling_scales,
124
                               myParserConfigIni.weight_Initialization_CNN,
125
                               myParserConfigIni.weight_Initialization_FCN,
126
                               myParserConfigIni.weightsFolderName,
127
                               myParserConfigIni.weightsTrainedIdx,
128
                               myParserConfigIni.tempSoftMax
129
                               )
130
                               # TODO: Specify also the weights if pre-trained
131
                               
132
                          
133
    #  ---------------  Initialize all the training parameters  --------------- 
134
    myLiviaNet3D.initTrainingParameters(myParserConfigIni.costFunction,
135
                                        myParserConfigIni.L1_reg_C,
136
                                        myParserConfigIni.L2_reg_C,
137
                                        myParserConfigIni.learning_rate,
138
                                        myParserConfigIni.momentumType,
139
                                        myParserConfigIni.momentumValue,
140
                                        myParserConfigIni.momentumNormalized,
141
                                        myParserConfigIni.optimizerType,
142
                                        myParserConfigIni.rho_RMSProp,
143
                                        myParserConfigIni.epsilon_RMSProp
144
                                        )
145
   
146
    # ---------------  Compile the functions (Training/Validation/Testing) --------------- 
147
    myLiviaNet3D.compileTheanoFunctions()
148
149
    #  --------------- Save the model --------------- 
150
    # Generate folders to store the model
151
    BASE_DIR  = os.getcwd()
152
    path_Temp = os.path.join(BASE_DIR,'outputFiles')
153
    # For the networks
154
    netFolderName  = os.path.join(path_Temp,myParserConfigIni.folderName)
155
    netFolderName  = os.path.join(netFolderName,'Networks')
156
   
157
    # For the predictions
158
    predlFolderName    = os.path.join(path_Temp,myParserConfigIni.folderName)
159
    predlFolderName    = os.path.join(predlFolderName,'Pred')
160
    predValFolderName  = os.path.join(predlFolderName,'Validation')
161
    predTestFolderName = os.path.join(predlFolderName,'Testing')
162
   
163
    makeFolder(netFolderName, "Networks")
164
    makeFolder(predValFolderName, "to store predictions (Validation)")
165
    makeFolder(predTestFolderName, "to store predictions (Testing)")
166
167
    modelFileName = netFolderName + "/" + myParserConfigIni.networkName + "_Epoch0"
168
    dump_model_to_gzip_file(myLiviaNet3D, modelFileName)
169
    
170
    strFinal =  " Network model saved in " + netFolderName + " as " + myParserConfigIni.networkName + "_Epoch0"
171
    print  strFinal
172
    
173
    return modelFileName
174
   
175