Diff of /make_final_split.py [000000] .. [70b6b3]

Switch to side-by-side view

--- a
+++ b/make_final_split.py
@@ -0,0 +1,77 @@
+import numpy as np
+import hashlib
+
+import utils
+import utils_lung
+import pathfinder
+
+rng = np.random.RandomState(42)
+
+
+tvt_ids = utils.load_pkl(pathfinder.VALIDATION_SPLIT_PATH)
+train_pids, valid_pids, test_pids = tvt_ids['training'], tvt_ids['validation'], tvt_ids['test']
+all_pids = train_pids + valid_pids + test_pids
+print 'total number of pids', len(all_pids)
+
+id2label = utils_lung.read_labels(pathfinder.LABELS_PATH)
+id2label_test = utils_lung.read_test_labels(pathfinder.TEST_LABELS_PATH)
+id2label.update(id2label_test)
+n_patients = len(id2label)
+
+pos_ids = []
+neg_ids = []
+
+for pid, label in id2label.iteritems():
+	if label ==1 :
+		pos_ids.append(pid)
+	elif label == 0 :
+		neg_ids.append(pid)
+	else:
+		raise ValueError("weird shit is going down")
+
+pos_ratio = 1. * len(pos_ids) / n_patients 
+print 'pos id ratio', pos_ratio
+
+split_ratio = 0.15
+n_target_split = int(np.round(split_ratio*n_patients))
+print 'given split ratio', split_ratio
+print 'target split ratio', 1. * n_target_split / n_patients
+
+n_pos_ftest = int(np.round(split_ratio*len(pos_ids)))
+n_neg_ftest = int(np.round(split_ratio*len(neg_ids)))
+
+final_pos_test = rng.choice(pos_ids,n_pos_ftest, replace=False)
+final_neg_test = rng.choice(neg_ids,n_neg_ftest, replace=False)
+final_test = np.append(final_pos_test,final_neg_test)
+print 'pos id ratio final test set', 1.*len(final_pos_test) / (len(final_test))	
+
+final_train = []
+final_pos_train = []
+final_neg_train = []
+for pid in all_pids:
+	if pid not in final_test:
+		final_train.append(pid)
+		if id2label[pid] == 1:
+			final_pos_train.append(pid)
+		elif id2label[pid] == 0:
+			final_neg_train.append(pid)
+		else:
+			raise ValueError("weird shit is going down")
+
+
+
+print 'pos id ratio final train set', 1.*len(final_pos_train) / (len(final_train))	
+print 'final test/(train+test):', 1.*len(final_test) / (len(final_train) + len(final_test))
+
+concat_str = ''.join(final_test)
+print 'md5 of concatenated pids:', hashlib.md5(concat_str).hexdigest()
+
+output = {'train':final_train, 'test':final_test}
+output_name = pathfinder.METADATA_PATH+'final_split.pkl'
+utils.save_pkl(output, output_name)
+print 'final split saved at ', output_name
+
+
+
+
+