--- a +++ b/3DNet/model.py @@ -0,0 +1,105 @@ +import torch +from torch import nn +from models import resnet + + +def generate_model(opt): + assert opt.model in [ + 'resnet' + ] + + print('model depth: ',opt.model_depth) + + if opt.model == 'resnet': + assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200] + + if opt.model_depth == 10: + model = resnet.resnet10( + # sample_input_W=opt.input_W, + # sample_input_H=opt.input_H, + # sample_input_D=opt.input_D, + shortcut_type=opt.resnet_shortcut, + no_cuda=opt.no_cuda, + # num_seg_classes=opt.n_seg_classes, + ) + elif opt.model_depth == 18: + model = resnet.resnet18( + sample_input_W=opt.input_W, + sample_input_H=opt.input_H, + sample_input_D=opt.input_D, + shortcut_type=opt.resnet_shortcut, + no_cuda=opt.no_cuda, + num_seg_classes=opt.n_seg_classes) + elif opt.model_depth == 34: + model = resnet.resnet34( + sample_input_W=opt.input_W, + sample_input_H=opt.input_H, + sample_input_D=opt.input_D, + shortcut_type=opt.resnet_shortcut, + no_cuda=opt.no_cuda, + num_seg_classes=opt.n_seg_classes) + elif opt.model_depth == 50: + model = resnet.resnet50( + sample_input_W=opt.input_W, + sample_input_H=opt.input_H, + sample_input_D=opt.input_D, + shortcut_type=opt.resnet_shortcut, + no_cuda=opt.no_cuda, + num_seg_classes=opt.n_seg_classes) + elif opt.model_depth == 101: + model = resnet.resnet101( + sample_input_W=opt.input_W, + sample_input_H=opt.input_H, + sample_input_D=opt.input_D, + shortcut_type=opt.resnet_shortcut, + no_cuda=opt.no_cuda, + num_seg_classes=opt.n_seg_classes) + elif opt.model_depth == 152: + model = resnet.resnet152( + sample_input_W=opt.input_W, + sample_input_H=opt.input_H, + sample_input_D=opt.input_D, + shortcut_type=opt.resnet_shortcut, + no_cuda=opt.no_cuda, + num_seg_classes=opt.n_seg_classes) + elif opt.model_depth == 200: + model = resnet.resnet200( + sample_input_W=opt.input_W, + sample_input_H=opt.input_H, + sample_input_D=opt.input_D, + shortcut_type=opt.resnet_shortcut, + no_cuda=opt.no_cuda, + num_seg_classes=opt.n_seg_classes) + + if not opt.no_cuda: + model = model.cuda() + model = nn.DataParallel(model) + net_dict = model.state_dict() + else: + net_dict = model.state_dict() + + # load pretrain + if opt.pretrain_path: + print ('loading pretrained model {}'.format(opt.pretrain_path)) + pretrain = torch.load(opt.pretrain_path) + pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict.keys() and 'conv1' not in k} + print(pretrain_dict.keys()) + + net_dict.update(pretrain_dict) + model.load_state_dict(net_dict) + + new_parameters = [] + for pname, p in model.named_parameters(): + for layer_name in opt.new_layer_names: + if pname.find(layer_name) >= 0: + new_parameters.append(p) + break + + new_parameters_id = list(map(id, new_parameters)) + base_parameters = list(filter(lambda p: id(p) not in new_parameters_id, model.parameters())) + parameters = {'base_parameters': base_parameters, + 'new_parameters': new_parameters} + + return model, parameters + + return model, model.parameters()