Switch to side-by-side view

--- a
+++ b/training/train_inpainter.py
@@ -0,0 +1,170 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn
+import torchvision.transforms as transforms
+from PIL import Image
+from torch.autograd import Variable
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
+from torch.utils.data import DataLoader
+
+from data.hyperkvasir import KvasirInpaintingDataset
+from models.inpainters import SegGenerator, SegDiscriminator
+from perturbation.polyp_inpainter import Inpainter
+
+
+# TODO refactor
+
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find("BatchNorm2d") != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
+
+
+def train_new_inpainter():
+    # Loss function
+    adversarial_loss = torch.nn.BCELoss()
+    pixelwise_loss = torch.nn.L1Loss()
+
+    # Initialize generator and discriminator
+    # generator = Generator(channels=3)
+    # discriminator = Discriminator(channels=3)
+    generator = SegGenerator()
+    discriminator = SegDiscriminator()
+
+    generator.load_state_dict(torch.load("Predictors/Inpainters/no-pretrain-deeplab-generator-940"))
+    discriminator.load_state_dict(torch.load("Predictors/Inpainters/no-pretrain-deeplab-discriminator-940"))
+
+    cuda = True
+    if cuda:
+        generator.cuda()
+        discriminator.cuda()
+        adversarial_loss.cuda()
+        pixelwise_loss.cuda()
+    # Dataset loader TODO refactor w/ albumentation library
+    transforms_ = [
+        transforms.Resize((400, 400), Image.BICUBIC),
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+    ]
+    dataloader = DataLoader(
+        KvasirInpaintingDataset("Datasets/HyperKvasir"),
+        batch_size=8,
+        shuffle=False,
+        num_workers=1,
+    )
+    # test_dataloader = DataLoader(
+    #     EtisDataset("Datasets/ETIS-LaribPolypDB"),
+    #     batch_size=12,
+    #     shuffle=True,
+    #     num_workers=1,
+    # )
+
+    # Optimizers
+    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
+    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.00001)
+    scheduler_G = CosineAnnealingWarmRestarts(optimizer_G, T_0=100, T_mult=2)
+    scheduler_D = CosineAnnealingWarmRestarts(optimizer_D, T_0=100, T_mult=2)
+
+    # Initialize weights
+    # generator.apply(weights_init_normal)
+    # discriminator.apply(weights_init_normal)
+    # patch_h, patch_w = int(50 / 2 ** 3), int(50 / 2 ** 3)
+    # patch = (1, patch_h, patch_w)
+    # print(patch)
+    for epoch in range(990, 5000):
+        printed = False
+        d_losses = []
+        g_advs = []
+        g_pixels = []
+
+        for i, (imgs, mask, masked_imgs, masked_parts, filename) in enumerate(dataloader):
+            imgs = imgs.cuda()
+            mask = mask.cuda()
+            masked_imgs = masked_imgs.cuda()
+            masked_parts = masked_parts.cuda()
+            mask_bool = mask == 1
+            # Adversarial ground truths (boxes)
+            # valid = Variable(torch.Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
+            # fake = Variable(torch.Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)
+            valid = torch.masked_select(torch.ones_like(mask), mask_bool)
+            fake = torch.masked_select(torch.zeros_like(mask), mask_bool)
+
+            # Configure input
+            imgs = Variable(imgs)
+            masked_imgs = Variable(masked_imgs)
+            masked_parts = Variable(masked_parts)
+
+            # -----------------
+            #  Train Generator
+            # -----------------
+
+            optimizer_G.zero_grad()
+
+            # Generate a batch of images
+            gen_parts = generator(masked_imgs)
+
+            # Adversarial and pixelwise loss
+            disc = discriminator(gen_parts)
+            # print(disc)
+
+            g_adv = adversarial_loss(torch.masked_select(disc, mask_bool), valid)
+            g_pixel = pixelwise_loss(torch.masked_select(gen_parts, mask_bool), torch.masked_select(imgs, mask_bool))
+            g_advs.append(g_adv.item())
+            g_pixels.append(g_pixel.item())
+            # Total loss
+            g_loss = 0.001 * g_adv + 0.999 * g_pixel
+
+            g_loss.backward()
+            optimizer_G.step()
+            scheduler_G.step(epoch)
+
+            # ---------------------
+            #  Train Discriminator
+            # ---------------------
+
+            optimizer_D.zero_grad()
+
+            # Measure discriminator's ability to classify real from generated samples
+
+            real_loss = adversarial_loss(torch.masked_select(discriminator(masked_parts), mask_bool), valid)
+            fake_loss = adversarial_loss(torch.masked_select(discriminator(gen_parts.detach()), mask_bool), fake)
+            d_loss = 0.5 * (real_loss + fake_loss)
+            d_losses.append(d_loss.item())
+            # wasserstein critic loss
+
+            # d_loss = -torch.mean(discriminator(masked_parts)) + torch.mean(discriminator(gen_parts.detach()))
+
+            d_loss.backward()
+            optimizer_D.step()
+            scheduler_D.step(epoch)
+            if not printed and epoch % 10 == 0:
+                torch.save(generator.state_dict(), f"Predictors/Inpainters/no-pretrain-deeplab-generator-{epoch}")
+                torch.save(discriminator.state_dict(),
+                           f"Predictors/Inpainters/no-pretrain-deeplab-discriminator-{epoch}")
+                plt.title("Part")
+                plt.imshow((gen_parts[0].detach().cpu().numpy().T))
+                plt.show()
+                # plt.title("Superimposed")
+                # plt.imshow((gen_parts[0].detach().cpu().numpy().T))
+                # plt.imshow(masked_imgs[0].detach().cpu().numpy().T)
+                # plt.show()
+                # plt.title("Real")
+                # plt.imshow(masked_parts[0].detach().cpu().numpy().T)
+                # plt.show()
+                try:
+                    test = Inpainter(f"Predictors/Inpainters/no-pretrain-deeplab-generator-{epoch}")
+                    test.get_test()
+                except FileNotFoundError:
+                    print("Weird...")
+                printed = True
+        print(
+            f"[Epoch {epoch}] [D loss: {np.mean(d_losses)}] [G adv: {np.mean(g_advs)}, pixel: {np.mean(g_pixels)}]"
+        )
+
+
+if __name__ == '__main__':
+    train_new_inpainter()