Diff of /data_set.py [000000] .. [4f54f1]

Switch to unified view

a b/data_set.py
1
import os
2
import numpy as np
3
import random as rnd
4
5
import config
6
from utils import read_csv_column, read_csv
7
8
9
class DataLoader(object):
10
    def __init__(self, 
11
                 images_loader,
12
                 labels_input=config.PATIENT_LABELS_CSV,
13
                 exact_tests=config.TEST_PATIENTS_IDS,
14
                 train_set=config.TRAINING_PATIENTS_IDS,
15
                 validation_set=config.VALIDATION_PATINETS_IDS,
16
                 add_transformed_positives=False):
17
        self._images_loader = images_loader
18
        self._labels = read_csv(labels_input)
19
20
        self._exact_tests = []
21
        if exact_tests:
22
            self._exact_tests = read_csv_column(exact_tests)
23
24
        self._train_set = list(read_csv_column(train_set, 
25
                                               columns=[1]))
26
        self._validation_set = list(read_csv_column(
27
            validation_set, columns=[1]))
28
        # Data augmentation for balancing the training set
29
        if add_transformed_positives:
30
            self._double_positive_class_data()
31
32
        self._examples_count = len(self._validation_set) + len(self._train_set)
33
        print("Total examples used for training and validation: ",
34
            self._examples_count)
35
        
36
        print("Total patients used for validation: ", 
37
            len(self._validation_set))
38
        print("Total patients used for training: ", 
39
            len(self._train_set))
40
        self._exact_tests_count = len(self._exact_tests)
41
42
    def _double_positive_class_data(self):
43
        positive = self.patients_from_class(self._train_set, 
44
                                            config.CANCER_CLS)
45
        print("Patients with cancer are: {}".format(len(positive)))
46
        # Anotate that original image should be transformed
47
        positive = [positive_name + '-augm' for positive_name in positive]
48
        self._train_set.extend(positive)
49
50
    def patients_from_class(self, patient_ids, clazz):
51
        return [patient for patient in patient_ids 
52
                if self.get_label(patient) == clazz]
53
54
    @property
55
    def exact_tests_count(self):
56
        return self._exact_tests_count
57
58
    @property
59
    def examples_count(self):
60
        return self._examples_count
61
    
62
    def train_samples_count(self):
63
        return len(self._train_set)
64
    
65
    def validation_samples_count(self):
66
        return len(self._validation_set)
67
68
    def get_label(self, patient_id):
69
        if 'augm' in patient_id:
70
            patient_id = patient_id.split('-')[0]
71
        try:
72
            clazz = self._labels.get_value(patient_id, config.COLUMN_NAME)
73
            return clazz
74
        except KeyError as e:
75
            print("No key found for patient with id {} in the labels.".format(
76
                   patient_id))
77
        return None
78
79
    def has_label(self, patient):
80
        try:
81
            self._labels.get_value(patient, config.COLUMN_NAME)
82
        except KeyError as e:
83
            return False
84
        return True
85
86
    def get_training_set(self):
87
        return DataSet(self._train_set, self)
88
89
    def get_validation_set(self):
90
        return DataSet(self._validation_set, self, False)
91
92
    def get_exact_tests_set(self):
93
        return DataSet(self._exact_tests, self, False)
94
95
    def load_image(self, patient):
96
        return self._images_loader.load_scans(patient)
97
98
    def results_out_dir(self):
99
        out_dir = os.path.join(config.MODELS_STORE_DIR, 
100
                               config.SELECTED_MODEL)
101
        if not os.path.exists(out_dir):
102
            os.makedirs(out_dir)
103
104
        return out_dir
105
106
107
class DataSet(object):
108
    def __init__(self, data_set, data_loader, shuffle=True):
109
        self._data_set = data_set
110
        self._data_loader = data_loader
111
        self._index_in_epoch = 0
112
        self._finished_epochs = 0
113
        self._num_samples = len(self._data_set)
114
        self._shuffle = shuffle
115
116
    def next_batch(self, batch_size):
117
        if self._index_in_epoch >= self._num_samples:
118
            # Epoche has finished, start new iteration
119
            self._finished_epochs += 1
120
            self._index_in_epoch = 0
121
            # Shuffle data
122
            if self._shuffle:
123
                rnd.shuffle(self._data_set)
124
        
125
        start = self._index_in_epoch
126
        self._index_in_epoch += batch_size
127
        end = self._index_in_epoch
128
129
        if end > self._num_samples:
130
            print("Not enough data for the batch to be retrieved.")
131
            return [], []
132
133
        data_set, labels = [], []
134
        try:
135
            for patient in self._data_set[start:end]:
136
                image, label = self._patient_with_label(patient)
137
                if len(image) and label is not None:
138
                    labels.append(label)
139
                    data_set.append(image)
140
141
            if len(data_set) < batch_size:
142
                print("Current batch size is less: {}".format(len(data_set)))
143
                print("Start {}, end {}, samples {}".format(start, end, 
144
                    self._num_samples))
145
        except FileNotFoundError as e:
146
            print("Unable to laod image for patient" + patient + 
147
                ". Please check if you have downloaded the data.",
148
                " Otherwise use the data_collector.py script.")
149
150
        return data_set, labels
151
152
    # Used during exact testing phase, here no labels are returned
153
    def next_patient(self):
154
        if self._index_in_epoch >= self._num_samples:
155
            return (None, [])
156
157
        patient_id = self._data_set[self._index_in_epoch]
158
        self._index_in_epoch += 1
159
        image = self._load_patient(patient_id)
160
        if self._validate_input_shape(image):
161
            return (patient_id, image)
162
        return (patient_id, [])
163
164
    def _patient_with_label(self, patient_id):
165
        label = self._data_loader.get_label(patient_id)
166
        if label is None:
167
            return ([], None)
168
        
169
        image = self._load_patient(patient_id)
170
        if self._validate_input_shape(image):
171
            return (image, label)
172
        
173
        return ([], None)
174
175
    def _load_patient(self, patient):
176
        return self._data_loader.load_image(patient)
177
178
    def _validate_input_shape(self, patient_image):
179
        return patient_image.shape == config.IMG_SHAPE
180
181
    @property
182
    def num_samples(self):
183
        return self._num_samples
184
185
    @property
186
    def finished_epochs(self):
187
        return self._finished_epochs
188
    
189
190
if __name__ == '__main__':
191
    data_loader = DataLoader()
192
    tr_set = data_loader.get_training_set()
193
    val_set = data_loader.get_validation_set()