--- a +++ b/perturbation/polyp_inpainter.py @@ -0,0 +1,55 @@ +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader + +# from perturbation.gmcnn.model.net_with_dropout import InpaintingModel_GMCNN_Given_Mask +from data.hyperkvasir import * +from models.inpainters import SegGenerator +from perturbation.perturbator import RandomDraw + + +class Inpainter: + # wrapper around gmcnn + def __init__(self, path_to_state_dict): + super(Inpainter, self).__init__() + self.config = None + self.model = SegGenerator() + self.model.load_state_dict(torch.load(path_to_state_dict)) + self.model.to("cuda") + self.model.eval() + self.perturbator = RandomDraw() + + def forward(self, img, mask, masked_image=None): + mask = mask.unsqueeze(1) + masked_image = img * (1 - mask) + polyp = self.model(masked_image) + merged = (1 - mask) * img + (polyp * mask) + return merged, polyp + + def add_polyp(self, img, old_mask_a): + new_polyp_mask = np.expand_dims(self.perturbator(rad=0.25), -1) + old_mask = np.expand_dims(old_mask_a, -1) + total_mask = np.clip(new_polyp_mask + old_mask, 0, 1) + masked = img * (1 - new_polyp_mask) + with torch.no_grad(): + polyp = self.model(torch.Tensor(masked).to("cuda").T.unsqueeze(0)) + + cpu_polyp = polyp.cpu().squeeze(0).T.numpy() + inpainted = masked + (cpu_polyp * new_polyp_mask) + return inpainted.astype(np.float32), total_mask.astype(np.float32) + + def get_test(self, split="test"): + for i, (image, mask, masked_image, part, fname) in enumerate( + DataLoader(KvasirSyntheticDataset("Datasets/HyperKvasir", split="test"), + batch_size=4)): + merged, mask = self.add_polyp(image, mask) + plt.title("Inpainted image") + plt.imshow(merged[0].T) + plt.show() + # plt.savefig(f"perturbation/inpaint_examples/{i}") + + +if __name__ == '__main__': + inpainter = Inpainter("Predictors/Inpainters/no-pretrain-deeplab-generator-4990") + inpainter.get_test()