# inpainting/model/basemodel.py

import os
import torch
import torch.nn as nn


# a composite model made up of several nets; each net is defined explicitly
# in its own .py file and registered on the model as an attribute 'net<Name>'
class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()

    def init(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.save_dir = opt.model_folder
        # run on the first listed GPU if any are given, otherwise on the CPU
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.model_names = []

    def setInput(self, img, mask):
        self.input = img
        self.mask = mask

    # the hooks below are overridden by concrete subclasses
    def forward(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    def update_learning_rate(self):
        pass

    # run a forward pass without tracking gradients (inference only)
    def test(self):
        with torch.no_grad():
            self.forward()

    # save models to the disk
    def save_networks(self, which_epoch):
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (which_epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                # tensors are saved on whatever device they currently live on;
                # load_networks maps them back onto self.device when loading
                torch.save(net.state_dict(), save_path)

    # remove InstanceNorm running-stats keys that pre-0.4 checkpoints contain
    # but the current module no longer tracks, so load_state_dict() succeeds
    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
    def load_networks(self, load_path):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # copy keys because the loop mutates state_dict
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    # print network information
    def print_networks(self, verbose=True):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    # set requires_grad=False to avoid unnecessary gradient computation
    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
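

# Usage sketch (an illustration, not part of the original repo): a minimal
# subclass showing how the BaseModel hooks are typically filled in.
# `InpaintModel`, the SimpleNamespace `opt`, and the single-conv "generator"
# are hypothetical stand-ins; the real nets are defined in other files and
# registered as attributes named 'net<Name>', listed in self.model_names so
# that save_networks / load_networks / print_networks can find them.
if __name__ == '__main__':
    from types import SimpleNamespace

    class InpaintModel(BaseModel):
        def init(self, opt):
            BaseModel.init(self, opt)
            self.model_names = ['G']
            # stand-in generator: 3 image channels + 1 mask channel in, 3 out
            self.netG = nn.Conv2d(4, 3, kernel_size=3, padding=1).to(self.device)

        def forward(self):
            # a real model would inpaint the masked region; this only
            # demonstrates the input/mask wiring set up by setInput()
            self.output = self.netG(torch.cat([self.input, self.mask], dim=1))

    opt = SimpleNamespace(gpu_ids=[], model_folder='.')
    model = InpaintModel()
    model.init(opt)
    model.print_networks(verbose=False)

    model.setInput(torch.rand(1, 3, 64, 64), torch.ones(1, 1, 64, 64))
    model.test()                                 # forward pass under no_grad

    model.set_requires_grad(model.netG, False)   # freeze G, e.g. while a D trains
    model.save_networks('latest')                # writes ./latest_net_G.pth
    model.load_networks('./latest_net_G.pth')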