--- a +++ b/inpainting/model/basemodel.py @@ -0,0 +1,102 @@ +import os +import torch +import torch.nn as nn + + +# a complex model consisted of several nets, and each net will be explicitly defined in other py class files +class BaseModel(nn.Module): + def __init__(self): + super(BaseModel, self).__init__() + + def init(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.save_dir = opt.model_folder + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + self.model_names = [] + + def setInput(self, img, mask): + self.input = img + self.mask = mask + + def forward(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def get_current_losses(self): + pass + + def update_learning_rate(self): + pass + + def test(self): + with torch.no_grad(): + self.forward() + + # save models to the disk + def save_networks(self, which_epoch): + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (which_epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + + if len(self.gpu_ids) > 0 and torch.cuda.is_available(): + torch.save(net.state_dict(), save_path) + # net.cuda(self.gpu_ids[0]) + else: + torch.save(net.state_dict(), save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + # load models from the disk + def load_networks(self, load_path): + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + print('loading the model from %s' % load_path) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(load_path) + # patch InstanceNorm checkpoints prior to 0.4 + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + net.load_state_dict(state_dict) + + # print network information + def print_networks(self, verbose=True): + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + # set requies_grad=Fasle to avoid computation + def set_requires_grad(self, nets, requires_grad=False): + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad