# monai 0.5.0/deprecated/networks.py
from train import *
from torch.nn import init
import monai
from torch.optim import lr_scheduler

def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    The original pix2pix and CycleGAN paper uses 'normal', but xavier and kaiming
    may work better for some applications; feel free to experiment.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm3d') != -1:  # BatchNorm layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
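
# Example usage (a minimal sketch, not part of the original module): initialize a
# toy 3D conv block with Kaiming weights.
#   block = torch.nn.Sequential(torch.nn.Conv3d(1, 8, 3), torch.nn.BatchNorm3d(8), torch.nn.ReLU())
#   init_weights(block, init_type='kaiming')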


def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1)
            lr_l = (1 - epoch / opt.epochs) ** 0.9  # polynomial ("poly") decay with power 0.9
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


# update learning rate (called once every epoch)
def update_learning_rate(scheduler, optimizer, metric=None):
    # ReduceLROnPlateau expects the monitored metric; the other policies step unconditionally.
    if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
        scheduler.step(metric)
    else:
        scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    # print('learning rate = %.7f' % lr)
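
# Example wiring (sketch; assumes `opt` comes from init.Options and exposes an `lr`
# field in addition to the lr_policy/epochs fields used above):
#   optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
#   scheduler = get_scheduler(optimizer, opt)
#   for epoch in range(opt.epochs):
#       val_loss = run_one_epoch(net, optimizer)  # hypothetical training/validation step
#       update_learning_rate(scheduler, optimizer, metric=val_loss)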


from torch.nn import Module, Sequential
from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
from torch.nn import ReLU, Sigmoid
import torch


def build_net():

    from init import Options
    opt = Options().parse()
    from monai.networks.layers import Norm
    from monai.networks.layers.factories import split_args
    act_type, args = split_args("RELU")

    # # create Unet
    # Unet = monai.networks.nets.UNet(
    #     dimensions=3,
    #     in_channels=opt.in_channels,
    #     out_channels=opt.out_channels,
    #     channels=(64, 128, 256, 512, 1024),
    #     strides=(2, 2, 2, 2),
    #     act=act_type,
    #     num_res_units=3,
    #     dropout=0.2,
    #     norm=Norm.BATCH,
    # )

    # create nn-Unet: derive per-level kernel sizes and strides from the patch size
    # and voxel spacing, following the nnU-Net heuristic used with MONAI's DynUNet
    if opt.resolution is None:
        sizes, spacings = opt.patch_size, opt.spacing
    else:
        sizes, spacings = opt.patch_size, opt.resolution

    strides, kernels = [], []

    while True:
        # downsample (stride 2) only along axes whose spacing is close to the finest
        # spacing and whose remaining size is at least 8 voxels
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])  # the first (input) level does not downsample
    kernels.append(len(spacings) * [3])     # the bottleneck level uses 3x3x3 kernels
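
    # Worked example (sketch, not from the original source): with patch_size (128, 128, 64)
    # and spacing (1.0, 1.0, 2.0), the loop above yields
    #   strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
    #   kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
    # i.e. the z-axis stops downsampling once its remaining size drops below 8 voxels.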

    nn_Unet = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        res_block=True,
        # act=act_type,
        # norm=Norm.BATCH,
    )

    init_weights(nn_Unet, init_type='normal')

    return nn_Unet
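
# Example (sketch): construct the network and count its parameters.
#   net = build_net()
#   n_params = sum(p.numel() for p in net.parameters())
#   print('DynUNet parameters: %d' % n_params)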


if __name__ == '__main__':
    import time
    import torch
    from torchsummaryX import summary
    from torch.nn import init
    from init import Options  # mirrors the local import inside build_net

    opt = Options().parse()

    torch.cuda.set_device(0)
    network = build_net()
    net = network.cuda().eval()

    # torch.autograd.Variable is deprecated; a plain tensor behaves the same
    data = torch.randn(int(opt.batch_size), int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2])).cuda()

    out = net(data)

    torch.onnx.export(net, data, "Unet_model_graph.onnx")

    summary(net, data)
    print("out size: {}".format(out.size()))