Switch to unified view

a b/perturbation/perturbator.py
1
import sys
2
3
import matplotlib.pyplot as plt
4
import numpy as np
5
import scipy.ndimage as filt
6
import torch.nn as nn
7
8
from utils import mask_generator
9
10
11
class BezierPolypExtender(nn.Module):
12
    def __init__(self, num_nodes, degree, minimum_distance=50, maximum_distance=100):
13
        super(BezierPolypExtender, self).__init__()
14
        self.num_nodes = num_nodes
15
        self.degree = degree
16
        self.minimum_distance = minimum_distance
17
        self.maximum_distance = maximum_distance
18
        # recursion depth is often exceeded despite low memory usage.
19
        sys.setrecursionlimit(10000)
20
21
    def get_distances_along_edge_from_seed(self, binary_edge_image, current_coord, out, iter=0):
22
        if iter > self.maximum_distance + self.minimum_distance:
23
            return out
24
        for xoff in (-1, 0, 1):
25
            for yoff in (-1, 0, 1):
26
                try:
27
                    if binary_edge_image[current_coord[0] + xoff, current_coord[1] + yoff] != 0 and out[
28
                        current_coord[0] + xoff, current_coord[1] + yoff] == 0:
29
                        iter += 1
30
                        out[current_coord[0] + xoff, current_coord[1] + yoff] = iter
31
                        self.get_distances_along_edge_from_seed(binary_edge_image,
32
                                                                [current_coord[0] + xoff, current_coord[1] + yoff], out,
33
                                                                iter)
34
                except IndexError:
35
                    print("continuing...")
36
                    continue
37
        return out
38
39
    def forward(self, original_mask):
40
        # select seed from edge pixels
41
        edges = ((filt.sobel(original_mask, axis=-1) ** 2 + filt.sobel(original_mask, axis=-2) ** 2) != 0).astype(int)
42
        edge_indexes = np.argwhere(edges == 1)
43
        plt.imshow(edges)
44
        plt.title(np.unique(edges))
45
        plt.savefig("edges")
46
        seed = edge_indexes[np.random.choice(range(edge_indexes.shape[0]), 1)][0]
47
        # generate edge image with proximity to seed
48
        proximity_image = np.zeros_like(edges)  # todo change to distance via contour
49
        proximity_image = self.get_distances_along_edge_from_seed(edges, seed, proximity_image)
50
        # i = 0
51
        # current_coords = seed.copy()
52
        # # todo perform thining ahead of iteration over contour to prevent gettign stuck
53
        # retrace_node = seed.copy()
54
        # while i < self.minimum_distance + self.maximum_distance:
55
        #     found_new = False
56
        #     print(f"{i}/{self.minimum_distance + self.maximum_distance}")
57
        #     for xoff in (-1, 0, 1):
58
        #         for yoff in (-1, 0, 1):
59
        #             if proximity_image[current_coords[0] + xoff, current_coords[1] + yoff] == 0 and edges[
60
        #                 current_coords[0] + xoff, current_coords[1] + yoff] != 0:
61
        #                 i += 1
62
        #                 proximity_image[current_coords[0] + xoff, current_coords[1] + yoff] = i
63
        #                 retrace_node = current_coords
64
        #                 current_coords = [current_coords[0] + xoff, current_coords[1] + yoff]
65
        #                 found_new = True
66
        #     if not found_new:
67
        #         print("failed to find new!")
68
        #         current_coords = retrace_node
69
        #         plt.imshow(edges, alpha=0.5)
70
        #         plt.imshow(proximity_image, alpha=0.5)
71
        #         plt.savefig("wtf.png")
72
        #         plt.show()
73
        plt.imshow(proximity_image)
74
        plt.show()  # for x in np.arange(proximity_image.shape[0]):
75
        #     for y in np.arange(proximity_image.shape[1]):
76
        #         if edges[x, y] == 1:
77
        #             proximity_image[x, y] = np.linalg.norm(seed - np.array([[x, y]]))
78
79
        # convert to pdf
80
        pdf = (np.max(proximity_image) - proximity_image)
81
        pdf[proximity_image < self.minimum_distance] = 0  # controls minimum abberation size
82
        pdf = pdf / np.sum(pdf)  # normalize
83
        ac_idx = np.argwhere(pdf != 0)
84
        probs = pdf[ac_idx[:, 0], ac_idx[:, 1]]
85
        anchorpoint = ac_idx[np.random.choice(range(ac_idx.shape[0]), 1, p=probs)]
86
87
        plt.imshow(pdf)
88
        plt.colorbar()
89
        plt.scatter(y=seed[:, 0], x=seed[:, 1], marker="X")
90
        plt.scatter(y=anchorpoint[:, 0], x=anchorpoint[:, 1], marker="o")
91
92
        plt.show()
93
94
95
class RandomDraw(nn.Module):
96
    def __init__(self):
97
        super(RandomDraw, self).__init__()
98
99
    def forward(self, rad):
100
        return mask_generator.generate_a_mask(rad=rad)