|
a |
|
b/quicknat.py |
|
|
1 |
"""Quicknat architecture""" |
|
|
2 |
import numpy as np |
|
|
3 |
import torch |
|
|
4 |
import torch.nn as nn |
|
|
5 |
from nn_common_modules import modules as sm |
|
|
6 |
from squeeze_and_excitation import squeeze_and_excitation as se |
|
|
7 |
|
|
|
8 |
class QuickNat(nn.Module): |
|
|
9 |
""" |
|
|
10 |
A PyTorch implementation of QuickNAT |
|
|
11 |
|
|
|
12 |
""" |
|
|
13 |
def __init__(self, params): |
|
|
14 |
""" |
|
|
15 |
|
|
|
16 |
:param params: {'num_channels':1, |
|
|
17 |
'num_filters':64, |
|
|
18 |
'kernel_h':5, |
|
|
19 |
'kernel_w':5, |
|
|
20 |
'stride_conv':1, |
|
|
21 |
'pool':2, |
|
|
22 |
'stride_pool':2, |
|
|
23 |
'num_classes':28 |
|
|
24 |
'se_block': False, |
|
|
25 |
'drop_out':0.2} |
|
|
26 |
""" |
|
|
27 |
super(QuickNat, self).__init__() |
|
|
28 |
print(se.SELayer(params['se_block'])) |
|
|
29 |
self.encode1 = sm.EncoderBlock(params, se_block_type=params['se_block']) |
|
|
30 |
params['num_channels'] = params['num_filters'] |
|
|
31 |
self.encode2 = sm.EncoderBlock(params, se_block_type=params['se_block']) |
|
|
32 |
self.encode3 = sm.EncoderBlock(params, se_block_type=params['se_block']) |
|
|
33 |
self.encode4 = sm.EncoderBlock(params, se_block_type=params['se_block']) |
|
|
34 |
self.bottleneck = sm.DenseBlock(params, se_block_type=params['se_block']) |
|
|
35 |
params['num_channels'] = params['num_filters'] * 2 |
|
|
36 |
self.decode1 = sm.DecoderBlock(params, se_block_type=params['se_block']) |
|
|
37 |
self.decode2 = sm.DecoderBlock(params, se_block_type=params['se_block']) |
|
|
38 |
self.decode3 = sm.DecoderBlock(params, se_block_type=params['se_block']) |
|
|
39 |
self.decode4 = sm.DecoderBlock(params, se_block_type=params['se_block']) |
|
|
40 |
params['num_channels'] = params['num_filters'] |
|
|
41 |
self.classifier = sm.ClassifierBlock(params) |
|
|
42 |
|
|
|
43 |
def forward(self, input): |
|
|
44 |
""" |
|
|
45 |
|
|
|
46 |
:param input: X |
|
|
47 |
:return: probabiliy map |
|
|
48 |
""" |
|
|
49 |
e1, out1, ind1 = self.encode1.forward(input) |
|
|
50 |
e2, out2, ind2 = self.encode2.forward(e1) |
|
|
51 |
e3, out3, ind3 = self.encode3.forward(e2) |
|
|
52 |
e4, out4, ind4 = self.encode4.forward(e3) |
|
|
53 |
|
|
|
54 |
bn = self.bottleneck.forward(e4) |
|
|
55 |
|
|
|
56 |
d4 = self.decode4.forward(bn, out4, ind4) |
|
|
57 |
d3 = self.decode1.forward(d4, out3, ind3) |
|
|
58 |
d2 = self.decode2.forward(d3, out2, ind2) |
|
|
59 |
d1 = self.decode3.forward(d2, out1, ind1) |
|
|
60 |
prob = self.classifier.forward(d1) |
|
|
61 |
|
|
|
62 |
return prob |
|
|
63 |
|
|
|
64 |
def enable_test_dropout(self): |
|
|
65 |
""" |
|
|
66 |
Enables test time drop out for uncertainity |
|
|
67 |
:return: |
|
|
68 |
""" |
|
|
69 |
attr_dict = self.__dict__['_modules'] |
|
|
70 |
for i in range(1, 5): |
|
|
71 |
encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)] |
|
|
72 |
encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train) |
|
|
73 |
decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train) |
|
|
74 |
|
|
|
75 |
@property |
|
|
76 |
def is_cuda(self): |
|
|
77 |
""" |
|
|
78 |
Check if model parameters are allocated on the GPU. |
|
|
79 |
""" |
|
|
80 |
return next(self.parameters()).is_cuda |
|
|
81 |
|
|
|
82 |
def save(self, path): |
|
|
83 |
""" |
|
|
84 |
Save model with its parameters to the given path. Conventionally the |
|
|
85 |
path should end with '*.model'. |
|
|
86 |
|
|
|
87 |
Inputs: |
|
|
88 |
- path: path string |
|
|
89 |
""" |
|
|
90 |
print('Saving model... %s' % path) |
|
|
91 |
torch.save(self, path) |
|
|
92 |
|
|
|
93 |
def predict(self, X, device=0, enable_dropout=False, out_prob=False): |
|
|
94 |
""" |
|
|
95 |
Predicts the outout after the model is trained. |
|
|
96 |
Inputs: |
|
|
97 |
- X: Volume to be predicted |
|
|
98 |
""" |
|
|
99 |
self.eval() |
|
|
100 |
|
|
|
101 |
if type(X) is np.ndarray: |
|
|
102 |
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True) |
|
|
103 |
elif type(X) is torch.Tensor and not X.is_cuda: |
|
|
104 |
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True) |
|
|
105 |
|
|
|
106 |
if enable_dropout: |
|
|
107 |
self.enable_test_dropout() |
|
|
108 |
|
|
|
109 |
with torch.no_grad(): |
|
|
110 |
out = self.forward(X) |
|
|
111 |
|
|
|
112 |
if out_prob: |
|
|
113 |
return out |
|
|
114 |
else: |
|
|
115 |
max_val, idx = torch.max(out, 1) |
|
|
116 |
idx = idx.data.cpu().numpy() |
|
|
117 |
prediction = np.squeeze(idx) |
|
|
118 |
del X, out, idx, max_val |
|
|
119 |
return prediction |