Switch to side-by-side view

--- a
+++ b/Segmentation/utils/data_loader_3d.py
@@ -0,0 +1,218 @@
+from glob import glob
+import h5py
+import numpy as np
+from random import randint
+from tensorflow.keras.utils import Sequence
+import tensorflow as tf
+from math import ceil
+
+
class VolumeGenerator(Sequence):
    """Keras Sequence yielding batches of 3D crops from paired .im/.seg HDF5 files."""

    def __init__(self, batch_size, sample_shape=(364, 364, 160),
                 file_path='t', shuffle_order=True,
                 normalise_input=True, remove_outliers=True,
                 transform_angle=False, transform_position=False,
                 get_slice=False, get_position=False, skip_empty=True,
                 examples_per_load=1, train_debug=False):
        """Configure the generator.

        Args:
            batch_size: volumes per batch; must not exceed the number of
                discovered data files.
            sample_shape: (x, y, z) size of the crop taken from each volume.
            file_path: path prefix of the data files, or the shorthands
                't' (train) / 'v' (valid) understood by get_paths.
            shuffle_order: reshuffle the sample order every epoch.
            normalise_input: standardise each input crop (see normalise).
            remove_outliers: clip input intensities before normalising.
            transform_angle: stored but not used in this class — presumably
                for a rotation augmentation elsewhere; TODO confirm.
            transform_position: "normal" / "uniform" for random crop
                placement; any other value centres the crop (get_sample_pos).
            get_slice: return only the centre z-slice of the label.
            get_position: also return the normalised crop position as input.
            skip_empty: re-draw crops whose label is entirely background.
            examples_per_load: crops taken from each loaded volume.
            train_debug: keep only the first fifth of the files for quick runs.

        Raises:
            ValueError: if batch_size exceeds the number of data files.
        """
        self.batch_size = batch_size
        self.sample_shape = sample_shape
        self.data_paths = VolumeGenerator.get_paths(file_path)
        self.shuffle_order = shuffle_order
        self.normalise_input = normalise_input
        self.remove_outliers = remove_outliers
        self.transform_angle = transform_angle
        self.transform_position = transform_position
        self.get_slice = get_slice
        self.get_position = get_position
        self.skip_empty = skip_empty
        self.examples_per_load = examples_per_load
        self.train_debug = train_debug

        if self.train_debug:
            # Shrink the dataset for quick debugging runs.
            self.data_paths = self.data_paths[:len(self.data_paths) // 5]

        # Validate with a real exception: an assert would be stripped under -O.
        if self.batch_size > len(self.data_paths):
            raise ValueError(
                f"Batch size {self.batch_size} must be less than or equal to "
                f"number of training examples {len(self.data_paths)}")
        self.on_epoch_end()
+
+    def on_epoch_end(self):
+        self.indexes = np.arange(len(self.data_paths))
+        if self.shuffle_order:
+            np.random.shuffle(self.indexes)
+
+    def __len__(self):
+        return ceil(len(self.data_paths) / self.batch_size)
+
+    def __getitem__(self, index):
+        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
+        batch = [self.data_paths[idx] for idx in indexes]
+        x, y = self.generate_batch(batch)
+        return x, y
+
    def generate_batch(self, batch, skip_fail=3):
        """Load each volume pair and cut `examples_per_load` crops from it.

        Args:
            batch: list of [x_path, y_path] pairs (as produced by get_paths).
            skip_fail: re-draw budget for empty-label crops per volume; once
                the budget is used up an empty crop is accepted anyway
                (only relevant when skip_empty is set).

        Returns:
            (x_train, y_train) as stacked numpy arrays. When get_position is
            True, x_train is instead the list [images, positions] with the
            per-axis crop origins scaled to [-1, 1].
        """
        x_train, y_train = [], []
        if self.get_position:
            image_arr, pos_arr = [], []
        for sample_path in batch:
            count = self.examples_per_load  # crops still to draw from this volume
            skip_count = skip_fail          # remaining re-draws, shared across this volume's crops
            x_path, y_path = sample_path

            volume_x_original = VolumeGenerator.load_file(x_path)
            volume_y_original = VolumeGenerator.load_file(y_path)

            while count > 0:
                # Choose the crop origin (random when transform_position is set).
                sample_pos, sample_pos_max = VolumeGenerator.get_sample_pos(volume_x_original.shape, self.sample_shape,
                                                                            self.transform_position)

                volume_x = VolumeGenerator.sample_from_volume(volume_x_original, self.sample_shape, sample_pos)
                volume_y = VolumeGenerator.sample_from_volume(volume_y_original, self.sample_shape, sample_pos)
                # Collapse the label's trailing (class) axis to one binary mask.
                volume_y = np.any(volume_y, axis=-1)

                if self.normalise_input or self.remove_outliers:
                    # NOTE(review): the mean is taken BEFORE clipping but used
                    # to normalise the clipped volume — confirm this is intended.
                    mean = tf.math.reduce_mean(volume_x)
                    if self.remove_outliers:
                        # In-place clip to an upper bound of 0.01; presumably the
                        # intensities are pre-scaled to a small range — TODO confirm.
                        np.clip(volume_x, None, 0.01, volume_x)
                    if self.normalise_input:
                        volume_x = VolumeGenerator.normalise(volume_x, mean)

                volume_x = VolumeGenerator.expand_dim_as_float(volume_x)
                volume_y = VolumeGenerator.expand_dim_as_float(volume_y)

                if self.get_slice:
                    # Keep only the centre slice along z as the target.
                    slice_idx = int((self.sample_shape[2] + 1) / 2) - 1
                    assert slice_idx >= 0
                    volume_y = volume_y[:, :, slice_idx]

                if self.skip_empty:
                    # Re-draw when the label is all background; when skip_count
                    # reaches zero the empty crop is accepted anyway.
                    if np.sum(volume_y) == 0:
                        skip_count -= 1
                        if skip_count > 0:
                            continue

                if self.get_position:
                    image_arr.append(volume_x)
                    # Crop origin per axis, scaled to [-1, 1] (0 == centred).
                    pos = np.empty(3, dtype=np.float32)
                    for i in range(3):
                        pos[i] = VolumeGenerator.normalise_position(sample_pos[i], sample_pos_max[i])
                    pos_arr.append(pos)
                else:
                    x_train.append(volume_x)
                y_train.append(volume_y)
                count -= 1

        if self.get_position:
            image_arr = np.stack(image_arr, axis=0)
            pos_arr = np.stack(pos_arr, axis=0)
            x_train = [image_arr, pos_arr]
        else:
            x_train = np.stack(x_train, axis=0)
        y_train = np.stack(y_train, axis=0)
        return x_train, y_train
+
+    @staticmethod
+    def get_sample_pos(volume_shape, sample_shape, transform_position):
+        """
+        - Get the position required to translate the volumes by. Ranges from 0 to volume_shape - sample_shape
+        - If (volume_shape - sample_shape) == 0, sample and volume same shape. Also the position is centred.
+        """
+        vol_x, vol_y, vol_z = volume_shape[0] - 1, volume_shape[1] - 1, volume_shape[2] - 1
+        samp_x, samp_y, samp_z = sample_shape[0] - 1, sample_shape[1] - 1, sample_shape[2] - 1
+        centre_x = int(vol_x / 2) - int(samp_x / 2)
+        centre_y = int(vol_y / 2) - int(samp_y / 2)
+        centre_z = int(vol_z / 2) - int(samp_z / 2)
+        x_max = volume_shape[0] - sample_shape[0]
+        y_max = volume_shape[1] - sample_shape[1]
+        z_max = volume_shape[2] - sample_shape[2]
+        pos_max = np.array([x_max, y_max, z_max], dtype=np.int32)
+        pos = None
+        if transform_position == "normal":
+            stddev_x = int(centre_x / 4)
+            stddev_y = int(centre_y / 4)
+            stddev_z = int(centre_z / 4)
+            x_pos = np.random.normal(centre_x, stddev_x)
+            y_pos = np.random.normal(centre_y, stddev_y)
+            z_pos = np.random.normal(centre_z, stddev_z)
+            float_pos = np.array([x_pos, y_pos, z_pos], dtype=np.float32)
+            float_pos = np.clip(float_pos, 0, [x_max, y_max, z_max])
+            pos = np.rint(float_pos)
+        elif transform_position == "uniform":
+            x_pos = np.random.uniform(0, x_max)
+            y_pos = np.random.uniform(0, y_max)
+            z_pos = np.random.uniform(0, z_max)
+            float_pos = np.array([x_pos, y_pos, z_pos], dtype=np.float32)
+            pos = np.rint(float_pos)
+        else:
+            x_pos = centre_x
+            y_pos = centre_y
+            z_pos = centre_z
+            pos = np.array([x_pos, y_pos, z_pos], dtype=np.int32)
+        pos = pos.astype(int)
+        return pos, pos_max
+
+    @staticmethod
+    def get_paths(file_path):
+        if file_path == "t":
+            file_path = "./Data/train/train"
+        elif file_path == "v":
+            file_path = "./Data/valid/valid"
+        X_list = glob(f'{file_path}*.im')
+        Y_list = glob(f'{file_path}*.seg')
+        data_paths = []
+        for x_name in X_list:
+            x_id = x_name[-10:-3]
+            y_name = f'{file_path}_{x_id}.seg'
+            assert y_name in Y_list, "{y_name} is missing in the data file"
+            data_paths.append([x_name, y_name])
+        return data_paths
+
+    @staticmethod
+    def load_file(file):
+        with h5py.File(file, 'r') as hf:
+            volume = np.array(hf['data'])
+        return volume
+
+    @staticmethod
+    def sample_from_volume(volume, sample_shape, sample_pos):
+        pos_x, pos_y, pos_z = sample_pos
+        volume_sample = volume[pos_x: pos_x + sample_shape[0],
+                               pos_y: pos_y + sample_shape[1],
+                               pos_z: pos_z + sample_shape[2]]
+        return volume_sample
+
+    @staticmethod
+    def normalise(x_image, mean=None, std=None):
+        if mean is None:
+            mean = tf.math.reduce_mean(x_image)
+        if std is None:
+            std = tf.math.reduce_std(x_image)
+        return (x_image - mean) / std
+
+    @staticmethod
+    def expand_dim_as_float(volume):
+        return np.expand_dims(volume, axis=-1).astype(np.float32)
+
+    @staticmethod
+    def normalise_position(pos, pos_max):
+        """
+            - Recieved the pos which is a value from 0 to (length - sample size)
+            - A value scaled between -1 and 1 where 0 represents a sample from the centre.
+        """
+        if pos_max == 0:
+            return 0
+        return 2 * ((pos / pos_max) - 0.5)
+
+
if __name__ == "__main__":
    import sys
    import os
    sys.path.insert(0, os.getcwd())

    # Smoke test: build one batch and report the tensor shapes/dtype.
    with_position = True
    generator = VolumeGenerator(1, (384, 384, 128), get_position=with_position,
                                examples_per_load=1)
    x, y = generator[0]
    if with_position:
        print(x[0].shape)
        print(x[1].shape)
    print(y.shape)
    print(y.dtype)