Diff of /ndv/modules/model.py [000000] .. [64faee]

Switch to unified view

a b/ndv/modules/model.py
1
# -*- coding: utf-8 -*-
2
"""01_model.ipynb
3
4
Automatically generated by Colaboratory.
5
6
Original file is located at
7
    https://colab.research.google.com/drive/1OWXPL8K-jKC4KgGmYXXkeBWOL2V1biuf
8
9
# Setup
10
"""
11
12
import torch
13
from fastai.callbacks import *
14
from fastai.vision import *
15
16
H = 160
17
W = 192
18
D = 128
19
20
def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs):
21
    "A sequence of modules composed of Group Norm, ReLU and Conv3d in order"
22
    if not num_groups : num_groups = int(c_in/2) if c_in%2 == 0 else None
23
    return nn.Sequential(nn.GroupNorm(num_groups, c_in),
24
                         nn.ReLU(),
25
                         nn.Conv3d(c_in, c_out, ks, **conv_kwargs))
26
27
def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs):
28
    "A ResNet-like block with the GroupNorm normalization providing optional bottle-neck functionality"
29
    nf_inner = nf / 2 if bottle_neck else nf
30
    return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs),
31
                        conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs),
32
                        MergeLayer())
33
34
def upsize(c_in, c_out, ks=1, scale=2):
35
    "Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D            trilinear upsampling"
36
    return nn.Sequential(nn.Conv3d(c_in, c_out, ks),
37
                       nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=True))
38
39
def hook_debug(module, input, output):
40
    """
41
    Print out what's been hooked usually for debugging purpose
42
    ----------------------------------------------------------
43
       Example:
44
       Hooks(ms, hook_debug, is_forward=True, detach=False)
45
    
46
    """
47
    print('Hooking ' + module.__class__.__name__)
48
    print('output size:', output.data.size())
49
    return output
50
51
class Encoder(nn.Module):
52
    "Encoder part"
53
    def __init__(self):
54
        super().__init__()
55
        self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1)         
56
        self.res_block1 = reslike_block(32, num_groups=8)
57
        self.conv_block1 = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)
58
        self.res_block2 = reslike_block(64, num_groups=8)
59
        self.conv_block2 = conv_block(64, 64, 3, num_groups=8, stride=1, padding=1)
60
        self.res_block3 = reslike_block(64, num_groups=8)
61
        self.conv_block3 = conv_block(64, 128, 3, num_groups=8, stride=2, padding=1)
62
        self.res_block4 = reslike_block(128, num_groups=8)
63
        self.conv_block4 = conv_block(128, 128, 3, num_groups=8, stride=1, padding=1)
64
        self.res_block5 = reslike_block(128, num_groups=8)
65
        self.conv_block5 = conv_block(128, 256, 3, num_groups=8, stride=2, padding=1)
66
        self.res_block6 = reslike_block(256, num_groups=8)
67
        self.conv_block6 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)
68
        self.res_block7 = reslike_block(256, num_groups=8)
69
        self.conv_block7 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)
70
        self.res_block8 = reslike_block(256, num_groups=8)
71
        self.conv_block8 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)
72
        self.res_block9 = reslike_block(256, num_groups=8)
73
    
74
    def forward(self, x):
75
        x = self.conv1(x)                                           # Output size: (1, 32, 160, 192, 128)
76
        x = self.res_block1(x)                                      # Output size: (1, 32, 160, 192, 128)
77
        x = self.conv_block1(x)                                     # Output size: (1, 64, 80, 96, 64)
78
        x = self.res_block2(x)                                      # Output size: (1, 64, 80, 96, 64)
79
        x = self.conv_block2(x)                                     # Output size: (1, 64, 80, 96, 64)
80
        x = self.res_block3(x)                                      # Output size: (1, 64, 80, 96, 64)
81
        x = self.conv_block3(x)                                     # Output size: (1, 128, 40, 48, 32)
82
        x = self.res_block4(x)                                      # Output size: (1, 128, 40, 48, 32)
83
        x = self.conv_block4(x)                                     # Output size: (1, 128, 40, 48, 32)
