[dc3c86]: / mrnet_itemlist.py

Download this file

149 lines (126 with data), 7.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as nd
from fastai.vision import *
# ItemBase subclass
class MRNetCase(ItemBase):
    """A single MRNet knee-scan case: one numpy array per imaging plane.

    Developed in parallel with `MRNetCaseList.get`, which loads the three
    planar arrays and wraps them in this class.
    """

    def __init__(self, axial, coronal, sagittal):
        # Keep the per-plane arrays individually, as the raw tuple (`obj`),
        # and stacked along a new leading axis (`data`) for the model input.
        self.axial = axial
        self.coronal = coronal
        self.sagittal = sagittal
        self.obj = (axial, coronal, sagittal)
        self.data = np.stack([axial, coronal, sagittal], axis=0)

    def __repr__(self):
        # __str__ falls back on __repr__ when not otherwise defined.
        return f'''
{self.__class__.__name__}
`obj` attribute is tuple(axial, coronal, sagittal):
{list(e.shape for e in self.obj)}
`data` attribute is all three planar scan data arrays,
with variations in slice count removed via interpolation
{self.data.shape}
{self.data}
'''

    # apply_tfms (optional)
# DataBunch subclass
class MRNetCaseDataBunch(DataBunch):
    "DataBunch for MRNet knee scan data."
    # NOTE: the previous no-op `__init__` override (pure super() pass-through)
    # was removed; the inherited constructor is used directly.

    @classmethod
    def from_df(cls, path:PathOrStr, df:pd.DataFrame, folder:PathOrStr=None, label_delim:str=None,
                valid_pct:float=0.2, fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, suffix:str='',
                **kwargs:Any)->'MRNetCaseDataBunch':
        "Create a `MRNetCaseDataBunch` from a `DataFrame` `df`."
        # Build the labelled source: case list from df -> random split -> labels,
        # then hand the LabelLists to fastai's standard DataBunch factory.
        src = (MRNetCaseList.from_df(df, path=path, folder=folder, suffix=suffix, cols=fn_col)
               .split_by_rand_pct(valid_pct)
               .label_from_df(label_delim=label_delim, cols=label_col))
        return cls.create_from_ll(src, **kwargs)
# ItemList subclass
class MRNetCaseList(ItemList):
    """ItemList of MRNet cases.

    Items are case numbers as strings rather than filenames, because each
    case is stored as three .npy files, one per imaging plane.
    """
    # class variables
    _bunch = MRNetCaseDataBunch
    # _processor
    # _label_cls

    # Any additional __init__ state saved on the ItemList must be forwarded
    # via `copy_new` so `new` can rebuild train/valid lists when splitting;
    # this subclass adds none, so only items/path are passed through.
    def __init__(self, items, path, **kwargs):
        super().__init__(items=items, path=path, **kwargs)

    def get(self, i):
        "Load the three planar scans of the `i`-th case and wrap them in a `MRNetCase`."
        # i indexes self.items, an ordered array of case numbers as strings
        case = super().get(i)
        # cases live under either train/ or valid/ in the MRNet folder layout;
        # probe the axial subfolder to decide which split this case belongs to
        tv = 'train' if (self.path/'train'/'axial'/(case + '.npy')).exists() else 'valid'
        # self.path is available from kwargs of the ItemList superclass
        imagearrays = [self.open(self.path/tv/plane/(case + '.npy'))
                       for plane in ('axial', 'coronal', 'sagittal')]
        assert len(imagearrays) == 3
        return MRNetCase(*imagearrays)

    def open(self, fn):
        "Loader for one planar array; needed because we subclass `ItemList`, not `ImageList`."
        return np.load(fn)

    def reconstruct(self, t):
        "Rebuild a `MRNetCase` (the same kind of object `get` returns) from a `.data` tensor `t`."
        # t is the pytorch tensor corresponding to the .data attribute of
        # MRNetCase; split it back into the three per-plane numpy arrays
        arrays = to_np(t)
        return MRNetCase(arrays[0, :, :, :], arrays[1, :, :, :], arrays[2, :, :, :])

    # NOTE(review): this method calls DataBunch-only APIs (`one_batch`, `dl`,
    # `train_ds`); it looks like it was meant to live on `MRNetCaseDataBunch`
    # before the paste lost its indentation — confirm before relying on it here.
    def show_batch(self, rows:int=4, ds_type:DatasetType=DatasetType.Train, **kwargs)->None:
        "Show a batch of data in `ds_type` on a few `rows`."
        x, y = self.one_batch(ds_type, True, True)
        # never ask for more items than one batch actually contains
        n_items = min(rows, self.dl(ds_type).batch_size)
        xs = [self.train_ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]
        #TODO: get rid of has_arg if possible
        if has_arg(self.train_ds.y.reconstruct, 'x'):
            ys = [self.train_ds.y.reconstruct(grab_idx(y, i), x=x) for i, x in enumerate(xs)]
        else:
            ys = [self.train_ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]
        self.train_ds.x.show_xys(xs, ys, **kwargs)

    # custom show_xys method for show_batch: each row is one case, three planes
    def show_xys(self, xs, ys, imgsize:int=4,
                 figsize:Tuple[int,int]=(13,13), **kwargs):
        "Show the `xs` (inputs) and `ys` (targets) on a figure of `figsize`."
        rows = len(xs)
        fig, axarray = plt.subplots(rows, 3, figsize=figsize)
        planes = ('axial', 'coronal', 'sagittal')
        for i, (x, y) in enumerate(zip(xs, ys)):
            # show the actual middle slice (data is stacked (plane, slice, H, W)),
            # rather than the previously hard-coded slice index 11
            mid = x.data.shape[1] // 2
            for p, plane in enumerate(planes):
                axarray[i, p].imshow(x.data[p, mid, :, :])
                axarray[i, p].set_title('{} ({})'.format(y, plane))
        plt.tight_layout()

    # TODO: def show_xyzs # once have predictions to work with
    # TODO: def analyze_pred

    @classmethod
    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=['.npy'], **kwargs)->'MRNetCaseList':
        "Get the case numbers for all MRNet cases, assuming the stock MRNet download layout."
        filepaths = get_files(path=path, extensions=extensions, recurse=True, **kwargs)
        # each case appears once per plane/split folder; dedupe on the file stem
        items = sorted(set(fp.stem for fp in filepaths))
        return cls(items=items, path=path)

    def split_by_folder(self, train:str='train', valid:str='valid') -> 'ItemLists':
        "Split cases into train/valid according to which folder holds their arrays."
        # arbitrarily probe the axial subfolder for each case's array;
        # previously the 'valid' folder name was hard-coded, ignoring the
        # `valid` parameter — now the parameter is honoured
        valid_idx = [i for i, case in enumerate(self.items)
                     if (self.path/valid/'axial'/(case + '.npy')).exists()]
        # then use split_by_idx to return the split item lists
        return self.split_by_idx(valid_idx=valid_idx)

    def link_label_df(self, df):
        "Associate labels to cases using a pandas DataFrame with a Case column and one or more label columns."
        # fastai's labelling machinery reads `self.inner_df` in multiple
        # places, so join the label df onto our case numbers and store it there
        casesDF = pd.DataFrame({'Case': self.items})
        self.inner_df = pd.merge(casesDF, df, on='Case')