Diff of /perturbation/model.py [000000] .. [92cc18]

Switch to unified view

a b/perturbation/model.py
1
import albumentations as alb
2
import matplotlib.pyplot as plt
3
import numpy as np
4
import torch
5
import torch.nn as nn
6
from perturbation.polyp_inpainter import *
7
from perturbation.perturbator import RandomDraw
8
9
10
class ModelOfNaturalVariationInpainter(nn.Module):
11
    def __init__(self, T0=0, use_inpainter=False):
12
        super(ModelOfNaturalVariationInpainter, self).__init__()
13
        self.temp = T0
14
        self.linstep = 0.1
15
        self.use_inpainter = use_inpainter
16
17
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
18
19
    def _map_to_range(self, max, min=0):
20
        # The temperature varies between 0 and 1, where 0 represents no augmentation and 1 represents the maximum augmentation
21
        # Obviously, having for example a 100% reduction in brightness is not really productive.
22
        return min + self.temp * (max - min)
23
24
    def get_encoded_transforms(self):
25
        quality_lower = max(100 - int(self._map_to_range(max=100)), 10)
26
        quality_upper = np.clip(quality_lower + 10, 10, 100)
27
        pixelwise = alb.Compose([
28
            alb.ColorJitter(brightness=self._map_to_range(max=0.2),
29
                            contrast=self._map_to_range(max=0.2),
30
                            saturation=self._map_to_range(max=0.2),
31
                            hue=self._map_to_range(max=0.05), p=0.5),
32
            alb.GaussNoise(var_limit=self._map_to_range(max=0.01), p=0.5),
33
            alb.ImageCompression(quality_lower=quality_lower,
34
                                 quality_upper=quality_upper,
35
                                 p=self.temp),
36
        ]
37
        )
38
        geometric = alb.Compose([alb.RandomRotate90(p=0.5),
39
                                 alb.Flip(p=0.5),
40
                                 alb.OpticalDistortion(distort_limit=self.temp, p=self.temp)]
41
                                )
42
        return pixelwise, geometric
43
44
    def forward(self, image, mask):
45
        assert len(image.shape) == 4, "Image must be in BxCxHxW format"
46
        augmented_imgs = torch.zeros_like(image)
47
        augmented_masks = torch.zeros_like(mask)
48
49
        for batch_idx in range(image.shape[0]):  # random transforms to every image in the batch
50
            aug_img = image[batch_idx].squeeze().cpu().numpy().T
51
            aug_mask = mask[batch_idx].squeeze().cpu().numpy().T
52
            # if np.random.rand() < self.temp and self.use_inpainter:
53
            #     # todo integrate inpainting
54
            #     inpainting_mask_numpy = self.perturbator(rad=0.25)
55
            #     inpainting_mask = torch.from_numpy(inpainting_mask_numpy).unsqueeze(0).to("cuda").float()
56
            #     with torch.no_grad():
57
            #         aug_img, polyp = self.inpainter(img=image[batch_idx], mask=inpainting_mask)
58
            #     aug_img = aug_img[0].cpu().numpy().T  # TODO fix this filth
59
            #     aug_mask = np.clip(aug_mask + inpainting_mask_numpy, 0, 1)
60
            pixelwise = self.pixelwise_augments(image=aug_img)["image"]
61
            geoms = self.geometric_augments(image=pixelwise, mask=aug_mask)
62
            augmented_imgs[batch_idx] = torch.Tensor(geoms["image"].T)
63
            augmented_masks[batch_idx] = torch.Tensor(geoms["mask"].T)
64
        return augmented_imgs, augmented_masks
65
66
    def step(self):
67
        self.temp = np.clip(self.temp + self.linstep, 0, 1)
68
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
69
70
71
class ModelOfNaturalVariation(nn.Module):
72
    def __init__(self, T0=0, use_inpainter=False):
73
        super(ModelOfNaturalVariation, self).__init__()
74
        self.temp = T0
75
        self.linstep = 0.1
76
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
77
        self.use_inpainter = use_inpainter
78
        if use_inpainter:
79
            self.inpainter = Inpainter("Predictors/Inpainters/no-pretrain-deeplab-generator-4990")
80
81
    def _map_to_range(self, max, min=0):
82
        # The temperature varies between 0 and 1, where 0 represents no augmentation and 1 represents the maximum augmentation
83
        # Obviously, having for example a 100% reduction in brightness is not really productive.
84
        return min + self.temp * (max - min)
85
86
    def get_encoded_transforms(self):
87
        quality_lower = max(100 - int(self._map_to_range(max=100)), 10)
88
        quality_upper = np.clip(quality_lower + 10, 10, 100)
89
        pixelwise = alb.Compose([
90
            alb.ColorJitter(brightness=self._map_to_range(max=0.2),
91
                            contrast=self._map_to_range(max=0.2),
92
                            saturation=self._map_to_range(max=0.2),
93
                            hue=self._map_to_range(max=0.05), p=self.temp),
94
            alb.GaussNoise(var_limit=self._map_to_range(max=0.01), p=self.temp),
95
            alb.ImageCompression(quality_lower=quality_lower,
96
                                 quality_upper=quality_upper,
97
                                 p=self.temp)
98
        ]
99
        )
100
        geometric = alb.Compose([alb.RandomRotate90(p=self.temp),
101
                                 alb.Flip(p=self.temp),
102
                                 alb.OpticalDistortion(distort_limit=self.temp, p=self.temp)])
103
        return pixelwise, geometric
104
105
    def forward(self, image, mask):
106
        # assert len(image.shape) == 4, "Image must be in BxCxHxW format"
107
108
        augmented_imgs = torch.zeros_like(image)
109
        augmented_masks = torch.zeros_like(mask)
110
        for batch_idx in range(image.shape[0]):  # random transforms to every image in the batch
111
            aug_img = image[batch_idx].squeeze().cpu().numpy().T
112
            aug_mask = mask[batch_idx].squeeze().cpu().numpy().T
113
            if self.use_inpainter and np.random.rand(1) < 1:
114
                aug_img, aug_mask = self.inpainter.add_polyp(aug_img, aug_mask)
115
            pixelwise = self.pixelwise_augments(image=aug_img)["image"]
116
            geoms = self.geometric_augments(image=pixelwise, mask=aug_mask)
117
            augmented_imgs[batch_idx] = torch.Tensor(geoms["image"].T)
118
            augmented_masks[batch_idx] = torch.Tensor(geoms["mask"].T)
119
        return augmented_imgs, augmented_masks
120
121
    def set_temp(self, temp):
122
        self.temp = temp
123
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
124
125
    def step(self):
126
        self.temp = np.clip(self.temp + self.linstep, 0, 1)
127
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
128
129
130
if __name__ == '__main__':
131
    from data.hyperkvasir import KvasirSegmentationDataset
132
    from torch.utils.data import DataLoader
133
134
    mnv = ModelOfNaturalVariation(1)
135
    for x, y, fname in DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", augment=False)):
136
        img = x
137
        mask = y
138
        aug_img, aug_mask = mnv(img, mask)
139
        plt.imshow(x[0].T)
140
        plt.axis("off")
141
        plt.savefig(f"experiments/Data/augmentation_samples/unaugmented_{fname}.png", bbox_inches='tight')
142
        plt.imshow(aug_img[0].T)
143
        plt.axis("off")
144
        plt.savefig(f"experiments/Data/augmentation_samples/augmented_{fname}.png", bbox_inches='tight')