a b/bin/ann_script_val.lua
1
2
3
-- Start-up banner so each log shows which script produced it, who wrote it and when it ran.
do
    local bannerLines = {
        '\n\n @ @ @ @ @ @ START @ @ @ @ @ @ @ ',
        'file: script_start.lua',
        'author Davide Chicco <davide.chicco@gmail.com>',
    }
    for _, line in ipairs(bannerLines) do
        print(line)
    end
    print(os.date("%c", os.time()))
end

-- Divisor used to express the running training loss as an "error progress"
-- percentage (presumably an upper bound of the per-sample MSE -- TODO confirm).
MAX_MSE = 4

-- RANDOM_SEED = 42
-- torch.manualSeed(RANDOM_SEED)
-- math.randomseed(RANDOM_SEED)

-- NOTE(review): another `local timeStart` is declared before the data split and is
-- the one passed to printTime at the end; this value appears unused.
local timeStart = os.time()
-- createPerceptron
--- Build a fully-connected network with ReLU activations (no output activation).
-- Layout: Linear(input -> hidden) + ReLU, then `this_hidden_layers` extra
-- Linear(hidden -> hidden) + ReLU blocks, then Linear(hidden -> output).
-- A Dropout layer follows every ReLU when the global DROPOUT_FLAG is true.
-- When the global XAVIER_INITIALIZATION is true, the weights are re-initialised
-- through ./weight-init.lua with the 'xavier' scheme.
-- NOTE(review): the network therefore has this_hidden_layers + 1 hidden Linear
-- layers; confirm that is what "hidden_layers" is meant to count.
-- @param this_input_number number of input features
-- @param this_hidden_units width of every hidden layer
-- @param this_hidden_layers number of extra hidden->hidden blocks
-- @param this_output_number number of network outputs
-- @return the nn.Sequential model
function createPerceptron(this_input_number, this_hidden_units, this_hidden_layers, this_output_number)
    -- Local, so building a model never writes a global `perceptron`.
    local perceptron = nn.Sequential()

    -- Append one ReLU activation, plus Dropout when DROPOUT_FLAG is enabled.
    local function addActivation()
        -- perceptron:add(nn.Sigmoid())
        perceptron:add(nn.ReLU())
        if DROPOUT_FLAG == true then perceptron:add(nn.Dropout()) end
    end

    perceptron:add(nn.Linear(this_input_number, this_hidden_units))
    addActivation()

    for _ = 1, this_hidden_layers do
        perceptron:add(nn.Linear(this_hidden_units, this_hidden_units))
        addActivation()
    end

    perceptron:add(nn.Linear(this_hidden_units, this_output_number))

    if XAVIER_INITIALIZATION == true then
        print("XAVIER_INITIALIZATION = "..tostring(XAVIER_INITIALIZATION))
        perceptron = require("./weight-init.lua")(perceptron, 'xavier') -- XAVIER
    end

    return perceptron
