--- a
+++ b/dataset.py
@@ -0,0 +1,88 @@
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+import nibabel as nib
+from utils import (generate_ex_list, gen_mask, correct_dims, resize_img,
+                   center_crop, find_and_crop_lesions, random_crop)
+from random import random
+
+
+class MRIDataset(Dataset):
+    def __init__(self,
+                 directory,
+                 size,
+                 sampling_mode="random",
+                 deterministic=False):
+        assert sampling_mode in ['resize', 'center', 'center_val', 'random']
+        self.directory = directory
+        self.inputs, self.labels = generate_ex_list(self.directory)
+        self.size = size
+        self.sampling_mode = sampling_mode
+        self.deterministic = deterministic
+        self.current_item_path = None  # use this to find the sample's path for debugging and visual inspection
+
+    def __len__(self):
+        return len(self.inputs)
+
+    def __getitem__(self, idx):
+        self.current_item_path = self.inputs[idx]
+        input_img = correct_dims(nib.load(self.inputs[idx]).get_data())
+        label_img = gen_mask(self.labels[idx])
+
+        # Resize to input image and label to size (self.size x self.size x self.size)
+        if self.sampling_mode == "resize":
+            ex, label = resize_img(input_img, label_img, self.size)
+
+        # Constant center-crop sample of size (self.size x self.size x self.size)
+        elif self.sampling_mode == 'center':
+            ex, label = center_crop(input_img, label_img, self.size)
+
+        # Find centers of lesion masks and crop image to include them
+        # to measure consistent validation performance with small crops
+        elif self.sampling_mode == "center_val":
+            ex, label = find_and_crop_lesions(input_img, label_img, self.size,
+                                              self.deterministic)
+
+        # Randomly crop sample of size (self.size x self.size x self.size)
+        elif self.sampling_mode == "random":
+            ex, label = random_crop(input_img, label_img, self.size)
+
+        else:
+            print("Invalid sampling mode.")
+            exit()
+
+        ex = np.divide(ex, 255.0)
+        label = np.array([(label > 0).astype(int)]).squeeze()
+        # (experimental) APPLY RANDOM FLIPPING ALONG EACH AXIS
+        if not self.deterministic:
+            for i in range(3):
+                if random() > 0.5:
+                    ex = np.flip(ex, i)
+                    label = np.flip(label, i)
+        inputs = torch.from_numpy(ex.copy()).type(
+            torch.FloatTensor).unsqueeze(0)
+        labels = torch.from_numpy(label.copy()).type(
+            torch.FloatTensor).unsqueeze(0)
+        return inputs, labels
+
+    def _load_full_label(self, label_paths):
+        """
+        Loading full-size label for evaluation.
+        """
+        label_full = gen_mask(label_paths)
+        label_full = np.array([(label_full > 0).astype(int)]).squeeze()
+        return torch.Tensor(label_full)
+
+    def _project_full_label(self, input_path, preds):
+        """
+        Projecting a predicted subvolume to full-size volume for evaluation.
+        Only supposed to work with center sampling_mode.
+        """
+        size = nib.load(input_path).get_data().shape
+        preds_full = np.zeros(size)
+        coords = [0]*3
+        for i in range(3):
+            coords[i] = int((size[i]-self.size[i])//2)
+        x, y, z = coords
+        preds_full[x:x+self.size[0], y:y+self.size[1], z:z+self.size[2]] = preds
+        return torch.Tensor(preds_full)