Switch to unified view

a b/Projects/Caffe/allCNN/Classes/CaffeHelpers.py
1
############################################################################################
2
#
3
# The MIT License (MIT)
4
# 
5
# Peter Moss Acute Myeloid/Lymphoblastic Leukemia AI Research Project
6
# Copyright (C) 2018 Adam Milton-Barker (AdamMiltonBarker.com)
7
# 
8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
# 
15
# The above copyright notice and this permission notice shall be included in
16
# all copies or substantial portions of the Software.
17
# 
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24
# THE SOFTWARE.
25
#
26
# Title:         Caffe Acute Lymphoblastic Leukemia CNN Caffe Helpers
27
# Description:   Common Caffe tools used by the Caffe Acute Lymphoblastic Leukemia CNN
28
# Configuration: Required/Confs.json
29
# Last Modified: 2019-03-10
30
# References:    Based on: ACUTE LEUKEMIA CLASSIFICATION USING CONVOLUTION NEURAL NETWORK 
31
#                IN CLINICAL DECISION SUPPORT SYSTEM
32
#                https://airccj.org/CSCP/vol7/csit77505.pdf
33
#
34
############################################################################################
35
36
import os, sys, random, cv2, lmdb
37
sys.path.append('/home/upsquared/caffe/python')
38
39
from caffe.proto import caffe_pb2
40
41
class CaffeHelpers():
42
    
43
    def __init__(self, confs, helpers, logFile):
44
45
        """
46
        Sets up all default requirements and placeholders needed for the 
47
        Caffe Acute Lymphoblastic Leukemia CNN Helpers.
48
        """
49
50
        self.Helpers = helpers
51
        self.confs = confs
52
        self.logFile = logFile
53
54
        self.labels = None
55
        self.trainLMDB = None
56
57
        self.trainData = []
58
        self.trainData0 = []
59
        self.trainData1 = []
60
        self.valData = []
61
        self.valData0 = []
62
        self.valData1 = []
63
        self.classNames = []
64
65
        self.imSize = (self.confs["Settings"]["Classifier"]["Input"]["imageWidth"], 
66
                       self.confs["Settings"]["Classifier"]["Input"]["imageHeight"])
67
68
        self.negativeTrainAmnt = self.confs["Settings"]["Classifier"]["Data"]["negativeTrainAmnt"]
69
        self.positiveTrainAmnt = self.confs["Settings"]["Classifier"]["Data"]["positiveTrainAmnt"]
70
71
        self.negativeTestAmnt = self.confs["Settings"]["Classifier"]["Data"]["negativeTestAmnt"]
72
        self.positiveTestAmnt = self.confs["Settings"]["Classifier"]["Data"]["positiveTestAmnt"]
73
        
74
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", "CaffeHelpers initiated")
75
76
    def deleteLMDB(self):
77
78
        """
79
        Deletes existing LMDB files.
80
        """
81
        
82
        os.system('rm -rf  ' + self.confs["Settings"]["Classifier"]["LMDB"]["train"])
83
        os.system('rm -rf  ' + self.confs["Settings"]["Classifier"]["LMDB"]["validation"])
84
        
85
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", "Existing LMDB deleted")
86
87
    def sortLabels(self):
88
89
        """
90
        Sorts the training / validation data and labels.
91
        """
92
93
        self.labels = open(self.confs["Settings"]["Classifier"]["Data"]["labels"], "w")
94
95
        for dirName in os.listdir(self.confs["Settings"]["Classifier"]["Model"]["dirData"]):
96
            if dirName == ".ipynb_checkpoints":
97
                continue
98
            path = os.path.join(self.confs["Settings"]["Classifier"]["Model"]["dirData"], dirName)
99
            if os.path.isdir(path):
100
                self.classNames.append(path)
101
                self.labels.write(dirName+"\n")
102
        
103
        self.labels.close()
104
        
105
        self.Helpers.logMessage(self.logFile, "allCNN",  "Status", str(len(self.classNames)) + " label(s) created")
106
107
    def appendDataPaths(self, directory, filename):
108
109
        """
110
        Appends training data paths.
111
        """
112
        
113
        filePath = os.path.join(directory, filename)
114
        cDirectory = os.path.basename(os.path.normpath(directory))
115
116
        image = cv2.imread(filePath)
117
        image = cv2.resize(image, self.imSize)
118
        cv2.imwrite(filePath, image)
119
120
        if cDirectory is self.confs["Settings"]["Classifier"]["Data"]["negativeDir"]:
121
            self.trainData0.append(filePath)
122
        elif cDirectory is self.confs["Settings"]["Classifier"]["Data"]["positiveDir"]:
123
            self.trainData1.append(filePath)
124
125
    def isValidFile(self, filename):
126
127
        """
128
        Checks that input file type is allowed.
129
        """
130
131
        return filename.endswith(tuple(self.confs["Settings"]["Classifier"]["Data"]["validFiles"]))
132
133
    def sortTrainingData(self):
134
135
        """
136
        Sorts the training / validation data
137
        """
138
139
        for directory in self.classNames:
140
            for filename in os.listdir(directory):
141
                if self.isValidFile(filename):
142
                    self.appendDataPaths(directory, filename)
143
                else:
144
                    continue
145
146
    def recreatePaperData(self):
147
148
        """
149
        Recreates the dataset sizes specified in the paper.
150
        """
151
        
152
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", "Total data size: " + str(len(self.trainData0) + len(self.trainData1)) + " (" + str(len(self.trainData0)) + " + " + str(len(self.trainData1)) + ")")
153
154
        msg = "Recreating negative training size of " + str(self.negativeTrainAmnt) 