end
-- function executeTest
--- Run a network over a labelled dataset, print diagnostics and return scores.
-- A prediction or label counts as positive when it is >= the global THRESHOLD.
-- NOTE(review): the optim training loop rescales the raw network output with
-- (out+1)/2 before computing the loss, but this function thresholds the RAW
-- output (the rescaling below is commented out). Confirm which scale THRESHOLD
-- is meant to apply to.
-- @param testPerceptron model exposing :forward(input); only output element [1] is used
-- @param dataset_patient_profile array of {inputTensor, labelTensor}; the label is labelTensor[1]
-- @return {MCC, accuracy, f1_score, AUROC, AUPR}
function executeTest(testPerceptron, dataset_patient_profile)
    local total = #dataset_patient_profile
    local correctPredictions = 0
    local atleastOneTrue = false
    local atleastOneFalse = false
    local predictionTestVect = {}
    local truthVect = {}

    for i = 1, total do
        local current_label = dataset_patient_profile[i][2][1]
        local original_prediction = testPerceptron:forward(dataset_patient_profile[i][1])[1]

        -- io.write("original_prediction = ".. original_prediction)
        -- Local: the original leaked `prediction` as a global.
        local prediction = original_prediction --(original_prediction+1)/2
        predictionTestVect[i] = prediction
        truthVect[i] = current_label

        -- A prediction is correct when it falls on the same side of THRESHOLD as the label.
        if (current_label >= THRESHOLD and prediction >= THRESHOLD)
            or (current_label < THRESHOLD and prediction < THRESHOLD) then
            correctPredictions = correctPredictions + 1
        end

        if prediction >= THRESHOLD then
            atleastOneTrue = true
        else
            atleastOneFalse = true
        end
    end

    -- Guard against 0/0 on an empty dataset.
    if total > 0 then
        print("\nCorrect predictions = "..round(correctPredictions*100/total,2).."%")
    else
        print("\nCorrect predictions = NOT computable (empty dataset)")
    end

    if atleastOneTrue == false then print("ATTENTION: all the predictions are FALSE") end
    if atleastOneFalse == false then print("ATTENTION: all the predictions are TRUE") end

    require './metrics_ROC_AUC_computer.lua'
    local output_AUC_computer = metrics_ROC_AUC_computer(predictionTestVect, truthVect)
    local auroc = output_AUC_computer[1]
    local aupr = output_AUC_computer[2]

    local printValues = false
    -- confusion_matrix returns {accuracy, FP indices, FP values, MCC, f1_score}
    local output_confusion_matrix = confusion_matrix(predictionTestVect, truthVect, THRESHOLD, printValues)

    -- MCC, accuracy, f1_score, AUROC, AUPR
    return {output_confusion_matrix[4], output_confusion_matrix[1], output_confusion_matrix[5], auroc, aupr}
end
--- Pause the whole process by running the shell's `sleep` command.
-- @param n seconds to sleep (number or numeric string; converted with tonumber)
function sleep(n)
    local seconds = tonumber(n)
    os.execute("sleep " .. seconds)
end
-- Function table.contains
--- Look for an element among the values of a table.
-- Two values match when their tostring() forms are equal or they are raw-equal,
-- so the number 1 and the string "1" count as the same element.
-- @param haystack table to search (walked with pairs, so hash-part order is unspecified)
-- @param needle value to look for
-- @return {true, position} where position is how many entries pairs() visited up to
--         and including the match, or {false, -1} when nothing matches
function table.contains(haystack, needle)
    local visited = 0
    for _, candidate in pairs(haystack) do
        visited = visited + 1
        if tostring(candidate) == tostring(needle) or candidate == needle then
            return {true, visited}
        end
    end
    return {false, -1}
end
-- Function that prints the time elapsed since a start time
--- Print the elapsed wall-clock time as raw seconds and as days/hours/minutes/seconds.
-- @param timeStart start time, as returned by os.time()
-- @param stringToPrint label inserted into both printed lines
-- @return elapsed seconds (os.time() - timeStart)
function printTime(timeStart, stringToPrint)
    -- Locals: the original leaked `timeEnd` and `duration` as globals.
    local duration = os.time() - timeStart
    print('\nduration '..stringToPrint.. ': '.. comma_value(tonumber(duration)).. ' seconds');
    io.flush();

    -- math.floor keeps every field integral. string.format's %d rejects non-integral
    -- floats on Lua 5.3+; on Lua 5.1/LuaJIT this matches the truncation done there.
    local days = math.floor(duration / 86400)
    local hours = math.floor(duration / 3600) % 24
    local minutes = math.floor(duration / 60) % 60
    local seconds = math.floor(duration) % 60
    print('duration '..stringToPrint.. ': '..string.format("%.2d days, %.2d hours, %.2d minutes, %.2d seconds", days, hours, minutes, seconds))
    io.flush();

    return duration;
