|
a |
|
b/steps/step_C_varpcanet.m |
|
|
1 |
function [] = step_C_varpcanet(dirTest, dirUtilities, ext, numCoresFeatExtr, numCoresKnn, fidLogs, logS, savefile, plotFigures)
%STEP_C_VARPCANET Compare pretrained and fine-tuned CNNs on original vs. unsharpened images.
%
%   For every database and every network in net_name, this step repeats
%   param.numIterations times:
%     1. a random fold split (recomputed at every iteration);
%     2. tuning of the focus threshold th_focus (passed to loadImages; the
%        value 100 is used for the "original" images) on the training split,
%        scoring each candidate with PCANet features + classification accuracy;
%     3. classification with features from a layer of the pretrained CNN,
%        on original and on unsharpened images;
%     4. fine-tuning of the CNN on both image versions and classification
%        of the test split;
%     5. Grad-CAM maps of the two fine-tuned networks.
%   Per-iteration results are saved in ./Results/<db>/<dbPart>/<net>/ and
%   averages over the iterations are printed at the end of each network.
%
%   Inputs:
%     dirTest          - prefix concatenated directly with the database name
%                        (presumably ends with a path separator)
%     dirUtilities     - forwarded to loadImages and find_th_focus
%     ext              - image file extension (without the dot)
%     numCoresFeatExtr - forwarded to start_pool and trainPCANet
%     numCoresKnn      - forwarded to the classification routines
%     fidLogs          - cell array of file ids used by fprintf_pers; slot 2 is
%                        replaced by the per-network log file when logging
%     logS             - write a per-network log file (only when savefile is true)
%     savefile         - create result folders and save .mat/log files
%     plotFigures      - forwarded to the CMC routines
%
%   `param` and `PCANet` are expected to be defined by ./params/paramsPCATuning.m.

%--------------------------------------
%General parameters
stepPrint = 100;
%PCA params (expected to define `param` and `PCANet` in this workspace)
run('./params/paramsPCATuning.m');

%--------------------------------------
%Databases: entries with the same index describe one database
dbname_All = {'ALL_IDB'};
dbname_part_All = {'ALL_IDB2'};
dbname_ROI_All = {'ROI_256'};

%Networks to evaluate (names must match the cases of loadPretrainedNet below)
net_name = {'AlexNet', 'VGG16', 'VGG19', 'ResNet18', 'ResNet50', 'ResNet101', 'DenseNet201'};

%--------------------------------------
%Color flags forwarded to loadImages / feature_extraction_cnn
colorS_init = 1;
colorS_tune = 1;
colorS_test = 1;

% processDummyDirs();

