--- /dev/null
+++ b/utils/data_utils.py
@@ -0,0 +1,233 @@
+import os
+
+import h5py
+import nibabel as nb
+import numpy as np
+import torch
+import torch.utils.data as data
+import utils.preprocessor as preprocessor
+
+
+class ImdbData(data.Dataset):
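+    """In-memory ``torch.utils.data.Dataset`` over image slices, label maps and weights.
+
+    ``X`` may be (N, H, W) or (N, C, H, W); a channel axis is added when missing.
+    ``transforms`` is an optional callable applied to each image tensor.
+    """
+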
+    def __init__(self, X, y, w, transforms=None):
+        self.X = X if len(X.shape) == 4 else X[:, np.newaxis, :, :]
+        self.y = y
+        self.w = w
+        self.transforms = transforms
+
+    def __getitem__(self, index):
+        img = torch.from_numpy(self.X[index])
+        label = torch.from_numpy(self.y[index])
+        weight = torch.from_numpy(self.w[index])
+        # Apply the optional transform to the image only; labels and weights
+        # are returned untouched.
+        if self.transforms is not None:
+            img = self.transforms(img)
+        return img, label, weight
+
+    def __len__(self):
+        return len(self.y)
+
+
+def get_imdb_dataset(data_params):
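+    """Build the train and test ``ImdbData`` datasets from the HDF5 files named in ``data_params``."""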
+    data_dir = data_params['data_dir']
+
+    def _read(file_key, h5_key):
+        # Read the whole array into memory so the HDF5 handle can be closed.
+        with h5py.File(os.path.join(data_dir, data_params[file_key]), 'r') as file_h:
+            return file_h[h5_key][()]
+
+    return (ImdbData(_read('train_data_file', 'data'),
+                     _read('train_label_file', 'label'),
+                     _read('train_class_weights_file', 'class_weights')),
+            ImdbData(_read('test_data_file', 'data'),
+                     _read('test_label_file', 'label'),
+                     _read('test_class_weights_file', 'class_weights')))
+
+
+def load_dataset(file_paths,
+                 orientation,
+                 remap_config,
+                 return_weights=False,
+                 reduce_slices=False,
+                 remove_black=False):
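+    """Load and preprocess every [image, label] pair in ``file_paths``.
+
+    Returns per-volume lists of volumes, label maps, and headers; class-weight
+    and per-pixel-weight lists are included only when ``return_weights`` is True.
+    """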
+    print("Loading and preprocessing data...")
+    volume_list, labelmap_list, headers, class_weights_list, weights_list = [], [], [], [], []
+
+    for file_path in file_paths:
+        volume, labelmap, class_weights, weights, header = load_and_preprocess(file_path, orientation,
+                                                                               remap_config=remap_config,
+                                                                               reduce_slices=reduce_slices,
+                                                                               remove_black=remove_black,
+                                                                               return_weights=return_weights)
+
+        volume_list.append(volume)
+        labelmap_list.append(labelmap)
+
+        if return_weights:
+            class_weights_list.append(class_weights)
+            weights_list.append(weights)
+
+        headers.append(header)
+
+        print("#", end='', flush=True)
+    print("100%", flush=True)
+    if return_weights:
+        return volume_list, labelmap_list, class_weights_list, weights_list, headers
+    else:
+        return volume_list, labelmap_list, headers
+
+
+def load_and_preprocess(file_path, orientation, remap_config, reduce_slices=False,
+                        remove_black=False,
+                        return_weights=False):
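+    """Load a single [image, label] pair, run it through ``preprocess``, and
+    return the image header alongside the preprocessed arrays.
+    """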
+    volume, labelmap, header = load_data(file_path, orientation)
+
+    volume, labelmap, class_weights, weights = preprocess(volume, labelmap, remap_config=remap_config,
+                                                          reduce_slices=reduce_slices,
+                                                          remove_black=remove_black,
+                                                          return_weights=return_weights)
+    return volume, labelmap, class_weights, weights, header
+
+
+def load_and_preprocess_eval(file_path, orientation, notlabel=True):
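+    """Load a single volume for evaluation.
+
+    Intensity volumes (``notlabel=True``) are min-max normalised to [0, 1];
+    label volumes are rounded to integers. "COR" and "AXI" transpose the axes
+    accordingly; any other value leaves the axes unchanged.
+    """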
+    volume_nifty = nb.load(file_path[0])
+    header = volume_nifty.header
+    volume = volume_nifty.get_fdata()
+    if notlabel:
+        volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume))
+    else:
+        volume = np.round(volume)
+    if orientation == "COR":
+        volume = volume.transpose((2, 0, 1))
+    elif orientation == "AXI":
+        volume = volume.transpose((1, 2, 0))
+    return volume, header
+
+
+def load_data(file_path, orientation):
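+    """Load an [image, label] pair, min-max normalise the image to [0, 1], and
+    rotate both volumes into ``orientation``; the image header is returned for
+    later export.
+    """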
+    volume_nifty, labelmap_nifty = nb.load(file_path[0]), nb.load(file_path[1])
+    volume, labelmap = volume_nifty.get_fdata(), labelmap_nifty.get_fdata()
+    volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume))
+    volume, labelmap = preprocessor.rotate_orientation(volume, labelmap, orientation)
+    return volume, labelmap, volume_nifty.header
+
+
+def preprocess(volume, labelmap, remap_config, reduce_slices=False, remove_black=False, return_weights=False):
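+    """Run the optional preprocessing steps in order: slice reduction, label
+    remapping and black-slice removal; when ``return_weights`` is True, also
+    estimate weights via ``preprocessor.estimate_weights_mfb`` (otherwise the
+    two weight outputs are ``None``).
+    """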
+    if reduce_slices:
+        volume, labelmap = preprocessor.reduce_slices(volume, labelmap)
+
+    if remap_config:
+        labelmap = preprocessor.remap_labels(labelmap, remap_config)
+
+    if remove_black:
+        volume, labelmap = preprocessor.remove_black(volume, labelmap)
+
+    if return_weights:
+        class_weights, weights = preprocessor.estimate_weights_mfb(labelmap)
+        return volume, labelmap, class_weights, weights
+    else:
+        return volume, labelmap, None, None
+
+
+def load_file_paths(data_dir, label_dir, data_id, volumes_txt_file=None):
+    """
+    This function returns the file paths combined as a list where each element is a 2 element tuple, 0th being data and 1st being label.
+    It should be modified to suit the need of the project
+    :param data_dir: Directory which contains the data files
+    :param label_dir: Directory which contains the label files
+    :param data_id: A flag indicates the name of Dataset for proper file reading
+    :param volumes_txt_file: (Optional) Path to the a csv file, when provided only these data points will be read
+    :return: list of file paths as string
+    """
+
+    if volumes_txt_file:
+        with open(volumes_txt_file) as file_handle:
+            volumes_to_use = file_handle.read().splitlines()
+    else:
+        volumes_to_use = os.listdir(data_dir)
+
+    if data_id == "MALC":
+        file_paths = [
+            [os.path.join(data_dir, vol, 'mri/orig.mgz'), os.path.join(label_dir, vol + '_glm.mgz')]
+            for
+            vol in volumes_to_use]
+    elif data_id == "ADNI":
+        file_paths = [
+            [os.path.join(data_dir, vol, 'orig.mgz'), os.path.join(label_dir, vol, 'Lab_con.mgz')]
+            for
+            vol in volumes_to_use]
+    elif data_id == "CANDI":
+        file_paths = [
+            [os.path.join(data_dir, vol + '/' + vol + '_1.mgz'),
+             os.path.join(label_dir, vol + '/' + vol + '_1_seg.mgz')]
+            for
+            vol in volumes_to_use]
+    elif data_id == "IBSR":
+        file_paths = [
+            [os.path.join(data_dir, vol, 'mri/orig.mgz'), os.path.join(label_dir, vol + '_map.nii.gz')]
+            for
+            vol in volumes_to_use]
+    else:
+        raise ValueError("Invalid entry, valid options are MALC, ADNI, CANDI and IBSR")
+
+    return file_paths
+
+
+def load_file_paths_eval(data_dir, volumes_txt_file, dir_struct):
+    """
+    This function returns the file paths combined as a list where each element is a 2 element tuple, 0th being data and 1st being label.
+    It should be modified to suit the need of the project
+    :param data_dir: Directory which contains the data files
+    :param volumes_txt_file:  Path to the a csv file, when provided only these data points will be read
+    :param dir_struct: If the id_list is in FreeSurfer style or normal
+    :return: list of file paths as string
+    """
+
+    with open(volumes_txt_file) as file_handle:
+        volumes_to_use = file_handle.read().splitlines()
+    if dir_struct == "FS":
+        file_paths = [
+            [os.path.join(data_dir, vol, 'mri/orig.mgz')]
+            for
+            vol in volumes_to_use]
+    elif dir_struct == "Linear":
+        file_paths = [
+            [os.path.join(data_dir, vol)]
+            for
+            vol in volumes_to_use]
+    elif dir_struct == "part_FS":
+        file_paths = [
+            [os.path.join(data_dir, vol, 'orig.mgz')]
+            for
+            vol in volumes_to_use]
+    else:
+        raise ValueError("Invalid entry, valid options are FS and Linear")
+    return file_paths
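+
+
+# Minimal usage sketch (the file names below are hypothetical; the dict keys
+# are exactly the ones get_imdb_dataset reads):
+#
+#   data_params = {
+#       'data_dir': 'datasets',
+#       'train_data_file': 'train_data.h5',
+#       'train_label_file': 'train_labels.h5',
+#       'train_class_weights_file': 'train_class_weights.h5',
+#       'test_data_file': 'test_data.h5',
+#       'test_label_file': 'test_labels.h5',
+#       'test_class_weights_file': 'test_class_weights.h5',
+#   }
+#   train_ds, test_ds = get_imdb_dataset(data_params)
+#   train_loader = data.DataLoader(train_ds, batch_size=4, shuffle=True)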