b/inpainting/model/basenet.py
|
import os

import torch
import torch.nn as nn


class BaseNet(nn.Module):
    """Base class for inpainting networks: device setup plus checkpoint save/load."""

    def __init__(self):
        super(BaseNet, self).__init__()

    def init(self, opt):
        # Deferred initializer: stores the options and resolves the target device.
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.save_dir = opt.checkpoint_dir
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')

    def forward(self, *inputs):
        # Subclasses are expected to override forward().
        return super(BaseNet, self).forward(*inputs)

    def test(self, *inputs):
        # Inference-only forward pass; gradients are not tracked.
        with torch.no_grad():
            return self.forward(*inputs)

    def save_network(self, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        os.makedirs(self.save_dir, exist_ok=True)
        # Save CPU tensors so the checkpoint loads on any machine, then
        # move the model back to its original device.
        torch.save(self.cpu().state_dict(), save_path)
        self.to(self.device)

    def load_network(self, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s does not exist yet!' % save_path)
            return
        pretrained_dict = torch.load(save_path, map_location=self.device)
        try:
            # Strict load: succeeds when the checkpoint matches the model exactly.
            self.load_state_dict(pretrained_dict)
        except RuntimeError:
            model_dict = self.state_dict()
            try:
                # The checkpoint may contain extra layers; keep only the keys
                # that exist in the current model.
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                self.load_state_dict(pretrained_dict)
                print('Pretrained network %s has extra layers; only loading layers that are used' % network_label)
            except RuntimeError:
                # The checkpoint has fewer (or differently shaped) layers;
                # copy whatever matches and report the rest.
                print('Pretrained network %s has fewer layers; the following are not initialized:' % network_label)
                for k, v in pretrained_dict.items():
                    if k in model_dict and v.size() == model_dict[k].size():
                        model_dict[k] = v

                for k, v in model_dict.items():
                    if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                        print(k.split('.')[0])
                self.load_state_dict(model_dict)
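
A minimal usage sketch (illustrative, not part of the file above): a subclass that defines its own layers and reuses BaseNet's checkpointing. The TinyNet class is hypothetical; the opt fields (gpu_ids, checkpoint_dir) mirror what init() reads from its options object.

    from types import SimpleNamespace

    import torch
    import torch.nn as nn

    class TinyNet(BaseNet):  # hypothetical subclass, for illustration only
        def __init__(self):
            super(TinyNet, self).__init__()
            self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    # opt carries the two fields BaseNet.init() actually uses.
    opt = SimpleNamespace(gpu_ids=[], checkpoint_dir='./checkpoints')

    net = TinyNet()
    net.init(opt)
    out = net.test(torch.randn(1, 3, 64, 64))  # no_grad forward pass
    net.save_network('tiny', 'latest')         # writes ./checkpoints/latest_net_tiny.pth
    net.load_network('tiny', 'latest')         # strict load, with partial-load fallback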