%--------------------------------------
%Loop on databases
for db = 1 : numel(dbname_All)

    close all
    pause(0.2);

    %DB selection
    dbname = dbname_All{db};
    dbnamePart = dbname_part_All{db};
    ROI = dbname_ROI_All{db};
    dirDB_wROI = [dirTest dbname '/' dbnamePart '/' ROI '/'];

    %--------------------------------------
    %Loop on networks
    for n = 1 : numel(net_name)

        %layer: features for the pretrained classifier; conv_layer: Grad-CAM
        [net, layer, conv_layer] = loadPretrainedNet(net_name{n});

        %--------------------------------------
        %Results folder of this database/network
        dirResults = ['./Results/' dbname '/' dbnamePart '/' net_name{n} '/'];
        mkdir_pers(dirResults, savefile);

        %Per-network log file (replaces slot 2 of fidLogs)
        timeStamp = strrep(datestr(datetime), ':', '-');
        if savefile && logS
            logFile = [dirResults dbname '_log_' timeStamp '.txt'];
            fidLog = fopen(logFile, 'w');
            if fidLog < 0
                error('step_C_varpcanet:logFile', 'Cannot open log file: %s', logFile);
            end
            fidLogs{2} = fidLog;
        end %if savefile && logS

        %--------------------------------------
        %Display
        fprintf_pers(fidLogs, '\n');
        fprintf_pers(fidLogs, '---------------\n');
        fprintf_pers(fidLogs, 'ALL-Unsharpen\n');
        fprintf_pers(fidLogs, [dbname '\n']);
        fprintf_pers(fidLogs, [dbnamePart '\n']);
        fprintf_pers(fidLogs, '---------------\n');
        fprintf_pers(fidLogs, '\n');

        fprintf_pers(fidLogs, '---------------\n');
        fprintf_pers(fidLogs, ['Net: ' net_name{n} '\n']);
        fprintf_pers(fidLogs, '---------------\n');
        fprintf_pers(fidLogs, '\n');

        %--------------------------------------
        %DB processing: list samples and compute labels
        files = dir([dirDB_wROI '*.' ext]);
        [~, labels, numImagesAll] = computeLabels(dirDB_wROI, files);

        fprintf_pers(fidLogs, 'Extracting samples...\n');
        fprintf_pers(fidLogs, ['\t' num2str(numImagesAll) ' images in total\n']);
        fprintf_pers(fidLogs, '\n');

        %--------------------------------------
        %Per-network accumulators, filled at index r of the iteration loop.
        %Cleared here so results (and the per-iteration .mat files) never mix
        %entries left over from the previous network.
        clear('errorStruct_pretrained_original', 'errorStruct_pretrained_unsharp', ...
            'errorStruct_finetune_original', 'errorStruct_finetune_unsharp', ...
            'cmc_original', 'cmc_sum_original', 'cmc_unsharp', 'cmc_sum_unsharp');

        %--------------------------------------
        %Loop on iterations
        for r = 1 : param.numIterations

            fprintf_pers(fidLogs, ['Iteration N. ' num2str(r) '\n']);

            %dirResults already ends with '/'
            fileSaveTest_iter = [dirResults 'results_iter_' num2str(r) '.mat'];

            %--------------------------------------
            %Random folds, recomputed at every iteration (repeated 2-fold);
            %randi(2, 1) picks which fold computeIndexesFold uses.
            %(10-fold alternative: compute cvIndices once before this loop
            %and call computeIndexesFold(cvIndices, r).)
            [allIndexes, cvIndices] = computeAllIndexesFold(numImagesAll, labels, param);
            [indImagesTrain, indImagesTest, numImagesTrain, numImagesTest] = computeIndexesFold(cvIndices, randi(2, 1));
            TrnLabels = labels(indImagesTrain);
            TestLabels = labels(indImagesTest);

            fprintf_pers(fidLogs, ['\t' num2str(numImagesTrain) ' images are chosen for training\n']);
            fprintf_pers(fidLogs, ['\t' num2str(numImagesTest) ' images are chosen for testing\n']);

            %%%%%%%%%%%%%% TRAINING %%%%%%%%%%%%%
            start_pool(numCoresFeatExtr);
            fprintf_pers(fidLogs, '\tTraining... \n');

            %Initial th_focus, estimated on the training images loaded with focus 100
            fprintf_pers(fidLogs, '\t\tLoading images for training... \n');
            [imagesCellTrain, ~, ~] = loadImages(files, dirDB_wROI, allIndexes, indImagesTrain, numImagesTrain, param, colorS_init, 100, dirUtilities, 0);
            imagesCellTrain = adjustFormat(imagesCellTrain);
            %(normalization via computeNorm is currently disabled)

            fprintf_pers(fidLogs, '\tComputing th_focus init... \n');
            th_focus_init = find_th_focus(imagesCellTrain, TrnLabels, [128 128], dirUtilities, fidLogs);
            clear imagesCellTrain

            %Candidate thresholds: th_focus_init +/- 0.5 with step 0.1, rounded to one decimal
            th_focus_start = round((th_focus_init - 0.5)*10)/10;
            th_focus_end = round((th_focus_init + 0.5)*10)/10;
            allfocuses = th_focus_start : 0.1 : th_focus_end;
            accuracy_knnALLFOCUS = [];

            %--------------------------------------
            %Tuning th_focus
            fprintf_pers(fidLogs, '\tTuning th_focus... \n');
            for th_focus = allfocuses

                fprintf_pers(fidLogs, '\n');
                fprintf_pers(fidLogs, ['\t\tth_focus: ' num2str(th_focus) '\n']);

                fprintf_pers(fidLogs, '\t\tLoading images... \n');
                [imagesCellTrain, ~, ~] = loadImages(files, dirDB_wROI, allIndexes, indImagesTrain, numImagesTrain, param, colorS_tune, th_focus, dirUtilities, 0);
                imagesCellTrain = adjustFormatForPCANet(imagesCellTrain, [128, 128]);

                %PCANet training (1 layer: PCA filters)
                [V, PCANet] = trainPCANet(imagesCellTrain, PCANet, fidLogs, param, numCoresFeatExtr);

                fprintf_pers(fidLogs, '\t\tFeature extraction... \n');
                [ftrain_all, numFeaturesTrain] = featExtrGaborAdapt(imagesCellTrain, V, PCANet, [], param, numImagesTrain, stepPrint);
                sizeTrain = size(ftrain_all, 2);

                %Score of this threshold: classification accuracy on the training split
                fprintf_pers(fidLogs, '\t\tClassification - original... \n');
                errorStruct_original_temp = computeClassificationPerformance(numFeaturesTrain, sizeTrain, ftrain_all, TrnLabels, stepPrint, numCoresKnn, fidLogs, param);
                [~, cmc_sum_original_temp] = computeCMC(errorStruct_original_temp.distMatrixTest, TrnLabels, ['cmc original temp iteration ' num2str(r)], 0);

                fprintf_pers(fidLogs, ['\t\tTraining accuracy (perc. of correctly classified samples, at iteration n. ' num2str(r) '): %s%%\n'], num2str(errorStruct_original_temp.accuracy_knn*100));
                fprintf_pers(fidLogs, ['\t\tAUC of CMC (at iteration n. ' num2str(r) '): %s\n'], num2str(cmc_sum_original_temp));

                %(to maximize the AUC of the CMC instead, store cmc_sum_original_temp)
                accuracy_knnALLFOCUS = [accuracy_knnALLFOCUS errorStruct_original_temp.accuracy_knn]; %#ok<AGROW>

            end %th_focus

            clear imagesCellTrain ftrain_all

            %%%%%%%%%%%%%% APPLY BEST FOCUS %%%%%%%%%%%%%
            fprintf_pers(fidLogs, '\n');
            fprintf_pers(fidLogs, '\tApply best focus\n');
            %Best score; among ties, the highest th_focus (allfocuses is ascending).
            %max ignores NaN scores, whereas sort() would have ranked them last.
            i_best_focus = find(accuracy_knnALLFOCUS == max(accuracy_knnALLFOCUS), 1, 'last');
            if isempty(i_best_focus)
                %every score is NaN: fall back to the last candidate
                i_best_focus = numel(allfocuses);
            end
            best_th_focus = allfocuses(i_best_focus);
            fprintf_pers(fidLogs, ['\t\tBest focus: ' num2str(best_th_focus) ' \n']);

            %Training data (normalization via computeNorm/applyNorm is currently disabled)
            fprintf_pers(fidLogs, '\t\tLoading training images - original... \n');
            [imagesCellTrain_original, ~, ~] = loadImages(files, dirDB_wROI, allIndexes, indImagesTrain, numImagesTrain, param, colorS_test, 100, dirUtilities, 0);
            imagesCellTrain_original = adjustFormat(imagesCellTrain_original);
            fprintf_pers(fidLogs, '\t\tLoading training images - unsharpened... \n');
            [imagesCellTrain_unsharp, ~, ~] = loadImages(files, dirDB_wROI, allIndexes, indImagesTrain, numImagesTrain, param, colorS_test, best_th_focus, dirUtilities, 0);
            imagesCellTrain_unsharp = adjustFormat(imagesCellTrain_unsharp);

            %Testing data (the means are forwarded to Grad-CAM)
            fprintf_pers(fidLogs, '\t\tLoading testing images - original... \n');
            [imagesCellTest_original, ~, ~] = loadImages(files, dirDB_wROI, allIndexes, indImagesTest, numImagesTest, param, colorS_test, 100, dirUtilities, 0);
            [imagesCellTest_original, meanAll_test_original] = adjustFormat(imagesCellTest_original);
            fprintf_pers(fidLogs, '\t\tLoading testing images - unsharpened... \n');
            [imagesCellTest_unsharp, filenameTest, indexes_test_imUnsharpened] = loadImages(files, dirDB_wROI, allIndexes, indImagesTest, numImagesTest, param, colorS_test, best_th_focus, dirUtilities, 0);
            [imagesCellTest_unsharp, meanAll_test_unsharp] = adjustFormat(imagesCellTest_unsharp);

            %%%%%%%%%%%%%% TESTING - PRE-TRAINED CNNs %%%%%%%%%%%%%
            fprintf_pers(fidLogs, '\t\tPretrained CNNs... \n');

            fprintf_pers(fidLogs, '\t\t\tFeature extraction - original... \n');
            ftrain_all_original = feature_extraction_cnn(imagesCellTrain_original, net, layer, colorS_test);
            ftest_all_original = feature_extraction_cnn(imagesCellTest_original, net, layer, colorS_test);

            fprintf_pers(fidLogs, '\t\t\tFeature extraction - unsharp... \n');
            ftrain_all_unsharp = feature_extraction_cnn(imagesCellTrain_unsharp, net, layer, colorS_test);
            ftest_all_unsharp = feature_extraction_cnn(imagesCellTest_unsharp, net, layer, colorS_test);

            numFeatures = size(ftest_all_unsharp, 1);
            sizeTest = size(ftest_all_unsharp, 2);

            %Classification performance (train/test)
            es_pretrained_original = computeClassificationPerformanceTrainTest(numFeatures, sizeTest, ftrain_all_original, ftest_all_original, TrnLabels, TestLabels, stepPrint, numCoresKnn, fidLogs, param);
            es_pretrained_unsharp = computeClassificationPerformanceTrainTest(numFeatures, sizeTest, ftrain_all_unsharp, ftest_all_unsharp, TrnLabels, TestLabels, stepPrint, numCoresKnn, fidLogs, param);

            %Free the CNN features (no longer needed)
            clear ftrain_all_original ftrain_all_unsharp ftest_all_original ftest_all_unsharp

            %CMC curves
            [cmc1, cmc_sum1] = computeCMC_trainTest(es_pretrained_original.distMatrixTest, TestLabels, ['cmc original iteration ' num2str(r)], plotFigures);
            [cmc2, cmc_sum2] = computeCMC_trainTest(es_pretrained_unsharp.distMatrixTest, TestLabels, ['cmc unsharp iteration ' num2str(r)], plotFigures);
            cmc_original{r} = cmc1;
            cmc_sum_original{r} = cmc_sum1;
            cmc_unsharp{r} = cmc2;
            cmc_sum_unsharp{r} = cmc_sum2;

            %Add rank5 BEFORE storing into the struct arrays. Adding a field to
            %element r gives the whole array that field, so storing the next
            %iteration's struct (without it) would fail with "dissimilar structures".
            es_pretrained_original.rank5 = cmc1(5);
            es_pretrained_unsharp.rank5 = cmc2(5);
            errorStruct_pretrained_original(r) = es_pretrained_original;
            errorStruct_pretrained_unsharp(r) = es_pretrained_unsharp;

            fprintf_pers(fidLogs, '\n');
            fprintf_pers(fidLogs, ['\tPretrained - Accuracy original (at iteration n. ' num2str(r) '): %s%%\n'], num2str(errorStruct_pretrained_original(r).accuracy_knn*100));
            fprintf_pers(fidLogs, ['\tPretrained - Accuracy unsharp (at iteration n. ' num2str(r) '): %s%%\n'], num2str(errorStruct_pretrained_unsharp(r).accuracy_knn*100));
            fprintf_pers(fidLogs, ['\tPretrained - Rank 5 accuracy original (at iteration n. ' num2str(r) '): %s%%\n'], num2str(cmc1(5)*100));
            fprintf_pers(fidLogs, ['\tPretrained - Rank 5 accuracy unsharp (at iteration n. ' num2str(r) '): %s%%\n'], num2str(cmc2(5)*100));

            pause(0.1)

            %%%%%%%%%%%%%% TESTING - FINE TUNING CNNs %%%%%%%%%%%%%
            fprintf_pers(fidLogs, '\n');
            fprintf_pers(fidLogs, '\tFine tuning CNNs... \n');

            %Data augmentation: random reflections and rotations
            %(random X/Y translations of [-30 30] pixels are currently disabled)
            imageAugmenter = imageDataAugmenter( ...
                'RandXReflection', true, ...
                'RandYReflection', true, ...
                'RandRotation', [-180 180]);

            numClasses = numel(unique(labels));
            inputSize = net.Layers(1).InputSize;

            %Replace the last layers to output numClasses classes
            lgraph = replaceLayers(net, numClasses);

            %Training options (alternatives tried: MiniBatchSize 128,
            %Shuffle 'never', Plots 'training-progress')
            options = trainingOptions('sgdm', ...
                'MiniBatchSize', 20, ...
                'MaxEpochs', 100, ...
                'InitialLearnRate', 1e-4, ...
                'Shuffle', 'every-epoch', ...
                'ValidationFrequency', 3, ...
                'Verbose', false, ...
                'Plots', 'none');

            fprintf_pers(fidLogs, '\t\tTraining original... \n');
            netTransfer_original = fineTuneCNN(imagesCellTrain_original, TrnLabels, './dummy_train_original/', inputSize, imageAugmenter, lgraph, options);

            fprintf_pers(fidLogs, '\t\tTraining unsharp... \n');
            netTransfer_unsharp = fineTuneCNN(imagesCellTrain_unsharp, TrnLabels, './dummy_train_unsharp/', inputSize, imageAugmenter, lgraph, options);

            fprintf_pers(fidLogs, '\t\tTesting original... \n');
            errorStruct_finetune_original(r) = computeClassPerformanceFineTuneCNN(imagesCellTest_original, TestLabels, './dummy_test_original/', inputSize, netTransfer_original, fidLogs);

            fprintf_pers(fidLogs, '\t\tTesting unsharp... \n');
            errorStruct_finetune_unsharp(r) = computeClassPerformanceFineTuneCNN(imagesCellTest_unsharp, TestLabels, './dummy_test_unsharp/', inputSize, netTransfer_unsharp, fidLogs);

            fprintf_pers(fidLogs, ['\tFine tuning - Accuracy original (at iteration n. ' num2str(r) '): %s%%\n'], num2str(errorStruct_finetune_original(r).accuracy_knn*100));
            fprintf_pers(fidLogs, ['\tFine tuning - Accuracy unsharp (at iteration n. ' num2str(r) '): %s%%\n'], num2str(errorStruct_finetune_unsharp(r).accuracy_knn*100));

            %--------------------------------------
            %Save (results accumulated so far for this network)
            if savefile
                save(fileSaveTest_iter, 'errorStruct_pretrained_original', 'errorStruct_pretrained_unsharp', 'errorStruct_finetune_original', 'errorStruct_finetune_unsharp', 'cmc_original', 'cmc_unsharp');
            end %if savefile

            fprintf_pers(fidLogs, '\n');

            %--------------------------------------
            %GRAD-CAM (test images reloaded without adjustFormat; the means
            %computed above are passed instead)
            fprintf_pers(fidLogs, 'Grad-CAM\n');
            fprintf_pers(fidLogs, '\n');
            dirGcam = [dirResults 'gcam_iter_' num2str(r) '/'];
            mkdir_pers(dirGcam, savefile);
            [imagesCellTest_original, ~, ~] = loadImages(files, dirDB_wROI, allIndexes, indImagesTest, numImagesTest, param, colorS_test, 100, dirUtilities, 0);
            [imagesCellTest_unsharp, ~, ~] = loadImages(files, dirDB_wROI, allIndexes, indImagesTest, numImagesTest, param, colorS_test, best_th_focus, dirUtilities, 0);
            computeGradCam(imagesCellTest_original, imagesCellTest_unsharp, meanAll_test_original, meanAll_test_unsharp, indexes_test_imUnsharpened, ...
                netTransfer_original, netTransfer_unsharp, conv_layer, inputSize, filenameTest, TestLabels, dirGcam);

            %Free this iteration's images and networks (names must match the
            %variables above, otherwise clear silently does nothing)
            clear imagesCellTrain_original imagesCellTrain_unsharp imagesCellTest_original imagesCellTest_unsharp netTransfer_original netTransfer_unsharp lgraph

        end %for r = 1 : param.numIterations

        close all
        pause(0.1)

        fprintf_pers(fidLogs, '\n');

        %--------------------------------------
        %Average classification performance
        %PRETRAINED
        fprintf_pers(fidLogs, '\n');
        fprintf_pers(fidLogs, 'Pretrained - Original\n');
        stampaErrors(errorStruct_pretrained_original, fidLogs);
        fprintf_pers(fidLogs, '\tRank 5 accuracy (mean; std): %s%%; %s%% \n', num2str(mean([errorStruct_pretrained_original.rank5])*100), num2str(std([errorStruct_pretrained_original.rank5])*100));
        fprintf_pers(fidLogs, 'Pretrained - Unsharp\n');
        stampaErrors(errorStruct_pretrained_unsharp, fidLogs);
        fprintf_pers(fidLogs, '\tRank 5 accuracy (mean; std): %s%%; %s%% \n', num2str(mean([errorStruct_pretrained_unsharp.rank5])*100), num2str(std([errorStruct_pretrained_unsharp.rank5])*100));

        %FINE TUNING
        fprintf_pers(fidLogs, '\n');
        fprintf_pers(fidLogs, 'Fine tuning - Original\n');
        stampaErrors(errorStruct_finetune_original, fidLogs);
        fprintf_pers(fidLogs, 'Fine tuning - Unsharp\n');
        stampaErrors(errorStruct_finetune_unsharp, fidLogs);

        %--------------------------------------
        %Average CMC
        stampaAvgCMC(cmc_original, 'original', dirResults, savefile, plotFigures);
        stampaAvgCMC(cmc_unsharp, 'unsharp', dirResults, savefile, plotFigures);

        fprintf_pers(fidLogs, '\n');

        %--------------------------------------
        %Close the per-network log file
        if savefile && logS
            fclose(fidLog);
        end %if savefile && logS
        % delete(gcp('nocreate'));
        %NOTE(review): fclose('all') also closes any file opened by the caller;
        %presumably fidLogs{1} is the console - confirm.
        fclose('all');

        close all
        pause(0.1)

    end %for n

end %for db

end %function step_C_varpcanet


function [net, layer, conv_layer] = loadPretrainedNet(netName)
%LOADPRETRAINEDNET Load a pretrained network and its layer names by name.
%   layer      - layer whose activations feature_extraction_cnn uses
%   conv_layer - convolutional layer passed to computeGradCam
switch netName
    case 'AlexNet'
        net = alexnet;
        layer = 'fc6';
        conv_layer = 'conv5';
    case 'VGG16'
        net = vgg16;
        layer = 'fc6';
        conv_layer = 'conv5_3';
    case 'VGG19'
        net = vgg19;
        layer = 'fc6';
        conv_layer = 'conv5_4';
    case 'ResNet18'
        net = resnet18;
        layer = 'fc1000';
        conv_layer = 'res5b_relu';
    case 'ResNet50'
        net = resnet50;
        layer = 'fc1000';
        conv_layer = 'res5c_branch2c';
    case 'ResNet101'
        net = resnet101;
        layer = 'fc1000';
        conv_layer = 'res5c_branch2c';
    case 'DenseNet201'
        net = densenet201;
        layer = 'fc1000';
        conv_layer = 'conv5_block32_2_conv';
    otherwise
        error('step_C_varpcanet:unknownNet', 'Unknown network name: %s', netName);
end %switch
end %function loadPretrainedNet