|
a |
|
b/bin/ann_script_val.lua |
|
|
1 |
|
|
|
2 |
|
|
|
-- Script prologue: print the start banner and the current local date/time,
-- set the divisor used to express the training loss as a percentage, and
-- take an initial wall-clock timestamp.
do
  local bannerLines = {
    '\n\n @ @ @ @ @ @ START @ @ @ @ @ @ @ ',
    'file: script_start.lua',
    'author Davide Chicco <davide.chicco@gmail.com>',
  }
  for _, text in ipairs(bannerLines) do
    print(text)
  end
  print(os.date("%c", os.time()))
end

-- Divisor of the running loss in the progress percentages printed while training.
MAX_MSE = 4

-- Seeding is commented out, so the data split depends on the interpreter's
-- default random state.
-- RANDOM_SEED = 42

-- torch.manualSeed(RANDOM_SEED)
-- math.randomseed(RANDOM_SEED)

local timeStart = os.time()
|
|
--- Build a fully connected feed-forward network (an nn.Sequential).
-- Layout: Linear(input -> hidden) + ReLU [+ Dropout], then
-- `this_hidden_layers` further Linear(hidden -> hidden) + ReLU [+ Dropout]
-- blocks, then a final Linear(hidden -> output).
-- Dropout layers are added only when the global DROPOUT_FLAG is true.
-- Xavier initialisation (via ./weight-init.lua) is applied only when the
-- global XAVIER_INITIALIZATION is true.
-- NOTE(review): the network therefore has this_hidden_layers + 1 hidden
-- layers in total. Confirm this is the intended meaning of the argument.
-- @param this_input_number  number of input features
-- @param this_hidden_units  width of every hidden layer
-- @param this_hidden_layers number of extra hidden-to-hidden layers
-- @param this_output_number number of output units
-- @return the nn.Sequential model
function createPerceptron(this_input_number, this_hidden_units, this_hidden_layers, this_output_number)
  -- Local: previously this leaked a global `perceptron`.
  local perceptron = nn.Sequential()

  -- Appends the activation (and optional dropout) that follows each hidden Linear.
  local function addActivation()
    -- perceptron:add(nn.Sigmoid())
    perceptron:add(nn.ReLU())
    if DROPOUT_FLAG == true then perceptron:add(nn.Dropout()) end
  end

  perceptron:add(nn.Linear(this_input_number, this_hidden_units))
  addActivation()

  for _ = 1, this_hidden_layers do
    perceptron:add(nn.Linear(this_hidden_units, this_hidden_units))
    addActivation()
  end

  perceptron:add(nn.Linear(this_hidden_units, this_output_number))

  if XAVIER_INITIALIZATION == true then
    print("XAVIER_INITIALIZATION = "..tostring(XAVIER_INITIALIZATION))
    perceptron = require("./weight-init.lua")(perceptron, 'xavier') -- XAVIER
  end

  return perceptron
