|
a |
|
b/perturbation/polyp_inpainter.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
import torch |
|
|
3 |
import torch.nn as nn |
|
|
4 |
from torch.utils.data.dataloader import DataLoader |
|
|
5 |
|
|
|
6 |
# from perturbation.gmcnn.model.net_with_dropout import InpaintingModel_GMCNN_Given_Mask |
|
|
7 |
from data.hyperkvasir import * |
|
|
8 |
from models.inpainters import SegGenerator |
|
|
9 |
from perturbation.perturbator import RandomDraw |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
class Inpainter:
    """Wrap a ``SegGenerator`` network to inpaint masked image regions.

    The generator weights are read from a state-dict file, the model is moved
    to ``device`` and switched to eval mode. A ``RandomDraw`` perturbator
    supplies the new mask used by :meth:`add_polyp`.
    """

    def __init__(self, path_to_state_dict, device="cuda"):
        """Load the generator and prepare it for inference.

        Args:
            path_to_state_dict: Path handed to ``torch.load``; the loaded
                object is passed to the model's ``load_state_dict``.
            device: Device the model and the inputs of :meth:`add_polyp` are
                moved to. Defaults to ``"cuda"``.
        """
        super().__init__()
        self.config = None
        self.device = device
        self.model = SegGenerator()
        # map_location allows a checkpoint saved on a different device
        # (e.g. another GPU index) to be loaded on the requested one.
        state_dict = torch.load(path_to_state_dict, map_location=device)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()
        self.perturbator = RandomDraw()

    def forward(self, img, mask, masked_image=None):
        """Inpaint the masked region of ``img`` and blend it into the image.

        Args:
            img: Image tensor fed (after masking) to the model.
                Presumably (batch, channels, H, W) -- TODO confirm.
            mask: Mask tensor; a singleton axis is inserted at dim 1 so it
                broadcasts against ``img``. Values of 1 mark pixels that are
                blanked out and replaced by the model output.
            masked_image: Ignored. The masked image is always recomputed from
                ``img`` and ``mask``; the parameter is kept only so existing
                callers keep working.

        Returns:
            Tuple ``(merged, polyp)``: ``merged`` is ``img`` outside the mask
            and the model output inside it; ``polyp`` is the raw model output.
        """
        mask = mask.unsqueeze(1)
        keep = 1 - mask
        masked_image = img * keep
        polyp = self.model(masked_image)
        merged = keep * img + (polyp * mask)
        return merged, polyp

    def add_polyp(self, img, old_mask_a):
        """Inpaint a randomly drawn region of ``img`` with the generator.

        Args:
            img: Image array that broadcasts against a mask with a trailing
                singleton axis. Presumably channel-last (H, W, C) -- TODO
                confirm against the caller.
            old_mask_a: Existing mask; a trailing axis is appended before it
                is combined with the newly drawn mask.

        Returns:
            Tuple ``(inpainted, total_mask)``, both cast to ``float32``:
            the image with the new region filled by the model output, and the
            sum of old and new masks clipped to ``[0, 1]``.
        """
        new_polyp_mask = np.expand_dims(self.perturbator(rad=0.25), -1)
        old_mask = np.expand_dims(old_mask_a, -1)
        total_mask = np.clip(new_polyp_mask + old_mask, 0, 1)
        masked = img * (1 - new_polyp_mask)

        net_input = torch.Tensor(masked).to(self.device)
        # Reverse the axis order explicitly: ``Tensor.T`` is deprecated for
        # tensors that are not 2-D, while ``permute`` does the same thing.
        reversed_axes = tuple(reversed(range(net_input.dim())))
        net_input = net_input.permute(*reversed_axes).unsqueeze(0)
        with torch.no_grad():
            polyp = self.model(net_input)

        # Undo the axis reversal on the NumPy side, where ``.T`` is
        # well-defined for any number of dimensions.
        cpu_polyp = polyp.cpu().squeeze(0).numpy().T
        inpainted = masked + (cpu_polyp * new_polyp_mask)
        return inpainted.astype(np.float32), total_mask.astype(np.float32)

    def get_test(self, split="test"):
        """Show inpainting results for every batch of a dataset split.

        Iterates ``KvasirSyntheticDataset("Datasets/HyperKvasir", split=split)``
        in batches of 4, runs :meth:`add_polyp` on each batch and displays the
        first result with matplotlib.

        Args:
            split: Dataset split forwarded to ``KvasirSyntheticDataset``.
                (It used to be ignored in favour of a hard-coded ``"test"``.)
        """
        dataset = KvasirSyntheticDataset("Datasets/HyperKvasir", split=split)
        loader = DataLoader(dataset, batch_size=4)
        for i, (image, mask, _masked_image, _part, _fname) in enumerate(loader):
            merged, mask = self.add_polyp(image, mask)
            plt.title("Inpainted image")
            plt.imshow(merged[0].T)
            plt.show()
            # plt.savefig(f"perturbation/inpaint_examples/{i}")
|
|
51 |
|
|
|
52 |
|
|
|
53 |
if __name__ == "__main__":
    # Script entry point: build the inpainter from the stored generator
    # checkpoint and display its results via get_test().
    Inpainter(
        "Predictors/Inpainters/no-pretrain-deeplab-generator-4990"
    ).get_test()