155
        msg += " with negative testing size of " + str(self.negativeTestAmnt) 
156
        msg += " and positive training size of " + str(self.positiveTrainAmnt) 
157
        msg += " with positive testing size of " + str(self.positiveTestAmnt)
158
159
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", msg)
160
        
161
        random.shuffle(self.trainData0)
162
        random.shuffle(self.trainData1)
163
164
        trainingData0 = self.trainData0[0:self.negativeTrainAmnt]
165
        trainingData1 = self.trainData1[0:self.positiveTrainAmnt]
166
167
        self.Helpers.logMessage(self.logFile, "allCNN", "Data", "Negative training data created, size: " + str(len(trainingData0)))
168
        self.Helpers.logMessage(self.logFile, "allCNN", "Data", "Positive training data created, size: " + str(len(trainingData1)))
169
        
170
        for i in range(0, len(trainingData0)): 
171
            self.trainData.append(trainingData0[i])
172
        for i in range(0, len(trainingData1)): 
173
            self.trainData.append(trainingData1[i])
174
        
175
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", "Paper training data created. " + str(len(trainingData0)) + " x Negative & " + str(len(trainingData1)) + " x Positive ")
176
177
        valData0 = self.trainData0[self.negativeTrainAmnt:]
178
        valData1 = self.trainData1[self.positiveTrainAmnt:]
179
180
        self.Helpers.logMessage(self.logFile, "allCNN", "Data", "Negative validation data created, size: " + str(len(valData0)))
181
        self.Helpers.logMessage(self.logFile, "allCNN", "Data", "Positive validation data created, size: " + str(len(valData1)))
182
        
183
        for i in range(0, len(valData0)): 
184
            self.valData.append(valData0[i])
185
        for i in range(0, len(valData1)): 
186
            self.valData.append(valData1[i])
187
        
188
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", "Paper validation data created. " + str(len(valData0)) + " x Negative & " + str(len(valData1)) + " x Positive ")
189
190
        msg = "Recreated negative training size of " + str(len(trainingData0)) 
191
        msg += " with negative testing size of " + str(len(trainingData1)) 
192
        msg += " and positive training size of " + str(len(valData0)) 
193
        msg += " with positive testing size of " + str(len(valData1))
194
195
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", msg)
196
197
    def createDatum(self, imageData, label):
198
199
        """
200
        Generates a Datum object including label.
201
        """
202
    
203
        datum = caffe_pb2.Datum()
204
        datum.channels = imageData.shape[2]
205
        datum.height = imageData.shape[0]
206
        datum.width = imageData.shape[1]
207
        datum.data = imageData.tobytes()
208
        datum.label = int(label)
209
210
        return datum
211
        
212
    def transform(self, img):
213
214
        """
215
        Transforms image using histogram equalization and resizing.
216
        """
217
        
218
        img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
219
        img[:, :, 1] = cv2.equalizeHist(img[:, :, 1])
220
        img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])
221
        
222
        return cv2.resize(img, 
223
                          self.imSize, 
224
                          interpolation = cv2.INTER_CUBIC)
225
226
    def createAllDatum(self, rlmdb, data, dType):
227
228
        """
229
        Generates all Datum objects.
230
        """
231
        
232
        if dType is "Training":
233
            dataPath = self.trainData
234
        elif dType is "Validation":
235
            dataPath = self.valData
236
237
        with rlmdb.begin(write=True) as i:
238
            count = 0
239
            for data in dataPath:
240
                i.put(
241
                    '{:08}'.format(count).encode('ascii'), 
242
                    self.createDatum(
243
                        cv2.resize(
244
                            self.transform(
245
                                cv2.imread(data, cv2.IMREAD_COLOR)
246
                            ), 
247
                            (self.confs["Settings"]["Classifier"]["Input"]["imageHeight"], self.confs["Settings"]["Classifier"]["Input"]["imageWidth"])), 
248
                        os.path.basename(os.path.dirname(data))
249
                    ).SerializeToString())
250
                count = count + 1
251
        rlmdb.close()
252
        
253
        self.Helpers.logMessage(self.logFile,  "allCNN", "Status", dType + " data count: " + str(count))
254
255
    def createTrainingLMDB(self):
256
257
        """
258
        Creates training LMDB database.
259
        """
260
261
        random.shuffle(self.trainData)
262
        self.createAllDatum(lmdb.open(self.confs["Settings"]["Classifier"]["LMDB"]["train"], map_size=int(1e12)), self.trainData, "Training")
263
        
264
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", "Training LDBM created")
265
266
    def createValidationLMDB(self):
267
268
        """
269
        Creates validation LMDB database.
270
        """
271
272
        random.shuffle(self.valData)
273
        self.createAllDatum(lmdb.open(self.confs["Settings"]["Classifier"]["LMDB"]["validation"], map_size=int(1e12)), self.trainData, "Validation")
274
        
275
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", "Validation LDBM created")
276
277
    def computeMean(self):
278
279
        """
280
        Computes the mean.
281
        """
282
283
        os.system('/home/upsquared/caffe/build/tools/compute_image_mean -backend=lmdb  ' + self.confs["Settings"]["Classifier"]["LMDB"]["train"] + ' ' + self.confs["Settings"]["Classifier"]["Caffe"]["proto"])
284
        
285
        self.Helpers.logMessage(self.logFile, "allCNN", "Status", "Mean computed")