# monai 0.5.0/networks.py
from train import *
from torch.nn import init
import monai
from torch.optim import lr_scheduler


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)      -- network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.

    'normal' was used in the original pix2pix and CycleGAN papers, but xavier and kaiming
    may work better for some applications. Feel free to experiment.
    """
    def init_func(m):  # initialization function applied to every submodule via net.apply
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm3d') != -1:
            # BatchNorm's weight is not a matrix; only a normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
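
# Illustrative usage (a sketch; the toy module below is hypothetical, any
# nn.Module works since init_func only touches Conv/Linear/BatchNorm3d layers):
#
#   import torch.nn as nn
#   toy = nn.Sequential(nn.Conv3d(1, 8, 3), nn.BatchNorm3d(8), nn.ReLU())
#   init_weights(toy, init_type='kaiming')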


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler selected by opt.lr_policy."""
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # polynomial decay; the original linear rule is kept for reference:
            # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1)
            lr_l = (1 - epoch / opt.epochs) ** 0.9
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


# update learning rate (called once every epoch)
def update_learning_rate(scheduler, optimizer):
    # note: the 'plateau' policy expects a metric, i.e. scheduler.step(val_loss)
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    # print('learning rate = %.7f' % lr)
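
# Illustrative wiring in a training script (a sketch; `net`, the loop body,
# and the opt.lr field are assumptions, while `opt` comes from init.Options
# as elsewhere in this file):
#
#   optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
#   scheduler = get_scheduler(optimizer, opt)
#   for epoch in range(opt.epochs):
#       ...  # train one epoch
#       update_learning_rate(scheduler, optimizer)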


from torch.nn import Module, Sequential
from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool3d, Dropout3d
from torch.nn import ReLU, Sigmoid
import torch


def build_net():

    from init import Options
    opt = Options().parse()
    from monai.networks.layers import Norm

    # create nn-UNet
    if opt.resolution is None:
        sizes, spacings = opt.patch_size, opt.spacing
    else:
        sizes, spacings = opt.patch_size, opt.resolution

    strides, kernels = [], []

    # Derive per-level kernel sizes and strides from the patch size and voxel
    # spacing, following the nnU-Net heuristic: downsample an axis (stride 2)
    # only while its spacing is within 2x of the finest axis and its size is
    # still at least 8 voxels; coarser axes get 1-voxel kernels until their
    # effective spacing catches up. A worked trace follows the loop.
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])  # first level keeps full resolution
    kernels.append(len(spacings) * [3])     # bottleneck level uses full 3x3x3 kernels
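
    # Worked trace (illustrative values, not from the config): with
    # patch_size (160, 160, 64) and spacing (1.0, 1.0, 2.5) the loop yields
    #   strides = [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
    #   kernels = [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
    # i.e. the coarse z-axis skips the first downsampling and uses 3x3x1 kernels
    # until its effective spacing matches the in-plane axes.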

    # create the DynUNet (nnU-Net style residual U-Net)

    nn_Unet = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],  # one transposed-conv upsample per downsampling level
        res_block=True,
    )

    init_weights(nn_Unet, init_type='normal')

    return nn_Unet
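
# Note: with DynUNet's default deep_supervision=False the network returns a
# single tensor shaped (N, out_channels, *patch_size), which the smoke test
# below relies on.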


if __name__ == '__main__':
    from torchsummaryX import summary
    from init import Options

    opt = Options().parse()

    torch.cuda.set_device(0)
    network = build_net()
    net = network.cuda().eval()

    data = torch.randn(1, int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2])).cuda()

    # torch.autograd.Variable is deprecated; plain tensors suffice since PyTorch 0.4
    with torch.no_grad():
        out = net(data)

    torch.onnx.export(net, data, "Unet_model_graph.onnx")

    summary(net, data)
    print("out size: {}".format(out.size()))