"""
Data-block helpers for loading multi-modal NIfTI volumes with fastai v1 and ANTsPy.

The module defines item classes (`NiftiImage`, `MultiNiftiImage`), the matching
`ItemList` subclasses, a few volume transforms and, at the bottom, a script section
that builds a `DataBunch` from the folders found under `path`.

NOTE: importing this module has side effects: it patches `ANTsImage.__repr__` and
`LabelList.__repr__`, lists directories under `path`, writes `valid.txt` into the
current working directory and builds `data`.
"""
import torch
import fastai
import sys
import os
from fastai.vision import *
from fastai.vision.data import SegmentationProcessor
import ants
from ants.core.ants_image import ANTsImage
from jupyterthemes import jtplot
sys.path.insert(0, './exp')
jtplot.style(theme='gruvboxd')

# Set a root directory
path = Path('/home/ubuntu/MultiCampus/MICCAI_BraTS_2019_Data_Training')


def is_mod(fn:str, mod:str)->bool:
    "Check if file path `fn` contains the modality name `mod` (case-insensitive; `mod` is used as a regex fragment)"
    import re
    return bool(re.match('.*' + mod, fn, re.IGNORECASE))


def is_mods(fn:str, mods:Collection[str])->bool:
    "Check if file path `fn` contains at least one of the modality names in `mods`"
    return any(is_mod(fn, mod) for mod in mods)


def _path_to_same_str(p_fn):
    "path -> str, but same on nt+posix, for alpha-sort only"
    s_fn = str(p_fn)
    # Both separator styles collapse to '.', so the sort key does not depend on the OS.
    return s_fn.replace('\\', '.').replace('/', '.')


def _get_files(path, file, modality):
    """
    Internal implementation for `get_files`: join the parent directory `path` with every
    name in `file` that is not hidden and matches one of the names in `modality`.

    Raises `AssertionError` when the number of matching files differs from the number of
    requested modalities.
    """
    p = Path(path)
    # A bare string is one modality; without this it would be iterated character by character.
    mods = listify(modality)
    res = [p/o for o in file if not o.startswith('.') and is_mods(o, mods)]
    assert len(res)==len(mods), (
        f"Expected {len(mods)} file(s) matching {mods} in '{p}', "
        f"but found {len(res)}: {[o.name for o in res]}")
    return res


def get_files(path:PathOrStr, modality:Union[str, Collection[str]],
              presort:bool=False)->FilePathList:
    "Return a list of full file paths in `path` each of which contains modality in its name"
    file = [o.name for o in os.scandir(path) if o.is_file()]
    res = _get_files(path, file, modality)
    if presort: res = sorted(res, key=_path_to_same_str)
    return res


def _repr_antsimage(self):
    "Text summary of an `ANTsImage`: pixel type, components, dimensions, spacing, origin and direction"
    if self.dimension == 3:
        header = 'NiftiImage ({})\n'.format(self.orientation)
    else:
        header = 'NiftiImage\n'
    rgb_tag = ' (RGB)' if 'RGB' in self._libsuffix else ''
    spacing = tuple([round(v, 4) for v in self.spacing])
    origin = tuple([round(v, 4) for v in self.origin])
    direction = np.round(self.direction.flatten(), 4)
    return (header +
            '\t {:<10} : {} ({})\n'.format('Pixel Type', self.pixeltype, self.dtype) +
            '\t {:<10} : {}{}\n'.format('Components', self.components, rgb_tag) +
            '\t {:<10} : {}\n'.format('Dimensions', self.shape) +
            '\t {:<10} : {}\n'.format('Spacing', spacing) +
            '\t {:<10} : {}\n'.format('Origin', origin) +
            '\t {:<10} : {}\n'.format('Direction', direction))


# Modify the representation of `ANTsImage` object
ANTsImage.__repr__ = _repr_antsimage


