inpainting/train.py
import os
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.utils as vutils
import torch
import torch.nn as nn
from data.data import InpaintingDataset, ToTensor
from model.net import InpaintingModel_GMCNN
from options.train_options import TrainOptions
from util.utils import getLatest
from data.hyperkvasir import KvasirSegmentationDataset
import torch.nn.functional as F
from inpainting.model.basemodel import BaseModel
from inpainting.model.basenet import BaseNet
from inpainting.model.loss import WGANLoss, IDMRFLoss
from inpainting.model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
from inpainting.model.net import *

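# NOTE: GMCNN and GlobalLocalDiscriminator used below are assumed to come from the star
# import of inpainting.model.net; they are not imported by name in this file.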
config = {"model": "DeepLab",
          "device": "cuda",
          "lr": 0.00001,
          "batch_size": 8,
          "epochs": 250}
DATASET_PATH = ""
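# NOTE: DATASET_PATH is left empty here; it presumably points at the HyperKvasir segmentation
# root expected by KvasirSegmentationDataset and must be filled in before training.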


class gmcnn_inpainter_trainer:
    """WIP trainer pairing the GMCNN generator with a global/local discriminator
    (hinge loss) plus L1 reconstruction, autoencoder and ID-MRF losses."""

    def __init__(self, config):
        self.model = InpaintingModel_GMCNN(in_channels=3, opt=config)
        self.dataset = KvasirSegmentationDataset(DATASET_PATH, augment=True)  # TODO make inpainting dataset
        self.dataloader = DataLoader(self.dataset, batch_size=config["batch_size"], shuffle=True)
        self.device = config["device"]
        self.recloss = nn.L1Loss()
        self.aeloss = nn.L1Loss()
        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()

        self.netGM = GMCNN(3, out_channels=3, cnum=32, act=F.elu, norm=F.instance_norm).to(self.device)
        init_weights(self.netGM)
        self.model_names = ['GM']

        self.netD = None

        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=config["lr"], betas=(0.5, 0.9))
        self.optimizer_D = None

        self.wganloss = None
        self.mrfloss = None
        # Loss weights. The reference GMCNN code takes these from its TrainOptions (opt.lambda_*);
        # the fallback values below are untuned placeholders so that forward_G() can run.
        self.lambda_adv = config.get("lambda_adv", 0.001)
        self.lambda_rec = config.get("lambda_rec", 1.0)
        self.lambda_ae = config.get("lambda_ae", 1.0)
        self.lambda_gp = config.get("lambda_gp", 10.0)
        self.lambda_mrf = config.get("lambda_mrf", 0.05)
        self.G_loss = None
        self.G_loss_reconstruction = None
        self.G_loss_mrf = None
        self.G_loss_adv, self.G_loss_adv_local = None, None
        self.G_loss_ae = None
        self.D_loss, self.D_loss_local = None, None
        self.GAN_loss = None

        self.gt, self.gt_local = None, None
        self.mask, self.mask_01 = None, None
        self.rect = None
        self.im_in, self.gin = None, None

        self.completed, self.completed_local = None, None
        self.completed_logit, self.completed_local_logit = None, None
        self.gt_logit, self.gt_local_logit = None, None

        self.pred = None
        # fc sizes imply 512x512 inputs (downsampled 16x, 64 * 4 feature channels)
        self.netD = GlobalLocalDiscriminator(3, cnum=64, act=F.elu,
                                             spectral_norm=True,
                                             g_fc_channels=512 // 16 * 512 // 16 * 64 * 4,
                                             l_fc_channels=512 // 16 * 512 // 16 * 64 * 4).to(self.device)
        init_weights(self.netD)
        self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()),
                                            lr=config["lr"], betas=(0.5, 0.9))
        self.wganloss = WGANLoss()
        self.mrfloss = IDMRFLoss()
    def train(self):
        for epoch in range(config["epochs"]):
            self.train_epoch(epoch)
            ret_loss = self.model.get_current_losses()
            self.model.save_networks(epoch + 1)

    def train_epoch(self, epoch):
        for img, mask, fname in self.dataloader:
            img, mask = img.to(self.device), mask.to(self.device)
            self.gt, self.gt_local = img, img  # no separate local crop yet
            self.mask_01 = mask
            self.mask = self.confidence_mask_layer(self.mask_01)  # confidence-driven weights for the rec. loss
            self.im_in = img * (1 - mask)
            # NOTE: netGM is constructed with in_channels=3, while gin concatenates the mask as an
            # extra channel; one of the two still needs adjusting (see TODO below).
            self.gin = torch.cat((self.im_in, mask), 1)
            self.model.setInput(img, mask)
            self.model.optimize_parameters()

            self.pred = self.netGM(self.gin)
            self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
            self.completed_local = self.completed

            for i in range(5):  # train the discriminator 5 times per generator step
                self.optimizer_D.zero_grad()
                self.optimizer_G.zero_grad()
                self.forward_D()
                self.backward_D()
                self.optimizer_D.step()

            self.optimizer_G.zero_grad()
            self.forward_G()
            self.backward_G()
            self.optimizer_G.step()
            # TODO come back here to finish gmcnn

    def forward_G(self):
        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae

        self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
        self.G_loss_mrf = self.mrfloss((self.completed_local + 1) / 2.0, (self.gt_local.detach() + 1) / 2.0)
        self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf

        self.G_loss_adv = -self.completed_logit.mean()
        self.G_loss_adv_local = -self.completed_local_logit.mean()
        self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)

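    # For reference, the generator objective assembled above is
    #   G_loss = lambda_rec * L1(completed * mask, gt * mask) / mean(mask_01)      (confidence-weighted rec. loss)
    #          + lambda_ae  * L1(pred, gt) outside the hole / mean(1 - mask_01)    (autoencoder loss)
    #          + lambda_mrf * IDMRF(completed_local, gt_local)                     (texture matching)
    #          - lambda_adv * (mean(D(completed)) + mean(D_local(completed_local)))  (adversarial term)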
    def forward_D(self):
        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(),
                                                                     self.completed_local.detach())
        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
        # hinge loss
        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(
            1.0 + self.completed_local_logit).mean()
        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()
        self.D_loss = self.D_loss + self.D_loss_local

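    # Standard hinge discriminator objective, summed over the global and local branches:
    #   D_loss = E[relu(1 - D(real))] + E[relu(1 + D(fake))]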
    def backward_G(self):
        self.G_loss.backward()

    def backward_D(self):
        self.D_loss.backward(retain_graph=True)
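

# Minimal usage sketch: an assumed entry point showing how the trainer might be launched;
# DATASET_PATH must be set to a valid dataset root first.
if __name__ == '__main__':
    trainer = gmcnn_inpainter_trainer(config)
    trainer.train()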