--- a +++ b/inpainting/model/basenet.py @@ -0,0 +1,51 @@ +import os +import torch +import torch.nn as nn + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + def init(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.save_dir = opt.checkpoint_dir + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + + def forward(self, *input): + return super(BaseNet, self).forward(*input) + + def test(self, *input): + with torch.no_grad(): + self.forward(*input) + + def save_network(self, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(self.cpu().state_dict(), save_path) + + def load_network(self, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + if not os.path.isfile(save_path): + print('%s not exists yet!' % save_path) + else: + try: + self.load_state_dict(torch.load(save_path)) + except: + pretrained_dict = torch.load(save_path) + model_dict = self.state_dict() + try: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + self.load_state_dict(pretrained_dict) + print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) + except: + print('Pretrained network %s has fewer layers; The following are not initialized: ' % network_label) + for k, v in pretrained_dict.items(): + if v.size() == model_dict[k].size(): + model_dict[k] = v + + for k, v in model_dict.items(): + if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): + print(k.split('.')[0]) + self.load_state_dict(model_dict)