--- a +++ b/experiments/test_inpainter.py @@ -0,0 +1,24 @@ +import matplotlib.pyplot as plt + +from models.segmentation_models import DeepLab +from data.hyperkvasir import KvasirSyntheticDataset +import torch +from torch.utils.data import DataLoader + + +def test(): + model = DeepLab().to("cuda") + model.load_state_dict(torch.load("Predictors/Augmented/DeepLab/inpainter_augmentation_1")) + for x, y, _ in DataLoader(KvasirSyntheticDataset("Datasets/HyperKvasir")): + x = x.to("cuda") + y = y.to("cuda") + out = model.predict(x) + plt.imshow(x[0].cpu().T) + plt.axis("off") + plt.show(bbox_inches='tight', pad_inches=0) + plt.imshow(y[0].cpu().T, alpha=0.5) + plt.show() + + +if __name__ == '__main__': + test()