[8eeb5a]: / inpainting / model / basemodel.py

Download this file

103 lines (86 with data), 3.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
import os
import torch
import torch.nn as nn
# A composite model made up of several networks; each network is defined explicitly in its own module.
class BaseModel(nn.Module):
    """Common scaffolding for a model that is composed of several sub-networks.

    Sub-networks are stored on the instance as attributes named
    ``'net' + name`` and every such ``name`` is listed in ``self.model_names``.
    The checkpoint, printing and ``requires_grad`` helpers below iterate over
    that list. ``forward``, ``optimize_parameters``, ``get_current_visuals``,
    ``get_current_losses`` and ``update_learning_rate`` are no-op placeholders
    here.
    """

    def __init__(self):
        super(BaseModel, self).__init__()

    def init(self, opt):
        """Store the options object and derive device / checkpoint settings.

        Reads ``opt.gpu_ids`` and ``opt.model_folder``. The device is the first
        listed GPU when ``gpu_ids`` is non-empty, otherwise the CPU.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.save_dir = opt.model_folder
        if self.gpu_ids:
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')
        self.model_names = []

    def setInput(self, img, mask):
        """Remember the image and mask that the next forward pass should use."""
        self.input = img
        self.mask = mask

    def forward(self):
        """Placeholder; does nothing in the base class."""
        pass

    def optimize_parameters(self):
        """Placeholder; does nothing in the base class."""
        pass

    def get_current_visuals(self):
        """Placeholder; returns None in the base class."""
        pass

    def get_current_losses(self):
        """Placeholder; returns None in the base class."""
        pass

    def update_learning_rate(self):
        """Placeholder; does nothing in the base class."""
        pass

    def test(self):
        """Run ``forward`` with gradient tracking disabled."""
        with torch.no_grad():
            self.forward()

    def save_networks(self, which_epoch):
        """Save every network listed in ``self.model_names`` to ``self.save_dir``.

        Each file is named ``'<which_epoch>_net_<name>.pth'`` and holds the
        network's ``state_dict``. Non-string entries of ``model_names`` are
        skipped.
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            save_filename = '%s_net_%s.pth' % (which_epoch, name)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, 'net' + name)
            # Save the wrapped module's weights so the keys carry no
            # 'module.' prefix; load_networks() unwraps DataParallel the same
            # way and expects unprefixed keys.
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            torch.save(net.state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop stale InstanceNorm running statistics from ``state_dict``.

        Walks ``module`` along the dotted path given by ``keys`` (one state-dict
        key split on '.'). When the path ends at ``running_mean`` or
        ``running_var`` of a module whose class name starts with
        ``'InstanceNorm'`` and that attribute is ``None`` on the module, the
        entry is removed from ``state_dict`` in place.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            is_instance_norm = module.__class__.__name__.startswith('InstanceNorm')
            if is_instance_norm and key in ('running_mean', 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(
                state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, load_path):
        """Load the checkpoint at ``load_path`` into every listed network.

        The same file is loaded for each string entry of ``self.model_names``.
        ``DataParallel`` wrappers are unwrapped first, and tensors are mapped
        onto ``self.device`` so a checkpoint written on another device (for
        example a GPU) can still be loaded.
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            net = getattr(self, 'net' + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            print('loading the model from %s' % load_path)
            # Reload per network: the patch step below mutates the dict.
            state_dict = torch.load(load_path, map_location=self.device)
            # patch InstanceNorm checkpoints prior to 0.4
            # (copy the keys first because entries may be removed in the loop)
            for key in list(state_dict.keys()):
                self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
            net.load_state_dict(state_dict)

    def print_networks(self, verbose=True):
        """Print each listed network's parameter count (in millions).

        When ``verbose`` is true the network itself is printed as well.
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            net = getattr(self, 'net' + name)
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
            print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of the given network(s).

        ``nets`` may be a single network or a list of networks; ``None``
        entries are ignored. Passing ``requires_grad=False`` avoids gradient
        computation for those parameters.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is None:
                continue
            for param in net.parameters():
                param.requires_grad = requires_grad