# ndv/modules/dataloader.py
1
import torch, fastai, sys, os
2
from fastai.vision import *
3
from fastai.vision.data import SegmentationProcessor
4
import ants
5
from ants.core.ants_image import ANTsImage
6
from jupyterthemes import jtplot
7
sys.path.insert(0, './exp')
8
jtplot.style(theme='gruvboxd')
9
10
# Set a root directory
11
path = Path('/home/ubuntu/MultiCampus/MICCAI_BraTS_2019_Data_Training')
12
13
def is_mod(fn:str, mod:str)->bool:
    """Return True if the file name `fn` contains the modality name `mod`.

    Matching is case-insensitive. `mod` is used as a regular expression that may
    occur anywhere in `fn`, so e.g. 'T1' also matches 'x_t1ce.nii.gz'.
    """
    import re
    pattern = re.compile('.*' + mod, re.IGNORECASE)
    return pattern.match(fn) is not None
18
19
def is_mods(fn:str, mods:Union[str, Collection[str]])->bool:
    """Return True if `fn` contains at least one of the modality names in `mods`.

    A single modality may be passed as a plain string. It is treated as a
    one-element collection rather than being iterated character by character,
    which would match almost any file name. Matching follows `is_mod`.
    """
    if isinstance(mods, str): mods = [mods]
    return any(is_mod(fn, mod) for mod in mods)
23
24
def _path_to_same_str(p_fn):
25
    "path -> str, but same on nt+posix, for alpha-sort only"
26
    s_fn = str(p_fn)
27
    s_fn = s_fn.replace('\\','.')
28
    s_fn = s_fn.replace('/','.')
29
    return s_fn
30
31
def _get_files(path, file, modality):
    """
    Internal implementation for `get_files`: join the parent directory `path`
    with every non-hidden name in `file` that matches one of `modality`.

    `modality` may be a single name or a collection of names. Exactly one
    matching file per requested modality is expected; otherwise an
    AssertionError describing the mismatch is raised.
    """
    # Normalise a single modality string to a list so `len` counts modalities, not characters.
    mods = [modality] if isinstance(modality, str) else list(modality)
    p = Path(path)
    res = [p/o for o in file if not o.startswith('.') and is_mods(o, mods)]
    assert len(res) == len(mods), (
        f"Expected {len(mods)} file(s) matching {mods} in {p}, "
        f"found {len(res)}: {[r.name for r in res]}")
    return res
40
41
def get_files(path:PathOrStr, modality:Union[str, Collection[str]],
              presort:bool=False)->FilePathList:
    """Return full paths of the regular files directly inside `path` whose names
    contain one of the requested modality names (see `_get_files`).

    If `presort` is True, the result is sorted with a separator-agnostic key so
    the order is the same on Windows and POSIX.
    """
    # Use scandir as a context manager so the directory handle is closed promptly.
    with os.scandir(path) as entries:
        file = [o.name for o in entries if o.is_file()]
    res = _get_files(path, file, modality)
    if presort: res = sorted(res, key=_path_to_same_str)
    return res
48
49
def _repr_antsimage(self):
50
    if self.dimension == 3:
51
        s = 'NiftiImage ({})\n'.format(self.orientation)
52
    else:
53
        s = 'NiftiImage\n'
54
    s = s +\
55
        '\t {:<10} : {} ({})\n'.format('Pixel Type', self.pixeltype, self.dtype)+\
56
        '\t {:<10} : {}{}\n'.format('Components', self.components, ' (RGB)' if 'RGB' in self._libsuffix else '')+\
57
        '\t {:<10} : {}\n'.format('Dimensions', self.shape)+\
58
        '\t {:<10} : {}\n'.format('Spacing', tuple([round(s,4) for s in self.spacing]))+\
59
        '\t {:<10} : {}\n'.format('Origin', tuple([round(o,4) for o in self.origin]))+\
60
        '\t {:<10} : {}\n'.format('Direction', np.round(self.direction.flatten(),4))
