Diff of /make_final_split.py [000000] .. [70b6b3]

Switch to unified view

a b/make_final_split.py
1
import numpy as np
2
import hashlib
3
4
import utils
5
import utils_lung
6
import pathfinder
7
8
# Build the final train/test split, stratified by label.
#
# Steps:
#   1. load the existing training/validation/test patient ids and merge them;
#   2. load the labels (regular + test labels) and bucket pids by label;
#   3. draw `split_ratio` of the positive and of the negative pids, without
#      replacement, as the final test set;
#   4. put every pid of the existing split that was not drawn into the final
#      training set;
#   5. print sanity-check ratios plus an md5 of the test pids, and pickle the
#      result to pathfinder.METADATA_PATH + 'final_split.pkl'.
#
# print is always called with a single pre-formatted string so the statements
# behave the same whether print is a statement or a function.

# Fixed seed so the sampled test set is the same on every run.
rng = np.random.RandomState(42)

tvt_ids = utils.load_pkl(pathfinder.VALIDATION_SPLIT_PATH)
train_pids, valid_pids, test_pids = tvt_ids['training'], tvt_ids['validation'], tvt_ids['test']
all_pids = train_pids + valid_pids + test_pids
print('total number of pids %s' % len(all_pids))

# The test labels are merged into the same mapping, so every labeled patient
# is a candidate for the final test set.
id2label = utils_lung.read_labels(pathfinder.LABELS_PATH)
id2label_test = utils_lung.read_test_labels(pathfinder.TEST_LABELS_PATH)
id2label.update(id2label_test)
n_patients = len(id2label)

pos_ids = []
neg_ids = []

# NOTE(review): pos_ids/neg_ids inherit the mapping's iteration order and
# rng.choice below samples by position, so the selected test set depends on
# that order as well as on the seed. Sorting here would remove the dependency
# but would also change the split (and the md5 printed below) - confirm before
# changing.
for pid, label in id2label.items():
    if label == 1:
        pos_ids.append(pid)
    elif label == 0:
        neg_ids.append(pid)
    else:
        raise ValueError('unexpected label %r for pid %r (expected 0 or 1)' % (label, pid))

pos_ratio = 1. * len(pos_ids) / n_patients
print('pos id ratio %s' % pos_ratio)

split_ratio = 0.15
# Only reported for reference: the actual test-set size is the sum of the two
# per-class counts below, each rounded on its own.
n_target_split = int(np.round(split_ratio * n_patients))
print('given split ratio %s' % split_ratio)
print('target split ratio %s' % (1. * n_target_split / n_patients))

n_pos_ftest = int(np.round(split_ratio * len(pos_ids)))
n_neg_ftest = int(np.round(split_ratio * len(neg_ids)))

# Keep the draw order (positives first, then negatives): both calls consume
# the same random stream, so swapping them would change the result.
final_pos_test = rng.choice(pos_ids, n_pos_ftest, replace=False)
final_neg_test = rng.choice(neg_ids, n_neg_ftest, replace=False)
final_test = np.append(final_pos_test, final_neg_test)
print('pos id ratio final test set %s' % (1. * len(final_pos_test) / len(final_test)))

# Set of the drawn pids for O(1) membership checks; testing `pid in final_test`
# against the numpy array would rescan the whole array for every pid.
final_test_lookup = set(final_test.tolist())

final_train = []
final_pos_train = []
final_neg_train = []
# NOTE(review): training pids come from the existing split (all_pids) while
# test pids were drawn from the label mapping; a labeled pid that is absent
# from all_pids and was not drawn ends up in neither set - confirm the two id
# collections are expected to match.
for pid in all_pids:
    if pid in final_test_lookup:
        continue
    final_train.append(pid)
    label = id2label[pid]
    if label == 1:
        final_pos_train.append(pid)
    elif label == 0:
        final_neg_train.append(pid)
    else:
        raise ValueError('unexpected label %r for pid %r (expected 0 or 1)' % (label, pid))

print('pos id ratio final train set %s' % (1. * len(final_pos_train) / len(final_train)))
print('final test/(train+test): %s' % (1. * len(final_test) / (len(final_train) + len(final_test))))

# Fingerprint of the selected test pids, in selection order, so two runs can
# be compared at a glance.
concat_str = ''.join(final_test)
if not isinstance(concat_str, bytes):
    # hashlib needs bytes when the joined pids form a text string.
    concat_str = concat_str.encode('utf-8')
print('md5 of concatenated pids: %s' % hashlib.md5(concat_str).hexdigest())

# 'train' is a plain list, 'test' is the numpy array returned by np.append;
# the container types are kept as they were for existing readers of the pickle.
output = {'train': final_train, 'test': final_test}
output_name = pathfinder.METADATA_PATH + 'final_split.pkl'
utils.save_pkl(output, output_name)
print('final split saved at  %s' % output_name)
73
74
75
76
77