[8eeb5a]: / inpainting / train.py

Download this file

144 lines (122 with data), 6.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.utils as vutils
import torch
import torch.nn as nn
from data.data import InpaintingDataset, ToTensor
from model.net import InpaintingModel_GMCNN
from options.train_options import TrainOptions
from util.utils import getLatest
from data.hyperkvasir import KvasirSegmentationDataset
import torch.nn.functional as F
from inpainting.model.basemodel import BaseModel
from inpainting.model.basenet import BaseNet
from inpainting.model.loss import WGANLoss, IDMRFLoss
from inpainting.model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
from inpainting.model.net import *
# Training hyper-parameters for the GMCNN inpainting run.
config = {
    "model": "DeepLab",
    "device": "cuda",
    "lr": 1e-5,
    "batch_size": 8,
    "epochs": 250,
}

# Root folder of the dataset; empty by default and must be filled in
# before training is launched.
DATASET_PATH = ""
class gmcnn_inpainter_trainer:
    """Trainer wiring a GMCNN inpainting generator against a global/local
    discriminator on the HyperKvasir segmentation data.

    NOTE(review): the class name is kept as-is (PEP 8 would prefer
    GMCNNInpainterTrainer) so existing callers do not break.
    """

    def __init__(self, config):
        """Build networks, losses, optimizers and the data pipeline.

        config keys read: "device", "batch_size", "lr"; optional overrides
        "lambda_adv", "lambda_rec", "lambda_ae", "lambda_gp", "lambda_mrf",
        and "epochs" (read by train()).
        """
        self.config = config
        self.device = config["device"]

        # High-level wrapper model (keeps its own save/load machinery).
        self.model = InpaintingModel_GMCNN(in_channels=3, opt=config)

        # TODO: replace with a dedicated inpainting dataset.
        self.dataset = KvasirSegmentationDataset(DATASET_PATH, augment=True)
        self.dataloader = DataLoader(self.dataset, batch_size=config["batch_size"], shuffle=True)

        # --- generator ---
        self.netGM = GMCNN(3, out_channels=3, cnum=32, act=F.elu,
                           norm=F.instance_norm).to(self.device)
        init_weights(self.netGM)
        self.model_names = ['GM']
        # BUG FIX: the original passed lr=opt.lr, but no `opt` existed in
        # scope (NameError at construction); the learning rate lives in the
        # config dict, same as for the discriminator optimizer below.
        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(),
                                            lr=config["lr"], betas=(0.5, 0.9))

        # --- losses ---
        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()
        self.recloss = nn.L1Loss()  # masked-region reconstruction term
        self.aeloss = nn.L1Loss()   # autoencoder term on the visible region
        self.wganloss = WGANLoss()
        self.mrfloss = IDMRFLoss()

        # Loss weights. The original left these commented out, yet forward_G
        # still read them (AttributeError at runtime). Fall back to the
        # defaults of the GMCNN reference implementation when the config has
        # no explicit override.
        self.lambda_adv = config.get("lambda_adv", 0.001)
        self.lambda_rec = config.get("lambda_rec", 1.4)
        self.lambda_ae = config.get("lambda_ae", 1.2)
        self.lambda_gp = config.get("lambda_gp", 10)
        self.lambda_mrf = config.get("lambda_mrf", 0.05)

        # Bookkeeping slots populated during training.
        self.G_loss = None
        self.G_loss_reconstruction = None
        self.G_loss_mrf = None
        self.G_loss_adv, self.G_loss_adv_local = None, None
        self.G_loss_ae = None
        self.D_loss, self.D_loss_local = None, None
        self.GAN_loss = None
        self.gt, self.gt_local = None, None
        self.mask, self.mask_01 = None, None
        self.rect = None
        self.im_in, self.gin = None, None
        self.completed, self.completed_local = None, None
        self.completed_logit, self.completed_local_logit = None, None
        self.gt_logit, self.gt_local_logit = None, None
        self.pred = None

        # --- discriminator ---
        self.netD = GlobalLocalDiscriminator(
            3, cnum=64, act=F.elu,
            spectral_norm=True,
            g_fc_channels=512 // 16 * 512 // 16 * 64 * 4,
            l_fc_channels=512 // 16 * 512 // 16 * 64 * 4).to(self.device)
        init_weights(self.netD)
        self.optimizer_D = torch.optim.Adam(
            (p for p in self.netD.parameters() if p.requires_grad),
            lr=config["lr"], betas=(0.5, 0.9))

    def train(self):
        """Run the full training loop, checkpointing after every epoch."""
        # CONSISTENCY FIX: the original read the module-level `config` global
        # here while __init__ took a `config` parameter; use the instance copy.
        for epoch in range(self.config["epochs"]):
            self.train_epoch(epoch)
            ret_loss = self.model.get_current_losses()  # NOTE(review): collected but unused — wire up logging
            self.model.save_networks(epoch + 1)

    def train_epoch(self, epoch):
        """One pass over the dataloader: per batch, 5 discriminator updates
        followed by a single generator update."""
        for img, mask, fname in self.dataloader:
            img, mask = img.to(self.device), mask.to(self.device)
            # Ground-truth / mask bookkeeping read by forward_G and forward_D.
            # No local crop is defined yet, so the "local" tensors alias the
            # full-image ones. NOTE(review): the raw mask is used as both the
            # weighting mask and the binary mask — presumably
            # confidence_mask_layer should produce self.mask; confirm.
            self.gt = img
            self.gt_local = img
            self.mask = mask
            self.mask_01 = mask
            img_in = img * (1 - mask)  # zero out the hole region
            # BUG FIX: the original stored the generator input in self.gen_in
            # but read self.gin two lines later (attribute was still None).
            self.gin = torch.cat((img_in, mask), 1)
            self.model.setInput(img, mask)
            self.model.optimize_parameters()
            self.pred = self.netGM(self.gin)
            # Paste the prediction into the hole, keep ground truth elsewhere.
            self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
            self.completed_local = self.completed
            for _ in range(5):  # train discriminator 5 times interleaved
                self.optimizer_D.zero_grad()
                self.optimizer_G.zero_grad()
                self.forward_D()
                self.backward_D()
                self.optimizer_D.step()
            self.optimizer_G.zero_grad()
            self.forward_G()
            self.backward_G()
            self.optimizer_G.step()
            # TODO come back here to finish gmcnn

    def forward_G(self):
        """Assemble the generator loss: reconstruction + AE + MRF + adversarial."""
        # Reconstruction on the hole, normalised by the hole area.
        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
        # Autoencoder term on the visible region, normalised likewise.
        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae
        self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
        # The (x + 1) / 2 rescaling assumes images in [-1, 1] — TODO confirm
        # against the dataset normalisation.
        self.G_loss_mrf = self.mrfloss((self.completed_local + 1) / 2.0, (self.gt_local.detach() + 1) / 2.0)
        self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf
        # Generator adversarial terms: maximise the critic score on fakes.
        self.G_loss_adv = -self.completed_logit.mean()
        self.G_loss_adv_local = -self.completed_local_logit.mean()
        self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)

    def forward_D(self):
        """Hinge loss for the global/local discriminator (generator detached)."""
        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(),
                                                                     self.completed_local.detach())
        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
        # hinge loss
        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(
            1.0 + self.completed_local_logit).mean()
        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()
        self.D_loss = self.D_loss + self.D_loss_local

    def backward_G(self):
        """Backpropagate the accumulated generator loss."""
        self.G_loss.backward()

    def backward_D(self):
        """Backpropagate the discriminator loss; the graph is retained because
        the same detached `completed` tensors are reused across the 5
        interleaved discriminator steps."""
        self.D_loss.backward(retain_graph=True)