61
    return s
62
63
# Modify the representation of `ANTsImage` object
64
ANTsImage.__repr__ = _repr_antsimage
65
66
class NiftiImage(ItemBase):
    "Support handling NIfTI image format"
    #TODO: Extend the code so as to support various Python (medical) libraries that can read NIfTI format

    def __init__(self, data:Union[Tensor,np.ndarray], obj:ANTsImage, path:PathOrStr):
        """Wrap the ANTsImage `obj` that was read from `path`.

        `data` is stored first. Assigning `obj` right afterwards (see `__setattr__`)
        replaces it with `torch.tensor(obj.numpy())`, so `data` always mirrors `obj`.
        """
        self.data = data
        self.obj = obj
        self.path = path
        # Modality = last '_'-separated token of the file name before its first '.',
        # e.g. 'BraTS19_x_flair.nii.gz' -> 'flair'. Only works for that naming scheme.
        # The base name is used so dots in parent directories (e.g. './data') do not
        # interfere. str() lets `pathlib.Path` objects be passed as well.
        self.mod = os.path.basename(str(path)).split(".")[0].split("_")[-1]

    def __repr__(self): return str(self.obj) + '\t {:<10} : {}\n\n'.format('Modality', str(self.mod))

    def __getattr__(self, k:str):
        """Forward unknown attribute lookups (methods and plain attributes) to the wrapped ANTsImage.

        Python only calls this when normal lookup fails. Dunder names are not
        forwarded, so protocols such as copy/pickle see this object, not the ANTsImage.
        If `obj` is not set yet (e.g. on an instance being unpickled), raise
        AttributeError instead of recursing forever.
        """
        if k.startswith('__') and k.endswith('__'): raise AttributeError(k)
        try: obj = self.__dict__['obj']
        except KeyError: raise AttributeError(k) from None
        return getattr(obj, k)

    def __setattr__(self, k, v):
        # Keep `data` in sync: assigning a new ANTsImage to `obj` refreshes the tensor.
        if k == 'obj':
            self.data = torch.tensor(v.numpy())
        return super().__setattr__(k, v)

    # This wraps ANTsPy's `plot` method to show NIfTI image
    def show(self, **kwargs):
        "Plot the wrapped image with `ants.plot`; extra keyword arguments are ignored."
        ants.plot(self.obj)

    # This wraps ANTsPy's `image_read` method to read NIfTI format
    @classmethod
    def create(cls, path:PathOrStr):
        "Read the file at `path` with `ants.image_read` and wrap it."
        nimg = ants.image_read(str(path))
        t = torch.tensor(nimg.numpy())
        return cls(t, nimg, path)

    def apply_tfms(self, tfms:List[Transform], *args, order='order', **kwargs):
        """Apply `tfms` in ascending order of their `order` attribute (missing -> 0).

        Each transform is called as `tfm(item, *args, **kwargs)` and its return value
        feeds the next one. The result may therefore be of another type
        (`subregionify` returns a `MultiNiftiImage`).
        """
        key = lambda o: getattr(o, order, 0)
        item = self
        for tfm in sorted(listify(tfms), key=key): item = tfm(item, *args, **kwargs)
        return item