end
|
|
--- Evaluate a trained network on a labelled data set and report scores.
-- The prediction for each element is the first output value of
-- testPerceptron:forward(). It is compared with the label using the global
-- THRESHOLD; a value >= THRESHOLD counts as positive.
-- Prints the share of correct predictions, plus a warning when every
-- prediction falls on the same side of the threshold.
-- AUROC/AUPR come from metrics_ROC_AUC_computer
-- (./metrics_ROC_AUC_computer.lua); MCC/accuracy/F1 come from
-- confusion_matrix().
-- @param testPerceptron          trained nn module
-- @param dataset_patient_profile array of {inputTensor, labelTensor} pairs
-- @return {MCC, accuracy, f1_score, AUROC, AUPR}
function executeTest(testPerceptron, dataset_patient_profile)
  local sampleCount = #dataset_patient_profile
  local correctPredictions = 0
  local atleastOneTrue = false
  local atleastOneFalse = false
  local predictionTestVect = {}
  local truthVect = {}

  for i = 1, sampleCount do
    local current_label = dataset_patient_profile[i][2][1]
    local original_prediction = testPerceptron:forward(dataset_patient_profile[i][1])[1]

    -- io.write("original_prediction = ".. original_prediction)
    -- NOTE(review): the training loop rescales the network output with
    -- (x+1)/2 before the loss, but the raw output is thresholded here.
    -- Confirm that is intended.
    -- Local: previously this leaked a global `prediction`.
    local prediction = original_prediction --(original_prediction+1)/2
    predictionTestVect[i] = prediction
    truthVect[i] = current_label

    -- Correct when label and prediction fall on the same side of THRESHOLD.
    local labelResult = false
    if current_label >= THRESHOLD and prediction >= THRESHOLD then
      labelResult = true
    elseif current_label < THRESHOLD and prediction < THRESHOLD then
      labelResult = true
    end
    if labelResult == true then correctPredictions = correctPredictions + 1 end

    if prediction >= THRESHOLD then
      atleastOneTrue = true
    else
      atleastOneFalse = true
    end
  end

  if sampleCount > 0 then
    print("\nCorrect predictions = "..round(correctPredictions*100/sampleCount,2).."%")
  else
    -- Avoids printing nan for a 0/0 percentage.
    print("\nCorrect predictions = NOT computable (empty data set)")
  end

  if atleastOneTrue==false then print("ATTENTION: all the predictions are FALSE") end
  if atleastOneFalse==false then print("ATTENTION: all the predictions are TRUE") end

  require './metrics_ROC_AUC_computer.lua'
  local output_AUC_computer = metrics_ROC_AUC_computer(predictionTestVect, truthVect)
  local auroc = output_AUC_computer[1]
  local aupr = output_AUC_computer[2]

  local printValues = false
  local output_confusion_matrix = confusion_matrix(predictionTestVect, truthVect, THRESHOLD, printValues)

  -- MCC, accuracy, f1_score, AUROC, AUPR
  return {output_confusion_matrix[4], output_confusion_matrix[1], output_confusion_matrix[5], auroc, aupr}
end
|
|
--- Pause the script by running the shell `sleep` command.
-- @param n number of seconds (anything tonumber() accepts)
function sleep(n)
  local seconds = tonumber(n)
  local cmd = "sleep " .. seconds
  os.execute(cmd)
end
|
|
--- Linear search for `element` among the values of a table.
-- Values match when they are == or when their tostring() forms are equal,
-- so 1 matches "1". The search uses pairs(); for non-sequence tables the
-- reported position therefore follows pairs() order, which Lua leaves
-- unspecified.
-- @param tbl     table to search. Renamed from `table`, which shadowed the
--                standard library inside the function.
-- @param element value to look for
-- @return {true, position} on a hit, {false, -1} otherwise
function table.contains(tbl, element)
  local position = 0
  for _, value in pairs(tbl) do
    position = position + 1
    -- print("value: "..tostring(value).." element: "..tostring(element));
    if tostring(value) == tostring(element) or value == element then
      return {true, position}
    end
  end
  return {false, -1}
end
|
|
--- Print the time elapsed since `timeStart`, in seconds and as d/h/m/s.
-- @param timeStart     start timestamp obtained from os.time()
-- @param stringToPrint label inserted into both printed lines
-- @return elapsed time in seconds
function printTime(timeStart, stringToPrint)
  -- Locals: these previously leaked into the global table.
  local timeEnd = os.time()
  local duration = timeEnd - timeStart

  print('\nduration '..stringToPrint.. ': '.. comma_value(tonumber(duration)).. ' seconds')
  io.flush()

  -- math.floor keeps the values integral. string.format("%d") raises an
  -- error on Lua 5.3+ for floats with a fractional part, and truncates
  -- (the same result for non-negative values) on Lua 5.1/LuaJIT.
  local days = math.floor((duration / (60 * 60)) / 24)
  local hours = math.floor(duration / (60 * 60) % 24)
  local minutes = math.floor(duration / 60 % 60)
  local seconds = math.floor(duration % 60)
  print('duration '..stringToPrint.. ': '..string.format("%.2d days, %.2d hours, %.2d minutes, %.2d seconds", days, hours, minutes, seconds))
  io.flush()

  return duration
