Switch to unified view

a b/training/train_inpainter.py
1
import matplotlib.pyplot as plt
2
import numpy as np
3
import torch
4
import torch.nn
5
import torchvision.transforms as transforms
6
from PIL import Image
7
from torch.autograd import Variable
8
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
9
from torch.utils.data import DataLoader
10
11
from data.hyperkvasir import KvasirInpaintingDataset
12
from models.inpainters import SegGenerator, SegDiscriminator
13
from perturbation.polyp_inpainter import Inpainter
14
15
16
# TODO refactor
17
18
def weights_init_normal(m):
19
    classname = m.__class__.__name__
20
    if classname.find("Conv") != -1:
21
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
22
    elif classname.find("BatchNorm2d") != -1:
23
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
24
        torch.nn.init.constant_(m.bias.data, 0.0)
25
26
27
def train_new_inpainter():
28
    # Loss function
29
    adversarial_loss = torch.nn.BCELoss()
30
    pixelwise_loss = torch.nn.L1Loss()
31
32
    # Initialize generator and discriminator
33
    # generator = Generator(channels=3)
34
    # discriminator = Discriminator(channels=3)
35
    generator = SegGenerator()
36
    discriminator = SegDiscriminator()
37
38
    generator.load_state_dict(torch.load("Predictors/Inpainters/no-pretrain-deeplab-generator-940"))
39
    discriminator.load_state_dict(torch.load("Predictors/Inpainters/no-pretrain-deeplab-discriminator-940"))
40
41
    cuda = True
42
    if cuda:
43
        generator.cuda()
44
        discriminator.cuda()
45
        adversarial_loss.cuda()
46
        pixelwise_loss.cuda()
47
    # Dataset loader TODO refactor w/ albumentation library
48
    transforms_ = [
49
        transforms.Resize((400, 400), Image.BICUBIC),
50
        transforms.ToTensor(),
51
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
52
    ]
53
    dataloader = DataLoader(
54
        KvasirInpaintingDataset("Datasets/HyperKvasir"),
55
        batch_size=8,
56
        shuffle=False,
57
        num_workers=1,
58
    )
59
    # test_dataloader = DataLoader(
60
    #     EtisDataset("Datasets/ETIS-LaribPolypDB"),
61
    #     batch_size=12,
62
    #     shuffle=True,
63
    #     num_workers=1,
64
    # )
65
66
    # Optimizers
67
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
68
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.00001)
69
    scheduler_G = CosineAnnealingWarmRestarts(optimizer_G, T_0=100, T_mult=2)
70
    scheduler_D = CosineAnnealingWarmRestarts(optimizer_D, T_0=100, T_mult=2)
71
72
    # Initialize weights
73
    # generator.apply(weights_init_normal)
74
    # discriminator.apply(weights_init_normal)
75
    # patch_h, patch_w = int(50 / 2 ** 3), int(50 / 2 ** 3)
76
    # patch = (1, patch_h, patch_w)
77
    # print(patch)
78
    for epoch in range(990, 5000):
79
        printed = False
80
        d_losses = []
81
        g_advs = []
82
        g_pixels = []
83
84
        for i, (imgs, mask, masked_imgs, masked_parts, filename) in enumerate(dataloader):
85
            imgs = imgs.cuda()
86
            mask = mask.cuda()
87
            masked_imgs = masked_imgs.cuda()
88
            masked_parts = masked_parts.cuda()
89
            mask_bool = mask == 1
90
            # Adversarial ground truths (boxes)
91
            # valid = Variable(torch.Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
92
            # fake = Variable(torch.Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)
93
            valid = torch.masked_select(torch.ones_like(mask), mask_bool)
94
            fake = torch.masked_select(torch.zeros_like(mask), mask_bool)
95
96
            # Configure input
97
            imgs = Variable(imgs)
98
            masked_imgs = Variable(masked_imgs)
99
            masked_parts = Variable(masked_parts)
100
101
            # -----------------
102
            #  Train Generator
103
            # -----------------
104
105
            optimizer_G.zero_grad()
106
107
            # Generate a batch of images
108
            gen_parts = generator(masked_imgs)
109
110
            # Adversarial and pixelwise loss
111
            disc = discriminator(gen_parts)
112
            # print(disc)
113
114
            g_adv = adversarial_loss(torch.masked_select(disc, mask_bool), valid)
115
            g_pixel = pixelwise_loss(torch.masked_select(gen_parts, mask_bool), torch.masked_select(imgs, mask_bool))
116
            g_advs.append(g_adv.item())
117
            g_pixels.append(g_pixel.item())
118
            # Total loss
119
            g_loss = 0.001 * g_adv + 0.999 * g_pixel
120
121
            g_loss.backward()
122
            optimizer_G.step()
123
            scheduler_G.step(epoch)
124
125
            # ---------------------
126
            #  Train Discriminator
127
            # ---------------------
128
129
            optimizer_D.zero_grad()
130
131
            # Measure discriminator's ability to classify real from generated samples
132
133
            real_loss = adversarial_loss(torch.masked_select(discriminator(masked_parts), mask_bool), valid)
134
            fake_loss = adversarial_loss(torch.masked_select(discriminator(gen_parts.detach()), mask_bool), fake)
135
            d_loss = 0.5 * (real_loss + fake_loss)
136
            d_losses.append(d_loss.item())
137
            # wasserstein critic loss
138
139
            # d_loss = -torch.mean(discriminator(masked_parts)) + torch.mean(discriminator(gen_parts.detach()))
140
141
            d_loss.backward()
142
            optimizer_D.step()
143
            scheduler_D.step(epoch)
144
            if not printed and epoch % 10 == 0:
145
                torch.save(generator.state_dict(), f"Predictors/Inpainters/no-pretrain-deeplab-generator-{epoch}")
146
                torch.save(discriminator.state_dict(),
147
                           f"Predictors/Inpainters/no-pretrain-deeplab-discriminator-{epoch}")
148
                plt.title("Part")
149
                plt.imshow((gen_parts[0].detach().cpu().numpy().T))
150
                plt.show()
151
                # plt.title("Superimposed")
152
                # plt.imshow((gen_parts[0].detach().cpu().numpy().T))
153
                # plt.imshow(masked_imgs[0].detach().cpu().numpy().T)
154
                # plt.show()
155
                # plt.title("Real")
156
                # plt.imshow(masked_parts[0].detach().cpu().numpy().T)
157
                # plt.show()
158
                try:
159
                    test = Inpainter(f"Predictors/Inpainters/no-pretrain-deeplab-generator-{epoch}")
160
                    test.get_test()
161
                except FileNotFoundError:
162
                    print("Weird...")
163
                printed = True
164
        print(
165
            f"[Epoch {epoch}] [D loss: {np.mean(d_losses)}] [G adv: {np.mean(g_advs)}, pixel: {np.mean(g_pixels)}]"
166
        )
167
168
169
if __name__ == '__main__':
170
    train_new_inpainter()