--- /dev/null
+++ b/inpainting/model/basemodel.py
@@ -0,0 +1,127 @@
+import os
+import torch
+import torch.nn as nn
+
+
+# A composite model made up of several networks; each network is defined in its own .py file.
+class BaseModel(nn.Module):
+    def __init__(self):
+        super(BaseModel, self).__init__()
+
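+    # shared setup: pick the first GPU in opt.gpu_ids if any are given, else run on CPU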
+    def init(self, opt):
+        self.opt = opt
+        self.gpu_ids = opt.gpu_ids
+        self.save_dir = opt.model_folder
+        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
+        self.model_names = []
+
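+    # cache the input image and its mask for the next forward pass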
+    def setInput(self, img, mask):
+        self.input = img
+        self.mask = mask
+
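+    # the hooks below are no-ops here and are overridden by concrete subclasses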
+    def forward(self):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+    def get_current_visuals(self):
+        pass
+
+    def get_current_losses(self):
+        pass
+
+    def update_learning_rate(self):
+        pass
+
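+    # inference-only forward pass; no_grad avoids building the autograd graph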
+    def test(self):
+        with torch.no_grad():
+            self.forward()
+
+    # save models to disk
+    def save_networks(self, which_epoch):
+        for name in self.model_names:
+            if isinstance(name, str):
+                save_filename = '%s_net_%s.pth' % (which_epoch, name)
+                save_path = os.path.join(self.save_dir, save_filename)
+                net = getattr(self, 'net' + name)
+                # unwrap DataParallel so the saved keys match what load_networks expects
+                if isinstance(net, torch.nn.DataParallel):
+                    net = net.module
+                torch.save(net.state_dict(), save_path)
+
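+    # recursively follow a dotted state_dict key down the module tree and drop
+    # InstanceNorm running stats present in old checkpoints that the module no longer tracks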
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+    # load models from the disk
+    def load_networks(self, load_path):
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                if isinstance(net, torch.nn.DataParallel):
+                    net = net.module
+                print('loading the model from %s' % load_path)
+                # map to this model's device so CPU-only machines can load GPU checkpoints
+                state_dict = torch.load(load_path, map_location=str(self.device))
+                # patch InstanceNorm checkpoints prior to 0.4
+                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                net.load_state_dict(state_dict)
+
+    # print network information
+    def print_networks(self, verbose=True):
+        print('---------- Networks initialized -------------')
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                num_params = 0
+                for param in net.parameters():
+                    num_params += param.numel()
+                if verbose:
+                    print(net)
+                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+        print('-----------------------------------------------')
+
+    # set requires_grad=False to avoid unnecessary gradient computation
+    def set_requires_grad(self, nets, requires_grad=False):
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
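+
+
+# A minimal sketch of how a concrete model might subclass BaseModel. The
+# `InpaintModel`, `SimpleNet`, and optimizer details below are hypothetical,
+# shown only to illustrate the expected contract:
+#
+#   class InpaintModel(BaseModel):
+#       def init(self, opt):
+#           BaseModel.init(self, opt)
+#           self.model_names = ['G']          # save/load will look for self.netG
+#           self.netG = SimpleNet().to(self.device)
+#           self.optimizer = torch.optim.Adam(self.netG.parameters())
+#
+#       def forward(self):
+#           self.output = self.netG(self.input * (1 - self.mask))
+#
+#       def optimize_parameters(self):
+#           self.forward()
+#           loss = ((self.output - self.input) * self.mask).abs().mean()
+#           self.optimizer.zero_grad()
+#           loss.backward()
+#           self.optimizer.step()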