end
|
|
--- Format a number with an explicit sign, rounded to two decimals.
-- Values >= 0 get a leading "+"; negative values keep their own "-".
-- @param value number or numeric string
-- @return string such as "+0.43" or "-1.5"
function signedValueFunction(value)
  local number = tonumber(value)
  local sign = ""
  if number >= 0 then
    sign = "+"
  end
  return sign .. tostring(round(number, 2))
end
|
|
-- from sam_lie
-- Compatible with Lua 5.0 and 5.1.
-- Disclaimer : use at own risk especially for hedge fund reports :-)
--============================================================
--- Add commas to separate thousands, e.g. 1234567 -> "1,234,567".
-- The pattern is anchored at the start and matches only an optional "-"
-- followed by digits, so only the leading integer part is grouped
-- ("1234.5678" -> "1,234.5678").
-- @param amount number or numeric string
-- @return the formatted string
function comma_value(amount)
  local formatted = amount
  while true do
    -- Local: the substitution count previously leaked as a global `k`.
    local replacements
    formatted, replacements = string.gsub(formatted, "^(-?%d+)(%d%d%d)", '%1,%2')
    if replacements == 0 then
      break
    end
  end
  return formatted
end
|
|
--- Compute and print a binary confusion matrix and its derived scores.
-- Predictions and ground-truth values are both binarised with `threshold`
-- (value >= threshold counts as positive).
-- Prints the TP/FN/FP/TN counts, MCC, accuracy, recall (true positive
-- rate), specificity (true negative rate) and F1 score.
-- A score whose denominator is zero is reported as not computable and is
-- set to -2.
-- @param predictionTestVect array of predicted scores
-- @param truthVect          array of ground-truth values, same length
-- @param threshold          decision threshold
-- @param printValues        when true, print every element with its class
-- @return {accuracy, arrayFPindices, arrayFPvalues, MatthewsCC, f1_score}
function confusion_matrix(predictionTestVect, truthVect, threshold, printValues)
  local NOT_COMPUTABLE = -2

  local tp, tn, fp, fn = 0, 0, 0, 0
  local arrayFPindices = {}
  local arrayFPvalues = {}

  for i = 1, #predictionTestVect do
    local prediction = predictionTestVect[i]
    local truth = truthVect[i]

    if printValues == true then
      io.write("predictionTestVect["..i.."] = ".. round(prediction,4).."\ttruthVect["..i.."] = "..truth.." ")
      io.flush()
    end

    if prediction >= threshold and truth >= threshold then
      tp = tp + 1
      if printValues == true then print(" TP ") end
    elseif prediction < threshold and truth >= threshold then
      fn = fn + 1
      if printValues == true then print(" FN ") end
    elseif prediction >= threshold and truth < threshold then
      fp = fp + 1
      if printValues == true then print(" FP ") end
      arrayFPindices[#arrayFPindices+1] = i
      arrayFPvalues[#arrayFPvalues+1] = prediction
    elseif prediction < threshold and truth < threshold then
      tn = tn + 1
      if printValues == true then print(" TN ") end
    end
  end

  print("TOTAL:")
  print(" FN = "..comma_value(fn).." / "..comma_value(tonumber(fn+tp)).."\t (truth == 1) & (prediction < threshold)");
  print(" TP = "..comma_value(tp).." / "..comma_value(tonumber(fn+tp)).."\t (truth == 1) & (prediction >= threshold)\n");
  print(" FP = "..comma_value(fp).." / "..comma_value(tonumber(fp+tn)).."\t (truth == 0) & (prediction >= threshold)");
  print(" TN = "..comma_value(tn).." / "..comma_value(tonumber(fp+tn)).."\t (truth == 0) & (prediction < threshold)\n");

  -- Multiply as floats. With Lua 5.3+ integer counts the four-way product
  -- exceeds 2^63 (and silently wraps) once each factor is above ~55,000.
  local upperMCC = tp * 1.0 * tn - fp * 1.0 * fn
  local innerSquare = (tp + fp) * 1.0 * (tp + fn) * (tn + fp) * (tn + fn)
  local lowerMCC = math.sqrt(innerSquare)
  local MatthewsCC = NOT_COMPUTABLE
  if lowerMCC > 0 then MatthewsCC = upperMCC / lowerMCC end
  local signedMCC = signedValueFunction(MatthewsCC)

  if MatthewsCC > NOT_COMPUTABLE then print("\n::::\tMatthews correlation coefficient = "..signedMCC.."\t::::\n");
  else print("Matthews correlation coefficient = NOT computable"); end

  local accuracy = NOT_COMPUTABLE
  if (tp + tn + fn + fp) > 0 then
    accuracy = (tp + tn)/(tp + tn + fn + fp)
    print("accuracy = "..round(accuracy,2).. " = (tp + tn) / (tp + tn +fn + fp) \t \t [worst = 0, best = 1]");
  else
    print("accuracy CANNOT be computed because (tp + tn + fn + fp)==0")
  end

  local recall = NOT_COMPUTABLE
  if (tp + fn) > 0 then
    recall = tp / (tp + fn)
    print("true positive rate = recall = "..round(recall,2).. " = tp / (tp + fn) \t \t [worst = 0, best = 1]");
  else
    print("true positive rate = recall CANNOT be computed because (tp + fn)==0")
  end

  local specificity = NOT_COMPUTABLE
  if (tn + fp) > 0 then
    specificity = tn / (tn + fp)
    print("true negative rate = specificity = "..round(specificity,2).. " = tn / (tn + fp) \t \t [worst = 0, best = 1]");
  else
    print("true negative rate = specificity CANNOT be computed because (tn + fp)==0")
  end

  local f1_score = NOT_COMPUTABLE
  if (tp+fp+fn)>0 then
    f1_score = (2*tp) / (2*tp+fp+fn)
    print("f1_score = "..round(f1_score,2).." = (2*tp) / (2*tp+fp+fn) \t [worst = 0, best = 1]");
  else
    print("f1_score CANNOT be computed because (tp+fp+fn)==0")
  end

  print("\n\nMCC \t\t F1_score \t accuracy \t TP_rate \t TN_rate")
  print(round(MatthewsCC,2).. " \t\t "..round(f1_score,2).." \t\t "..round(accuracy,2).." \t\t "..round(recall,2).." \t\t ".. round(specificity,2).."\n\n")

  -- local numberOfPredictedOnes = tp + fp;
  -- print("numberOfPredictedOnes = (TP + FP) = "..comma_value(numberOfPredictedOnes).." = "..round(numberOfPredictedOnes*100/(tp + tn + fn + fp),2).."%");
  --
  -- io.write("\nDiagnosis: ");
  -- if (fn >= tp and (fn+tp)>0) then print("too many FN false negatives"); end
  -- if (fp >= tn and (fp+tn)>0) then print("too many FP false positives"); end
  --
  -- if (tn > (10*fp) and tp > (10*fn)) then print("Excellent ! ! !");
  -- elseif (tn > (5*fp) and tp > (5*fn)) then print("Very good ! !");
  -- elseif (tn > (2*fp) and tp > (2*fn)) then print("Good !");
  -- elseif (tn >= fp and tp >= fn) then print("Alright");
  -- else print("Baaaad"); end

  return {accuracy, arrayFPindices, arrayFPvalues, MatthewsCC, f1_score}
end
|
|
--- Shuffle `tab` in place with a (partial) Fisher-Yates pass and return it.
-- Each of the first `count` positions (default: `n`) is swapped with a
-- random position in [i, n]; `n` defaults to #tab.
-- Example: permute({1,2,3,4,5,6,7,8,9,10}, 10, 10)
-- @return the same table, now shuffled
function permute(tab, n, count)
  local last = n or #tab
  local steps = count or last
  local i = 1
  while i <= steps do
    local pick = math.random(i, last)
    tab[i], tab[pick] = tab[pick], tab[i]
    i = i + 1
  end
  return tab
end
|
|
--- Round `num` to `idp` decimal places (default 0).
-- Uses floor(x + 0.5), so halves round towards +infinity: round(-2.5) == -2.
function round(num, idp)
  local places = idp
  if not places then places = 0 end
  local scale = 10 ^ places
  local shifted = num * scale + 0.5
  return math.floor(shifted) / scale
end
|
|
-- ##############################
-- Load the data set. The first CSV row is the header (column names); every
-- following row is converted cell by cell with tonumber() into profile_vett.
local profile_vett = {}
local csv = require("csv")
local fileName = "../data/LungCancerDataset_AllRecords_NORM_reduced_features.csv"
-- tostring(arg[1])
-- cervical_arranged_NORM.csv
-- cervical_arranged_NORM_ONLY_BIOPSY_TARGET.csv

print("Readin' "..tostring(fileName));

-- Fail with a clear message rather than "attempt to index a nil value" later.
local f, openError = csv.open(fileName)
if not f then
  error("Error: cannot open "..tostring(fileName)..": "..tostring(openError))
end
local column_names = {}

local rowCount = 0 -- data rows stored so far (header excluded)
local isHeader = true

for fields in f:lines() do
  if isHeader then
    for i, v in ipairs(fields) do
      column_names[i] = v
    end
    isHeader = false
  else
    rowCount = rowCount + 1
    local row = {}
    for i, v in ipairs(fields) do
      -- NOTE(review): a non-numeric cell becomes nil and leaves a hole in the row.
      row[i] = tonumber(v)
    end
    profile_vett[rowCount] = row
  end
end
|
|
-- Run configuration. The globals are read by the functions defined above
-- and by the training loop below.
OPTIM_PACKAGE = true      -- true: optim.sgd loop; false: nn.StochasticGradient
MAX_VALUE = 1             -- every feature value is divided by this
local output_number = 1   -- number of network output units
THRESHOLD = 0.5 -- ORIGINAL
-- THRESHOLD = 0.1529
XAVIER_INITIALIZATION = false
DROPOUT_FLAG = false
MOMENTUM = false
MOMENTUM_ALPHA = 0.5
LEARN_RATE = 0.01 -- default was 0.01
ITERATIONS = 200 -- default was 200 -- I'M ANALYZING THIS PARAMETER IN THIS ANALYSIS

local hidden_units = 50 -- default was 50
local mcc = "mcc"
local aupr = "aupr"
OPTIMIZE_SCORE = mcc

-- Echo the settings, in a fixed order, with a blank line before the first.
do
  local settings = {
    {"OPTIM_PACKAGE", OPTIM_PACKAGE},
    {"XAVIER_INITIALIZATION", XAVIER_INITIALIZATION},
    {"DROPOUT_FLAG", DROPOUT_FLAG},
    {"MOMENTUM_ALPHA", MOMENTUM_ALPHA},
    {"MOMENTUM", MOMENTUM},
    {"LEARN_RATE", LEARN_RATE},
    {"ITERATIONS", ITERATIONS},
  }
  for idx, entry in ipairs(settings) do
    local prefix = ""
    if idx == 1 then prefix = "\n" end
    print(prefix .. entry[1] .. " = " .. tostring(entry[2]))
  end
end

-- Grid of network shapes searched by the optimisation loops below.
-- local hidden_layers = 1 -- best is 1
local hiddenUnitVect = {5,50,100,150,200,250,300,350,400}
local hiddenLayerVect = {1, 2, 3}
-- local hiddenLayerVect = {1}
|
|
local max_values = {}
-- filePointer = io.open("normalized_data_file.csv", "w")

-- Split every row into features (all columns but the last, divided by
-- MAX_VALUE) and a label (the last column). The column count of the first
-- row is used for every row.
local profile_vett_data = {}
local label_vett = {}

if #profile_vett == 0 then
  print("Error: no data rows were read from the input file. The program will stop")
  os.exit(1)
end

local columnCount = #(profile_vett[1])
for i = 1, #profile_vett do
  local row = profile_vett[i]
  local features = {}
  for j = 1, columnCount - 1 do
    features[j] = row[j] / MAX_VALUE
  end
  profile_vett_data[i] = features
  label_vett[i] = row[columnCount]
end

print("Number of value profiles (rows) = "..#profile_vett_data);
print("Number features (columns) = "..#(profile_vett_data[1]));
print("Number of targets (rows) = "..#label_vett);

-- NOTE(review): NORMALIZATION is not assigned anywhere in the visible code,
-- so this check is presumably inactive and max_values stays empty.
if NORMALIZATION==true and #max_values ~= #(profile_vett_data[1]) then
  print("Error: different number of max_values and features. The program will stop");
  -- Non-zero status; the original bare os.exit() reported success.
  os.exit(1)
end

local patient_outcome = label_vett
local patients_vett = profile_vett_data
-- print(patients_vett)

-- filePointer:close()
-- os.exit()
|
|
-- ########################################################

-- START

local timeStart = os.time();
local indexVect = {};

for i=1, #patients_vett do indexVect[i] = i; end
-- Local: used only in this section (it was previously an accidental global).
local permutedIndexVect = permute(indexVect, #indexVect, #indexVect);

-- VALIDATION_SET_PERC = 20
-- TEST_SET_SIZE = 100
--
-- local validation_set_size = round((VALIDATION_SET_PERC*(#patients_vett-TEST_SET_SIZE))/100)
--
-- print("training_set_size = "..((#patients_vett-TEST_SET_SIZE)-validation_set_size).." elements");
-- print("validation_set_size = "..validation_set_size.." elements\n");
--
-- print("TEST_SET_SIZE = "..TEST_SET_SIZE.." elements\n");
print("#patients_vett = "..#patients_vett);

TRAINING_SET_PERC = 60
VALIDATION_SET_PERC = 20
TEST_SET_PERC = 20 -- informational: the test set receives whatever rows remain

local training_set_size = round((TRAINING_SET_PERC*(#patients_vett))/100)
local validation_set_size = round((VALIDATION_SET_PERC*(#patients_vett))/100)
local test_set_size = #patients_vett - validation_set_size - training_set_size

print("\ntraining_set_size = "..training_set_size);
print("validation_set_size = "..validation_set_size);
print("test_set_size = "..test_set_size.."\n");

-- os.exit()

train_patient_profile = {}
validation_patient_profile = {}
test_patient_profile = {}
modelFileVect = {}

local original_validation_indexes = {}

-- Build one {featureTensor, labelTensor} sample from an original row index.
local function makeSample(originalIndex)
  return {torch.Tensor(patients_vett[originalIndex]), torch.Tensor{patient_outcome[originalIndex]}}
end

-- Split the shuffled rows into three sets:
--   first training_set_size rows   -> training
--   next validation_set_size rows  -> validation
--   remaining rows                 -> test
for i = 1, #patients_vett do
  local originalIndex = permutedIndexVect[i]
  if i <= training_set_size then
    train_patient_profile[#train_patient_profile+1] = makeSample(originalIndex)
  elseif i <= training_set_size + validation_set_size then
    original_validation_indexes[#original_validation_indexes+1] = originalIndex
    validation_patient_profile[#validation_patient_profile+1] = makeSample(originalIndex)
  else
    test_patient_profile[#test_patient_profile+1] = makeSample(originalIndex)
  end
end
|
|
require 'nn'
-- Number of input features, read from the first training sample's input tensor.
input_number = (#(train_patient_profile[1][1]))[1]

-- nn.StochasticGradient:train() reads the dataset length through :size().
function train_patient_profile:size() return #train_patient_profile end

function validation_patient_profile:size() return #validation_patient_profile end

-- Optional CSV log of the positive-class error progress during training.
local printError = false
local fileName = nil
local filePointer = nil
if printError == true then
  fileName = "./mse_log/positive_error_progress"..tostring(os.time())..".csv"
  local openError
  filePointer, openError = io.open(fileName, "w")
  -- Fail now instead of "attempt to index a nil value" at the first write.
  if not filePointer then
    error("Error: cannot open "..fileName..": "..tostring(openError))
  end
end

-- OPTIMIZATION LOOPS
-- Validation scores and network shape, one entry per grid configuration.
local MCC_vect = {}
local f1score_vect = {}
local auroc_vett = {}
local aupr_vett = {}
local hus_vect = {}
local hl_vect = {}
|
|
-- Grid search. For every (hidden layers, hidden units) pair:
--   1. train a fresh network on the training set;
--   2. score it on the validation set and record the scores;
--   3. save the model under ./models/ for the selection step below.
for b=1,#hiddenLayerVect do
  for a=1,#hiddenUnitVect do

    local hidden_units = hiddenUnitVect[a]
    local hidden_layers = hiddenLayerVect[b]
    print("$$$ hidden_units = "..hidden_units.."\t hidden_layers = "..hidden_layers.." $$$")

    local perceptron = createPerceptron(input_number, hidden_units, hidden_layers, output_number)
    local criterion = nn.MSECriterion()
    local lossSum = 0
    local positiveLossSum = 0
    local error_progress = 0
    local numberOfOnes = 0
    local positiveErrorProgress = 0

    if OPTIM_PACKAGE == false then
      -- Local: this previously leaked a global `myTrainer`.
      local myTrainer = nn.StochasticGradient(perceptron, criterion)
      myTrainer.learningRate = LEARN_RATE
      myTrainer.maxIteration = ITERATIONS
      myTrainer:train(train_patient_profile)
    else
      require 'optim'
      local params, gradParams = perceptron:getParameters()

      -- Momentum is passed to optim.sgd only when MOMENTUM is enabled.
      -- The original branches were swapped, so MOMENTUM = false still
      -- trained with momentum MOMENTUM_ALPHA.
      local optimState = {learningRate = LEARN_RATE}
      if MOMENTUM == true then
        optimState.momentum = MOMENTUM_ALPHA
      end

      local total_runs = ITERATIONS*#train_patient_profile
      local loopIterations = 1

      for epoch=1,ITERATIONS do
        for k=1,#train_patient_profile do
          -- Loss/gradient closure evaluated by optim.sgd on training element k.
          local function feval(params)
            gradParams:zero()
            local thisProfile = train_patient_profile[k][1]
            local thisLabel = train_patient_profile[k][2]
            local thisPrediction = perceptron:forward(thisProfile)
            -- [-1,+1] -> [0,1]
            thisPrediction = (thisPrediction+1)/2
            -- NOTE(review): the 1/2 factor of this rescaling is not applied
            -- to dloss_doutput before perceptron:backward() below. Gradients
            -- are therefore twice the true ones (equivalent to doubling
            -- LEARN_RATE). Left unchanged to preserve tuned results.

            local loss = criterion:forward(thisPrediction, thisLabel)

            lossSum = lossSum + loss
            error_progress = lossSum*100 / (loopIterations*MAX_MSE)

            --print("thisLabel[1] = "..thisLabel[1].." positiveLossSum = "..positiveLossSum.." numberOfOnes = "..numberOfOnes);

            if thisLabel[1]==1 then
              positiveLossSum = positiveLossSum + loss
              numberOfOnes = numberOfOnes + 1
            end

            if (numberOfOnes > 0 ) then
              positiveErrorProgress = positiveLossSum*100 / (numberOfOnes*MAX_MSE)
            end

            -- Progress line each time completion reaches a multiple of 4%.
            if ((loopIterations*100/total_runs)*25)%100==0 then
              io.write("completion: ", round((loopIterations*100/total_runs),2).."%" )
              io.write(" (epoch="..epoch..")(element="..k..") loss = "..round(loss,3).." ")
              io.write("\terror progress = "..round(error_progress,5).."%\n")
            end

            if printError== true then
              filePointer:write(loopIterations..","..positiveErrorProgress.."\n")
            end

            local dloss_doutput = criterion:backward(thisPrediction, thisLabel)

            perceptron:backward(thisProfile, dloss_doutput)

            return loss,gradParams
          end
          optim.sgd(feval, params, optimState)
          loopIterations = loopIterations+1
        end
      end
    end

    print("\n\n### executeTest(perceptron, validation_patient_profile)")
    local testOutput = executeTest(perceptron, validation_patient_profile)

    -- executeTest returns {MCC, accuracy, f1_score, AUROC, AUPR}.
    MCC_vect[#MCC_vect+1] = testOutput[1]
    f1score_vect[#f1score_vect+1] = testOutput[3]
    auroc_vett[#auroc_vett+1] = testOutput[4]
    aupr_vett[#aupr_vett+1] = testOutput[5]

    hus_vect[#hus_vect+1] = hidden_units
    hl_vect[#hl_vect+1] = hidden_layers

    local modelFile = "./models/model_hus"..hidden_units.."_hl"..hidden_layers.."_time"..tostring(os.time());
    torch.save(tostring(modelFile), perceptron);
    print("Saved model file: "..tostring(modelFile).."\n");
    modelFileVect[#modelFileVect+1] = modelFile;

  end
end
|
|
-- Print the validation scores of every configuration and find the best one
-- by MCC and by AUPR. Ties go to the later configuration (>=).
-- The maxima start at -math.huge rather than -1: confusion_matrix reports a
-- not-computable MCC as -2, and with -1 no model was selected when every
-- MCC was -2, so torch.load("nil") was attempted.
local maxMCC = -math.huge
local maxMCCpos = -1

local max_aupr = -math.huge
local max_aupr_pos = -1

for k=1,#MCC_vect do
  io.write("@ @ ["..k.."] ")
  io.write("\tAUPR = "..round(aupr_vett[k],2).."% ")
  io.write("\tMCC = "..round(MCC_vect[k],2))
  io.write("\tF1_score = "..round(f1score_vect[k],2))
  io.write("\tAUROC = "..round(auroc_vett[k],2).."% ")
  io.write("\thidden units = "..hus_vect[k].." ")
  io.write("\thidden layers = "..hl_vect[k].." ")
  io.write(" @ @ \n")
  io.flush()

  if MCC_vect[k]>=maxMCC then
    maxMCC = MCC_vect[k]
    maxMCCpos = k
  end
  if aupr_vett[k]>=max_aupr then
    max_aupr = aupr_vett[k]
    max_aupr_pos = k
  end
end

-- CHOOSING THE MODEL BY OPTIMZING THE MCC OR AUPR
local selectedPos = nil
if OPTIMIZE_SCORE == mcc then
  selectedPos = maxMCCpos
elseif OPTIMIZE_SCORE == aupr then
  selectedPos = max_aupr_pos
end

if selectedPos == nil or modelFileVect[selectedPos] == nil then
  error("Error: no model could be selected (OPTIMIZE_SCORE = "..tostring(OPTIMIZE_SCORE)..")")
end

local modelFileToLoad = tostring(modelFileVect[selectedPos])
print("\nmodelFileVect["..selectedPos.."]\nmodelFileToLoad ="..modelFileToLoad)

local loadedModel = torch.load(modelFileToLoad)
|
|
-- Score the selected model on the held-out test set.
print("\n\n### executeTest(loadedModel, test_patient_profile)")
local executeTestOutput = executeTest(loadedModel, test_patient_profile)

-- executeTest returns {MCC, accuracy, f1_score, AUROC, AUPR}.
local lastMCC = executeTestOutput[1]
local lastF1score = executeTestOutput[3]

print("':':':':' lastMCC = "..round(lastMCC,2).." lastF1score = "..round(lastF1score,2).." ':':':':'")

-- Delete the temporary model files saved during the grid search.
-- os.remove needs no shell and no `sys` package; `sys` is never required
-- in the visible code.
for i=1,#modelFileVect do
  local modelPath = tostring(modelFileVect[i])
  io.write("command: rm "..modelPath.." \n")
  local removed, removeError = os.remove(modelPath)
  if not removed then
    print("could not delete "..modelPath..": "..tostring(removeError))
  end
end

if printError == true then
  filePointer:close()
end

printTime(timeStart, " complete execution")