end
--- Format a number rounded to 2 decimals, with an explicit "+" when it is non-negative.
-- @param value number or numeric string
-- @return string such as "+0.53" or "-1.2"
function signedValueFunction(value)
    local number = tonumber(value)
    local sign = ""
    if number >= 0 then
        sign = "+"
    end
    return tostring(sign .. tostring(round(number, 2)))
end
-- from sam_lie
-- Compatible with Lua 5.0 and 5.1.
-- Disclaimer : use at own risk especially for hedge fund reports :-)
--============================================================
--- Add commas that separate thousands in the integer part of a number.
-- @param amount number or numeric string, e.g. 1234567 or "-1234.5"
-- @return string with a comma every three integer digits, e.g. "1,234,567"
function comma_value(amount)
    local formatted = amount
    -- Local: the original leaked the substitution count `k` as a global.
    local k
    repeat
        -- Each pass inserts one comma before the last three digits of the
        -- leading (optionally negative) digit run; stop when nothing changes.
        formatted, k = string.gsub(formatted, "^(-?%d+)(%d%d%d)", '%1,%2')
    until k == 0
    return formatted
end
-- function that computes the confusion matrix
--- Compute, print and return binary-classification scores.
-- Predictions and ground-truth values are both binarised with `threshold`
-- (a value >= threshold counts as positive).
-- @param predictionTestVect array of numeric predictions
-- @param truthVect array of numeric ground-truth values, same length as predictionTestVect
-- @param threshold decision threshold
-- @param printValues when true, every element and its TP/FN/FP/TN class is printed
-- @return {accuracy, arrayFPindices, arrayFPvalues, MatthewsCC, f1_score}; a score
--         whose denominator is zero is -2, the "not computable" sentinel
function confusion_matrix(predictionTestVect, truthVect, threshold, printValues)
    -- Return numerator/denominator, or the -2 sentinel instead of NaN when the denominator is 0.
    local function ratioOrSentinel(numerator, denominator)
        if denominator > 0 then return numerator / denominator end
        return -2
    end

    local tp, tn, fp, fn = 0, 0, 0, 0
    local arrayFPindices = {}
    local arrayFPvalues = {}

    for i = 1, #predictionTestVect do
        local prediction = predictionTestVect[i]
        local truth = truthVect[i]

        if printValues == true then
            io.write("predictionTestVect["..i.."] = ".. round(prediction,4).."\ttruthVect["..i.."] = "..truth.." ");
            io.flush();
        end

        if prediction >= threshold and truth >= threshold then
            tp = tp + 1
            if printValues == true then print(" TP ") end
        elseif prediction < threshold and truth >= threshold then
            fn = fn + 1
            if printValues == true then print(" FN ") end
        elseif prediction >= threshold and truth < threshold then
            fp = fp + 1
            if printValues == true then print(" FP ") end
            arrayFPindices[#arrayFPindices+1] = i
            arrayFPvalues[#arrayFPvalues+1] = prediction
        elseif prediction < threshold and truth < threshold then
            tn = tn + 1
            if printValues == true then print(" TN ") end
        end
    end

    print("TOTAL:")
    print(" FN = "..comma_value(fn).." / "..comma_value(tonumber(fn+tp)).."\t (truth == 1) & (prediction < threshold)");
    print(" TP = "..comma_value(tp).." / "..comma_value(tonumber(fn+tp)).."\t (truth == 1) & (prediction >= threshold)\n");
    print(" FP = "..comma_value(fp).." / "..comma_value(tonumber(fp+tn)).."\t (truth == 0) & (prediction >= threshold)");
    print(" TN = "..comma_value(tn).." / "..comma_value(tonumber(fp+tn)).."\t (truth == 0) & (prediction < threshold)\n");

    -- Matthews correlation coefficient; stays at -2 when any marginal count is zero.
    local MatthewsCC = -2
    local lowerMCC = math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn))
    if lowerMCC > 0 then MatthewsCC = ((tp*tn) - (fp*fn)) / lowerMCC end
    local signedMCC = signedValueFunction(MatthewsCC);

    if MatthewsCC > -2 then print("\n::::\tMatthews correlation coefficient = "..signedMCC.."\t::::\n");
    else print("Matthews correlation coefficient = NOT computable"); end

    local accuracy = ratioOrSentinel(tp + tn, tp + tn + fn + fp)
    print("accuracy = "..round(accuracy,2).. " = (tp + tn) / (tp + tn +fn + fp) \t  \t [worst = 0, best =  1]");

    local recall = ratioOrSentinel(tp, tp + fn)
    print("true positive rate = recall = "..round(recall,2).. " = tp / (tp + fn) \t  \t [worst = 0, best =  1]");

    local specificity = ratioOrSentinel(tn, tn + fp)
    print("true negative rate = specificity = "..round(specificity,2).. " = tn / (tn + fp) \t  \t [worst = 0, best =  1]");

    local f1_score = -2
    if (tp+fp+fn) > 0 then
        f1_score = (2*tp) / (2*tp+fp+fn)
        print("f1_score = "..round(f1_score,2).." = (2*tp) / (2*tp+fp+fn) \t [worst = 0, best = 1]");
    else
        print("f1_score CANNOT be computed because (tp+fp+fn)==0")
    end

    print("\n\nMCC \t\t F1_score \t accuracy \t TP_rate \t TN_rate")
    print(round(MatthewsCC,2).. " \t\t "..round(f1_score,2).." \t\t "..round(accuracy,2).." \t\t "..round(recall,2).." \t\t ".. round(specificity,2).."\n\n")

    -- local numberOfPredictedOnes = tp + fp;
    -- print("numberOfPredictedOnes = (TP + FP) = "..comma_value(numberOfPredictedOnes).." = "..round(numberOfPredictedOnes*100/(tp + tn + fn + fp),2).."%");
    -- io.write("\nDiagnosis: ");
    -- if (fn >= tp and (fn+tp)>0) then print("too many FN false negatives"); end
    -- if (fp >= tn and (fp+tn)>0) then print("too many FP false positives"); end
    -- if (tn > (10*fp) and tp > (10*fn)) then print("Excellent ! ! !");
    -- elseif (tn > (5*fp) and tp > (5*fn)) then print("Very good ! !");
    -- elseif (tn > (2*fp) and tp > (2*fn)) then print("Good !");
    -- elseif (tn >= fp and tp >= fn) then print("Alright");
    -- else print("Baaaad"); end

    return {accuracy, arrayFPindices, arrayFPvalues, MatthewsCC, f1_score};
