Diff of /mrnet_itemlist.py [000000] .. [dc3c86]

--- a
+++ b/mrnet_itemlist.py
@@ -0,0 +1,148 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy import ndimage as nd
+from fastai.vision import *
+
+
+# ItemBase subclass
+class MRNetCase(ItemBase):
+
+  # ItemBase.data -- developed in parallel with the ItemList's .get method,
+  # which returns one np.array per plane (axial, coronal, sagittal)
+    # the raw MRNet scans have a different slice count per plane and per case, so each
+    # plane is resampled to a fixed depth (an arbitrary choice) before stacking into .data
+    n_slices = 24
+
+    def __init__(self, axial, coronal, sagittal):
+        self.axial,self.coronal,self.sagittal = axial,coronal,sagittal
+        self.obj = (axial,coronal,sagittal)
+        zoomed = [nd.zoom(a, (self.n_slices/a.shape[0], 1, 1), order=1) for a in self.obj]
+        self.data = np.stack(zoomed, axis=0)
+        
+  # string representation via __repr__, since __str__ falls back on __repr__ if not otherwise defined
+    def __repr__(self): 
+        return f'''
+        {self.__class__.__name__} 
+        `obj` attribute is tuple(axial, coronal, sagittal): 
+        {list(e.shape for e in self.obj)}
+            
+        `data` attribute is all three planar scan data arrays, 
+        with variations in slice count removed via interpolation
+        {self.data.shape}
+        {self.data}
+        '''
+
+  # apply_tfms (optional)
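+
+# a minimal usage sketch for MRNetCase ('0000' and `data_path` below are
+# placeholders, not part of this module):
+#
+#   data_path = Path('data/MRNet-v1.0')
+#   arrays = [np.load(data_path/'train'/plane/'0000.npy')
+#             for plane in ('axial','coronal','sagittal')]
+#   case = MRNetCase(*arrays)
+#   case.data.shape   # -> (3, MRNetCase.n_slices, 256, 256)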
+
+# DataBunch subclass
+class MRNetCaseDataBunch(DataBunch):
+    "DataBunch for MRNet knee scan data."
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def show_batch(self, rows:int=4, ds_type:DatasetType=DatasetType.Train, **kwargs)->None:
+        "Show a batch of data in `ds_type` on a few `rows`, one case per row."
+        x,y = self.one_batch(ds_type, True, True)
+        n_items = min(rows, self.dl(ds_type).batch_size)
+        xs = [self.train_ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]
+        #TODO: get rid of has_arg if possible
+        if has_arg(self.train_ds.y.reconstruct, 'x'):
+            ys = [self.train_ds.y.reconstruct(grab_idx(y, i), x=x) for i,x in enumerate(xs)]
+        else: ys = [self.train_ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]
+        self.train_ds.x.show_xys(xs, ys, **kwargs)
+
+    @classmethod
+    def from_df(cls, path:PathOrStr, df:pd.DataFrame, label_delim:str=None, valid_pct:float=0.2,
+                fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, **kwargs:Any)->'MRNetCaseDataBunch':
+        "Create a `MRNetCaseDataBunch` from a `DataFrame` `df` of case numbers and labels."
+        # items are case numbers rather than filenames, so the folder/suffix
+        # handling of ImageDataBunch.from_df does not apply here
+        src = (MRNetCaseList.from_df(df, path=path, cols=fn_col)
+                .split_by_rand_pct(valid_pct)
+                .label_from_df(label_delim=label_delim, cols=label_col))
+        # MRNetCaseList._bunch is this class, so databunch() returns a MRNetCaseDataBunch
+        return src.databunch(**kwargs)
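+
+    # a usage sketch (`labels_df` is a placeholder for a DataFrame whose first
+    # column holds the case-number strings and whose second column holds labels):
+    #
+    #   data = MRNetCaseDataBunch.from_df(path, labels_df, bs=4)
+    #
+    # extra kwargs (bs, num_workers, ...) pass through to LabelLists.databunch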
+
+
+# ItemList subclass
+class MRNetCaseList(ItemList):
+  # class variables
+    _bunch = MRNetCaseDataBunch
+    # _processor
+    # _label_cls
+
+  # __init__ arguments
+  # items for this subclass are case-number strings rather than filenames,
+  # since each case has three files, one per plane
+  # note (per the fastai custom-ItemList guidance): any additional __init__ arguments
+  # saved in the ItemList's state must be passed along by `new`, because `new` is what
+  # creates the train and validation sets when splitting; to do that, add their names
+  # to `copy_new` inside __init__, and keep **kwargs as is
+    def __init__(self, items, path, **kwargs):
+        super().__init__(items=items, path=path, **kwargs)
+
+  # core methods
+    # get
+    def get(self, i):
+        # i indexes self.items, which is an ordered array of case numbers as strings
+        case = super().get(i)
+        # cases belong to either the train or the valid split in the folder structure
+        tv = 'train' if (self.path/'train'/'axial'/(case + '.npy')).exists() else 'valid'
+        imagearrays = []
+        for plane in ('axial','coronal','sagittal'):
+            # self.path is available from kwargs of ItemList superclass
+            fn  = self.path/tv/plane/(case + '.npy')
+            res = self.open(fn)
+            imagearrays.append(res) 
+        assert len(imagearrays) == 3
+        return MRNetCase(*imagearrays)
+
+    # since subclassing ItemList rather than ImageList, need an open method
+    def open(self, fn): return np.load(fn)
+
+    def reconstruct(self, t):
+        # t is the pytorch tensor corresponding to the .data attribute of MRNetCase
+        # the result of reconstruct should be to 
+        # "return the same kind of object as .get returns"
+        # which is a MRNetCase 
+        # and to build that, a tuple of numpy arrays is required
+        arrays = to_np(t)
+        return MRNetCase(arrays[0,:,:,:],arrays[1,:,:,:],arrays[2,:,:,:]) 
+
+    # show_batch is defined on MRNetCaseDataBunch above, since it needs DataBunch
+    # attributes (one_batch, dl, train_ds); DataBunch.show_batch ends up calling
+    # the show_xys defined below
+
+    # custom show_xys method for show_batch
+    # would like to have each row correspond to a case from three planes
+    def show_xys(self, xs, ys, imgsize:int=4, 
+                 figsize:Tuple[int,int]=(13,13), **kwargs):
+        "Show the `xs` (inputs) and `ys` (targets) on a figure of `figsize`."
+        rows = len(xs)
+        fig, axarray = plt.subplots(rows, 3, figsize=figsize, squeeze=False)
+        # start by showing the middle slice from each plane
+        planes = ('axial','coronal','sagittal')
+        for i,(x,y) in enumerate(zip(xs, ys)):
+            mid = x.data.shape[1]//2   # middle slice index
+            for p,plane in enumerate(planes):
+                axarray[i,p].imshow(x.data[p,mid,:,:])
+                axarray[i,p].set_title('{} ({})'.format(y, plane))
+        plt.tight_layout()
+    # TODO: def show_xyzs # once have predictions to work with
+
+    # TODO: def analyze_pred
+
+    @classmethod
+    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=['.npy'], **kwargs)->'MRNetCaseList':
+        "Get the Case numbers for all MRNet cases, assuming directory structure unchanged from MRNet data download"
+        filepaths = get_files(path=path, extensions=extensions, recurse=True, **kwargs)
+        items = sorted(set([fp.stem for fp in filepaths]))
+        return cls(items=items, path=path)
+
+    def split_by_folder(self, train:str='train', valid:str='valid') -> 'ItemLists':
+        "Split the cases by whether their arrays live under the `train` or the `valid` folder."
+        # an item is a case number; check which split's folder holds its arrays,
+        # arbitrarily using the axial subfolder for the existence test
+        valid_idx = [i for i,case in enumerate(self.items) if (self.path/valid/'axial'/(case + '.npy')).exists()]
+        # then use split_by_idx to return the split ItemLists
+        return self.split_by_idx(valid_idx=valid_idx)
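+
+    # from_folder and split_by_folder assume the layout of the MRNet download,
+    #   <path>/{train,valid}/{axial,coronal,sagittal}/<case>.npy
+    # so, for example,
+    #   MRNetCaseList.from_folder(path).split_by_folder()
+    # yields ItemLists whose .train and .valid hold the deduplicated case numbers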
+
+    def link_label_df(self, df):
+        "Associate labels with cases using a pandas DataFrame that has a Case column and one or more label columns"
+        # to reuse the existing fastai labelling code (multi-label and such),
+        # the labels DataFrame is attached to the CaseList as self.inner_df,
+        # which is what label_from_df and friends reference
+        # note: df['Case'] must hold the same strings as self.items (e.g. zero-padded
+        # case numbers such as '0000'), otherwise the merge drops rows and the
+        # positional alignment between items and inner_df breaks
+        casesDF = pd.DataFrame({'Case': self.items})
+        self.inner_df = pd.merge(casesDF, df, on='Case')
+
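+
+# a minimal end-to-end sketch, guarded so that importing this module has no side
+# effects; MRNET_PATH is a placeholder for the local MRNet download
+if __name__ == '__main__':
+    MRNET_PATH = Path('data/MRNet-v1.0')
+    cases = MRNetCaseList.from_folder(MRNET_PATH)
+    print(len(cases.items), 'cases found')
+    print(cases[0])                       # repr of the first MRNetCase
+    ils = cases.split_by_folder()
+    print(len(ils.train.items), 'train cases,', len(ils.valid.items), 'valid cases')
+    # with a labels DataFrame (a Case column holding the same case-number strings
+    # as cases.items, plus one or more label columns -- names are placeholders):
+    #   cases.link_label_df(labels_df)
+    #   lls = cases.split_by_folder().label_from_df(cols='abnormal')
+    #   data = lls.databunch(bs=4)
+    #   data.show_batch(rows=2)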
+