84
        x = self.res_block5(x)                                      # Output size: (1, 128, 40, 48, 32)
85
        x = self.conv_block5(x)                                     # Output size: (1, 256, 20, 24, 16)
86
        x = self.res_block6(x)                                      # Output size: (1, 256, 20, 24, 16)
87
        x = self.conv_block6(x)                                     # Output size: (1, 256, 20, 24, 16)
88
        x = self.res_block7(x)                                      # Output size: (1, 256, 20, 24, 16)
89
        x = self.conv_block7(x)                                     # Output size: (1, 256, 20, 24, 16)
90
        x = self.res_block8(x)                                      # Output size: (1, 256, 20, 24, 16)
91
        x = self.conv_block8(x)                                     # Output size: (1, 256, 20, 24, 16)
92
        x = self.res_block9(x)                                      # Output size: (1, 256, 20, 24, 16)
93
        return x
94
95
class Decoder(nn.Module):
96
    "Decoder Part"
97
    def __init__(self):
98
        super().__init__()
99
        self.upsize1 = upsize(256, 128)
100
        self.reslike1 = reslike_block(128, num_groups=8)
101
        self.upsize2 = upsize(128, 64)
102
        self.reslike2 = reslike_block(64, num_groups=8)
103
        self.upsize3 = upsize(64, 32)
104
        self.reslike3 = reslike_block(32, num_groups=8)
105
        self.conv1 = nn.Conv3d(32, 3, 1) 
106
        self.sigmoid1 = torch.nn.Sigmoid()
107
108
    def forward(self, x):
109
        x = self.upsize1(x)                                         # Output size: (1, 128, 40, 48, 32)
110
        x = x + hooks.stored[2]                                     # Output size: (1, 128, 40, 48, 32)
111
        x = self.reslike1(x)                                        # Output size: (1, 128, 40, 48, 32)
112
        x = self.upsize2(x)                                         # Output size: (1, 64, 80, 96, 64)
113
        x = x + hooks.stored[1]                                     # Output size: (1, 64, 80, 96, 64)
114
        x = self.reslike2(x)                                        # Output size: (1, 64, 80, 96, 64)
115
        x = self.upsize3(x)                                         # Output size: (1, 32, 160, 192, 128)
116
        x = x + hooks.stored[0]                                     # Output size: (1, 32, 160, 192, 128)
117
        x = self.reslike3(x)                                        # Output size: (1, 32, 160, 192, 128)
118
        x = self.conv1(x)                                           # Output size: (1, 3, 160, 192, 128)
119
        x = self.sigmoid1(x)                                        # Output size: (1, 3, 160, 192, 128)
120
        return x
121
122
class VAEEncoder(nn.Module):
123
    "Variational auto-encoder encoder part"
124
    def __init__(self, latent_dim:int=128):
125
        super().__init__()
126
        self.latent_dim = latent_dim
127
        self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1)
128
        self.linear1 = nn.Linear(60, 1)
129
        
130
        # Assumed latent variable's probability density function parameters
131
        self.z_mean = nn.Linear(256, latent_dim)
132
        self.z_log_var = nn.Linear(256, latent_dim)
133
        self.epsilon = nn.Parameter(torch.randn(1, latent_dim))
134
        
135
    def forward(self, x):
136
        x = self.conv_block(x)                                   # Output size: (1, 16, 10, 12, 8)                                  
137
        x = x.view(256, -1)                                      # Output size: (256, 60)                                       
138
        x = self.linear1(x)                                      # Output size: (256, 1)
139
        x = x.view(1, 256)                                       # Output size: (1, 256)   
140
        z_mean = self.z_mean(x)                                  # Output size: (1, 128)
141
        z_var = self.z_log_var(x).exp()                          # Output size: (1, 128)              
142
        return z_mean + z_var * self.epsilon                     # Output size: (1, 128)
143
144
class VAEDecoder(nn.Module):
145
    "Variational auto-encoder decoder part"
146
    def __init__(self):
147
        super().__init__()
148
        self.linear1 = nn.Linear(128, 256*60)
149
        self.relu1 = nn.ReLU()
