Diff of /inpainting/train.py [000000] .. [92cc18]

Switch to unified view

a b/inpainting/train.py
1
import os
2
from torch.utils.data import DataLoader
3
from torchvision import transforms
4
import torchvision.utils as vutils
5
import torch
6
import torch.nn as nn
7
from data.data import InpaintingDataset, ToTensor
8
from model.net import InpaintingModel_GMCNN
9
from options.train_options import TrainOptions
10
from util.utils import getLatest
11
from data.hyperkvasir import KvasirSegmentationDataset
12
import torch.nn.functional as F
13
from inpainting.model.basemodel import BaseModel
14
from inpainting.model.basenet import BaseNet
15
from inpainting.model.loss import WGANLoss, IDMRFLoss
16
from inpainting.model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
17
from inpainting.model.net import *
18
19
config = {"model": "DeepLab",
20
          "device": "cuda",
21
          "lr": 0.00001,
22
          "batch_size": 8,
23
          "epochs": 250}
24
DATASET_PATH = ""
25
26
27
class gmcnn_inpainter_trainer:
28
    def __init__(self, config):
29
        self.model = InpaintingModel_GMCNN(in_channels=3, opt=config)
30
        self.dataset = KvasirSegmentationDataset(DATASET_PATH, augment=True)  # todo make inpainting dataset
31
        self.dataloader = DataLoader(self.dataset, batch_size=config["batch_size"], shuffle=True)
32
        self.device = config["device"]
33
        self.recloss = nn.L1Loss()
34
        self.aeloss = nn.L1Loss()
35
        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()
36
37
        self.netGM = GMCNN(3, out_channels=3, cnum=32, act=F.elu, norm=F.instance_norm).cuda()
38
        init_weights(self.netGM)
39
        self.model_names = ['GM']
40
41
        self.netD = None
42
43
        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=opt.lr, betas=(0.5, 0.9))
44
        self.optimizer_D = None
45
46
        self.wganloss = None
47
        self.recloss = nn.L1Loss()
48
        self.aeloss = nn.L1Loss()
49
        self.mrfloss = None
50
        # self.lambda_adv = opt.lambda_adv
51
        # self.lambda_rec = opt.lambda_rec
52
        # self.lambda_ae = opt.lambda_ae
53
        # self.lambda_gp = opt.lambda_gp
54
        # self.lambda_mrf = opt.lambda_mrf
55
        self.G_loss = None
56
        self.G_loss_reconstruction = None
57
        self.G_loss_mrf = None
58
        self.G_loss_adv, self.G_loss_adv_local = None, None
59
        self.G_loss_ae = None
60
        self.D_loss, self.D_loss_local = None, None
61
        self.GAN_loss = None
62
63
        self.gt, self.gt_local = None, None
64
        self.mask, self.mask_01 = None, None
65
        self.rect = None
66
        self.im_in, self.gin = None, None
67
68
        self.completed, self.completed_local = None, None
69
        self.completed_logit, self.completed_local_logit = None, None
70
        self.gt_logit, self.gt_local_logit = None, None
71
72
        self.pred = None
73
        self.netD = GlobalLocalDiscriminator(3, cnum=64, act=F.elu,
74
                                             spectral_norm=True,
75
                                             g_fc_channels=512 // 16 * 512 // 16 * 64 * 4,
76
                                             l_fc_channels=512 // 16 * 512 // 16 * 64 * 4).to(self.device)
77
        init_weights(self.netD)
78
        self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()), lr=config["lr"],
79
                                            betas=(0.5, 0.9))
80
        self.wganloss = WGANLoss()
81
        self.mrfloss = IDMRFLoss()
82
83
    def train(self):
84
        for epoch in range(config["epochs"]):
85
            self.train_epoch(epoch)
86
            ret_loss = self.model.get_current_losses()
87
            self.model.save_networks(epoch + 1)
88
89
    def train_epoch(self, epoch):
90
        for img, mask, fname in self.dataloader:
91
            img, mask = img.to(self.device), mask.to(self.device)
92
            img_in = img * (1 - mask)
93
            self.gen_in = torch.cat((img_in, mask), 1)
94
            self.model.setInput(img, mask)
95
            self.model.optimize_parameters()
96
97
            self.pred = self.netGM(self.gin)
98
            self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
99
            self.completed_local = self.completed
100
101
            for i in range(5):  # train discriminator 5 times interleaved
102
                self.optimizer_D.zero_grad()
103
                self.optimizer_G.zero_grad()
104
                self.forward_D()
105
                self.backward_D()
106
                self.optimizer_D.step()
107
108
            self.optimizer_G.zero_grad()
109
            self.forward_G()
110
            self.backward_G()
111
            self.optimizer_G.step()
112
            # TODO come back here to finish gmcnn
113
114
    def forward_G(self):
115
        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
116
        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
117
        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
118
        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
119
        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae
120
121
        self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
122
        self.G_loss_mrf = self.mrfloss((self.completed_local + 1) / 2.0, (self.gt_local.detach() + 1) / 2.0)
123
        self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf
124
125
        self.G_loss_adv = -self.completed_logit.mean()
126
        self.G_loss_adv_local = -self.completed_local_logit.mean()
127
        self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)
128
129
    def forward_D(self):
130
        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(),
131
                                                                     self.completed_local.detach())
132
        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
133
        # hinge loss
134
        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(
135
            1.0 + self.completed_local_logit).mean()
136
        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()
137
        self.D_loss = self.D_loss + self.D_loss_local
138
139
    def backward_G(self):
140
        self.G_loss.backward()
141
142
    def backward_D(self):
143
        self.D_loss.backward(retain_graph=True)