Diff of /mrnet_orig.py [000000] .. [dc3c86]

import numpy as np
from fastai.vision import *
import torch

data_path = Path('../data')
sag_path = data_path/'sagittal'
cor_path = data_path/'coronal'
ax_path = data_path/'axial'

weights = torch.load('loss_weights.pt')
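
# Illustrative sketch (not in the original file): `loss_weights.pt` is assumed
# to hold [positive, negative] class weights; inverse class frequency is one
# common recipe for them.
def _make_loss_weights(labels):
    n, pos = float(len(labels)), float(sum(labels))
    wts = torch.tensor([n / pos, n / (n - pos)])  # upweight the rarer class
    torch.save(wts, 'loss_weights.pt')
    return wts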


class MRNet(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.model = models.alexnet(pretrained=pretrained)  # AlexNet backbone as the slice-level feature extractor
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(256, 1)

    def forward(self, x):
        # in the original code, the input was squeezed here, but this won't work with fastai
        x = self.model.features(x)            # (n_slices, 256, h, w)
        x = self.gap(x).view(x.size(0), -1)   # (n_slices, 256)
        x = torch.max(x, 0, keepdim=True)[0]  # max-pool across slices -> (1, 256)
        return torch.sigmoid(self.classifier(x))

    # note: overriding __call__ bypasses nn.Module's hook machinery
    def __call__(self, x): return self.forward(x)
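
# Illustrative sketch (not in the original file): a quick shape check of the
# forward pass. MRNetCallback below strips the batch dimension, so the model
# sees one exam as a stack of slices and returns one probability per exam.
def _check_mrnet_shapes(n_slices:int=25):
    net = MRNet(pretrained=False)
    x = torch.randn(n_slices, 3, 256, 256)  # one exam: n_slices 3-channel slices
    out = net(x)                            # slice features are max-pooled over dim 0
    assert out.shape == (1, 1)
    return out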


class WtBCELoss(nn.Module):
    "Weighted binary cross-entropy on probabilities (the model already applies sigmoid)."
    def __init__(self, wts):
        super().__init__()
        self.wts = wts.float()

    def forward(self, output, target):
        # note: the logs are unguarded, so output must stay strictly inside (0, 1)
        target = target.float()
        loss = (self.wts[0] * target * torch.log(output)
                + self.wts[1] * (1 - target) * torch.log(1 - output))
        return torch.neg(torch.mean(loss))
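
# Illustrative sketch (not in the original file): with both weights set to 1,
# WtBCELoss reduces to plain (unweighted) binary cross-entropy.
def _check_wt_bce():
    loss_fn = WtBCELoss(torch.tensor([1., 1.]))
    output, target = torch.tensor([0.9]), torch.tensor([1.])
    expected = nn.functional.binary_cross_entropy(output, target)
    assert torch.allclose(loss_fn(output, target), expected)
    return expected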


class MR3DImDataBunch(ImageDataBunch):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # based on fastai's one_batch, with the final squeeze added
    def one_batch(self, ds_type:DatasetType=DatasetType.Train, detach:bool=True, denorm:bool=True, cpu:bool=True)->Collection[Tensor]:
        "Get one batch from the data loader of `ds_type`. Optionally `detach` and `denorm`."
        dl = self.dl(ds_type)
        w = self.num_workers
        self.num_workers = 0
        try:     x,y = next(iter(dl))
        finally: self.num_workers = w
        if detach: x,y = to_detach(x,cpu=cpu),to_detach(y,cpu=cpu)
        norm = getattr(self,'norm',False)
        if denorm and norm:
            x = self.denorm(x)
            if norm.keywords.get('do_y',False): y = self.denorm(y, do_x=True)
        x = torch.squeeze(x, dim=0) # squeeze needed here for learn.summary()
        return x,y


class MR3DImageList(ImageList):
    _bunch = MR3DImDataBunch # necessary for Data Block API functionality

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_slc = 51 # sized for the sagittal stacks only; TODO: rewrite for any max stack size
        self.c = 1        # one binary output

    # pads both sides of the image stack with zero slices out to max_slc
    def open(self, fn):
        x = np.load(fn)
        if x.shape[0] < self.max_slc:
            x_pad = np.zeros((self.max_slc, 256, 256))
            mid = x_pad.shape[0] // 2
            up = x.shape[0] // 2
            if x.shape[0] % 2 == 1: x_pad[mid-up:mid+up+1] = x
            else: x_pad[mid-up:mid+up] = x
        else:
            x_pad = x
        return self.arr2image(np.stack([x_pad]*3, axis=1))  # greyscale -> 3 channels: (slices, 3, 256, 256)

    # converts an np.ndarray to the fastai Image class
    @staticmethod
    def arr2image(arr:np.ndarray, div:bool=True, cls:type=Image):
        x = Tensor(arr)
        if div: x.div_(255)
        return cls(x)
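
# Illustrative sketch (not in the original file): the centred zero-padding used
# by `open`, shown on a toy stack of 3 slices padded out to 7.
def _check_center_pad(n_slices:int=3, max_slc:int=7):
    x = np.ones((n_slices, 256, 256))
    x_pad = np.zeros((max_slc, 256, 256))
    mid, up = max_slc // 2, n_slices // 2
    if n_slices % 2 == 1: x_pad[mid-up:mid+up+1] = x
    else:                 x_pad[mid-up:mid+up] = x
    assert x_pad.sum() == x.sum()  # every original slice lands inside the pad
    return x_pad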


# squeeze the input before the forward pass and loss calculation
class MRNetCallback(Callback):
    def on_batch_begin(self, last_input, **kwargs):
        x = torch.squeeze(last_input, dim=0)  # drop the batch dimension of 1 added by the DataLoader
        return dict(last_input=x)
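
# Illustrative sketch (not in the original file): fastai v1 callbacks replace
# batch state by returning a dict of the new values.
def _check_callback_squeeze():
    cb = MRNetCallback()
    batch = torch.randn(1, 25, 3, 256, 256)  # DataLoader output for one exam
    out = cb.on_batch_begin(last_input=batch)
    assert out['last_input'].shape == (25, 3, 256, 256)
    return out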


class MRNetLearner(Learner):
    # freeze_to/freeze/unfreeze redefined specifically for the MRNet layer groups
    def freeze_to(self, n:int)->None:
        "Freeze layers up to layer group `n`."
        for g in self.layer_groups[:n]:
            if is_listy(g):
                for l in g:
                    if not self.train_bn or not isinstance(l, bn_types): requires_grad(l, False)
            else: requires_grad(g, False)
        for g in self.layer_groups[n:]: requires_grad(g, True)
        self.create_opt(defaults.lr)

    def freeze(self)->None:
        "Freeze up to the last layer group."
        self.freeze_to(-1)
        self.create_opt(defaults.lr)

    def unfreeze(self):
        "Unfreeze the entire model."
        self.freeze_to(0)
        self.create_opt(defaults.lr)


def mrnet_learner(data:DataBunch, model:nn.Module=None, pretrained:bool=True, init=nn.init.kaiming_normal_, **kwargs:Any)->Learner:
    # `model` previously defaulted to `MRNet()`, a mutable default evaluated once at
    # import time; build a fresh model per call instead so `pretrained` takes effect
    if model is None: model = MRNet(pretrained=pretrained)
    _layer_groups = [model.model.features, model.model.avgpool, model.model.classifier, model.gap, model.classifier]
    learn = MRNetLearner(data, model, layer_groups=_layer_groups, **kwargs)
    if pretrained: learn.freeze()
    if init: apply_init(model.classifier, init)
    return learn
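
# Illustrative end-to-end sketch (not in the original file). The labelling step
# is a placeholder: the real MRNet labels live in separate CSV files.
def _build_learner_sketch():
    data = (MR3DImageList.from_folder(sag_path, extensions=['.npy'])
            .split_by_rand_pct(0.2)
            .label_from_func(lambda fn: 0.0, label_cls=FloatList)  # placeholder labeller
            .databunch(bs=1, num_workers=0))  # bs=1: one padded exam per batch
    learn = mrnet_learner(data, loss_func=WtBCELoss(weights),
                          callbacks=[MRNetCallback()])
    return learn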