|
a |
|
b/ndv/modules/model.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
"""01_model.ipynb |
|
|
3 |
|
|
|
4 |
Automatically generated by Colaboratory. |
|
|
5 |
|
|
|
6 |
Original file is located at |
|
|
7 |
https://colab.research.google.com/drive/1OWXPL8K-jKC4KgGmYXXkeBWOL2V1biuf |
|
|
8 |
|
|
|
9 |
# Setup |
|
|
10 |
""" |
|
|
11 |
|
|
|
12 |
import torch |
|
|
13 |
from fastai.callbacks import * |
|
|
14 |
from fastai.vision import * |
|
|
15 |
|
|
|
16 |
H = 160 |
|
|
17 |
W = 192 |
|
|
18 |
D = 128 |
|
|
19 |
|
|
|
20 |
def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs): |
|
|
21 |
"A sequence of modules composed of Group Norm, ReLU and Conv3d in order" |
|
|
22 |
if not num_groups : num_groups = int(c_in/2) if c_in%2 == 0 else None |
|
|
23 |
return nn.Sequential(nn.GroupNorm(num_groups, c_in), |
|
|
24 |
nn.ReLU(), |
|
|
25 |
nn.Conv3d(c_in, c_out, ks, **conv_kwargs)) |
|
|
26 |
|
|
|
27 |
def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs): |
|
|
28 |
"A ResNet-like block with the GroupNorm normalization providing optional bottle-neck functionality" |
|
|
29 |
nf_inner = nf / 2 if bottle_neck else nf |
|
|
30 |
return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs), |
|
|
31 |
conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs), |
|
|
32 |
MergeLayer()) |
|
|
33 |
|
|
|
34 |
def upsize(c_in, c_out, ks=1, scale=2): |
|
|
35 |
"Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D trilinear upsampling" |
|
|
36 |
return nn.Sequential(nn.Conv3d(c_in, c_out, ks), |
|
|
37 |
nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=True)) |
|
|
38 |
|
|
|
39 |
def hook_debug(module, input, output): |
|
|
40 |
""" |
|
|
41 |
Print out what's been hooked usually for debugging purpose |
|
|
42 |
---------------------------------------------------------- |
|
|
43 |
Example: |
|
|
44 |
Hooks(ms, hook_debug, is_forward=True, detach=False) |
|
|
45 |
|
|
|
46 |
""" |
|
|
47 |
print('Hooking ' + module.__class__.__name__) |
|
|
48 |
print('output size:', output.data.size()) |
|
|
49 |
return output |
|
|
50 |
|
|
|
51 |
class Encoder(nn.Module): |
|
|
52 |
"Encoder part" |
|
|
53 |
def __init__(self): |
|
|
54 |
super().__init__() |
|
|
55 |
self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1) |
|
|
56 |
self.res_block1 = reslike_block(32, num_groups=8) |
|
|
57 |
self.conv_block1 = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1) |
|
|
58 |
self.res_block2 = reslike_block(64, num_groups=8) |
|
|
59 |
self.conv_block2 = conv_block(64, 64, 3, num_groups=8, stride=1, padding=1) |
|
|
60 |
self.res_block3 = reslike_block(64, num_groups=8) |
|
|
61 |
self.conv_block3 = conv_block(64, 128, 3, num_groups=8, stride=2, padding=1) |
|
|
62 |
self.res_block4 = reslike_block(128, num_groups=8) |
|
|
63 |
self.conv_block4 = conv_block(128, 128, 3, num_groups=8, stride=1, padding=1) |
|
|
64 |
self.res_block5 = reslike_block(128, num_groups=8) |
|
|
65 |
self.conv_block5 = conv_block(128, 256, 3, num_groups=8, stride=2, padding=1) |
|
|
66 |
self.res_block6 = reslike_block(256, num_groups=8) |
|
|
67 |
self.conv_block6 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1) |
|
|
68 |
self.res_block7 = reslike_block(256, num_groups=8) |
|
|
69 |
self.conv_block7 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1) |
|
|
70 |
self.res_block8 = reslike_block(256, num_groups=8) |
|
|
71 |
self.conv_block8 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1) |
|
|
72 |
self.res_block9 = reslike_block(256, num_groups=8) |
|
|
73 |
|
|
|
74 |
def forward(self, x): |
|
|
75 |
x = self.conv1(x) # Output size: (1, 32, 160, 192, 128) |
|
|
76 |
x = self.res_block1(x) # Output size: (1, 32, 160, 192, 128) |
|
|
77 |
x = self.conv_block1(x) # Output size: (1, 64, 80, 96, 64) |
|
|
78 |
x = self.res_block2(x) # Output size: (1, 64, 80, 96, 64) |
|
|
79 |
x = self.conv_block2(x) # Output size: (1, 64, 80, 96, 64) |
|
|
80 |
x = self.res_block3(x) # Output size: (1, 64, 80, 96, 64) |
|
|
81 |
x = self.conv_block3(x) # Output size: (1, 128, 40, 48, 32) |
|
|
82 |
x = self.res_block4(x) # Output size: (1, 128, 40, 48, 32) |
|
|
83 |
x = self.conv_block4(x) # Output size: (1, 128, 40, 48, 32) |
|
|
84 |
x = self.res_block5(x) # Output size: (1, 128, 40, 48, 32) |
|
|
85 |
x = self.conv_block5(x) # Output size: (1, 256, 20, 24, 16) |
|
|
86 |
x = self.res_block6(x) # Output size: (1, 256, 20, 24, 16) |
|
|
87 |
x = self.conv_block6(x) # Output size: (1, 256, 20, 24, 16) |
|
|
88 |
x = self.res_block7(x) # Output size: (1, 256, 20, 24, 16) |
|
|
89 |
x = self.conv_block7(x) # Output size: (1, 256, 20, 24, 16) |
|
|
90 |
x = self.res_block8(x) # Output size: (1, 256, 20, 24, 16) |
|
|
91 |
x = self.conv_block8(x) # Output size: (1, 256, 20, 24, 16) |
|
|
92 |
x = self.res_block9(x) # Output size: (1, 256, 20, 24, 16) |
|
|
93 |
return x |
|
|
94 |
|
|
|
95 |
class Decoder(nn.Module): |
|
|
96 |
"Decoder Part" |
|
|
97 |
def __init__(self): |
|
|
98 |
super().__init__() |
|
|
99 |
self.upsize1 = upsize(256, 128) |
|
|
100 |
self.reslike1 = reslike_block(128, num_groups=8) |
|
|
101 |
self.upsize2 = upsize(128, 64) |
|
|
102 |
self.reslike2 = reslike_block(64, num_groups=8) |
|
|
103 |
self.upsize3 = upsize(64, 32) |
|
|
104 |
self.reslike3 = reslike_block(32, num_groups=8) |
|
|
105 |
self.conv1 = nn.Conv3d(32, 3, 1) |
|
|
106 |
self.sigmoid1 = torch.nn.Sigmoid() |
|
|
107 |
|
|
|
108 |
def forward(self, x): |
|
|
109 |
x = self.upsize1(x) # Output size: (1, 128, 40, 48, 32) |
|
|
110 |
x = x + hooks.stored[2] # Output size: (1, 128, 40, 48, 32) |
|
|
111 |
x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32) |
|
|
112 |
x = self.upsize2(x) # Output size: (1, 64, 80, 96, 64) |
|
|
113 |
x = x + hooks.stored[1] # Output size: (1, 64, 80, 96, 64) |
|
|
114 |
x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64) |
|
|
115 |
x = self.upsize3(x) # Output size: (1, 32, 160, 192, 128) |
|
|
116 |
x = x + hooks.stored[0] # Output size: (1, 32, 160, 192, 128) |
|
|
117 |
x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128) |
|
|
118 |
x = self.conv1(x) # Output size: (1, 3, 160, 192, 128) |
|
|
119 |
x = self.sigmoid1(x) # Output size: (1, 3, 160, 192, 128) |
|
|
120 |
return x |
|
|
121 |
|
|
|
122 |
class VAEEncoder(nn.Module): |
|
|
123 |
"Variational auto-encoder encoder part" |
|
|
124 |
def __init__(self, latent_dim:int=128): |
|
|
125 |
super().__init__() |
|
|
126 |
self.latent_dim = latent_dim |
|
|
127 |
self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1) |
|
|
128 |
self.linear1 = nn.Linear(60, 1) |
|
|
129 |
|
|
|
130 |
# Assumed latent variable's probability density function parameters |
|
|
131 |
self.z_mean = nn.Linear(256, latent_dim) |
|
|
132 |
self.z_log_var = nn.Linear(256, latent_dim) |
|
|
133 |
self.epsilon = nn.Parameter(torch.randn(1, latent_dim)) |
|
|
134 |
|
|
|
135 |
def forward(self, x): |
|
|
136 |
x = self.conv_block(x) # Output size: (1, 16, 10, 12, 8) |
|
|
137 |
x = x.view(256, -1) # Output size: (256, 60) |
|
|
138 |
x = self.linear1(x) # Output size: (256, 1) |
|
|
139 |
x = x.view(1, 256) # Output size: (1, 256) |
|
|
140 |
z_mean = self.z_mean(x) # Output size: (1, 128) |
|
|
141 |
z_var = self.z_log_var(x).exp() # Output size: (1, 128) |
|
|
142 |
return z_mean + z_var * self.epsilon # Output size: (1, 128) |
|
|
143 |
|
|
|
144 |
class VAEDecoder(nn.Module): |
|
|
145 |
"Variational auto-encoder decoder part" |
|
|
146 |
def __init__(self): |
|
|
147 |
super().__init__() |
|
|
148 |
self.linear1 = nn.Linear(128, 256*60) |
|
|
149 |
self.relu1 = nn.ReLU() |
|
|
150 |
self.upsize1 = upsize(16, 256) |
|
|
151 |
self.upsize2 = upsize(256, 128) |
|
|
152 |
self.reslike1 = reslike_block(128, num_groups=8) |
|
|
153 |
self.upsize3 = upsize(128, 64) |
|
|
154 |
self.reslike2 = reslike_block(64, num_groups=8) |
|
|
155 |
self.upsize4 = upsize(64, 32) |
|
|
156 |
self.reslike3 = reslike_block(32, num_groups=8) |
|
|
157 |
self.conv1 = nn.Conv3d(32, 4, 1) |
|
|
158 |
|
|
|
159 |
def forward(self, x): |
|
|
160 |
x = self.linear1(x) # Output size: (1, 256*60) |
|
|
161 |
x = self.relu1(x) # Output size: (1, 256*60) |
|
|
162 |
x = x.view(1, 16, 10, 12, 8) # Output size: (1, 16, 10, 12, 8) |
|
|
163 |
x = self.upsize1(x) # Output size: (1, 256, 20, 24, 16) |
|
|
164 |
x = self.upsize2(x) # Output size: (1, 128, 40, 48, 32) |
|
|
165 |
x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32) |
|
|
166 |
x = self.upsize3(x) # Output size: (1, 64, 80, 96, 64) |
|
|
167 |
x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64) |
|
|
168 |
x = self.upsize4(x) # Output size: (1, 32, 160, 192, 128) |
|
|
169 |
x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128) |
|
|
170 |
x = self.conv1(x) # Output size: (1, 4, 160, 192, 128) |
|
|
171 |
return x |
|
|
172 |
|
|
|
173 |
class AutoUNet(nn.Module): |
|
|
174 |
"3D U-Net using autoencoder regularization" |
|
|
175 |
def __init__(self): |
|
|
176 |
super().__init__() |
|
|
177 |
self.encoder = Encoder() |
|
|
178 |
self.decoder = Decoder() |
|
|
179 |
self.vencoder = VAEEncoder(latent_dim=128) |
|
|
180 |
self.vdecoder = VAEDecoder() |
|
|
181 |
|
|
|
182 |
def forward(self, input): |
|
|
183 |
interm_res = self.encoder(input) |
|
|
184 |
top_res = self.decoder(interm_res) # Output size: (1, 3, 160, 192, 128) |
|
|
185 |
bottom_res = self.vdecoder(self.vencoder(interm_res)) # Output size: (1, 4, 160, 192, 128) |
|
|
186 |
return top_res, bottom_res |
|
|
187 |
|
|
|
188 |
class SoftDiceLoss(nn.Module): |
|
|
189 |
"Soft dice loss based on a measure of overlap between prediction and ground truth" |
|
|
190 |
def __init__(self, epsilon=1e-6, c=3): |
|
|
191 |
super().__init__() |
|
|
192 |
self.epsilon = epsilon |
|
|
193 |
self.c = 3 |
|
|
194 |
|
|
|
195 |
def forward(self, x:Tensor, y:Tensor): |
|
|
196 |
intersection = 2 * ( (x*y).sum() ) |
|
|
197 |
union = (x**2).sum() + (y**2).sum() |
|
|
198 |
return 1 - ( ( intersection / (union + self.epsilon) ) / self.c ) |
|
|
199 |
|
|
|
200 |
class KLDivergence(nn.Module): |
|
|
201 |
"KL divergence between the estimated normal distribution and a prior distribution" |
|
|
202 |
N = H * W * D #hyperparameter check |
|
|
203 |
|
|
|
204 |
def __init__(self): |
|
|
205 |
super().__init__() |
|
|
206 |
|
|
|
207 |
def forward(self, z_mean:Tensor, z_log_var:Tensor): |
|
|
208 |
z_var = z_log_var.exp() |
|
|
209 |
return (1/self.N) * ( (z_mean**2 + z_var**2 - z_log_var**2 - 1).sum() ) |
|
|
210 |
|
|
|
211 |
class L2Loss(nn.Module): |
|
|
212 |
"Measuring the `Euclidian distance` between prediction and ground truh using `L2 Norm`" |
|
|
213 |
def __init__(self): |
|
|
214 |
super().__init__() |
|
|
215 |
|
|
|
216 |
def forward(self, x:Tensor, y:Tensor): |
|
|
217 |
return ( (x - y)**2 ).sum() |
|
|
218 |
|
|
|
219 |
autounet = AutoUNet() |
|
|
220 |
ms = [autounet.encoder.res_block1, |
|
|
221 |
autounet.encoder.res_block3, |
|
|
222 |
autounet.encoder.res_block5, |
|
|
223 |
autounet.vencoder.z_mean, |
|
|
224 |
autounet.vencoder.z_log_var] |
|
|
225 |
hooks = hook_outputs(ms, detach=False, grad=False) |
|
|
226 |
|
|
|
227 |
lr = 1e-4 |
|
|
228 |
optimizer = optim.Adam(autounet.parameters(), lr) |