end
-- Permutations
-- tab = {1,2,3,4,5,6,7,8,9,10}
-- permute(tab, 10, 10)
--- Shuffle `tab` in place with forward Fisher-Yates steps.
-- Step i swaps tab[i] with a random tab[j], where j is in [i, n].
-- @param tab array to shuffle (modified in place)
-- @param n number of leading elements eligible for shuffling (default #tab)
-- @param count number of positions to fill from the front (default n); values
--        greater than n are clamped to n
-- @return the same table `tab`
function permute(tab, n, count)
    n = n or #tab
    -- Beyond position n, math.random(i, n) would raise "interval is empty".
    local steps = math.min(count or n, n)
    for i = 1, steps do
        local j = math.random(i, n)
        tab[i], tab[j] = tab[j], tab[i]
    end
    return tab
end
--- Round a number to `idp` decimal places; halves round up (towards +infinity).
-- @param num number to round
-- @param idp number of decimal places (default 0)
-- @return the rounded number
function round(num, idp)
    local scale = 10 ^ (idp or 0)
    local shifted = num * scale + 0.5
    return math.floor(shifted) / scale
end
-- ##############################
-- Read the CSV dataset. The first line supplies column_names; every later line
-- becomes one row of profile_vett, with each field converted by tonumber.
local profile_vett = {}
local csv = require("csv")
local fileName = "../data/LungCancerDataset_AllRecords_NORM_reduced_features.csv"
--     tostring(arg[1])
-- cervical_arranged_NORM.csv
-- cervical_arranged_NORM_ONLY_BIOPSY_TARGET.csv

print("Readin' "..tostring(fileName));

-- assert surfaces csv.open's error message instead of a later nil-index error.
local f = assert(csv.open(fileName))
local column_names = {}

-- j is 0 while reading the header line, then the 1-based index of the data row.
local j = 0

for fields in f:lines() do
    if j == 0 then
        for i, v in ipairs(fields) do
            column_names[i] = v
        end
    else
        local row = {}
        for i, v in ipairs(fields) do
            row[i] = tonumber(v)
        end
        profile_vett[j] = row
    end
    j = j + 1
end

-- Release the file handle (guarded in case the csv file object has no close method).
if f.close then f:close() end
-- Experiment configuration (globals read by the functions and loops of this script).
OPTIM_PACKAGE = true -- true: train with optim.sgd; false: nn.StochasticGradient
MAX_VALUE = 1 -- every feature value is divided by this
local output_number = 1
THRESHOLD = 0.5 -- ORIGINAL
-- THRESHOLD = 0.1529
XAVIER_INITIALIZATION = false
DROPOUT_FLAG = false
MOMENTUM = false
MOMENTUM_ALPHA = 0.5
LEARN_RATE = 0.01 -- default was 0.01
ITERATIONS = 200 -- default was 200 -- I'M ANALYZING THIS PARAMETER IN THIS ANALYSIS

local hidden_units = 50 -- default was 50
local mcc = "mcc"
local aupr = "aupr"
-- Validation score used to choose the final model (mcc or aupr).
OPTIMIZE_SCORE = mcc

-- Echo the main settings into the log.
do
    local settings = {
        {"\nOPTIM_PACKAGE  = ", OPTIM_PACKAGE},
        {"XAVIER_INITIALIZATION = ", XAVIER_INITIALIZATION},
        {"DROPOUT_FLAG = ", DROPOUT_FLAG},
        {"MOMENTUM_ALPHA = ", MOMENTUM_ALPHA},
        {"MOMENTUM = ", MOMENTUM},
        {"LEARN_RATE = ", LEARN_RATE},
        {"ITERATIONS = ", ITERATIONS},
    }
    for _, setting in ipairs(settings) do
        print(setting[1] .. tostring(setting[2]))
    end
end

-- Hyper-parameter grid explored by the training loop.
-- local hidden_layers = 1 -- best is 1
local hiddenUnitVect = {5,50,100,150,200,250,300,350,400}
local hiddenLayerVect = {1, 2, 3}
-- local hiddenLayerVect = {1}
local max_values = {}
-- filePointer = io.open("normalized_data_file.csv", "w")

-- Split each CSV row into its features (every column except the last, divided by
-- MAX_VALUE) and its label (the last column).
local profile_vett_data = {}
local label_vett = {}

-- Without any data row there is no column count, and the prints below would fail.
if #profile_vett == 0 then
    print("Error: the input file contains no data rows. The program will stop");
    os.exit(1);
end
-- The column count comes from the first row and is applied to every row.
local columnCount = #(profile_vett[1])

for i = 1, #profile_vett do
    profile_vett_data[i] = {}
    for j = 1, columnCount do
        if j < columnCount then
            profile_vett_data[i][j] = (profile_vett[i][j]/MAX_VALUE)
        else
            label_vett[i] = profile_vett[i][j]
        end
    end
end

print("Number of value profiles (rows) = "..#profile_vett_data);
print("Number features (columns) = "..#(profile_vett_data[1]));
print("Number of targets (rows) = "..#label_vett);

-- NOTE(review): NORMALIZATION is never assigned in the visible script and
-- max_values is never filled, so this check is currently inactive.
if NORMALIZATION==true and #max_values ~= #(profile_vett_data[1]) then
    print("Error: different number of max_values and features. The program will stop");
    -- Non-zero status so callers can detect the failure.
    os.exit(1);
end

local patient_outcome = label_vett
local patients_vett = profile_vett_data
-- print(patients_vett)
-- ########################################################
-- START

local timeStart = os.time();
local indexVect = {};

-- Shuffle the row indices so the train/validation/test split below is random.
for i = 1, #patients_vett do indexVect[i] = i end
permutedIndexVect = permute(indexVect, #indexVect, #indexVect);

-- Earlier fixed-size split, kept for reference:
-- VALIDATION_SET_PERC = 20
-- TEST_SET_SIZE = 100
-- local validation_set_size = round((VALIDATION_SET_PERC*(#patients_vett-TEST_SET_SIZE))/100)
-- print("training_set_size = "..((#patients_vett-TEST_SET_SIZE)-validation_set_size).." elements");
-- print("validation_set_size = "..validation_set_size.." elements\n");
-- print("TEST_SET_SIZE = "..TEST_SET_SIZE.." elements\n");
print("#patients_vett = "..#patients_vett);

TRAINING_SET_PERC = 60
VALIDATION_SET_PERC = 20
TEST_SET_PERC = 20

local training_set_size = round((TRAINING_SET_PERC*(#patients_vett))/100)
local validation_set_size = round((VALIDATION_SET_PERC*(#patients_vett))/100)
-- The test set takes whatever the two rounded sizes leave over.
local test_set_size = #patients_vett - validation_set_size - training_set_size

print("\ntraining_set_size = "..training_set_size);
print("validation_set_size = "..validation_set_size);
print("test_set_size = "..test_set_size.."\n");

train_patient_profile = {}
validation_patient_profile = {}
test_patient_profile = {}
modelFileVect = {}

local original_validation_indexes = {}

-- Walk the shuffled positions in three consecutive ranges: train, validation, test.
do
    -- Build the {features tensor, label tensor} pair for shuffled position `pos`.
    local function makeSample(pos)
        local row = permutedIndexVect[pos]
        return {torch.Tensor(patients_vett[row]), torch.Tensor{patient_outcome[row]}}
    end

    local patientCount = #patients_vett
    local trainEnd = math.min(training_set_size, patientCount)
    local validationEnd = math.min(training_set_size + validation_set_size, patientCount)

    for pos = 1, trainEnd do
        train_patient_profile[#train_patient_profile+1] = makeSample(pos)
    end

    for pos = trainEnd + 1, validationEnd do
        original_validation_indexes[#original_validation_indexes+1] = permutedIndexVect[pos];
        validation_patient_profile[#validation_patient_profile+1] = makeSample(pos)
    end

    for pos = validationEnd + 1, patientCount do
        test_patient_profile[#test_patient_profile+1] = makeSample(pos)
    end
end
require 'nn'

-- Every later step reads sample [1]; fail clearly if the split left the training set empty.
assert(#train_patient_profile > 0, "training set is empty: not enough data rows")
-- Number of input features, read from the size of the first training input tensor.
input_number = (#(train_patient_profile[1][1]))[1]

-- nn.StochasticGradient calls dataset:size(); provide it for both sets.
function train_patient_profile:size() return #train_patient_profile end

function validation_patient_profile:size() return #validation_patient_profile end

-- Optional CSV log of the running error on positive samples (disabled by default).
local printError = false
local fileName = nil
local filePointer = nil
if printError == true then
    fileName = "./mse_log/positive_error_progress"..tostring(os.time())..".csv"
    -- assert reports why the log could not be opened (e.g. missing ./mse_log directory).
    filePointer = assert(io.open(fileName, "w"))
end

-- OPTIMIZATION LOOPS
-- Validation scores and hyper-parameters of every trained configuration.
local MCC_vect = {}
local f1score_vect = {}
local auroc_vett = {}
local aupr_vett = {}
local hus_vect = {}
local hl_vect = {}
464
for b=1,#hiddenLayerVect do
465
    for a=1,#hiddenUnitVect do
466
467
            local hidden_units = hiddenUnitVect[a]
468
            local hidden_layers = hiddenLayerVect[b]
469
            print("$$$ hidden_units = "..hidden_units.."\t hidden_layers = "..hidden_layers.." $$$")
470
471
            local perceptron = createPerceptron(input_number, hidden_units, hidden_layers, output_number)
472
            local criterion = nn.MSECriterion()  
473
            local lossSum = 0
474
            local positiveLossSum = 0
475
            local error_progress = 0
476
            local numberOfOnes = 0
477
            local positiveErrorProgress = 0
478
479
            if OPTIM_PACKAGE == false then
480
                myTrainer = nn.StochasticGradient(perceptron, criterion)
481
                myTrainer.learningRate = LEARN_RATE
482
                myTrainer.maxIteration = ITERATIONS
483
                myTrainer:train(train_patient_profile)
484
            else
485
                require 'optim'
486
                local params, gradParams = perceptron:getParameters()     
487
                local optimState = nil
488
489
                if MOMENTUM==true then 
490
                    optimState = {learningRate = LEARN_RATE}
491
                else 
492
                    optimState = {learningRate = LEARN_RATE,
493
                    momentum = MOMENTUM_ALPHA }
494
                end
495
496
                local total_runs = ITERATIONS*#train_patient_profile
497
                local loopIterations = 1
498
499
                for epoch=1,ITERATIONS do
500
                    for k=1,#train_patient_profile do
501
                        -- Function feval 
502
                        local function feval(params)
503
                            gradParams:zero()
504
                            local thisProfile = train_patient_profile[k][1]
505
                            local thisLabel = train_patient_profile[k][2]
506
                            local thisPrediction = perceptron:forward(thisProfile)
507
                        -- [-1,+1] -> [0,1]
508
                        thisPrediction = (thisPrediction+1)/2
509
                        
510
                            local loss = criterion:forward(thisPrediction, thisLabel)
511
512
                            
513
                                                    
514
                            lossSum = lossSum + loss
515
                            error_progress = lossSum*100 / (loopIterations*MAX_MSE)
516
                        
517
                            --print("thisLabel[1] = "..thisLabel[1].." positiveLossSum = "..positiveLossSum.." numberOfOnes = "..numberOfOnes);
518
                    
519
                            if thisLabel[1]==1 then
520
                                positiveLossSum = positiveLossSum + loss
521
                                numberOfOnes = numberOfOnes + 1       
522
                            end
523
                        
524
                            if (numberOfOnes > 0 ) then 
525
                                positiveErrorProgress = positiveLossSum*100 / (numberOfOnes*MAX_MSE) 
526
                            end
527
                        
528
                            if ((loopIterations*100/total_runs)*25)%100==0 then
529
                                io.write("completion: ", round((loopIterations*100/total_runs),2).."%" )
530
                                io.write(" (epoch="..epoch..")(element="..k..") loss = "..round(loss,3).." ")      
531
                                io.write("\terror progress = "..round(error_progress,5).."%\n")
532
                            end
533
                        
534
                            if printError== true then
535
                                filePointer:write(loopIterations..","..positiveErrorProgress.."\n")
536
                            end
537
        
538
                            local dloss_doutput = criterion:backward(thisPrediction, thisLabel)
539
        
540
                            perceptron:backward(thisProfile, dloss_doutput)
541
542
                            return loss,gradParams
543
                        end
544
                    optim.sgd(feval, params, optimState)
545
                    loopIterations = loopIterations+1
546
                    end     
547
                end
548
            end
549
        print("\n\n### executeTest(perceptron, validation_patient_profile)")     
550
        local testOutput = executeTest(perceptron, validation_patient_profile)
551
552
        MCC_vect[#MCC_vect+1] = testOutput[1]
553
        f1score_vect[#f1score_vect+1] = testOutput[3]
554
        auroc_vett[#auroc_vett+1] = testOutput[4]
555
        aupr_vett[#aupr_vett+1] = testOutput[5]
556
557
        hus_vect[#hus_vect+1] = hidden_units
558
        hl_vect[#hl_vect+1] = hidden_layers
559
560
        local modelFile = "./models/model_hus"..hidden_units.."_hl"..hidden_layers.."_time"..tostring(os.time());
561
        torch.save(tostring(modelFile), perceptron);     
562
        print("Saved model file: "..tostring(modelFile).."\n");
563
        modelFileVect[#modelFileVect+1] = modelFile;
564
565
    end
566
end
-- Print every configuration's validation scores and track the best one by MCC and by AUPR.
-- Starting from -math.huge (not -1) guarantees a model is selected even when every
-- MCC is the -2 "not computable" sentinel. Ties go to the later configuration (>=).
local maxMCC = -math.huge
local maxMCCpos = -1

local max_aupr = -math.huge
local max_aupr_pos = -1

for k = 1, #MCC_vect do
    io.write("@ @ ["..k.."] ")
    io.write("\tAUPR = "..round(aupr_vett[k],2).."% ")
    io.write("\tMCC = "..round(MCC_vect[k],2))
    io.write("\tF1_score  = "..round(f1score_vect[k],2))
    io.write("\tAUROC  = "..round(auroc_vett[k],2).."% ")
    io.write("\thidden units = "..hus_vect[k].." ")
    io.write("\thidden layers = "..hl_vect[k].." ")
    io.write(" @ @ \n")
    io.flush()

    if MCC_vect[k] >= maxMCC then
        maxMCC = MCC_vect[k]
        maxMCCpos = k
    end
    if aupr_vett[k] >= max_aupr then
        max_aupr = aupr_vett[k]
        max_aupr_pos = k
    end
end

-- CHOOSING THE MODEL BY OPTIMIZING THE MCC OR AUPR
local selectedPos
if OPTIMIZE_SCORE == mcc then
    selectedPos = maxMCCpos
elseif OPTIMIZE_SCORE == aupr then
    selectedPos = max_aupr_pos
else
    error("unknown OPTIMIZE_SCORE: "..tostring(OPTIMIZE_SCORE))
end

-- Fail clearly instead of trying to load a file literally named "nil".
if modelFileVect[selectedPos] == nil then
    error("no trained model could be selected (position "..tostring(selectedPos)..")")
end

local modelFileToLoad = tostring(modelFileVect[selectedPos])
print("\nmodelFileVect["..selectedPos.."]\nmodelFileToLoad ="..modelFileToLoad)
-- Evaluate the selected model on the held-out test set.
local loadedModel = torch.load(modelFileToLoad)

print("\n\n### executeTest(loadedModel, test_patient_profile)")
local executeTestOutput = executeTest(loadedModel, test_patient_profile)

-- executeTest returns {MCC, accuracy, f1_score, AUROC, AUPR}
local lastMCC = executeTestOutput[1]
local lastF1score = executeTestOutput[3]

print("':':':':' lastMCC = "..round(lastMCC,2).."  lastF1score = "..round(lastF1score,2).." ':':':':'")

-- Delete the per-configuration model files saved during the grid search.
-- os.remove needs no shell (no injection through file names) and no `sys` package,
-- which this script never requires.
for i = 1, #modelFileVect do
    local modelPath = tostring(modelFileVect[i])
    io.write("removing: "..modelPath.." \n")
    local removed, removeError = os.remove(modelPath)
    if not removed then
        print("could not remove "..modelPath..": "..tostring(removeError))
    end
end

if printError == true then
    filePointer:close()
end

printTime(timeStart, " complete execution")