networks.py
|
|
# from train import *
from torch.nn import init
from init import Options
import monai
from torch.optim import lr_scheduler


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper, but xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm3d') != -1:  # BatchNorm layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
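
# net.apply calls init_func recursively on every submodule, so the builders below
# re-initialise the convolutions inside MONAI's pre-built blocks as well. With
# init_type='kaiming' the init_gain argument is unused, because kaiming_normal_
# derives its scale from fan_in.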
|
|
def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1)
            lr_l = (1 - epoch / opt.epochs) ** 0.9
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
|
|
# update learning rate (called once every epoch)
def update_learning_rate(scheduler, optimizer):
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    # print('learning rate = %.7f' % lr)
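
# Rough usage sketch for the two helpers above (assumes opt exposes lr, lr_policy
# and epochs, and that a model already exists; illustration only, not executed):
#
#   optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
#   scheduler = get_scheduler(optimizer, opt)
#   for epoch in range(opt.epochs):
#       ...                                         # train for one epoch
#       update_learning_rate(scheduler, optimizer)  # step the LR once per epoch
#
# Note that the 'plateau' policy expects a validation metric to be passed to
# scheduler.step(metric), which update_learning_rate does not provide.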
|
|
from torch.nn import Module, Sequential
from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
from torch.nn import ReLU, Sigmoid
import torch
|
|
def build_net():

    from init import Options
    opt = Options().parse()
    from monai.networks.layers import Norm

    # create nn-Unet
    if opt.resolution is None:
        sizes, spacings = opt.patch_size, opt.spacing
    else:
        sizes, spacings = opt.patch_size, opt.resolution

    strides, kernels = [], []

    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
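
    # The loop above follows the nnU-Net heuristic used in MONAI's DynUNet examples:
    # at each resolution level an axis is halved (stride 2) only while its spacing is
    # within 2x of the finest spacing and its size is still at least 8 voxels, and
    # convolutions along axes that are too coarse use kernel size 1 instead of 3.
    # For example, a hypothetical patch size (160, 160, 80) with spacing (1.0, 1.0, 2.5)
    # would start with kernel [3, 3, 1] and stride [2, 2, 1], then use [3, 3, 3] and
    # [2, 2, 2] at the deeper levels; the actual lists depend on opt.patch_size and
    # opt.spacing / opt.resolution.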
|
|
    # create Unet
    nn_Unet = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        res_block=True,
    )

    init_weights(nn_Unet, init_type='normal')

    return nn_Unet
|
|
def build_UNETR():

    from init import Options
    opt = Options().parse()

    # create UneTR
    UneTR = monai.networks.nets.UNETR(
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        img_size=opt.patch_size,
        feature_size=32,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )
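    # hidden_size=768, mlp_dim=3072 and num_heads=12 match the ViT-B defaults, and
    # hidden_size must remain divisible by num_heads. UNETR tokenises the volume into
    # 16x16x16 patches, so opt.patch_size should be divisible by 16 on every axis
    # (e.g. a 96x96x96 or 128x128x128 patch).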
|
|
    init_weights(UneTR, init_type='normal')

    return UneTR
|
|
if __name__ == '__main__':
    import time
    import torch
    from torch.autograd import Variable
    from torchsummaryX import summary
    from torch.nn import init

    opt = Options().parse()

    torch.cuda.set_device(0)
    # network = build_net()
    network = build_UNETR()
    net = network.cuda().eval()

    data = Variable(torch.randn(1, int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2]))).cuda()

    out = net(data)

    # torch.onnx.export(net, data, "Unet_model_graph.onnx")

    summary(net, data)
    print("out size: {}".format(out.size()))
|
|