import torch
from torch import nn
from models import resnet
def generate_model(opt):
    """Build the network selected by ``opt`` and optionally load pretrained weights.

    Attributes read from ``opt``:
        model: architecture name; only ``'resnet'`` is accepted.
        model_depth: one of 10, 18, 34, 50, 101, 152, 200.
        resnet_shortcut, no_cuda: forwarded to the ``resnet`` constructor.
        input_W, input_H, input_D, n_seg_classes: forwarded to the
            constructor for every depth except 10.
        pretrain_path: checkpoint path; a falsy value skips weight loading.
        new_layer_names: substrings identifying parameters that belong to
            newly added layers (only read when ``pretrain_path`` is set).

    Returns:
        A ``(model, parameters)`` tuple. When ``opt.pretrain_path`` is set,
        ``parameters`` is a dict with the keys ``'base_parameters'`` and
        ``'new_parameters'`` (two lists of parameters); otherwise it is the
        iterator returned by ``model.parameters()``.

    Raises:
        AssertionError: if ``opt.model`` or ``opt.model_depth`` is unsupported.
    """
    supported_models = ('resnet',)
    supported_depths = (10, 18, 34, 50, 101, 152, 200)

    assert opt.model in supported_models, \
        'unsupported model: {!r}'.format(opt.model)
    print('model depth: ', opt.model_depth)

    if opt.model == 'resnet':
        assert opt.model_depth in supported_depths, \
            'unsupported model depth: {!r}'.format(opt.model_depth)

        model_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'no_cuda': opt.no_cuda,
        }
        # resnet10 is deliberately called without the input-size and
        # segmentation-class arguments; every other depth receives them.
        if opt.model_depth != 10:
            model_kwargs.update(
                sample_input_W=opt.input_W,
                sample_input_H=opt.input_H,
                sample_input_D=opt.input_D,
                num_seg_classes=opt.n_seg_classes,
            )
        # Constructors are named resnet10, resnet18, ..., resnet200.
        build = getattr(resnet, 'resnet{:d}'.format(int(opt.model_depth)))
        model = build(**model_kwargs)

    if not opt.no_cuda:
        model = model.cuda()
        # Wrapping in DataParallel prefixes every state-dict key with "module.".
        model = nn.DataParallel(model)
    net_dict = model.state_dict()

    if not opt.pretrain_path:
        return model, model.parameters()

    # load pretrain
    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Map tensors to the CPU when CUDA is disabled so that checkpoints saved
    # from GPU tensors can still be deserialized on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pretrain = torch.load(
        opt.pretrain_path,
        map_location='cpu' if opt.no_cuda else None,
    )
    # Keep only weights whose key exists in this model; others are ignored.
    # NOTE(review): the substring test drops every key containing 'conv1',
    # which may cover more layers than just the first convolution -- confirm
    # this is intended.
    pretrain_dict = {
        k: v
        for k, v in pretrain['state_dict'].items()
        if k in net_dict and 'conv1' not in k
    }
    print(pretrain_dict.keys())
    net_dict.update(pretrain_dict)
    model.load_state_dict(net_dict)

    # Parameters whose name contains any of opt.new_layer_names are reported
    # separately from the remaining (base) parameters.
    new_parameters = [
        p
        for pname, p in model.named_parameters()
        if any(layer_name in pname for layer_name in opt.new_layer_names)
    ]
    new_parameter_ids = set(map(id, new_parameters))
    base_parameters = [
        p for p in model.parameters() if id(p) not in new_parameter_ids
    ]
    parameters = {
        'base_parameters': base_parameters,
        'new_parameters': new_parameters,
    }
    return model, parameters