evaluate_submission.py
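# Evaluate a submission CSV against the held-out test set labels and report the
# log loss, i.e. the expected leaderboard score.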
import numpy as np
import csv
import collections
import sys
import utils_lung
import pathfinder


# Call this to compute the leaderboard performance (log loss) of a submission.
def leaderboard_performance(submission_file_path):
    real = utils_lung.read_test_labels(pathfinder.TEST_LABELS_PATH)
    pred = parse_predictions(submission_file_path)

    real = collections.OrderedDict(sorted(real.iteritems()))
    pred = collections.OrderedDict(sorted(pred.iteritems()))

    check_validity(real, pred)

    return log_loss(real.values(), pred.values())


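# Read a submission CSV with one row per patient: columns 'id' (patient id)
# and 'cancer' (predicted probability).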
def parse_predictions(submission_file_path):
    pred = {}
    with open(submission_file_path, 'rb') as csvfile:
        reader = csv.DictReader(csvfile, delimiter=',')
        for row in reader:
            pred[row['id']] = float(row['cancer'])
    return pred


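# Sanity checks: the submission must contain exactly one prediction for every
# patient in the test set.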
def check_validity(real, pred):
    if len(real) != len(pred):
        raise ValueError(
            'The number of test set labels (={}) does not match the number of predictions (={})'.format(
                len(real), len(pred)))

    if real.viewkeys() != pred.viewkeys():
        raise ValueError(
            'The patients in the test set do not match the patients in the predictions'
        )


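# Binary cross-entropy averaged over patients; predictions are clipped to
# [eps, 1 - eps] so that log(0) never occurs.
# Sanity check (illustrative values only): log_loss([0, 1], [0.5, 0.5]) == -ln(0.5) ~ 0.693.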
def log_loss(y_real, y_pred, eps=1e-15):
    y_pred = np.clip(y_pred, eps, 1 - eps)
    y_real = np.array(y_real)
    losses = y_real * np.log(y_pred) + (1 - y_real) * np.log(1 - y_pred)
    return -np.average(losses)


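# When run as a script: evaluate a submission CSV and print its log loss.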
if __name__ == '__main__':
    # Usage: evaluate_submission.py <absolute path to csv>
    if len(sys.argv) >= 2:
        submission_path = sys.argv[1]
    else:
        submission_path = '/home/user/Downloads/submission_0.55555.csv'
    loss = leaderboard_performance(submission_path)
    print loss