Diff of /3DNet/model.py [000000] .. [9f60b7]

Switch to side-by-side view

--- a
+++ b/3DNet/model.py
@@ -0,0 +1,105 @@
+import torch
+from torch import nn
+from models import resnet
+
+
+def generate_model(opt):
+    assert opt.model in [
+        'resnet'
+    ]
+
+    print('model depth: ',opt.model_depth)
+
+    if opt.model == 'resnet':
+        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
+        
+        if opt.model_depth == 10:
+            model = resnet.resnet10(
+                # sample_input_W=opt.input_W,
+                # sample_input_H=opt.input_H,
+                # sample_input_D=opt.input_D,
+                shortcut_type=opt.resnet_shortcut,
+                no_cuda=opt.no_cuda,
+                # num_seg_classes=opt.n_seg_classes,
+            )
+        elif opt.model_depth == 18:
+            model = resnet.resnet18(
+                sample_input_W=opt.input_W,
+                sample_input_H=opt.input_H,
+                sample_input_D=opt.input_D,
+                shortcut_type=opt.resnet_shortcut,
+                no_cuda=opt.no_cuda,
+                num_seg_classes=opt.n_seg_classes)
+        elif opt.model_depth == 34:
+            model = resnet.resnet34(
+                sample_input_W=opt.input_W,
+                sample_input_H=opt.input_H,
+                sample_input_D=opt.input_D,
+                shortcut_type=opt.resnet_shortcut,
+                no_cuda=opt.no_cuda,
+                num_seg_classes=opt.n_seg_classes)
+        elif opt.model_depth == 50:
+            model = resnet.resnet50(
+                sample_input_W=opt.input_W,
+                sample_input_H=opt.input_H,
+                sample_input_D=opt.input_D,
+                shortcut_type=opt.resnet_shortcut,
+                no_cuda=opt.no_cuda,
+                num_seg_classes=opt.n_seg_classes)
+        elif opt.model_depth == 101:
+            model = resnet.resnet101(
+                sample_input_W=opt.input_W,
+                sample_input_H=opt.input_H,
+                sample_input_D=opt.input_D,
+                shortcut_type=opt.resnet_shortcut,
+                no_cuda=opt.no_cuda,
+                num_seg_classes=opt.n_seg_classes)
+        elif opt.model_depth == 152:
+            model = resnet.resnet152(
+                sample_input_W=opt.input_W,
+                sample_input_H=opt.input_H,
+                sample_input_D=opt.input_D,
+                shortcut_type=opt.resnet_shortcut,
+                no_cuda=opt.no_cuda,
+                num_seg_classes=opt.n_seg_classes)
+        elif opt.model_depth == 200:
+            model = resnet.resnet200(
+                sample_input_W=opt.input_W,
+                sample_input_H=opt.input_H,
+                sample_input_D=opt.input_D,
+                shortcut_type=opt.resnet_shortcut,
+                no_cuda=opt.no_cuda,
+                num_seg_classes=opt.n_seg_classes)
+    
+    if not opt.no_cuda:
+            model = model.cuda()
+            model = nn.DataParallel(model)
+            net_dict = model.state_dict()
+    else:
+        net_dict = model.state_dict()
+    
+    # load pretrain
+    if opt.pretrain_path:
+        print ('loading pretrained model {}'.format(opt.pretrain_path))
+        pretrain = torch.load(opt.pretrain_path)
+        pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict.keys() and 'conv1' not in k}
+        print(pretrain_dict.keys())
+
+        net_dict.update(pretrain_dict)
+        model.load_state_dict(net_dict)
+
+        new_parameters = []
+        for pname, p in model.named_parameters():
+            for layer_name in opt.new_layer_names:
+                if pname.find(layer_name) >= 0:
+                    new_parameters.append(p)
+                    break
+
+        new_parameters_id = list(map(id, new_parameters))
+        base_parameters = list(filter(lambda p: id(p) not in new_parameters_id, model.parameters()))
+        parameters = {'base_parameters': base_parameters,
+                      'new_parameters': new_parameters}
+
+        return model, parameters
+
+    return model, model.parameters()