|
a |
|
b/inpainting/model/basemodel.py |
|
|
1 |
import os |
|
|
2 |
import torch |
|
|
3 |
import torch.nn as nn |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
# a complex model consisted of several nets, and each net will be explicitly defined in other py class files |
|
|
7 |
class BaseModel(nn.Module): |
|
|
8 |
def __init__(self): |
|
|
9 |
super(BaseModel, self).__init__() |
|
|
10 |
|
|
|
11 |
def init(self, opt): |
|
|
12 |
self.opt = opt |
|
|
13 |
self.gpu_ids = opt.gpu_ids |
|
|
14 |
self.save_dir = opt.model_folder |
|
|
15 |
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') |
|
|
16 |
self.model_names = [] |
|
|
17 |
|
|
|
18 |
def setInput(self, img, mask): |
|
|
19 |
self.input = img |
|
|
20 |
self.mask = mask |
|
|
21 |
|
|
|
22 |
def forward(self): |
|
|
23 |
pass |
|
|
24 |
|
|
|
25 |
def optimize_parameters(self): |
|
|
26 |
pass |
|
|
27 |
|
|
|
28 |
def get_current_visuals(self): |
|
|
29 |
pass |
|
|
30 |
|
|
|
31 |
def get_current_losses(self): |
|
|
32 |
pass |
|
|
33 |
|
|
|
34 |
def update_learning_rate(self): |
|
|
35 |
pass |
|
|
36 |
|
|
|
37 |
def test(self): |
|
|
38 |
with torch.no_grad(): |
|
|
39 |
self.forward() |
|
|
40 |
|
|
|
41 |
# save models to the disk |
|
|
42 |
def save_networks(self, which_epoch): |
|
|
43 |
for name in self.model_names: |
|
|
44 |
if isinstance(name, str): |
|
|
45 |
save_filename = '%s_net_%s.pth' % (which_epoch, name) |
|
|
46 |
save_path = os.path.join(self.save_dir, save_filename) |
|
|
47 |
net = getattr(self, 'net' + name) |
|
|
48 |
|
|
|
49 |
if len(self.gpu_ids) > 0 and torch.cuda.is_available(): |
|
|
50 |
torch.save(net.state_dict(), save_path) |
|
|
51 |
# net.cuda(self.gpu_ids[0]) |
|
|
52 |
else: |
|
|
53 |
torch.save(net.state_dict(), save_path) |
|
|
54 |
|
|
|
55 |
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): |
|
|
56 |
key = keys[i] |
|
|
57 |
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer |
|
|
58 |
if module.__class__.__name__.startswith('InstanceNorm') and \ |
|
|
59 |
(key == 'running_mean' or key == 'running_var'): |
|
|
60 |
if getattr(module, key) is None: |
|
|
61 |
state_dict.pop('.'.join(keys)) |
|
|
62 |
else: |
|
|
63 |
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) |
|
|
64 |
|
|
|
65 |
# load models from the disk |
|
|
66 |
def load_networks(self, load_path): |
|
|
67 |
for name in self.model_names: |
|
|
68 |
if isinstance(name, str): |
|
|
69 |
net = getattr(self, 'net' + name) |
|
|
70 |
if isinstance(net, torch.nn.DataParallel): |
|
|
71 |
net = net.module |
|
|
72 |
print('loading the model from %s' % load_path) |
|
|
73 |
# if you are using PyTorch newer than 0.4 (e.g., built from |
|
|
74 |
# GitHub source), you can remove str() on self.device |
|
|
75 |
state_dict = torch.load(load_path) |
|
|
76 |
# patch InstanceNorm checkpoints prior to 0.4 |
|
|
77 |
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop |
|
|
78 |
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) |
|
|
79 |
net.load_state_dict(state_dict) |
|
|
80 |
|
|
|
81 |
# print network information |
|
|
82 |
def print_networks(self, verbose=True): |
|
|
83 |
print('---------- Networks initialized -------------') |
|
|
84 |
for name in self.model_names: |
|
|
85 |
if isinstance(name, str): |
|
|
86 |
net = getattr(self, 'net' + name) |
|
|
87 |
num_params = 0 |
|
|
88 |
for param in net.parameters(): |
|
|
89 |
num_params += param.numel() |
|
|
90 |
if verbose: |
|
|
91 |
print(net) |
|
|
92 |
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) |
|
|
93 |
print('-----------------------------------------------') |
|
|
94 |
|
|
|
95 |
# set requies_grad=Fasle to avoid computation |
|
|
96 |
def set_requires_grad(self, nets, requires_grad=False): |
|
|
97 |
if not isinstance(nets, list): |
|
|
98 |
nets = [nets] |
|
|
99 |
for net in nets: |
|
|
100 |
if net is not None: |
|
|
101 |
for param in net.parameters(): |
|
|
102 |
param.requires_grad = requires_grad |