[9f60b7]: / 3DNet / model.py

Download this file

106 lines (93 with data), 3.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
import torch
from torch import nn
from models import resnet
# Depths for which ``models.resnet`` exposes a ``resnet<depth>`` factory.
_RESNET_DEPTHS = (10, 18, 34, 50, 101, 152, 200)


def _build_resnet(opt):
    """Instantiate ``resnet.resnet<opt.model_depth>`` configured from *opt*.

    Every depth except 10 receives the sample input sizes
    (``opt.input_W/H/D``), ``opt.resnet_shortcut``, ``opt.no_cuda`` and
    ``opt.n_seg_classes``.  ``resnet10`` is called with only
    ``shortcut_type`` and ``no_cuda``.
    NOTE(review): presumably resnet10 relies on its own defaults for the
    omitted arguments -- confirm against models/resnet.py.
    """
    # Look up only the requested factory, so a missing one for another depth
    # cannot break construction of this one.
    factory = getattr(resnet, 'resnet{}'.format(opt.model_depth))
    if opt.model_depth == 10:
        return factory(
            shortcut_type=opt.resnet_shortcut,
            no_cuda=opt.no_cuda,
        )
    return factory(
        sample_input_W=opt.input_W,
        sample_input_H=opt.input_H,
        sample_input_D=opt.input_D,
        shortcut_type=opt.resnet_shortcut,
        no_cuda=opt.no_cuda,
        num_seg_classes=opt.n_seg_classes,
    )


def _split_parameters(model, new_layer_names):
    """Split *model*'s parameters into ``base`` and ``new`` groups.

    A parameter is "new" when its qualified name (from ``named_parameters``)
    contains any string in *new_layer_names* as a substring; every other
    parameter is "base".  Each parameter lands in exactly one group.

    Returns:
        dict with keys ``'base_parameters'`` and ``'new_parameters'``, each
        a list of parameters.
    """
    new_parameters = [
        param
        for name, param in model.named_parameters()
        if any(layer_name in name for layer_name in new_layer_names)
    ]
    # Identity set: parameters are tensors, so compare by object id rather
    # than by (elementwise) equality.
    new_ids = {id(param) for param in new_parameters}
    base_parameters = [p for p in model.parameters() if id(p) not in new_ids]
    return {'base_parameters': base_parameters,
            'new_parameters': new_parameters}


def generate_model(opt):
    """Build the 3D ResNet described by *opt*, optionally loading weights.

    Attributes read from *opt*: ``model`` (must be ``'resnet'``),
    ``model_depth`` (one of 10, 18, 34, 50, 101, 152, 200),
    ``resnet_shortcut``, ``no_cuda``, ``pretrain_path``, and, for depths
    other than 10, ``input_W``/``input_H``/``input_D``/``n_seg_classes``.
    ``new_layer_names`` is read only when ``pretrain_path`` is set.

    When ``opt.no_cuda`` is falsy the model is moved to CUDA and wrapped in
    ``nn.DataParallel``.

    Returns:
        ``(model, model.parameters())`` when no pretrained path is given.
        Otherwise ``(model, groups)``, where *groups* is a dict with
        ``'base_parameters'`` and ``'new_parameters'`` lists.  Parameters
        whose names contain any of ``opt.new_layer_names`` go to
        ``'new_parameters'``.  Presumably the caller gives the two groups
        different optimizer settings.

    Raises:
        ValueError: if ``opt.model`` or ``opt.model_depth`` is unsupported.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``,
    # which would leave ``model`` unbound further down.
    if opt.model != 'resnet':
        raise ValueError("unsupported model {!r}; expected 'resnet'".format(opt.model))
    print('model depth: ', opt.model_depth)
    if opt.model_depth not in _RESNET_DEPTHS:
        raise ValueError('unsupported resnet depth {!r}; expected one of {}'.format(
            opt.model_depth, _RESNET_DEPTHS))

    model = _build_resnet(opt)
    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model)

    if not opt.pretrain_path:
        return model, model.parameters()

    # load pretrain
    print ('loading pretrained model {}'.format(opt.pretrain_path))
    # Deserialize onto the CPU: a checkpoint saved from GPU tensors otherwise
    # fails to load on a CPU-only run (opt.no_cuda).  load_state_dict copies
    # the values onto whatever device the model's parameters live on.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    pretrain = torch.load(opt.pretrain_path, map_location='cpu')
    net_dict = model.state_dict()
    # Keep only checkpoint entries the current model has, skipping any key
    # containing 'conv1'.
    # NOTE(review): the substring test also matches block-level keys such as
    # '...layer1.0.conv1.weight' if the residual blocks name their first conv
    # 'conv1' -- confirm whether only the stem conv should be skipped.
    # NOTE(review): if the checkpoint was saved from a DataParallel model, its
    # keys presumably carry a 'module.' prefix that the unwrapped (no_cuda)
    # model lacks, so they would be silently dropped here -- verify.
    pretrain_dict = {k: v for k, v in pretrain['state_dict'].items()
                     if k in net_dict and 'conv1' not in k}
    print(pretrain_dict.keys())
    net_dict.update(pretrain_dict)
    model.load_state_dict(net_dict)
    return model, _split_parameters(model, opt.new_layer_names)