Switch to unified view

a b/experiments/reconstruction_test.py
1
import matplotlib.pyplot as plt
2
3
from models.segmentation_models import *
4
from data.hyperkvasir import KvasirSegmentationDataset
5
from torch.utils.data import DataLoader
6
import torch
7
import torch.nn as nn
8
import copy
9
10
11
class SplicedReconstructor(nn.Module):
12
    def __init__(self):
13
        super(SplicedReconstructor, self).__init__()
14
        inductivenet = InductiveNet()
15
        inductivenet.load_state_dict(torch.load("Predictors/Augmented/InductiveNet/consistency_1"))
16
        self.decoder = copy.deepcopy(inductivenet.reconstruction_decoder)
17
        self.head = copy.deepcopy(inductivenet.reconstruction_head)
18
        del inductivenet
19
        deeplab = DeepLab()
20
        deeplab.load_state_dict(torch.load("Predictors/Augmented/DeepLab/consistency_1"))
21
        self.encoder = copy.deepcopy(deeplab.encoder)
22
        del deeplab
23
24
    def predict(self, x):
25
        features = self.encoder(x)
26
        reconstructor_output = self.decoder(*features)
27
        reconstructed = self.head(reconstructor_output)
28
        return reconstructed
29
30
31
if __name__ == '__main__':
32
    model = SplicedReconstructor().to("cuda").eval()
33
34
    for x, y, _ in DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir/", "test")):
35
        with torch.no_grad():
36
            reconstruction = model.predict(x.to("cuda")).cpu()
37
        fig, ax = plt.subplots(ncols=1, nrows=2, sharey=True, sharex=True, figsize=(2, 1), dpi=1000)
38
        fig.subplots_adjust(wspace=0, hspace=0)
39
        ax[0].imshow(reconstruction[0].T)
40
        ax[1].imshow(x[0].T)
41
        plt.show()
42
        print("Showing...")