--- a
+++ b/data/base_dataset.py
@@ -0,0 +1,215 @@
+# Manuel A. Morales (moralesq@mit.edu)
+# Harvard-MIT Department of Health Sciences & Technology  
+# Athinoula A. Martinos Center for Biomedical Imaging
+
+import numpy as np
+from abc import ABC, abstractmethod
+from tensorflow.keras.utils import Sequence
+from scipy.ndimage.measurements import center_of_mass
+
+import nibabel as nib
+from dipy.align.reslice import reslice
+
+class BaseDataset(Sequence, ABC):
+    """This class is an abstract base class (ABC) for datasets."""
+
+    def __init__(self, opt):
+        self.opt  = opt
+        self.root = opt.dataroot
+
+    @abstractmethod
+    def __len__(self):
+        """Return the size of the dataset."""
+        return 
+    
+    @abstractmethod
+    def __getitem__(self, idx):
+        """Return a data point and its metadata information."""
+        pass
+                       
+class Transforms():
+    
+    def __init__(self, opt):
+        self.opt = opt 
+        self.transform, self.transform_inv = self.get_transforms(opt)
+       
+    def __crop__(self, x, inv=False):
+        
+        if inv:
+            nx, ny = self.original_shape[:2]
+            xinv = np.zeros(self.original_shape[:2] + x.shape[2:])
+            xinv[nx//2-64:nx//2+64, ny//2-64:ny//2+64] += x
+            return xinv
+        else:
+            nx, ny = x.shape[:2]
+            return x[nx//2-64:nx//2+64, ny//2-64:ny//2+64]
+    
+    def __reshape_to_carson__(self, x, inv=False):
+        
+        if inv:
+            if len(self.original_shape)==3:
+                x = x.transpose(1,2,0,3)
+            elif len(self.original_shape)==4:
+                nx,ny,nz,nt=self.original_shape
+                Nx, Ny = x.shape[1:3]
+                x = x.reshape((nt, nz, Nx, Ny, self.opt.nlabels))
+                x = x.transpose(2,3,1,0,4)                
+        else:
+            if len(x.shape) == 3:
+                nx,ny,nz=x.shape
+                x=x.transpose(2,0,1)
+            elif len(x.shape) == 4:
+                nx,ny,nz,nt=x.shape
+                x=x.transpose(3,2,0,1)
+                x=x.reshape((nt*nz,nx,ny))            
+        return x
+
+    def __reshape_to_carmen__(self, x, inv=False):
+        if inv:
+            x = np.concatenate((np.zeros(x[:1].shape), x))
+            x = x.transpose((1,2,3,0,4)) 
+        else:
+            assert len(x.shape) == 4
+            nx,ny,nz,nt=x.shape
+            x=x.transpose(3,0,1,2)
+            x=np.stack((np.repeat(x[:1],nt-1,axis=0), x[1:nt]), -1)
+        return x  
+    
+    def __zscore__(self, x):
+
+        if len(x.shape) == 3:
+            axis=(1,2) # normalize in-plane images independently
+        elif len(x.shape) == 5:
+            axis=(1,2,3) # normalize volumes independently
+
+        self.mu = x.mean(axis=axis, keepdims=True)
+        self.sd = x.std(axis=axis, keepdims=True)
+        return (x - self.mu)/(self.sd + 1e-8)
+
+    def get_transforms(self, opt):
+
+        transform_list     = []
+        transform_inv_list = []
+        if 'crop' in opt.preprocess:
+            transform_list.append(self.__crop__)
+            transform_inv_list.append(lambda x:self.__crop__(x,inv=True))
+        if 'reshape_to_carson' in opt.preprocess:
+            transform_list.append(self.__reshape_to_carson__)
+            transform_inv_list.append(lambda x:self.__reshape_to_carson__(x,inv=True))
+        elif 'reshape_to_carmen' in opt.preprocess:
+            transform_list.append(self.__reshape_to_carmen__)
+            transform_inv_list.append(lambda x:self.__reshape_to_carmen__(x,inv=True))
+        if 'zscore' in opt.preprocess:
+            transform_list.append(self.__zscore__)                
+        
+        return transform_list, transform_inv_list
+          
+    def apply(self, x):
+        
+        self.original_shape = x.shape
+        for transform in self.transform:
+            x = transform(x)
+        return x
+    
+    def apply_inv(self, x):
+        
+        for transform in self.transform_inv[::-1]:
+            x = transform(x)
+        return x    
+    
+
+def _centercrop(x):
+    nx, ny = x.shape[:2]
+    return x[nx//2-64:nx//2+64,ny//2-64:ny//2+64]
+
+def _roll(x,rx,ry):
+    x = np.roll(x,rx,axis=0)
+    x = np.roll(x,ry,axis=1)
+    return x
+
+def _roll2center(x, center):
+    return _roll(x, int(x.shape[0]//2-center[0]), int(x.shape[1]//2-center[1]))
+    
+def _roll2center_crop(x, center):
+    x = _roll2center(x, center)
+    return _centercrop(x)
+    
+    
+#####################################################
+## FUNCTIONS TO ADD MORE FLEXIBILITY IN SEGMENTATION
+#####################################################
+
+def resample_nifti_inv(nifti_resampled, zooms, order=1, mode='nearest'):
+    """ Resample `nifti_resampled` to `zooms` resolution.
+    """
+    data_resampled   = nifti_resampled.get_fdata()
+    zooms_resampled  = nifti_resampled.header.get_zooms()[:3]
+    affine_resampled = nifti_resampled.affine 
+        
+    data_resampled, affine_resampled = reslice(data_resampled, 
+                                               affine_resampled, zooms_resampled, zooms, order=order, mode=mode)
+
+    nifti = nib.Nifti1Image(data_resampled, affine_resampled)
+    
+    return nifti
+    
+def convert_back_to_nifti(data_resampled, nifti_info_subject, inv_256x256=False, order=1, mode='nearest'):
+
+    if inv_256x256:
+        data_resampled_mod_corr = roll_and_pad_256x256_to_center_inv(data_resampled, nifti_info=nifti_info_subject)
+    else:
+        data_resampled_mod_corr = data_resampled
+        
+    affine           = nifti_info_subject['affine']
+    affine_resampled = nifti_info_subject['affine_resampled']
+    zooms            = nifti_info_subject['zooms'][:3]
+    zooms_resampled  = nifti_info_subject['zooms_resampled'][:3]
+    
+    data_resampled, affine_resampled = reslice(data_resampled_mod_corr, 
+                                               affine_resampled, zooms_resampled, zooms, order=order, mode=mode)
+    nifti = nib.Nifti1Image(data_resampled, affine_resampled)
+    
+    return nifti
+
+def roll(x,rx,ry):
+        x = np.roll(x,rx,axis=0)
+        x = np.roll(x,ry,axis=1)
+        return x
+    
+def roll2center(x, center):
+    return roll(x, int(x.shape[0]//2-center[0]), int(x.shape[1]//2-center[1]))
+    
+def pad_256x256(x):
+        xpad = (512-x.shape[0])//2, (512-x.shape[0])-(512-x.shape[0])//2
+        ypad = (512-x.shape[1])//2, (512-x.shape[1])-(512-x.shape[1])//2
+        pads = (xpad,ypad)+((0,0),)*(len(x.shape)-2)
+        vals = ((0,0),)*len(x.shape)
+        x = np.pad(x, pads, 'constant', constant_values=vals)
+        x = x[512//2-256//2:512//2+256//2,512//2-256//2:512//2+256//2]
+        return x
+    
+def roll_and_pad_256x256_to_center(x, center):
+    x = roll2center(x, center)
+    x = pad_256x256(x)
+    return x
+
+def roll_and_pad_256x256_to_center_inv(x, nifti_info):
+
+    # Recover 256x256 array that was center-cropped to 128x128!
+    x_256_256 = np.zeros((256,256)+x.shape[2:])
+    x_256_256[128-64:128+64,128-64:128+64] += x
+    
+    # Coordinates to put the image in its original location.
+    cx, cy         = nifti_info['center_resampled'][:2]
+    cx_mod, cy_mod = nifti_info['center_resampled_256x256'][:2]
+    
+    x_inv = np.zeros(nifti_info['shape_resampled'][:3]+x.shape[3:])
+
+    dx = min(int(cx),64)
+    dy = min(int(cy),64)
+    if (dx!=64)|(dy!=64):
+        print('WARNING:FOV < 128x128!')
+
+    x_inv[int(cx-dx):int(cx+dx),int(cy-dy):int(cy+dy)] += x_256_256[int(cx_mod-dx):int(cx_mod+dx),
+                                                                    int(cy_mod-dy):int(cy_mod+dy)]
+    return x_inv