[4f54f1]: / split_data.py

Download this file

74 lines (51 with data), 2.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import numpy as np
import random as rnd
import pandas as pd
import config
from utils import read_csv_column, read_csv, store_to_csv
def load_labels():
    """Load the patient label table.

    Returns:
        Whatever ``utils.read_csv`` returns for ``config.PATIENT_LABELS_CSV``.
        Downstream code indexes it by patient id and column name, so it is
        presumably a pandas DataFrame indexed by patient id (TODO confirm
        against utils.read_csv).
    """
    labels_path = config.PATIENT_LABELS_CSV
    return read_csv(labels_path)
def load_test_ids():
    """Load the ids of the held-out test patients.

    Returns:
        The values read by ``utils.read_csv_column`` from
        ``config.TEST_PATIENTS_IDS``. ``load_patient_ids`` excludes these ids
        from the train/validation pool.
    """
    ids_path = config.TEST_PATIENTS_IDS
    return read_csv_column(ids_path)
def get_patient_name(patient_file):
    """Derive a patient id from an image file path.

    The directory part is dropped and the base name is cut at its FIRST
    '.', so ``"dir/abc.nii.gz"`` yields ``"abc"``.

    Args:
        patient_file: file path or bare file name.

    Returns:
        The part of the base name before the first '.' (the whole base name
        if it contains no '.').
    """
    file_name = os.path.basename(patient_file)
    stem, _separator, _rest = file_name.partition('.')
    return stem
def load_patient_ids():
    """List ids of patients that have images and are not test patients.

    Each entry of ``config.ALL_IMGS`` is mapped to a patient id with
    ``get_patient_name``; ids that appear in the test-id list are skipped.

    Returns:
        List of patient ids in ``os.listdir`` order (which is arbitrary).
    """
    excluded = set(load_test_ids())
    # Generator so each directory entry's name is computed only once.
    names = (get_patient_name(entry) for entry in os.listdir(config.ALL_IMGS))
    return [name for name in names if name not in excluded]
def get_class(labels, patient_id, column=None):
    """Return the label cell for one patient.

    Args:
        labels: label table, looked up by row label and column name (presumably
            a pandas DataFrame indexed by patient id, as loaded by
            ``load_labels`` -- TODO confirm).
        patient_id: row label to look up.
        column: column to read. Defaults to ``config.COLUMN_NAME``; resolved at
            call time so the module-level config is only needed when omitted.

    Returns:
        The scalar stored at ``(patient_id, column)``.

    Raises:
        KeyError: if ``patient_id`` or ``column`` is not present in ``labels``.
    """
    if column is None:
        column = config.COLUMN_NAME
    # DataFrame.get_value was deprecated in pandas 0.21 and removed in 1.0;
    # .at is the supported label-based scalar accessor.
    return labels.at[patient_id, column]
def count_patients_from_class(patient_ids, labels, clazz):
    """Count the patients whose label equals ``clazz``.

    Args:
        patient_ids: iterable of ids to inspect.
        labels: label table passed through to ``get_class``.
        clazz: label value to match.

    Returns:
        Number of ids in ``patient_ids`` for which
        ``get_class(labels, id) == clazz``.
    """
    matches = 0
    for pid in patient_ids:
        if get_class(labels, pid) == clazz:
            matches += 1
    return matches
def split_data(validation_fraction=0.15, seed=None):
    """Split the non-test patients into training and validation sets.

    Loads the label table and the ids of all patients with images (minus the
    test patients). A random ``validation_fraction`` of those ids becomes the
    validation set and the rest becomes the training set. Class counts for
    each set are printed, and each set is written with its labels via
    ``store_to_csv``.

    Args:
        validation_fraction: share of the loaded patient ids to place in the
            validation set. Defaults to 0.15, the previously hard-coded value.
        seed: optional seed. When given, a private ``random.Random(seed)`` is
            used, so the split is reproducible without touching the global
            generator. When ``None``, the module-level ``random`` generator is
            used as before.

    Raises:
        ValueError: if ``validation_fraction`` is outside ``[0, 1]``.
    """
    if not 0 <= validation_fraction <= 1:
        raise ValueError(
            "validation_fraction must be in [0, 1], got {!r}".format(
                validation_fraction))

    labels = load_labels()
    total = len(labels)
    print("Total labels loaded: ", total)
    # Sorted so a fixed seed gives the same split no matter what order
    # os.listdir returned the files in.
    patient_ids = sorted(load_patient_ids())
    print("Total patient ids loaded: ", len(patient_ids))
    print("Patient with cancer are: ", count_patients_from_class(
        patient_ids, labels, config.CANCER_CLS))

    # Size the validation set from the ids actually being split, not from the
    # label table. The table may hold rows for patients not in patient_ids,
    # which would skew the fraction or make sample() raise ValueError.
    validation_size = int(validation_fraction * len(patient_ids))
    sampler = rnd if seed is None else rnd.Random(seed)
    validation_set = sampler.sample(patient_ids, validation_size)
    # Set membership keeps the complement O(n) instead of O(n^2).
    validation_lookup = set(validation_set)
    train_set = [patient for patient in patient_ids
                 if patient not in validation_lookup]

    print("Patients for training: ", len(train_set))
    print("Patients for validation: ", len(validation_set))
    print("Patients with cancer in validation set {}, no cancer {}.".format(
        count_patients_from_class(validation_set, labels, config.CANCER_CLS),
        count_patients_from_class(validation_set, labels, config.NO_CANCER_CLS)))
    print("Patients with cancer in training set {}, no cancer {}.".format(
        count_patients_from_class(train_set, labels, config.CANCER_CLS),
        count_patients_from_class(train_set, labels, config.NO_CANCER_CLS)))

    validation_labels = [get_class(labels, p) for p in validation_set]
    # NOTE: "PATINETS" is spelled exactly as the original code referenced it;
    # presumably config defines this name -- rename in both places or neither.
    store_to_csv(validation_set, validation_labels,
                 config.VALIDATION_PATINETS_IDS)
    train_labels = [get_class(labels, p) for p in train_set]
    store_to_csv(train_set, train_labels,
                 config.TRAINING_PATIENTS_IDS)
if __name__ == '__main__':
    # Script entry point: run the train/validation split and write the CSVs.
    split_data()