perturbation/model.py

import albumentations as alb
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from perturbation.polyp_inpainter import *
from perturbation.perturbator import RandomDraw
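
# Both classes below implement a "model of natural variation": image/mask augmentations
# whose severity is controlled by a single temperature value in [0, 1]. The temperature
# starts at T0, is increased by `linstep` on each call to step(), and is mapped onto
# per-augmentation parameter ranges by _map_to_range().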
|
|
class ModelOfNaturalVariationInpainter(nn.Module):
    def __init__(self, T0=0, use_inpainter=False):
        super(ModelOfNaturalVariationInpainter, self).__init__()
        self.temp = T0
        self.linstep = 0.1
        self.use_inpainter = use_inpainter

        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()

    def _map_to_range(self, max, min=0):
        # The temperature varies between 0 and 1, where 0 means no augmentation and 1
        # the maximum augmentation. A 100% reduction in brightness, for example, would
        # not be productive, so each parameter is mapped onto its own [min, max] range.
        return min + self.temp * (max - min)

    def get_encoded_transforms(self):
        quality_lower = max(100 - int(self._map_to_range(max=100)), 10)
        quality_upper = np.clip(quality_lower + 10, 10, 100)
        pixelwise = alb.Compose([
            alb.ColorJitter(brightness=self._map_to_range(max=0.2),
                            contrast=self._map_to_range(max=0.2),
                            saturation=self._map_to_range(max=0.2),
                            hue=self._map_to_range(max=0.05), p=0.5),
            alb.GaussNoise(var_limit=self._map_to_range(max=0.01), p=0.5),
            alb.ImageCompression(quality_lower=quality_lower,
                                 quality_upper=quality_upper,
                                 p=self.temp),
        ])
        geometric = alb.Compose([alb.RandomRotate90(p=0.5),
                                 alb.Flip(p=0.5),
                                 alb.OpticalDistortion(distort_limit=self.temp, p=self.temp)])
        return pixelwise, geometric

    def forward(self, image, mask):
        assert len(image.shape) == 4, "Image must be in BxCxHxW format"
        augmented_imgs = torch.zeros_like(image)
        augmented_masks = torch.zeros_like(mask)

        for batch_idx in range(image.shape[0]):  # random transforms for every image in the batch
            aug_img = image[batch_idx].squeeze().cpu().numpy().T
            aug_mask = mask[batch_idx].squeeze().cpu().numpy().T
            # if np.random.rand() < self.temp and self.use_inpainter:
            #     # todo integrate inpainting
            #     inpainting_mask_numpy = self.perturbator(rad=0.25)
            #     inpainting_mask = torch.from_numpy(inpainting_mask_numpy).unsqueeze(0).to("cuda").float()
            #     with torch.no_grad():
            #         aug_img, polyp = self.inpainter(img=image[batch_idx], mask=inpainting_mask)
            #     aug_img = aug_img[0].cpu().numpy().T  # TODO fix this filth
            #     aug_mask = np.clip(aug_mask + inpainting_mask_numpy, 0, 1)
            pixelwise = self.pixelwise_augments(image=aug_img)["image"]
            geoms = self.geometric_augments(image=pixelwise, mask=aug_mask)
            augmented_imgs[batch_idx] = torch.Tensor(geoms["image"].T)
            augmented_masks[batch_idx] = torch.Tensor(geoms["mask"].T)
        return augmented_imgs, augmented_masks

    def step(self):
        # Linear temperature schedule: raise the temperature by one increment, capped at 1,
        # and rebuild the transforms with the new parameter ranges.
        self.temp = np.clip(self.temp + self.linstep, 0, 1)
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
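

# ModelOfNaturalVariation is the variant exercised in the __main__ block below. Unlike
# ModelOfNaturalVariationInpainter above, the probability of every augmentation scales
# with the temperature (above, most probabilities are fixed at 0.5), and inpainting is
# delegated to Inpainter.add_polyp when use_inpainter is set.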
|
|
class ModelOfNaturalVariation(nn.Module):
    def __init__(self, T0=0, use_inpainter=False):
        super(ModelOfNaturalVariation, self).__init__()
        self.temp = T0
        self.linstep = 0.1
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
        self.use_inpainter = use_inpainter
        if use_inpainter:
            self.inpainter = Inpainter("Predictors/Inpainters/no-pretrain-deeplab-generator-4990")

    def _map_to_range(self, max, min=0):
        # The temperature varies between 0 and 1, where 0 means no augmentation and 1
        # the maximum augmentation. A 100% reduction in brightness, for example, would
        # not be productive, so each parameter is mapped onto its own [min, max] range.
        return min + self.temp * (max - min)

    def get_encoded_transforms(self):
        quality_lower = max(100 - int(self._map_to_range(max=100)), 10)
        quality_upper = np.clip(quality_lower + 10, 10, 100)
        pixelwise = alb.Compose([
            alb.ColorJitter(brightness=self._map_to_range(max=0.2),
                            contrast=self._map_to_range(max=0.2),
                            saturation=self._map_to_range(max=0.2),
                            hue=self._map_to_range(max=0.05), p=self.temp),
            alb.GaussNoise(var_limit=self._map_to_range(max=0.01), p=self.temp),
            alb.ImageCompression(quality_lower=quality_lower,
                                 quality_upper=quality_upper,
                                 p=self.temp)
        ])
        geometric = alb.Compose([alb.RandomRotate90(p=self.temp),
                                 alb.Flip(p=self.temp),
                                 alb.OpticalDistortion(distort_limit=self.temp, p=self.temp)])
        return pixelwise, geometric

    def forward(self, image, mask):
        # assert len(image.shape) == 4, "Image must be in BxCxHxW format"

        augmented_imgs = torch.zeros_like(image)
        augmented_masks = torch.zeros_like(mask)
        for batch_idx in range(image.shape[0]):  # random transforms for every image in the batch
            aug_img = image[batch_idx].squeeze().cpu().numpy().T
            aug_mask = mask[batch_idx].squeeze().cpu().numpy().T
            # Note: the random draw below always passes, so the inpainter is applied to
            # every sample whenever use_inpainter is enabled.
            if self.use_inpainter and np.random.rand(1) < 1:
                aug_img, aug_mask = self.inpainter.add_polyp(aug_img, aug_mask)
            pixelwise = self.pixelwise_augments(image=aug_img)["image"]
            geoms = self.geometric_augments(image=pixelwise, mask=aug_mask)
            augmented_imgs[batch_idx] = torch.Tensor(geoms["image"].T)
            augmented_masks[batch_idx] = torch.Tensor(geoms["mask"].T)
        return augmented_imgs, augmented_masks

    def set_temp(self, temp):
        self.temp = temp
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()

    def step(self):
        # Linear temperature schedule: raise the temperature by one increment, capped at 1,
        # and rebuild the transforms with the new parameter ranges.
        self.temp = np.clip(self.temp + self.linstep, 0, 1)
        self.pixelwise_augments, self.geometric_augments = self.get_encoded_transforms()
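

# A minimal sketch of how the temperature schedule could be driven during training; the
# loop below is illustrative (the names `train_loader`, `epochs`, and the loss computation
# are assumptions, not part of this module):
#
#     mnv = ModelOfNaturalVariation(T0=0)
#     for epoch in range(epochs):
#         for img, mask, _ in train_loader:
#             aug_img, aug_mask = mnv(img, mask)
#             ...  # compute losses on (img, mask) and (aug_img, aug_mask)
#         mnv.step()  # raise the temperature, i.e. the augmentation severity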
|
|
if __name__ == '__main__':
    from data.hyperkvasir import KvasirSegmentationDataset
    from torch.utils.data import DataLoader

    # Qualitative sanity check: save an unaugmented and an augmented version of each sample.
    mnv = ModelOfNaturalVariation(1)
    for x, y, fname in DataLoader(KvasirSegmentationDataset("Datasets/HyperKvasir", augment=False)):
        img = x
        mask = y
        aug_img, aug_mask = mnv(img, mask)
        plt.imshow(x[0].T)
        plt.axis("off")
        plt.savefig(f"experiments/Data/augmentation_samples/unaugmented_{fname}.png", bbox_inches='tight')
        plt.imshow(aug_img[0].T)
        plt.axis("off")
        plt.savefig(f"experiments/Data/augmentation_samples/augmented_{fname}.png", bbox_inches='tight')