150
        self.upsize1 = upsize(16, 256)
151
        self.upsize2 = upsize(256, 128)
152
        self.reslike1 = reslike_block(128, num_groups=8)
153
        self.upsize3 = upsize(128, 64)
154
        self.reslike2 = reslike_block(64, num_groups=8)
155
        self.upsize4 = upsize(64, 32)
156
        self.reslike3 = reslike_block(32, num_groups=8)
157
        self.conv1 = nn.Conv3d(32, 4, 1)
158
    
159
    def forward(self, x):
160
        x = self.linear1(x)                                          # Output size: (1, 256*60)      
161
        x = self.relu1(x)                                            # Output size: (1, 256*60)
162
        x = x.view(1, 16, 10, 12, 8)                                 # Output size: (1, 16, 10, 12, 8)
163
        x = self.upsize1(x)                                          # Output size: (1, 256, 20, 24, 16)
164
        x = self.upsize2(x)                                          # Output size: (1, 128, 40, 48, 32)
165
        x = self.reslike1(x)                                         # Output size: (1, 128, 40, 48, 32)
166
        x = self.upsize3(x)                                          # Output size: (1, 64, 80, 96, 64)
167
        x = self.reslike2(x)                                         # Output size: (1, 64, 80, 96, 64)
168
        x = self.upsize4(x)                                          # Output size: (1, 32, 160, 192, 128)
169
        x = self.reslike3(x)                                         # Output size: (1, 32, 160, 192, 128)
170
        x = self.conv1(x)                                            # Output size: (1, 4, 160, 192, 128) 
171
        return x
172
173
class AutoUNet(nn.Module):
174
  "3D U-Net using autoencoder regularization"
175
  def __init__(self):
176
    super().__init__()
177
    self.encoder = Encoder()
178
    self.decoder = Decoder()
179
    self.vencoder = VAEEncoder(latent_dim=128)
180
    self.vdecoder = VAEDecoder()
181
182
  def forward(self, input):
183
    interm_res = self.encoder(input)
184
    top_res = self.decoder(interm_res)                               # Output size: (1, 3, 160, 192, 128)
185
    bottom_res = self.vdecoder(self.vencoder(interm_res))            # Output size: (1, 4, 160, 192, 128)
186
    return top_res, bottom_res
187
188
class SoftDiceLoss(nn.Module): 
189
    "Soft dice loss based on a measure of overlap between prediction and ground truth"
190
    def __init__(self, epsilon=1e-6, c=3):
191
        super().__init__()
192
        self.epsilon = epsilon
193
        self.c = 3
194
    
195
    def forward(self, x:Tensor, y:Tensor):
196
        intersection = 2 * ( (x*y).sum() )
197
        union = (x**2).sum() + (y**2).sum() 
198
        return 1 - ( ( intersection / (union + self.epsilon) ) / self.c )
199
200
class KLDivergence(nn.Module): 
201
    "KL divergence between the estimated normal distribution and a prior distribution"
202
    N = H * W * D  #hyperparameter check
203
204
    def __init__(self):
205
        super().__init__()
206
    
207
    def forward(self, z_mean:Tensor, z_log_var:Tensor):
208
        z_var = z_log_var.exp()
209
        return (1/self.N) * ( (z_mean**2 + z_var**2 - z_log_var**2 - 1).sum() )
210
211
class L2Loss(nn.Module): 
212
    "Measuring the `Euclidian distance` between prediction and ground truh using `L2 Norm`"
213
    def __init__(self):
214
        super().__init__()
215
        
216
    def forward(self, x:Tensor, y:Tensor):
217
        return  ( (x - y)**2 ).sum()
218
219
autounet = AutoUNet()
220
ms = [autounet.encoder.res_block1, 
221
      autounet.encoder.res_block3, 
222
      autounet.encoder.res_block5, 
223
      autounet.vencoder.z_mean, 
224
      autounet.vencoder.z_log_var]
225
hooks = hook_outputs(ms, detach=False, grad=False)
226
227
lr = 1e-4
228
optimizer = optim.Adam(autounet.parameters(), lr)