|
a |
|
b/data_set.py |
|
|
1 |
import os |
|
|
2 |
import numpy as np |
|
|
3 |
import random as rnd |
|
|
4 |
|
|
|
5 |
import config |
|
|
6 |
from utils import read_csv_column, read_csv |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
class DataLoader(object): |
|
|
10 |
def __init__(self, |
|
|
11 |
images_loader, |
|
|
12 |
labels_input=config.PATIENT_LABELS_CSV, |
|
|
13 |
exact_tests=config.TEST_PATIENTS_IDS, |
|
|
14 |
train_set=config.TRAINING_PATIENTS_IDS, |
|
|
15 |
validation_set=config.VALIDATION_PATINETS_IDS, |
|
|
16 |
add_transformed_positives=False): |
|
|
17 |
self._images_loader = images_loader |
|
|
18 |
self._labels = read_csv(labels_input) |
|
|
19 |
|
|
|
20 |
self._exact_tests = [] |
|
|
21 |
if exact_tests: |
|
|
22 |
self._exact_tests = read_csv_column(exact_tests) |
|
|
23 |
|
|
|
24 |
self._train_set = list(read_csv_column(train_set, |
|
|
25 |
columns=[1])) |
|
|
26 |
self._validation_set = list(read_csv_column( |
|
|
27 |
validation_set, columns=[1])) |
|
|
28 |
# Data augmentation for balancing the training set |
|
|
29 |
if add_transformed_positives: |
|
|
30 |
self._double_positive_class_data() |
|
|
31 |
|
|
|
32 |
self._examples_count = len(self._validation_set) + len(self._train_set) |
|
|
33 |
print("Total examples used for training and validation: ", |
|
|
34 |
self._examples_count) |
|
|
35 |
|
|
|
36 |
print("Total patients used for validation: ", |
|
|
37 |
len(self._validation_set)) |
|
|
38 |
print("Total patients used for training: ", |
|
|
39 |
len(self._train_set)) |
|
|
40 |
self._exact_tests_count = len(self._exact_tests) |
|
|
41 |
|
|
|
42 |
def _double_positive_class_data(self): |
|
|
43 |
positive = self.patients_from_class(self._train_set, |
|
|
44 |
config.CANCER_CLS) |
|
|
45 |
print("Patients with cancer are: {}".format(len(positive))) |
|
|
46 |
# Anotate that original image should be transformed |
|
|
47 |
positive = [positive_name + '-augm' for positive_name in positive] |
|
|
48 |
self._train_set.extend(positive) |
|
|
49 |
|
|
|
50 |
def patients_from_class(self, patient_ids, clazz): |
|
|
51 |
return [patient for patient in patient_ids |
|
|
52 |
if self.get_label(patient) == clazz] |
|
|
53 |
|
|
|
54 |
@property |
|
|
55 |
def exact_tests_count(self): |
|
|
56 |
return self._exact_tests_count |
|
|
57 |
|
|
|
58 |
@property |
|
|
59 |
def examples_count(self): |
|
|
60 |
return self._examples_count |
|
|
61 |
|
|
|
62 |
def train_samples_count(self): |
|
|
63 |
return len(self._train_set) |
|
|
64 |
|
|
|
65 |
def validation_samples_count(self): |
|
|
66 |
return len(self._validation_set) |
|
|
67 |
|
|
|
68 |
def get_label(self, patient_id): |
|
|
69 |
if 'augm' in patient_id: |
|
|
70 |
patient_id = patient_id.split('-')[0] |
|
|
71 |
try: |
|
|
72 |
clazz = self._labels.get_value(patient_id, config.COLUMN_NAME) |
|
|
73 |
return clazz |
|
|
74 |
except KeyError as e: |
|
|
75 |
print("No key found for patient with id {} in the labels.".format( |
|
|
76 |
patient_id)) |
|
|
77 |
return None |
|
|
78 |
|
|
|
79 |
def has_label(self, patient): |
|
|
80 |
try: |
|
|
81 |
self._labels.get_value(patient, config.COLUMN_NAME) |
|
|
82 |
except KeyError as e: |
|
|
83 |
return False |
|
|
84 |
return True |
|
|
85 |
|
|
|
86 |
def get_training_set(self): |
|
|
87 |
return DataSet(self._train_set, self) |
|
|
88 |
|
|
|
89 |
def get_validation_set(self): |
|
|
90 |
return DataSet(self._validation_set, self, False) |
|
|
91 |
|
|
|
92 |
def get_exact_tests_set(self): |
|
|
93 |
return DataSet(self._exact_tests, self, False) |
|
|
94 |
|
|
|
95 |
def load_image(self, patient): |
|
|
96 |
return self._images_loader.load_scans(patient) |
|
|
97 |
|
|
|
98 |
def results_out_dir(self): |
|
|
99 |
out_dir = os.path.join(config.MODELS_STORE_DIR, |
|
|
100 |
config.SELECTED_MODEL) |
|
|
101 |
if not os.path.exists(out_dir): |
|
|
102 |
os.makedirs(out_dir) |
|
|
103 |
|
|
|
104 |
return out_dir |
|
|
105 |
|
|
|
106 |
|
|
|
107 |
class DataSet(object): |
|
|
108 |
def __init__(self, data_set, data_loader, shuffle=True): |
|
|
109 |
self._data_set = data_set |
|
|
110 |
self._data_loader = data_loader |
|
|
111 |
self._index_in_epoch = 0 |
|
|
112 |
self._finished_epochs = 0 |
|
|
113 |
self._num_samples = len(self._data_set) |
|
|
114 |
self._shuffle = shuffle |
|
|
115 |
|
|
|
116 |
def next_batch(self, batch_size): |
|
|
117 |
if self._index_in_epoch >= self._num_samples: |
|
|
118 |
# Epoche has finished, start new iteration |
|
|
119 |
self._finished_epochs += 1 |
|
|
120 |
self._index_in_epoch = 0 |
|
|
121 |
# Shuffle data |
|
|
122 |
if self._shuffle: |
|
|
123 |
rnd.shuffle(self._data_set) |
|
|
124 |
|
|
|
125 |
start = self._index_in_epoch |
|
|
126 |
self._index_in_epoch += batch_size |
|
|
127 |
end = self._index_in_epoch |
|
|
128 |
|
|
|
129 |
if end > self._num_samples: |
|
|
130 |
print("Not enough data for the batch to be retrieved.") |
|
|
131 |
return [], [] |
|
|
132 |
|
|
|
133 |
data_set, labels = [], [] |
|
|
134 |
try: |
|
|
135 |
for patient in self._data_set[start:end]: |
|
|
136 |
image, label = self._patient_with_label(patient) |
|
|
137 |
if len(image) and label is not None: |
|
|
138 |
labels.append(label) |
|
|
139 |
data_set.append(image) |
|
|
140 |
|
|
|
141 |
if len(data_set) < batch_size: |
|
|
142 |
print("Current batch size is less: {}".format(len(data_set))) |
|
|
143 |
print("Start {}, end {}, samples {}".format(start, end, |
|
|
144 |
self._num_samples)) |
|
|
145 |
except FileNotFoundError as e: |
|
|
146 |
print("Unable to laod image for patient" + patient + |
|
|
147 |
". Please check if you have downloaded the data.", |
|
|
148 |
" Otherwise use the data_collector.py script.") |
|
|
149 |
|
|
|
150 |
return data_set, labels |
|
|
151 |
|
|
|
152 |
# Used during exact testing phase, here no labels are returned |
|
|
153 |
def next_patient(self): |
|
|
154 |
if self._index_in_epoch >= self._num_samples: |
|
|
155 |
return (None, []) |
|
|
156 |
|
|
|
157 |
patient_id = self._data_set[self._index_in_epoch] |
|
|
158 |
self._index_in_epoch += 1 |
|
|
159 |
image = self._load_patient(patient_id) |
|
|
160 |
if self._validate_input_shape(image): |
|
|
161 |
return (patient_id, image) |
|
|
162 |
return (patient_id, []) |
|
|
163 |
|
|
|
164 |
def _patient_with_label(self, patient_id): |
|
|
165 |
label = self._data_loader.get_label(patient_id) |
|
|
166 |
if label is None: |
|
|
167 |
return ([], None) |
|
|
168 |
|
|
|
169 |
image = self._load_patient(patient_id) |
|
|
170 |
if self._validate_input_shape(image): |
|
|
171 |
return (image, label) |
|
|
172 |
|
|
|
173 |
return ([], None) |
|
|
174 |
|
|
|
175 |
def _load_patient(self, patient): |
|
|
176 |
return self._data_loader.load_image(patient) |
|
|
177 |
|
|
|
178 |
def _validate_input_shape(self, patient_image): |
|
|
179 |
return patient_image.shape == config.IMG_SHAPE |
|
|
180 |
|
|
|
181 |
@property |
|
|
182 |
def num_samples(self): |
|
|
183 |
return self._num_samples |
|
|
184 |
|
|
|
185 |
@property |
|
|
186 |
def finished_epochs(self): |
|
|
187 |
return self._finished_epochs |
|
|
188 |
|
|
|
189 |
|
|
|
190 |
if __name__ == '__main__': |
|
|
191 |
data_loader = DataLoader() |
|
|
192 |
tr_set = data_loader.get_training_set() |
|
|
193 |
val_set = data_loader.get_validation_set() |