Switch to unified view

a b/inpainting/model/basenet.py
1
import os
2
import torch
3
import torch.nn as nn
4
5
class BaseNet(nn.Module):
6
    def __init__(self):
7
        super(BaseNet, self).__init__()
8
9
    def init(self, opt):
10
        self.opt = opt
11
        self.gpu_ids = opt.gpu_ids
12
        self.save_dir = opt.checkpoint_dir
13
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
14
15
    def forward(self, *input):
16
        return super(BaseNet, self).forward(*input)
17
18
    def test(self, *input):
19
        with torch.no_grad():
20
            self.forward(*input)
21
22
    def save_network(self, network_label, epoch_label):
23
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
24
        save_path = os.path.join(self.save_dir, save_filename)
25
        torch.save(self.cpu().state_dict(), save_path)
26
27
    def load_network(self, network_label, epoch_label):
28
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
29
        save_path = os.path.join(self.save_dir, save_filename)
30
        if not os.path.isfile(save_path):
31
            print('%s not exists yet!' % save_path)
32
        else:
33
            try:
34
                self.load_state_dict(torch.load(save_path))
35
            except:
36
                pretrained_dict = torch.load(save_path)
37
                model_dict = self.state_dict()
38
                try:
39
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
40
                    self.load_state_dict(pretrained_dict)
41
                    print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
42
                except:
43
                    print('Pretrained network %s has fewer layers; The following are not initialized: ' % network_label)
44
                    for k, v in pretrained_dict.items():
45
                        if v.size() == model_dict[k].size():
46
                            model_dict[k] = v
47
48
                    for k, v in model_dict.items():
49
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
50
                            print(k.split('.')[0])
51
                    self.load_state_dict(model_dict)