102
103
class MultiNiftiImage(ItemBase):
    "Support handling multi-channel NIfTI images"
    def __init__(self, obj:Tuple[NiftiImage]):
        # NOTE: `subregionify` passes a list of raw ANTsImage objects rather than a
        # tuple of NiftiImage, which violates the annotation. The `data` setter below
        # handles both cases.
        self.obj = obj
        self.data = None  # value ignored: the setter stacks the channels of `obj`

    def __repr__(self):
        return f"Inside {self.__class__.__name__}:\n {list(self.obj)}"

    def __getitem__(self, i):
        return self.obj[i]

    @classmethod
    def create(cls, paths:FilePathList):
        "Read every file in `paths` as one `NiftiImage` channel."
        return cls(tuple(NiftiImage.create(str(p)) for p in paths))

    def apply_tfms(self, tfms:List[Transform], *args, order='order', **kwargs):
        """Apply `tfms` to every channel independently, then re-stack `data`.

        `order` is forwarded by keyword. Passing it positionally would hand the
        string 'order' to each transform as an extra positional argument.
        """
        self.obj = tuple(nft.apply_tfms(tfms, *args, order=order, **kwargs) for nft in self.obj)
        self.data = None  # the setter recomputes the stacked tensor from `self.obj`
        return self

    @property
    def data(self):
        "The channels of `obj` stacked along a new leading dimension."
        return self._data

    @data.setter
    def data(self, _):
        # The assigned value is ignored; `data` is always derived from `obj`.
        # NiftiImage channels carry a `.data` tensor; raw ANTsImage channels are converted.
        if hasattr(self.obj[0], "data"): chans = [nft.data for nft in self.obj]
        else: chans = [torch.tensor(nft.numpy()) for nft in self.obj]
        self._data = torch.stack(chans, dim=0)
134
135
class NiftiImageList(ItemList):
    "`ItemList` whose items are NIfTI file paths, each loaded as a `NiftiImage`."

    def __repr__(self)->str:
        preview = show_some(self.items, n_max=4, sep="\n")
        return f'{self.__class__.__name__} ({len(self.items)} items)\n{preview}\nPath: {self.path}'

    def get(self, i)->NiftiImage:
        "Read item `i` from disk."
        return NiftiImage.create(str(self.items[i]))
144
145
class MultiNiftiImageList(ItemList):
    "`ItemList` whose items are lists of NIfTI file paths (one per modality), loaded as `MultiNiftiImage`."

    def __repr__(self)->str:
        return '{} ({} items)\n{}\nPath: {}'.format(self.__class__.__name__,
                                                    len(self.items), show_some(self.items, n_max=4, sep="\n"),
                                                    self.path)

    def get(self, i)->MultiNiftiImage:
        "Read the modality files of item `i` as one multi-channel image."
        return MultiNiftiImage.create([str(fp) for fp in self.items[i]])

    @classmethod
    def from_folder(cls, folderpaths:FilePathList, modality:Union[str, Collection[str]],
                    presort:bool=False, **kwargs):
        """
        Build the list from `folderpaths`, the full paths of the subject folders.

        Each item is the list of files in one folder whose names match `modality`
        (see `get_files`), in the same order as `folderpaths`.

        Files inside a folder are always sorted so that, assuming consistent file
        naming, channels come in the same order for every subject. `presort` is
        accepted for API compatibility but ignored.
        `path` may be given in `kwargs`. If omitted, the module-level `path`
        (dataset root) is used.
        """
        root = kwargs.pop('path', None)
        if root is None: root = path  # module-level dataset root
        filepaths = [get_files(fp, modality=modality, presort=True) for fp in folderpaths]
        return cls(items=filepaths, path=root, **kwargs)
173
174
hgg_subdirs = (path/'HGG').ls()
175
lgg_subdirs = (path/'LGG').ls()
176
parent_folders = hgg_subdirs + lgg_subdirs
177
178
def get_parents(path:Path, pname:str, shuffle:bool=True, pct=0.2):
    """List (a fraction of) the sub-directory names of every folder named `pname` under `path`.

    Args:
        path: root directory, walked recursively with `os.walk`.
        pname: name of the parent folder(s) whose child directories are collected.
        shuffle: if True, shuffle the names randomly before selecting.
        pct: fraction of names to return; the count is `round(pct * n)`.
    Returns:
        A list of directory names (not full paths).
    """
    # Import the module: `from random import shuffle` would shadow the `shuffle` flag,
    # making the check below always true.
    import random
    names = [d for root, dirs, _ in os.walk(path) if Path(root).name == pname for d in dirs]
    if shuffle: random.shuffle(names)
    return names[:round(pct * len(names))]
