--- a +++ b/networks.py @@ -0,0 +1,168 @@ +# from train import * +from torch.nn import init +from init import Options +import monai +from torch.optim import lr_scheduler + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm3d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + # print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function <init_func> + + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1) + lr_l = (1 - epoch / opt.epochs) ** 0.9 + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +# update learning rate (called once every epoch) +def update_learning_rate(scheduler, optimizer): + scheduler.step() + lr = optimizer.param_groups[0]['lr'] + # print('learning rate = %.7f' % lr) + + +from torch.nn import Module, Sequential +from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d +from torch.nn import ReLU, Sigmoid +import torch + + +def build_net(): + + from init import Options + opt = Options().parse() + from monai.networks.layers import Norm + + # create nn-Unet + if opt.resolution is None: + sizes, spacings = opt.patch_size, opt.spacing + else: + sizes, spacings = opt.patch_size, opt.resolution + + strides, kernels = [], [] + + while True: + spacing_ratio = [sp / min(spacings) for sp in spacings] + stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)] + kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] + if all(s == 1 for s in stride): + break + sizes = [i / j for i, j in zip(sizes, stride)] + spacings = [i * j for i, j in zip(spacings, stride)] + kernels.append(kernel) + strides.append(stride) + strides.insert(0, len(spacings) * [1]) + kernels.append(len(spacings) * [3]) + + # # create Unet + + nn_Unet = monai.networks.nets.DynUNet( + spatial_dims=3, + in_channels=opt.in_channels, + out_channels=opt.out_channels, + kernel_size=kernels, + strides=strides, + upsample_kernel_size=strides[1:], + res_block=True, + ) + + init_weights(nn_Unet, init_type='normal') + + return nn_Unet + + +def build_UNETR(): + + from init import Options + opt = Options().parse() + + # create UneTR + + UneTR = monai.networks.nets.UNETR( + in_channels=opt.in_channels, + out_channels=opt.out_channels, + img_size=opt.patch_size, + feature_size=32, + hidden_size=768, + mlp_dim=3072, + num_heads=12, + pos_embed="conv", + norm_name="instance", + res_block=True, + dropout_rate=0.0, + ) + + init_weights(UneTR, init_type='normal') + + return UneTR + + +if __name__ == '__main__': + import time + import torch + from torch.autograd import Variable + from torchsummaryX import summary + from torch.nn import init + + opt = Options().parse() + + torch.cuda.set_device(0) + # network = build_net() + network = build_UNETR() + net = network.cuda().eval() + + data = Variable(torch.randn(1, int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2]))).cuda() + + out = net(data) + + # torch.onnx.export(net, data, "Unet_model_graph.onnx") + + summary(net,data) + print("out size: {}".format(out.size())) + + + + + +