# Diff of /split_data.py [000000] .. [4f54f1]
#
# Switch to unified view
#
# a b/split_data.py
1
import os
2
import numpy as np
3
import random as rnd
4
import pandas as pd
5
6
import config
7
from utils import read_csv_column, read_csv, store_to_csv
8
9
10
def load_labels():
    """Load the patient label table.

    Returns:
        The result of ``read_csv`` applied to ``config.PATIENT_LABELS_CSV``.
    """
    return read_csv(config.PATIENT_LABELS_CSV)
12
13
14
def load_test_ids():
    """Load the ids of the patients reserved for testing.

    Returns:
        The result of ``read_csv_column`` applied to
        ``config.TEST_PATIENTS_IDS``.
    """
    return read_csv_column(config.TEST_PATIENTS_IDS)
16
17
18
def get_patient_name(patient_file):
    """Return the patient id encoded in a file name.

    The id is the base name of ``patient_file`` up to, but not including,
    its first dot; e.g. ``/data/abc.tar.gz`` yields ``abc``. A name with no
    dot is returned unchanged, and a name starting with a dot yields ``''``.

    Args:
        patient_file: File name or path of a patient's image file.

    Returns:
        The part of the base name before the first ``'.'``.
    """
    return os.path.basename(patient_file).split('.')[0]
20
21
22
def load_patient_ids():
    """List the ids of all patients that are not reserved for testing.

    Every entry of the ``config.ALL_IMGS`` directory is reduced to its
    patient id with ``get_patient_name``; ids present in the test id list
    are dropped.

    Returns:
        A list of patient ids, in the order ``os.listdir`` returned them.
    """
    # A set gives O(1) membership tests while filtering.
    test_ids = set(load_test_ids())
    # Derive each id once instead of once for the filter and once for the value.
    names = (get_patient_name(entry) for entry in os.listdir(config.ALL_IMGS))
    return [name for name in names if name not in test_ids]
29
30
def get_class(labels, patient_id):
    """Look up the class label of a single patient.

    Args:
        labels: Label table; assumed to be a pandas ``DataFrame`` indexed by
            patient id (the original code called ``DataFrame.get_value`` on
            it) -- TODO confirm against ``utils.read_csv``.
        patient_id: Row label of the patient to look up.

    Returns:
        The scalar stored in column ``config.COLUMN_NAME`` for that patient.

    Raises:
        KeyError: If ``patient_id`` or the column is not present.
    """
    # ``DataFrame.get_value`` was deprecated in pandas 0.21 and removed in
    # 1.0; ``.at`` is the equivalent scalar accessor in old and new releases.
    return labels.at[patient_id, config.COLUMN_NAME]
32
33
34
def count_patients_from_class(patient_ids, labels, clazz):
    """Count how many of the given patients carry a given class label.

    Args:
        patient_ids: Iterable of patient ids to inspect.
        labels: Label table passed through to ``get_class``.
        clazz: Class value to count.

    Returns:
        The number of ids in ``patient_ids`` whose label equals ``clazz``.
    """
    # Count lazily; no intermediate list is needed just to take its length.
    return sum(1 for patient in patient_ids
               if get_class(labels, patient) == clazz)
37
38
39
def split_data():
    """Randomly split the non-test patients into training and validation sets.

    Loads the labels and the patient ids, draws a random validation sample,
    prints the size and class balance of both subsets, and stores each
    subset (ids plus labels) through ``store_to_csv`` at
    ``config.VALIDATION_PATINETS_IDS`` and ``config.TRAINING_PATIENTS_IDS``.

    The split uses the unseeded module-level ``random`` generator, so every
    run produces a different partition.
    """
    labels = load_labels()
    total = len(labels)
    print("Total labels loaded: ", total)
    patient_ids = load_patient_ids()
    print("Total patient ids loaded: ", len(patient_ids))
    print("Patient with cancer are: ", count_patients_from_class(
        patient_ids, labels, config.CANCER_CLS))

    # NOTE(review): the size is 15% of the number of label rows, not of
    # ``patient_ids``; ``rnd.sample`` raises ValueError if it exceeds
    # ``len(patient_ids)`` -- confirm both collections cover the same patients.
    validation_size = int(0.15 * total)

    validation_set = rnd.sample(patient_ids, validation_size)
    # Set lookup keeps the complement computation linear instead of scanning
    # the validation list once per patient.
    validation_lookup = set(validation_set)
    train_set = [patient for patient in patient_ids
                 if patient not in validation_lookup]

    print("Patients for training: ", len(train_set))
    print("Patients for validation: ", len(validation_set))

    print("Patients with cancer in validation set {}, no cancer {}.".format(
        count_patients_from_class(validation_set, labels, config.CANCER_CLS),
        count_patients_from_class(validation_set, labels, config.NO_CANCER_CLS)))
    print("Patients with cancer in training set {}, no cancer {}.".format(
        count_patients_from_class(train_set, labels, config.CANCER_CLS),
        count_patients_from_class(train_set, labels, config.NO_CANCER_CLS)))

    # The attribute name below is spelled as declared in ``config``.
    validation_labels = [get_class(labels, p) for p in validation_set]
    store_to_csv(validation_set, validation_labels,
                 config.VALIDATION_PATINETS_IDS)

    train_labels = [get_class(labels, p) for p in train_set]
    store_to_csv(train_set, train_labels,
                 config.TRAINING_PATIENTS_IDS)
71
72
73
if __name__ == '__main__':
    # Perform the split only when run as a script, not when imported.
    split_data()