--- a +++ b/mrnet_orig.py @@ -0,0 +1,122 @@ +import numpy as np +from fastai.vision import * +import torch + +data_path = Path('../data') +sag_path = data_path/'sagittal' +cor_path = data_path/'coronal' +ax_path = data_path/'axial' + +weights = torch.load('loss_weights.pt') + +class MRNet(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.model = models.alexnet(pretrained=pretrained) + self.gap = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(256, 1) + + def forward(self, x): + # in the original code, the input was squeezed here, but this won't work with fastai + x = self.model.features(x) + x = self.gap(x).view(x.size(0), -1) + x = torch.max(x, 0, keepdim=True)[0] + return torch.sigmoid(self.classifier(x)) + + def __call__(self, x): return self.forward(x) + + +class WtBCELoss(nn.Module): + def __init__(self, wts): + super().__init__() + self.wts = wts.float() + + def forward(self, output, target): + loss = self.wts[0]*(target.float() * torch.log(output).float()) + self.wts[1]*((1-target).float() * torch.log(1-output).float()) + return torch.neg(torch.mean(loss)) + + +class MR3DImDataBunch(ImageDataBunch): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def one_batch(self, ds_type:DatasetType=DatasetType.Train, detach:bool=True, denorm:bool=True, cpu:bool=True)->Collection[Tensor]: + "Get one batch from the data loader of `ds_type`. Optionally `detach` and `denorm`." + dl = self.dl(ds_type) + w = self.num_workers + self.num_workers = 0 + try: x,y = next(iter(dl)) + finally: self.num_workers = w + if detach: x,y = to_detach(x,cpu=cpu),to_detach(y,cpu=cpu) + norm = getattr(self,'norm',False) + if denorm and norm: + x = self.denorm(x) + if norm.keywords.get('do_y',False): y = self.denorm(y, do_x=True) + x = torch.squeeze(x, dim=0) # squeeze needed here for learn.summary() + return x,y + + +class MR3DImageList(ImageList): + _bunch = MR3DImDataBunch # necessary for Data Block API functionality + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_slc = 51 # optimized for sagittal image stacks only...TODO rewrite for any max stack size + self.c = 1 + + # pads on both sides of image stack with zero arrays to equal max_slc + def open(self, fn): + x = np.load(fn) + if x.shape[0] < self.max_slc: + x_pad = np.zeros((self.max_slc, 256, 256)) + mid = x_pad.shape[0] // 2 + up = x.shape[0] // 2 + if x.shape[0] % 2 == 1: x_pad[mid-up:mid+up+1] = x + else: x_pad[mid-up:mid+up] = x + else: + x_pad = x + return self.arr2image(np.stack([x_pad]*3, axis=1)) + + # converts np.ndarray to fastai Image class + @staticmethod + def arr2image(arr:np.ndarray, div:bool=True, cls:type=Image): + x = Tensor(arr) + if div == True: x.div_(255) + return cls(x) + + +# squeeze input prior to loss calculation +class MRNetCallback(Callback): + def on_batch_begin(self, last_input, **kwargs): + x = torch.squeeze(last_input, dim=0) + return dict(last_input=x) + + +class MRNetLearner(Learner): + # redefine specifically for MRNet layer groups + def freeze_to(self, n:int)->None: + "Freeze layers up to layer group `n`." + for g in self.layer_groups[:n]: + if is_listy(g): + for l in g: + if not self.train_bn or not isinstance(l, bn_types): requires_grad(l, False) + else: requires_grad(g, False) + for g in self.layer_groups[n:]: requires_grad(g, True) + self.create_opt(defaults.lr) + + def freeze(self)->None: + "Freeze up to the last layer group." + self.freeze_to(-1) + self.create_opt(defaults.lr) + + def unfreeze(self): + "Unfreeze entire model." + self.freeze_to(0) + self.create_opt(defaults.lr) + + +def mrnet_learner(data:DataBunch, model:Callable=MRNet(), pretrained:bool=True, init=nn.init.kaiming_normal_, **kwargs:Any)->Learner: + _layer_groups = [model.model.features, model.model.avgpool, model.model.classifier, model.gap, model.classifier] + learn = MRNetLearner(data, model, layer_groups=_layer_groups, **kwargs) + if pretrained: learn.freeze() + if init: apply_init(model.classifier, init) + return learn