a b/perturbation/polyp_inpainter.py
1
import matplotlib.pyplot as plt
2
import torch
3
import torch.nn as nn
4
from torch.utils.data.dataloader import DataLoader
5
6
# from perturbation.gmcnn.model.net_with_dropout import InpaintingModel_GMCNN_Given_Mask
7
from data.hyperkvasir import *
8
from models.inpainters import SegGenerator
9
from perturbation.perturbator import RandomDraw
10
11
12
class Inpainter:
    """Wrap a pretrained ``SegGenerator`` so it can paint synthetic polyps into images.

    The generator weights are loaded from a state-dict file and the model is kept
    in eval mode on ``device``. A ``RandomDraw`` perturbator supplies the random
    masks that mark where new polyps are painted.
    """

    def __init__(self, path_to_state_dict, device="cuda"):
        """Load the generator weights and prepare the mask perturbator.

        Args:
            path_to_state_dict: Path to a state dict saved for ``SegGenerator``.
            device: Device the generator runs on. Defaults to ``"cuda"`` (the
                previously hard-coded value). The checkpoint is mapped straight
                onto this device, so GPU-saved weights can also be loaded on CPU.
        """
        super().__init__()
        self.config = None  # not read anywhere in this class
        self.device = device
        self.model = SegGenerator()
        self.model.load_state_dict(torch.load(path_to_state_dict, map_location=device))
        self.model.to(device)
        self.model.eval()
        self.perturbator = RandomDraw()

    @staticmethod
    def _reverse_dims(tensor):
        """Reverse the order of all tensor dimensions (the old ``Tensor.T`` behaviour).

        ``Tensor.T`` on tensors whose ndim != 2 is deprecated in PyTorch, so the
        permutation is spelled out explicitly here.
        """
        return tensor.permute(tuple(range(tensor.dim() - 1, -1, -1)))

    def forward(self, img, mask, masked_image=None):
        """Inpaint the masked region of an image batch with the generator.

        Args:
            img: Image tensor; presumably (B, C, H, W) - TODO confirm.
            mask: Mask tensor without a channel axis; ``unsqueeze(1)`` inserts one
                so it broadcasts against ``img``. Value 1 marks pixels to inpaint.
            masked_image: Ignored - the masked image is always recomputed from
                ``img`` and ``mask``. Kept only for call-site compatibility.

        Returns:
            ``(merged, polyp)``: ``img`` with the masked region replaced by the
            generator output, and the raw generator output.
        """
        mask = mask.unsqueeze(1)
        masked_image = img * (1 - mask)
        polyp = self.model(masked_image)
        merged = (1 - mask) * img + polyp * mask
        return merged, polyp

    def add_polyp(self, img, old_mask_a, rad=0.25):
        """Paint a new, randomly shaped polyp into a single image.

        Args:
            img: numpy image. Channel-last (H, W, C) is assumed, because a trailing
                axis is appended to the 2-D masks to broadcast against it - TODO confirm.
            old_mask_a: 2-D mask of polyps already present in ``img``.
            rad: Radius forwarded to the ``RandomDraw`` perturbator (default 0.25,
                the previously hard-coded value).

        Returns:
            ``(inpainted, total_mask)`` as float32 arrays: the image with the new
            polyp painted in, and the union of old and new masks clipped to [0, 1].
        """
        new_polyp_mask = np.expand_dims(self.perturbator(rad=rad), -1)
        old_mask = np.expand_dims(old_mask_a, -1)
        total_mask = np.clip(new_polyp_mask + old_mask, 0, 1)
        masked = img * (1 - new_polyp_mask)

        # Reversing all axes turns channel-last input into channel-first and adds a
        # batch axis. Note this also swaps the two spatial axes (as the original
        # `.T` did); the inverse reversal below restores the original layout.
        model_input = self._reverse_dims(
            torch.as_tensor(masked, dtype=torch.float32).to(self.device)
        ).unsqueeze(0)
        with torch.no_grad():
            polyp = self.model(model_input)
        cpu_polyp = self._reverse_dims(polyp.cpu().squeeze(0)).numpy()

        # Only the newly masked region takes generator output; the rest keeps `masked`.
        inpainted = masked + cpu_polyp * new_polyp_mask
        return inpainted.astype(np.float32), total_mask.astype(np.float32)

    def get_test(self, split="test"):
        """Display inpainted samples from the synthetic HyperKvasir dataset.

        Args:
            split: Dataset split to visualise (previously ignored; "test" was
                always used).
        """
        loader = DataLoader(
            KvasirSyntheticDataset("Datasets/HyperKvasir", split=split),
            batch_size=4,
        )
        # NOTE(review): add_polyp draws a single 2-D mask and appears to expect one
        # channel-last image, while this loader yields batches of 4 - confirm the
        # shapes produced by KvasirSyntheticDataset.
        for image, mask, _masked_image, _part, _fname in loader:
            merged, _total_mask = self.add_polyp(image, mask)
            plt.title("Inpainted image")
            plt.imshow(merged[0].T)
            plt.show()
51
52
53
if __name__ == "__main__":
    # Load the pretrained generator checkpoint and preview inpainted test samples.
    checkpoint_path = "Predictors/Inpainters/no-pretrain-deeplab-generator-4990"
    inpainter = Inpainter(checkpoint_path)
    inpainter.get_test()