184
185
def write_val_list(fname:str='valid.txt', vals:List[str]=None):
    """Write `vals` to `fname`, one name per line (no trailing newline), and report the count.

    The file is later read by `split_by_pname_file` for the train/validation
    split. `vals=None` writes an empty file instead of raising.
    """
    val_list = [] if vals is None else list(vals)
    with open(fname, 'w') as f:
        f.write('\n'.join(val_list))
    print("{} items written into {}.".format(len(val_list), fname))
191
192
val_list = get_parents(path, 'HGG', pct=0.15) + get_parents(path, 'LGG', pct=0.1)
193
write_val_list('valid.txt', val_list)
194
195
def split_by_parents(self, valid_names:'ItemList')->'ItemLists':
    "Send to validation every item whose file sits in a folder named in `valid_names`."
    def _in_valid(fp):
        # an item is a validation item iff its parent folder name is listed
        return fp.parent.name in valid_names
    return self.split_by_valid_func(_in_valid)
198
199
def split_by_pname_file(self, fname:PathOrStr, path:PathOrStr=None)->'ItemLists':
    "Split using the folder names stored in `fname` as the validation set; `path` overrides `self.path`."
    base = Path(self.path if path is None else path)
    names = loadtxt_str(base/fname)
    return self.split_by_parents(names)
204
205
def split_by_valid_func(self, func:Callable)->'ItemLists':
    "Split by `func` (True means validation), applied to the first file path of each item."
    # Each item is a list of per-modality file paths; only its first entry is tested.
    valid_idx = []
    for idx, files in enumerate(self.items):
        if func(files[0]): valid_idx.append(idx)
    return self.split_by_idx(valid_idx)
209
    
210
def _repr_labellist(self)->str:
    "Short `LabelList` summary: class, item count, at most one (x, y) preview, and path."
    n_show = min(1, len(self.items))
    pairs = [self[i] for i in range(n_show)]
    xs = [pair[0] for pair in pairs]
    ys = [pair[1] for pair in pairs]
    return (f'{self.__class__.__name__} ({len(self.items)} items)\n'
            f'x: {self.x.__class__.__name__}\n{show_some(xs, n_max=1)}\n'
            f'y: {self.y.__class__.__name__}\n{show_some(ys, n_max=1)}\n'
            f'Path: {self.path}')
216
217
# Modify the methods of `MultiNiftiImageList` object
218
MultiNiftiImageList.split_by_parents = split_by_parents
219
MultiNiftiImageList.split_by_pname_file = split_by_pname_file
220
MultiNiftiImageList.split_by_valid_func = split_by_valid_func
221
222
# Modify the representation of `LabelList` object
223
LabelList.__repr__ = _repr_labellist
224
225
class NiftiSegmentationLabelList(NiftiImageList):
    "`ItemList` for NIfTI segmentation masks"
    _processor=SegmentationProcessor

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        # Register `classes` in fastai's `copy_new` so it is presumably carried over
        # to lists derived from this one (TODO confirm against fastai ItemList.new).
        self.copy_new.append('classes')
        self.classes,self.loss_func = classes,None

    def reconstruct(self, t:Tensor):
        """Rebuild a `NiftiImage` from the tensor `t`.

        `t` is detached and moved to the CPU first, because `Tensor.numpy()` refuses
        GPU tensors and tensors that require grad. The list's root path is passed as
        a string because `NiftiImage` parses it with string operations.
        NOTE(review): the modality is derived from the list's root path, not from an
        item file. `ants.from_numpy` is called without spacing/origin/direction.
        """
        obj = ants.from_numpy(t.detach().cpu().numpy())
        return NiftiImage(t, obj, str(self.path))
238
239
def get_y_fn(x):
    """Return the segmentation-mask path for item `x` (a list of per-modality file paths).

    The mask sits next to the first file and is named '<folder name>_seg.nii.gz',
    e.g. '.../S1/S1_flair.nii.gz' -> '.../S1/S1_seg.nii.gz'. `Path.parent.name` is
    used instead of splitting `as_posix()` on `os.sep`, which breaks on Windows.
    """
    folder = x[0].parent
    return folder/f'{folder.name}_seg.nii.gz'
