|
a |
|
b/make_final_split.py |
|
|
1 |
import numpy as np |
|
|
2 |
import hashlib |
|
|
3 |
|
|
|
4 |
import utils |
|
|
5 |
import utils_lung |
|
|
6 |
import pathfinder |
|
|
7 |
|
|
|
8 |
rng = np.random.RandomState(42) |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
tvt_ids = utils.load_pkl(pathfinder.VALIDATION_SPLIT_PATH) |
|
|
12 |
train_pids, valid_pids, test_pids = tvt_ids['training'], tvt_ids['validation'], tvt_ids['test'] |
|
|
13 |
all_pids = train_pids + valid_pids + test_pids |
|
|
14 |
print 'total number of pids', len(all_pids) |
|
|
15 |
|
|
|
16 |
id2label = utils_lung.read_labels(pathfinder.LABELS_PATH) |
|
|
17 |
id2label_test = utils_lung.read_test_labels(pathfinder.TEST_LABELS_PATH) |
|
|
18 |
id2label.update(id2label_test) |
|
|
19 |
n_patients = len(id2label) |
|
|
20 |
|
|
|
21 |
pos_ids = [] |
|
|
22 |
neg_ids = [] |
|
|
23 |
|
|
|
24 |
for pid, label in id2label.iteritems(): |
|
|
25 |
if label ==1 : |
|
|
26 |
pos_ids.append(pid) |
|
|
27 |
elif label == 0 : |
|
|
28 |
neg_ids.append(pid) |
|
|
29 |
else: |
|
|
30 |
raise ValueError("weird shit is going down") |
|
|
31 |
|
|
|
32 |
pos_ratio = 1. * len(pos_ids) / n_patients |
|
|
33 |
print 'pos id ratio', pos_ratio |
|
|
34 |
|
|
|
35 |
split_ratio = 0.15 |
|
|
36 |
n_target_split = int(np.round(split_ratio*n_patients)) |
|
|
37 |
print 'given split ratio', split_ratio |
|
|
38 |
print 'target split ratio', 1. * n_target_split / n_patients |
|
|
39 |
|
|
|
40 |
n_pos_ftest = int(np.round(split_ratio*len(pos_ids))) |
|
|
41 |
n_neg_ftest = int(np.round(split_ratio*len(neg_ids))) |
|
|
42 |
|
|
|
43 |
final_pos_test = rng.choice(pos_ids,n_pos_ftest, replace=False) |
|
|
44 |
final_neg_test = rng.choice(neg_ids,n_neg_ftest, replace=False) |
|
|
45 |
final_test = np.append(final_pos_test,final_neg_test) |
|
|
46 |
print 'pos id ratio final test set', 1.*len(final_pos_test) / (len(final_test)) |
|
|
47 |
|
|
|
48 |
final_train = [] |
|
|
49 |
final_pos_train = [] |
|
|
50 |
final_neg_train = [] |
|
|
51 |
for pid in all_pids: |
|
|
52 |
if pid not in final_test: |
|
|
53 |
final_train.append(pid) |
|
|
54 |
if id2label[pid] == 1: |
|
|
55 |
final_pos_train.append(pid) |
|
|
56 |
elif id2label[pid] == 0: |
|
|
57 |
final_neg_train.append(pid) |
|
|
58 |
else: |
|
|
59 |
raise ValueError("weird shit is going down") |
|
|
60 |
|
|
|
61 |
|
|
|
62 |
|
|
|
63 |
print 'pos id ratio final train set', 1.*len(final_pos_train) / (len(final_train)) |
|
|
64 |
print 'final test/(train+test):', 1.*len(final_test) / (len(final_train) + len(final_test)) |
|
|
65 |
|
|
|
66 |
concat_str = ''.join(final_test) |
|
|
67 |
print 'md5 of concatenated pids:', hashlib.md5(concat_str).hexdigest() |
|
|
68 |
|
|
|
69 |
output = {'train':final_train, 'test':final_test} |
|
|
70 |
output_name = pathfinder.METADATA_PATH+'final_split.pkl' |
|
|
71 |
utils.save_pkl(output, output_name) |
|
|
72 |
print 'final split saved at ', output_name |
|
|
73 |
|
|
|
74 |
|
|
|
75 |
|
|
|
76 |
|
|
|
77 |
|