# Manuel A. Morales (moralesq@mit.edu)
# Harvard-MIT Department of Health Sciences & Technology
# Athinoula A. Martinos Center for Biomedical Imaging
import numpy as np
from abc import ABC, abstractmethod
from tensorflow.keras.utils import Sequence
from scipy.ndimage.measurements import center_of_mass
import nibabel as nib
from dipy.align.reslice import reslice
class BaseDataset(Sequence, ABC):
"""This class is an abstract base class (ABC) for datasets."""
def __init__(self, opt):
self.opt = opt
self.root = opt.dataroot
@abstractmethod
def __len__(self):
"""Return the size of the dataset."""
return
@abstractmethod
def __getitem__(self, idx):
"""Return a data point and its metadata information."""
pass
class Transforms():
def __init__(self, opt):
self.opt = opt
self.transform, self.transform_inv = self.get_transforms(opt)
def __crop__(self, x, inv=False):
if inv:
nx, ny = self.original_shape[:2]
xinv = np.zeros(self.original_shape[:2] + x.shape[2:])
xinv[nx//2-64:nx//2+64, ny//2-64:ny//2+64] += x
return xinv
else:
nx, ny = x.shape[:2]
return x[nx//2-64:nx//2+64, ny//2-64:ny//2+64]
def __reshape_to_carson__(self, x, inv=False):
if inv:
if len(self.original_shape)==3:
x = x.transpose(1,2,0,3)
elif len(self.original_shape)==4:
nx,ny,nz,nt=self.original_shape
Nx, Ny = x.shape[1:3]
x = x.reshape((nt, nz, Nx, Ny, self.opt.nlabels))
x = x.transpose(2,3,1,0,4)
else:
if len(x.shape) == 3:
nx,ny,nz=x.shape
x=x.transpose(2,0,1)
elif len(x.shape) == 4:
nx,ny,nz,nt=x.shape
x=x.transpose(3,2,0,1)
x=x.reshape((nt*nz,nx,ny))
return x
def __reshape_to_carmen__(self, x, inv=False):
if inv:
x = np.concatenate((np.zeros(x[:1].shape), x))
x = x.transpose((1,2,3,0,4))
else:
assert len(x.shape) == 4
nx,ny,nz,nt=x.shape
x=x.transpose(3,0,1,2)
x=np.stack((np.repeat(x[:1],nt-1,axis=0), x[1:nt]), -1)
return x
def __zscore__(self, x):
if len(x.shape) == 3:
axis=(1,2) # normalize in-plane images independently
elif len(x.shape) == 5:
axis=(1,2,3) # normalize volumes independently
self.mu = x.mean(axis=axis, keepdims=True)
self.sd = x.std(axis=axis, keepdims=True)
return (x - self.mu)/(self.sd + 1e-8)
def get_transforms(self, opt):
transform_list = []
transform_inv_list = []
if 'crop' in opt.preprocess:
transform_list.append(self.__crop__)
transform_inv_list.append(lambda x:self.__crop__(x,inv=True))
if 'reshape_to_carson' in opt.preprocess:
transform_list.append(self.__reshape_to_carson__)
transform_inv_list.append(lambda x:self.__reshape_to_carson__(x,inv=True))
elif 'reshape_to_carmen' in opt.preprocess:
transform_list.append(self.__reshape_to_carmen__)
transform_inv_list.append(lambda x:self.__reshape_to_carmen__(x,inv=True))
if 'zscore' in opt.preprocess:
transform_list.append(self.__zscore__)
return transform_list, transform_inv_list
def apply(self, x):
self.original_shape = x.shape
for transform in self.transform:
x = transform(x)
return x
def apply_inv(self, x):
for transform in self.transform_inv[::-1]:
x = transform(x)
return x
def _centercrop(x):
nx, ny = x.shape[:2]
return x[nx//2-64:nx//2+64,ny//2-64:ny//2+64]
def _roll(x,rx,ry):
x = np.roll(x,rx,axis=0)
x = np.roll(x,ry,axis=1)
return x
def _roll2center(x, center):
return _roll(x, int(x.shape[0]//2-center[0]), int(x.shape[1]//2-center[1]))
def _roll2center_crop(x, center):
x = _roll2center(x, center)
return _centercrop(x)
#####################################################
## FUNCTIONS TO ADD MORE FLEXIBILITY IN SEGMENTATION
#####################################################
def resample_nifti_inv(nifti_resampled, zooms, order=1, mode='nearest'):
""" Resample `nifti_resampled` to `zooms` resolution.
"""
data_resampled = nifti_resampled.get_fdata()
zooms_resampled = nifti_resampled.header.get_zooms()[:3]
affine_resampled = nifti_resampled.affine
data_resampled, affine_resampled = reslice(data_resampled,
affine_resampled, zooms_resampled, zooms, order=order, mode=mode)
nifti = nib.Nifti1Image(data_resampled, affine_resampled)
return nifti
def convert_back_to_nifti(data_resampled, nifti_info_subject, inv_256x256=False, order=1, mode='nearest'):
if inv_256x256:
data_resampled_mod_corr = roll_and_pad_256x256_to_center_inv(data_resampled, nifti_info=nifti_info_subject)
else:
data_resampled_mod_corr = data_resampled
affine = nifti_info_subject['affine']
affine_resampled = nifti_info_subject['affine_resampled']
zooms = nifti_info_subject['zooms'][:3]
zooms_resampled = nifti_info_subject['zooms_resampled'][:3]
data_resampled, affine_resampled = reslice(data_resampled_mod_corr,
affine_resampled, zooms_resampled, zooms, order=order, mode=mode)
nifti = nib.Nifti1Image(data_resampled, affine_resampled)
return nifti
def roll(x,rx,ry):
x = np.roll(x,rx,axis=0)
x = np.roll(x,ry,axis=1)
return x
def roll2center(x, center):
return roll(x, int(x.shape[0]//2-center[0]), int(x.shape[1]//2-center[1]))
def pad_256x256(x):
xpad = (512-x.shape[0])//2, (512-x.shape[0])-(512-x.shape[0])//2
ypad = (512-x.shape[1])//2, (512-x.shape[1])-(512-x.shape[1])//2
pads = (xpad,ypad)+((0,0),)*(len(x.shape)-2)
vals = ((0,0),)*len(x.shape)
x = np.pad(x, pads, 'constant', constant_values=vals)
x = x[512//2-256//2:512//2+256//2,512//2-256//2:512//2+256//2]
return x
def roll_and_pad_256x256_to_center(x, center):
x = roll2center(x, center)
x = pad_256x256(x)
return x
def roll_and_pad_256x256_to_center_inv(x, nifti_info):
# Recover 256x256 array that was center-cropped to 128x128!
x_256_256 = np.zeros((256,256)+x.shape[2:])
x_256_256[128-64:128+64,128-64:128+64] += x
# Coordinates to put the image in its original location.
cx, cy = nifti_info['center_resampled'][:2]
cx_mod, cy_mod = nifti_info['center_resampled_256x256'][:2]
x_inv = np.zeros(nifti_info['shape_resampled'][:3]+x.shape[3:])
dx = min(int(cx),64)
dy = min(int(cy),64)
if (dx!=64)|(dy!=64):
print('WARNING:FOV < 128x128!')
x_inv[int(cx-dx):int(cx+dx),int(cy-dy):int(cy+dy)] += x_256_256[int(cx_mod-dx):int(cx_mod+dx),
int(cy_mod-dy):int(cy_mod+dy)]
return x_inv