|
a |
|
b/experiments/test_inpainter.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
|
|
|
3 |
from models.segmentation_models import DeepLab |
|
|
4 |
from data.hyperkvasir import KvasirSyntheticDataset |
|
|
5 |
import torch |
|
|
6 |
from torch.utils.data import DataLoader |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def _to_hwc(t):
    """Convert one image/mask tensor into an array ``plt.imshow`` can display.

    Assumes channel-first layout, i.e. ``(C, H, W)`` or ``(H, W)`` -- TODO
    confirm against ``KvasirSyntheticDataset`` and ``DeepLab.predict``.
    A 3-D tensor is permuted to ``(H, W, C)``, and a trailing singleton
    channel is dropped so single-channel masks render with a colormap.
    2-D tensors are returned unchanged (as a NumPy array).
    """
    t = t.detach().cpu()
    if t.dim() == 3:
        # Tensor.T on a 3-D tensor reverses *all* dims -> (W, H, C), which
        # shows a transposed picture; .T is also deprecated for ndim != 2.
        t = t.permute(1, 2, 0)
        if t.shape[-1] == 1:
            t = t.squeeze(-1)
    return t.numpy()


def _show(image, overlay=None):
    """Display ``image`` in one figure, optionally blending ``overlay`` on top at 50% opacity."""
    plt.imshow(image)
    if overlay is not None:
        # Draw into the same figure: once plt.show() returns, an interactive
        # backend has closed the figure, so a later imshow would open a new one.
        plt.imshow(overlay, alpha=0.5)
    plt.axis("off")
    # NOTE: bbox_inches/pad_inches are savefig() options; show() rejects them.
    plt.show()


def test(
    checkpoint="Predictors/Augmented/DeepLab/inpainter_augmentation_1",
    dataset_root="Datasets/HyperKvasir",
    device="cuda",
):
    """Visually inspect a DeepLab checkpoint on the synthetic Kvasir dataset.

    For every sample yielded by ``KvasirSyntheticDataset(dataset_root)`` this
    opens three figures, one after another:

    1. the input image;
    2. the image with the dataset mask overlaid;
    3. the image with the model prediction overlaid.

    Args:
        checkpoint: Path to a ``state_dict`` file loaded with ``torch.load``.
        dataset_root: Root directory passed to ``KvasirSyntheticDataset``.
        device: Torch device that the model and each batch are moved to.
    """
    model = DeepLab().to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    # Switch BatchNorm/Dropout layers (if any) to inference behaviour.
    # NOTE(review): DeepLab.predict may already do this -- harmless if so.
    model.eval()
    for x, y, _ in DataLoader(KvasirSyntheticDataset(dataset_root)):
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            out = model.predict(x)
        image = _to_hwc(x[0])
        _show(image)
        _show(image, overlay=_to_hwc(y[0]))
        # Assumes predict returns a batch-first tensor -- TODO confirm shape.
        _show(image, overlay=_to_hwc(out[0]))
|
|
21 |
|
|
|
22 |
|
|
|
23 |
if __name__ == "__main__":
    # Run the visual inspection only when executed directly, never on import.
    test()