class NiftiImage(ItemBase):
    "Support handling NIfTI image format"
    #TODO: Extend the code so as to support various Python (medical) libraries that can read NIfTI format
    def __init__(self, data:Union[Tensor,np.array], obj:ANTsImage, path:str):
        self.data = data
        # Assigning `obj` goes through `__setattr__`, which rebuilds `data` from `obj`.
        self.obj = obj
        self.path = path
        # Only works for a specific folder tree: the modality is the text between the last
        # '_' and the first '.' of the path. `str()` lets a `Path` be passed as well.
        self.mod = str(self.path).split(".")[0].split("_")[-1]

    def __repr__(self): return str(self.obj) + '\t {:<10} : {}\n\n'.format('Modality', str(self.mod))

    def __getattr__(self, k:str):
        "Forward unknown *callable* attributes to the wrapped `ANTsImage`; non-callables yield `None`"
        # `__getattr__` only runs when normal lookup fails. If `obj` itself is missing,
        # touching `self.obj` below would recurse forever.
        if k == 'obj': raise AttributeError(k)
        func = getattr(self.obj, k)
        if isinstance(func, Callable): return func
        return None

    def __setattr__(self, k, v):
        # Keep the tensor view in sync whenever the underlying image is replaced.
        if k == 'obj':
            self.data = torch.tensor(v.numpy())
        return super().__setattr__(k, v)

    # This wraps ANTsPy's `plot` method to show NIfTI image
    def show(self, **kwargs):
        # `kwargs` are accepted for interface compatibility but not forwarded.
        ants.plot(self.obj)

    # This wraps ANTsPy's `image_read` method to read NIfTI format
    @classmethod
    def create(cls, path:PathOrStr):
        "Read the file at `path` with `ants.image_read` and wrap it"
        nimg = ants.image_read(str(path))
        t = torch.tensor(nimg.numpy())
        return cls(t, nimg, path)

    def apply_tfms(self, tfms:List[Transform], *args, order='order', **kwargs):
        "Apply `tfms` in ascending order of the attribute named by `order` (missing attribute counts as 0)"
        item = self
        #ascending order eg. [3,2,1] -> [1,2,3]
        for tfm in sorted(listify(tfms), key=lambda o: getattr(o, order, 0)):
            item = tfm(item, *args, **kwargs)
        return item


class MultiNiftiImage(ItemBase):
    "Support handling multi-channel NIfTI images"
    def __init__(self, obj:Tuple[NiftiImage]):
        self.obj = obj # type annotation violated when `subregionify` is used. Should be fixed.
        self.data = None # the `data` setter ignores this value and stacks the channels of `self.obj`

    def __repr__(self):
        return f"Inside {self.__class__.__name__}:\n {list(self.obj)}"

    def __getitem__(self, i):
        return self.obj[i]

    @classmethod
    def create(cls, paths:FilePathList):
        "Read every file in `paths` as a `NiftiImage` and bundle them as channels"
        obj = tuple([NiftiImage.create(str(path)) for path in paths])
        return cls(obj)

    def apply_tfms(self, tfms:List[Transform], *args, order='order', **kwargs):
        "Apply `tfms` to every channel, then rebuild the stacked `data` tensor"
        # `order` must be forwarded by keyword; passed positionally it would land in `*args`
        # and reach each transform as its first extra positional argument.
        self.obj = tuple([nft.apply_tfms(tfms, *args, order=order, **kwargs) for nft in self.obj])
        self.data = None # triggers the setter, which restacks from the new `self.obj`
        return self

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, _):
        # The assigned value is deliberately ignored: `data` is always derived from `self.obj`.
        self._data = ( torch.stack([nft.data for nft in self.obj], dim=0)
                       if hasattr(self.obj[0], "data")
                       else torch.stack([torch.tensor(nft.numpy()) for nft in self.obj], dim=0) )


class NiftiImageList(ItemList):
    "`ItemList` whose items are paths to single NIfTI files"

    def __repr__(self)->str:
        return '{} ({} items)\n{}\nPath: {}'.format(self.__class__.__name__,
                    len(self.items), show_some(self.items, n_max=4, sep="\n"),
                    self.path)

    def get(self, i)->NiftiImage:
        fn = str(self.items[i])
        return NiftiImage.create(fn)


