|
a |
|
b/mrnet_orig.py |
|
|
1 |
import numpy as np |
|
|
2 |
from fastai.vision import * |
|
|
3 |
import torch |
|
|
4 |
|
|
|
5 |
data_path = Path('../data') |
|
|
6 |
sag_path = data_path/'sagittal' |
|
|
7 |
cor_path = data_path/'coronal' |
|
|
8 |
ax_path = data_path/'axial' |
|
|
9 |
|
|
|
10 |
weights = torch.load('loss_weights.pt') |
|
|
11 |
|
|
|
12 |
class MRNet(nn.Module): |
|
|
13 |
def __init__(self, pretrained=True): |
|
|
14 |
super().__init__() |
|
|
15 |
self.model = models.alexnet(pretrained=pretrained) |
|
|
16 |
self.gap = nn.AdaptiveAvgPool2d(1) |
|
|
17 |
self.classifier = nn.Linear(256, 1) |
|
|
18 |
|
|
|
19 |
def forward(self, x): |
|
|
20 |
# in the original code, the input was squeezed here, but this won't work with fastai |
|
|
21 |
x = self.model.features(x) |
|
|
22 |
x = self.gap(x).view(x.size(0), -1) |
|
|
23 |
x = torch.max(x, 0, keepdim=True)[0] |
|
|
24 |
return torch.sigmoid(self.classifier(x)) |
|
|
25 |
|
|
|
26 |
def __call__(self, x): return self.forward(x) |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
class WtBCELoss(nn.Module): |
|
|
30 |
def __init__(self, wts): |
|
|
31 |
super().__init__() |
|
|
32 |
self.wts = wts.float() |
|
|
33 |
|
|
|
34 |
def forward(self, output, target): |
|
|
35 |
loss = self.wts[0]*(target.float() * torch.log(output).float()) + self.wts[1]*((1-target).float() * torch.log(1-output).float()) |
|
|
36 |
return torch.neg(torch.mean(loss)) |
|
|
37 |
|
|
|
38 |
|
|
|
39 |
class MR3DImDataBunch(ImageDataBunch): |
|
|
40 |
def __init__(self, *args, **kwargs): |
|
|
41 |
super().__init__(*args, **kwargs) |
|
|
42 |
|
|
|
43 |
def one_batch(self, ds_type:DatasetType=DatasetType.Train, detach:bool=True, denorm:bool=True, cpu:bool=True)->Collection[Tensor]: |
|
|
44 |
"Get one batch from the data loader of `ds_type`. Optionally `detach` and `denorm`." |
|
|
45 |
dl = self.dl(ds_type) |
|
|
46 |
w = self.num_workers |
|
|
47 |
self.num_workers = 0 |
|
|
48 |
try: x,y = next(iter(dl)) |
|
|
49 |
finally: self.num_workers = w |
|
|
50 |
if detach: x,y = to_detach(x,cpu=cpu),to_detach(y,cpu=cpu) |
|
|
51 |
norm = getattr(self,'norm',False) |
|
|
52 |
if denorm and norm: |
|
|
53 |
x = self.denorm(x) |
|
|
54 |
if norm.keywords.get('do_y',False): y = self.denorm(y, do_x=True) |
|
|
55 |
x = torch.squeeze(x, dim=0) # squeeze needed here for learn.summary() |
|
|
56 |
return x,y |
|
|
57 |
|
|
|
58 |
|
|
|
59 |
class MR3DImageList(ImageList): |
|
|
60 |
_bunch = MR3DImDataBunch # necessary for Data Block API functionality |
|
|
61 |
def __init__(self, *args, **kwargs): |
|
|
62 |
super().__init__(*args, **kwargs) |
|
|
63 |
self.max_slc = 51 # optimized for sagittal image stacks only...TODO rewrite for any max stack size |
|
|
64 |
self.c = 1 |
|
|
65 |
|
|
|
66 |
# pads on both sides of image stack with zero arrays to equal max_slc |
|
|
67 |
def open(self, fn): |
|
|
68 |
x = np.load(fn) |
|
|
69 |
if x.shape[0] < self.max_slc: |
|
|
70 |
x_pad = np.zeros((self.max_slc, 256, 256)) |
|
|
71 |
mid = x_pad.shape[0] // 2 |
|
|
72 |
up = x.shape[0] // 2 |
|
|
73 |
if x.shape[0] % 2 == 1: x_pad[mid-up:mid+up+1] = x |
|
|
74 |
else: x_pad[mid-up:mid+up] = x |
|
|
75 |
else: |
|
|
76 |
x_pad = x |
|
|
77 |
return self.arr2image(np.stack([x_pad]*3, axis=1)) |
|
|
78 |
|
|
|
79 |
# converts np.ndarray to fastai Image class |
|
|
80 |
@staticmethod |
|
|
81 |
def arr2image(arr:np.ndarray, div:bool=True, cls:type=Image): |
|
|
82 |
x = Tensor(arr) |
|
|
83 |
if div == True: x.div_(255) |
|
|
84 |
return cls(x) |
|
|
85 |
|
|
|
86 |
|
|
|
87 |
# squeeze input prior to loss calculation |
|
|
88 |
class MRNetCallback(Callback): |
|
|
89 |
def on_batch_begin(self, last_input, **kwargs): |
|
|
90 |
x = torch.squeeze(last_input, dim=0) |
|
|
91 |
return dict(last_input=x) |
|
|
92 |
|
|
|
93 |
|
|
|
94 |
class MRNetLearner(Learner): |
|
|
95 |
# redefine specifically for MRNet layer groups |
|
|
96 |
def freeze_to(self, n:int)->None: |
|
|
97 |
"Freeze layers up to layer group `n`." |
|
|
98 |
for g in self.layer_groups[:n]: |
|
|
99 |
if is_listy(g): |
|
|
100 |
for l in g: |
|
|
101 |
if not self.train_bn or not isinstance(l, bn_types): requires_grad(l, False) |
|
|
102 |
else: requires_grad(g, False) |
|
|
103 |
for g in self.layer_groups[n:]: requires_grad(g, True) |
|
|
104 |
self.create_opt(defaults.lr) |
|
|
105 |
|
|
|
106 |
def freeze(self)->None: |
|
|
107 |
"Freeze up to the last layer group." |
|
|
108 |
self.freeze_to(-1) |
|
|
109 |
self.create_opt(defaults.lr) |
|
|
110 |
|
|
|
111 |
def unfreeze(self): |
|
|
112 |
"Unfreeze entire model." |
|
|
113 |
self.freeze_to(0) |
|
|
114 |
self.create_opt(defaults.lr) |
|
|
115 |
|
|
|
116 |
|
|
|
117 |
def mrnet_learner(data:DataBunch, model:Callable=MRNet(), pretrained:bool=True, init=nn.init.kaiming_normal_, **kwargs:Any)->Learner: |
|
|
118 |
_layer_groups = [model.model.features, model.model.avgpool, model.model.classifier, model.gap, model.classifier] |
|
|
119 |
learn = MRNetLearner(data, model, layer_groups=_layer_groups, **kwargs) |
|
|
120 |
if pretrained: learn.freeze() |
|
|
121 |
if init: apply_init(model.classifier, init) |
|
|
122 |
return learn |