Diff of /networks.py [000000] .. [83198a]

# from train import *
from torch.nn import init
from init import Options
import monai
from torch.optim import lr_scheduler


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    'normal' was used in the original pix2pix and CycleGAN papers, but 'xavier' and
    'kaiming' may work better for some applications; feel free to experiment.
    """
    def init_func(m):  # initialization function applied to every submodule
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm3d') != -1:  # BatchNorm weight is not a matrix; only normal distribution applies
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func> recursively

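# Example (illustrative, not part of the original file): how init_weights is
# typically used on a model. The toy network below is hypothetical.
#
#   import torch.nn as nn
#   toy = nn.Sequential(nn.Conv3d(1, 8, kernel_size=3), nn.BatchNorm3d(8), nn.ReLU())
#   init_weights(toy, init_type='kaiming')  # re-draws Conv/Linear weights in place
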
def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1)
            lr_l = (1 - epoch / opt.epochs) ** 0.9  # polynomial ("poly") decay with power 0.9
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


# update learning rate (called once every epoch)
def update_learning_rate(scheduler, optimizer):
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    # print('learning rate = %.7f' % lr)

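# Example (illustrative): a minimal sketch of how get_scheduler and
# update_learning_rate fit into a training loop. `model` and the opt fields
# used here (lr, lr_policy, epochs) are assumed to come from init.Options.
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
#   scheduler = get_scheduler(optimizer, opt)
#   for epoch in range(opt.epochs):
#       # ... train one epoch ...
#       update_learning_rate(scheduler, optimizer)  # step once per epoch
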
# NOTE: the layer imports below are not used anywhere in this file.
from torch.nn import Module, Sequential
from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
from torch.nn import ReLU, Sigmoid
import torch

def build_net():

    from init import Options
    opt = Options().parse()
    from monai.networks.layers import Norm

    # derive per-level kernel sizes and strides for the nn-UNet-style DynUNet
    if opt.resolution is None:
        sizes, spacings = opt.patch_size, opt.spacing
    else:
        sizes, spacings = opt.patch_size, opt.resolution

    strides, kernels = [], []

    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

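    # Worked example (illustrative numbers, not the repo defaults): with
    # sizes = [128, 128, 64] and spacings = [1, 1, 2], the loop above halves
    # each axis until it is either too small (< 8 voxels) or too anisotropic
    # (spacing ratio > 2), giving
    #   strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
    #   kernels = [[3, 3, 3]] * 6
    # i.e. five downsampling stages, the last of which leaves the z axis untouched.
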
    # create the nn-UNet (MONAI DynUNet)
    nn_Unet = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],  # upsampling mirrors the downsampling path
        res_block=True,
    )

    init_weights(nn_Unet, init_type='normal')

    return nn_Unet

def build_UNETR():

    from init import Options
    opt = Options().parse()

    # create the UNETR (transformer encoder with convolutional decoder)
    UneTR = monai.networks.nets.UNETR(
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        img_size=opt.patch_size,
        feature_size=32,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    init_weights(UneTR, init_type='normal')

    return UneTR

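# Note (added): MONAI's UNETR backbone is a ViT with 16x16x16 patch embedding,
# so each element of img_size (here opt.patch_size) should be divisible by 16.
# A guard one could add (hypothetical, not in the original code):
#
#   assert all(int(s) % 16 == 0 for s in opt.patch_size), \
#       "UNETR expects patch_size dimensions divisible by 16"
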
if __name__ == '__main__':
    import torch
    from torchsummaryX import summary

    opt = Options().parse()

    torch.cuda.set_device(0)
    # network = build_net()
    network = build_UNETR()
    net = network.cuda().eval()

    # a single dummy batch at the configured patch size
    # (torch.autograd.Variable is deprecated; a plain tensor behaves identically)
    data = torch.randn(1, int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2])).cuda()

    out = net(data)

    # torch.onnx.export(net, data, "Unet_model_graph.onnx")

    summary(net, data)
    print("out size: {}".format(out.size()))