240
241
# Names of the three sub-region channels produced by `subregionify`, in its output
# order (Whole Tumor, Tumor Core, Enhancing Tumor). Used as the label `classes`.
subregion = np.array(['WT', 'TC', 'ET'])
242
243
def crop_3d(item:NiftiImage, do_resolve=False, *args, lowerind:Tuple, upperind:Tuple, **kwargs):
    "Crop `item` per axis from `lowerind` to `upperind` with ANTs `crop_indices`; other args are ignored."
    # Reassigning `item.obj` modifies the item in place (and refreshes its tensor).
    item.obj = item.obj.crop_indices(lowerind, upperind)
    return item
248
249
def standardize(item:NiftiImage, do_resolve=False, *args, **kwargs):
    """Normalise the non-zero voxels of `item` in place and return it.

    Non-zero voxels are z-scored (zero mean, unit std) and then divided by their
    maximum z-score, so the largest of them becomes 1. Zero voxels stay 0.
    If there are no non-zero voxels, or they all share one value (std 0), the item
    is returned unchanged instead of raising or filling the volume with NaN.
    Extra positional/keyword arguments are accepted and ignored.
    """
    arr = item.obj.numpy()
    mask = arr != 0
    fg = arr[mask]
    if fg.size == 0: return item
    std = fg.std()
    if std == 0: return item
    fg = (fg - fg.mean()) / std
    arr[mask] = fg / fg.max()  # max z-score is > 0 here because std > 0
    # `new_image_like` keeps spacing/origin/direction, which `ants.from_numpy` would reset.
    item.obj = item.obj.new_image_like(arr)
    return item
257
258
def subregionify(item:NiftiImage, do_resolve=False, *args, **kwargs):
    """Turn a label map into 3 nested binary masks: Whole Tumor (WT), Tumor Core (TC), Enhancing Tumor (ET).

    Label mapping encoded here: WT = {1, 2, 4}, TC = {1, 4}, ET = {4}.
    Every other value maps to 0, so each mask is strictly binary.
    Returns a `MultiNiftiImage` whose channels (order WT, TC, ET) are raw ANTsImage
    masks, not `NiftiImage`. Extra arguments are ignored.
    """
    arr = item.obj.numpy()
    masks = (np.isin(arr, (1., 2., 4.)), np.isin(arr, (1., 4.)), arr == 4.)
    # `new_image_like` keeps the header (spacing/origin/direction) of the source mask.
    return MultiNiftiImage([item.obj.new_image_like(m.astype(arr.dtype)) for m in masks])
268
269
crop_3d = Transform(crop_3d, order=0)            # Applied to 'x' first then `y` for a implementation detail with overwrite
270
standardize = Transform(standardize, order=1)    # Only applied to 'x'
271
subregionify = Transform(subregionify, order=1)  # Only applied to 'y'
272
273
x_transform = [crop_3d, standardize]
274
y_transform = [crop_3d, subregionify]
275
276
# Crop box (voxel indices per axis), shared by images and masks so they stay aligned.
_crop_kwargs = dict(lowerind=(40, 28, 10), upperind=(200, 220, 138))

# Build the DataBunch: 4 MRI modalities per subject as input and the sub-region masks
# as target, split by the folder names in ./valid.txt. Train and valid use the same transforms.
data = (MultiNiftiImageList
        .from_folder(parent_folders, modality=['Flair', 'T1', 'T2', 'T1ce'])
        .split_by_pname_file(fname='valid.txt', path=Path('.'))
        .label_from_func(get_y_fn, classes=subregion, label_cls=NiftiSegmentationLabelList)
        .transform((x_transform, x_transform), tfm_y=False, **_crop_kwargs)
        .transform_y((y_transform, y_transform), **_crop_kwargs)
        .databunch(bs=1, collate_fn=data_collate, num_workers=0))