class MultiNiftiImageList(ItemList):
    "`ItemList` whose items are lists of file paths, one path per channel"

    def __repr__(self)->str:
        return '{} ({} items)\n{}\nPath: {}'.format(self.__class__.__name__,
                    len(self.items), show_some(self.items, n_max=4, sep="\n"),
                    self.path)

    def get(self, i)->MultiNiftiImage:
        filepaths = [str(fp) for fp in self.items[i]]
        return MultiNiftiImage.create(filepaths)

    @classmethod
    def from_folder(cls, folderpaths:FilePathList, modality:Union[str, Collection[str]],
                    presort:bool=True, **kwargs):
        """
        This method assumes a list of full paths to the desired files' parent folders
        and returns a `MultiNiftiImageList` whose item is a nested list with each sublist
        belonging to its parent folder.

        `presort` is forwarded to `get_files`; it defaults to `True` because this method
        has always sorted the files of each folder. Unless `path` is given in `kwargs`,
        the module-level `path` is used as the list's root.
        -------------------------------------------------------------------------
        Test:
        assert len(filepaths) == len(path)

        """
        filepaths = [get_files(fp, modality=modality, presort=presort) for fp in folderpaths]
        # Let callers override the root instead of failing on a duplicate `path` keyword.
        kwargs.setdefault('path', path)
        return cls(items=filepaths, **kwargs)


hgg_subdirs = (path/'HGG').ls()
lgg_subdirs = (path/'LGG').ls()
parent_folders = hgg_subdirs + lgg_subdirs


def get_parents(path:Path, pname:str, shuffle:bool=True, pct=0.2):
    "List a certain percent of items under a specified parent directory randomly or not"
    # Import the module, not the function: `from random import shuffle` would rebind the
    # `shuffle` argument and make the flag always truthy.
    import random
    ps = [name for root, dirs, _ in os.walk(path) if Path(root).name == pname for name in dirs]
    if shuffle: random.shuffle(ps)
    return ps[:round(pct*len(ps))]


def write_val_list(fname:str='valid.txt', vals:List[str]=None):
    "Write a list of names into `fname` to be used for train/validation split"
    val_list = [] if vals is None else vals
    with open(fname, 'w') as f:
        f.write('\n'.join(val_list))
    print("{} items written into {}.".format(len(val_list), fname))


val_list = get_parents(path, 'HGG', pct=0.15) + get_parents(path, 'LGG', pct=0.1)
write_val_list('valid.txt', val_list)


def split_by_parents(self, valid_names:'ItemList')->'ItemLists':
    "Split the data by using the parent names in `valid_names` for validation."
    names = set(valid_names) # built once so every membership test is a hash lookup
    return self.split_by_valid_func(lambda o: o.parent.name in names)


def split_by_pname_file(self, fname:PathOrStr, path:PathOrStr=None)->'ItemLists':
    "Split the data by using the parent names in `fname` for the validation set. `path` will override `self.path`."
    path = Path(ifnone(path, self.path))
    valid_names = loadtxt_str(path/fname)
    return self.split_by_parents(valid_names)


def split_by_valid_func(self, func:Callable)->'ItemLists':
    "Split the data by result of `func` (which returns `True` for validation set)."
    # Each item is a list of paths; `func` is evaluated on the first path of the item.
    valid_idx = [i for i,o in enumerate(self.items) if func(o[0])]
    return self.split_by_idx(valid_idx)


def _repr_labellist(self)->str:
    "Text summary of a `LabelList` showing at most the first (x, y) pair"
    items = [self[i] for i in range(min(1,len(self.items)))]
    res = f'{self.__class__.__name__} ({len(self.items)} items)\n'
    res += f'x: {self.x.__class__.__name__}\n{show_some([i[0] for i in items], n_max=1)}\n'
    res += f'y: {self.y.__class__.__name__}\n{show_some([i[1] for i in items], n_max=1)}\n'
    return res + f'Path: {self.path}'


