Diff of /data_set.py [000000] .. [4f54f1]

Switch to side-by-side view

--- a
+++ b/data_set.py
@@ -0,0 +1,193 @@
+import os
+import numpy as np
+import random as rnd
+
+import config
+from utils import read_csv_column, read_csv
+
+
+class DataLoader(object):
+    def __init__(self, 
+                 images_loader,
+                 labels_input=config.PATIENT_LABELS_CSV,
+                 exact_tests=config.TEST_PATIENTS_IDS,
+                 train_set=config.TRAINING_PATIENTS_IDS,
+                 validation_set=config.VALIDATION_PATINETS_IDS,
+                 add_transformed_positives=False):
+        self._images_loader = images_loader
+        self._labels = read_csv(labels_input)
+
+        self._exact_tests = []
+        if exact_tests:
+            self._exact_tests = read_csv_column(exact_tests)
+
+        self._train_set = list(read_csv_column(train_set, 
+                                               columns=[1]))
+        self._validation_set = list(read_csv_column(
+            validation_set, columns=[1]))
+        # Data augmentation for balancing the training set
+        if add_transformed_positives:
+            self._double_positive_class_data()
+
+        self._examples_count = len(self._validation_set) + len(self._train_set)
+        print("Total examples used for training and validation: ",
+            self._examples_count)
+        
+        print("Total patients used for validation: ", 
+            len(self._validation_set))
+        print("Total patients used for training: ", 
+            len(self._train_set))
+        self._exact_tests_count = len(self._exact_tests)
+
+    def _double_positive_class_data(self):
+        positive = self.patients_from_class(self._train_set, 
+                                            config.CANCER_CLS)
+        print("Patients with cancer are: {}".format(len(positive)))
+        # Anotate that original image should be transformed
+        positive = [positive_name + '-augm' for positive_name in positive]
+        self._train_set.extend(positive)
+
+    def patients_from_class(self, patient_ids, clazz):
+        return [patient for patient in patient_ids 
+                if self.get_label(patient) == clazz]
+
+    @property
+    def exact_tests_count(self):
+        return self._exact_tests_count
+
+    @property
+    def examples_count(self):
+        return self._examples_count
+    
+    def train_samples_count(self):
+        return len(self._train_set)
+    
+    def validation_samples_count(self):
+        return len(self._validation_set)
+
+    def get_label(self, patient_id):
+        if 'augm' in patient_id:
+            patient_id = patient_id.split('-')[0]
+        try:
+            clazz = self._labels.get_value(patient_id, config.COLUMN_NAME)
+            return clazz
+        except KeyError as e:
+            print("No key found for patient with id {} in the labels.".format(
+                   patient_id))
+        return None
+
+    def has_label(self, patient):
+        try:
+            self._labels.get_value(patient, config.COLUMN_NAME)
+        except KeyError as e:
+            return False
+        return True
+
+    def get_training_set(self):
+        return DataSet(self._train_set, self)
+
+    def get_validation_set(self):
+        return DataSet(self._validation_set, self, False)
+
+    def get_exact_tests_set(self):
+        return DataSet(self._exact_tests, self, False)
+
+    def load_image(self, patient):
+        return self._images_loader.load_scans(patient)
+
+    def results_out_dir(self):
+        out_dir = os.path.join(config.MODELS_STORE_DIR, 
+                               config.SELECTED_MODEL)
+        if not os.path.exists(out_dir):
+            os.makedirs(out_dir)
+
+        return out_dir
+
+
+class DataSet(object):
+    def __init__(self, data_set, data_loader, shuffle=True):
+        self._data_set = data_set
+        self._data_loader = data_loader
+        self._index_in_epoch = 0
+        self._finished_epochs = 0
+        self._num_samples = len(self._data_set)
+        self._shuffle = shuffle
+
+    def next_batch(self, batch_size):
+        if self._index_in_epoch >= self._num_samples:
+            # Epoche has finished, start new iteration
+            self._finished_epochs += 1
+            self._index_in_epoch = 0
+            # Shuffle data
+            if self._shuffle:
+                rnd.shuffle(self._data_set)
+        
+        start = self._index_in_epoch
+        self._index_in_epoch += batch_size
+        end = self._index_in_epoch
+
+        if end > self._num_samples:
+            print("Not enough data for the batch to be retrieved.")
+            return [], []
+
+        data_set, labels = [], []
+        try:
+            for patient in self._data_set[start:end]:
+                image, label = self._patient_with_label(patient)
+                if len(image) and label is not None:
+                    labels.append(label)
+                    data_set.append(image)
+
+            if len(data_set) < batch_size:
+                print("Current batch size is less: {}".format(len(data_set)))
+                print("Start {}, end {}, samples {}".format(start, end, 
+                    self._num_samples))
+        except FileNotFoundError as e:
+            print("Unable to laod image for patient" + patient + 
+                ". Please check if you have downloaded the data.",
+                " Otherwise use the data_collector.py script.")
+
+        return data_set, labels
+
+    # Used during exact testing phase, here no labels are returned
+    def next_patient(self):
+        if self._index_in_epoch >= self._num_samples:
+            return (None, [])
+
+        patient_id = self._data_set[self._index_in_epoch]
+        self._index_in_epoch += 1
+        image = self._load_patient(patient_id)
+        if self._validate_input_shape(image):
+            return (patient_id, image)
+        return (patient_id, [])
+
+    def _patient_with_label(self, patient_id):
+        label = self._data_loader.get_label(patient_id)
+        if label is None:
+            return ([], None)
+        
+        image = self._load_patient(patient_id)
+        if self._validate_input_shape(image):
+            return (image, label)
+        
+        return ([], None)
+
+    def _load_patient(self, patient):
+        return self._data_loader.load_image(patient)
+
+    def _validate_input_shape(self, patient_image):
+        return patient_image.shape == config.IMG_SHAPE
+
+    @property
+    def num_samples(self):
+        return self._num_samples
+
+    @property
+    def finished_epochs(self):
+        return self._finished_epochs
+    
+
+if __name__ == '__main__':
+    data_loader = DataLoader()
+    tr_set = data_loader.get_training_set()
+    val_set = data_loader.get_validation_set()
\ No newline at end of file