# Modify the methods of `MultiNiftiImageList` object
MultiNiftiImageList.split_by_parents = split_by_parents
MultiNiftiImageList.split_by_pname_file = split_by_pname_file
MultiNiftiImageList.split_by_valid_func = split_by_valid_func

# Modify the representation of `LabelList` object
LabelList.__repr__ = _repr_labellist


class NiftiSegmentationLabelList(NiftiImageList):
    "`ItemList` for NIfTI segmentation masks"
    _processor=SegmentationProcessor

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        self.copy_new.append('classes')
        self.classes,self.loss_func = classes,None

    def reconstruct(self, t:Tensor):
        "Wrap tensor `t` back into a `NiftiImage`, using `self.path` as its path"
        # `.numpy()` only works on CPU tensors that do not require grad.
        obj = ants.from_numpy(t.detach().cpu().numpy())
        return NiftiImage(t, obj, self.path)


def get_y_fn(x):
    "Label path for item `x`: `<parent folder name>_seg.nii.gz` next to the item's first file"
    first = x[0]
    return first.parent/Path(first.parent.name + '_seg.nii.gz')


subregion = np.array(['WT', 'TC', 'ET'])


def crop_3d(item:NiftiImage, do_resolve=False, *args, lowerind:Tuple, upperind:Tuple, **kwargs):
    "Crop 3-dimensional NIfTI image by slicing indices from lower to upper indices per image axis"
    item.obj = item.obj.crop_indices(lowerind, upperind)
    return item


def standardize(item:NiftiImage, do_resolve=False, *args, **kwargs):
    """
    Standardize our custom itembase `NiftiImage` based on non-zero voxels only: subtract their
    mean, divide by their std, then divide by the resulting maximum. Zero voxels stay zero.
    An image without any non-zero voxel is returned unchanged.
    """
    arr = item.obj.numpy()
    mask = arr != 0
    if not mask.any(): return item # no statistics to compute; `max()` of an empty array would raise
    vals = arr[mask]
    vals = (vals - vals.mean()) / vals.std()
    arr[mask] = vals / vals.max()
    item.obj = ants.from_numpy(arr)
    return item


def subregionify(item:NiftiImage, do_resolve=False, *args, **kwargs):
    "Combine the three annotations into 3 nested subregions: Whole Tumor(WT), Tumor Core(TC), Enhancing Tumor(ET)"
    arr = item.obj.numpy()
    # For each subregion, the label values that get rewritten; any other value is left as is.
    relabel = (
        {2.: 1., 4.: 1.},          # WT: labels 1, 2 and 4 all become 1
        {2.: 0., 4.: 1.},          # TC: labels 1 and 4 become 1, label 2 becomes 0
        {1.: 0., 2.: 0., 4.: 1.},  # ET: only label 4 becomes 1
    )
    regions = []
    for mapping in relabel:
        out = arr.copy()
        for src, dst in mapping.items():
            out[arr == src] = dst # masks come from the untouched input, so order is irrelevant
        regions.append(out)
    return MultiNiftiImage([ants.from_numpy(region) for region in regions])


crop_3d = Transform(crop_3d, order=0) # Applied to 'x' first then `y` for an implementation detail with overwrite
standardize = Transform(standardize, order=1) # Only applied to 'x'
subregionify = Transform(subregionify, order=1) # Only applied to 'y'

x_transform = [crop_3d, standardize]
y_transform = [crop_3d, subregionify]

data = (MultiNiftiImageList.from_folder(parent_folders, modality=['Flair', 'T1', 'T2', 'T1ce'])
        .split_by_pname_file(fname='valid.txt', path=Path('.'))
        .label_from_func(get_y_fn, classes=subregion, label_cls=NiftiSegmentationLabelList)
        .transform((x_transform, x_transform), tfm_y=False, lowerind=(40,28,10), upperind=(200,220,138))
        .transform_y((y_transform, y_transform), lowerind=(40,28,10), upperind=(200,220,138))
        .databunch(bs=1, collate_fn=data_collate, num_workers=0))