--- a +++ b/perturbation/perturbator.py @@ -0,0 +1,100 @@ +import sys + +import matplotlib.pyplot as plt +import numpy as np +import scipy.ndimage as filt +import torch.nn as nn + +from utils import mask_generator + + +class BezierPolypExtender(nn.Module): + def __init__(self, num_nodes, degree, minimum_distance=50, maximum_distance=100): + super(BezierPolypExtender, self).__init__() + self.num_nodes = num_nodes + self.degree = degree + self.minimum_distance = minimum_distance + self.maximum_distance = maximum_distance + # recursion depth is often exceeded despite low memory usage. + sys.setrecursionlimit(10000) + + def get_distances_along_edge_from_seed(self, binary_edge_image, current_coord, out, iter=0): + if iter > self.maximum_distance + self.minimum_distance: + return out + for xoff in (-1, 0, 1): + for yoff in (-1, 0, 1): + try: + if binary_edge_image[current_coord[0] + xoff, current_coord[1] + yoff] != 0 and out[ + current_coord[0] + xoff, current_coord[1] + yoff] == 0: + iter += 1 + out[current_coord[0] + xoff, current_coord[1] + yoff] = iter + self.get_distances_along_edge_from_seed(binary_edge_image, + [current_coord[0] + xoff, current_coord[1] + yoff], out, + iter) + except IndexError: + print("continuing...") + continue + return out + + def forward(self, original_mask): + # select seed from edge pixels + edges = ((filt.sobel(original_mask, axis=-1) ** 2 + filt.sobel(original_mask, axis=-2) ** 2) != 0).astype(int) + edge_indexes = np.argwhere(edges == 1) + plt.imshow(edges) + plt.title(np.unique(edges)) + plt.savefig("edges") + seed = edge_indexes[np.random.choice(range(edge_indexes.shape[0]), 1)][0] + # generate edge image with proximity to seed + proximity_image = np.zeros_like(edges) # todo change to distance via contour + proximity_image = self.get_distances_along_edge_from_seed(edges, seed, proximity_image) + # i = 0 + # current_coords = seed.copy() + # # todo perform thining ahead of iteration over contour to prevent gettign stuck + # retrace_node = seed.copy() + # while i < self.minimum_distance + self.maximum_distance: + # found_new = False + # print(f"{i}/{self.minimum_distance + self.maximum_distance}") + # for xoff in (-1, 0, 1): + # for yoff in (-1, 0, 1): + # if proximity_image[current_coords[0] + xoff, current_coords[1] + yoff] == 0 and edges[ + # current_coords[0] + xoff, current_coords[1] + yoff] != 0: + # i += 1 + # proximity_image[current_coords[0] + xoff, current_coords[1] + yoff] = i + # retrace_node = current_coords + # current_coords = [current_coords[0] + xoff, current_coords[1] + yoff] + # found_new = True + # if not found_new: + # print("failed to find new!") + # current_coords = retrace_node + # plt.imshow(edges, alpha=0.5) + # plt.imshow(proximity_image, alpha=0.5) + # plt.savefig("wtf.png") + # plt.show() + plt.imshow(proximity_image) + plt.show() # for x in np.arange(proximity_image.shape[0]): + # for y in np.arange(proximity_image.shape[1]): + # if edges[x, y] == 1: + # proximity_image[x, y] = np.linalg.norm(seed - np.array([[x, y]])) + + # convert to pdf + pdf = (np.max(proximity_image) - proximity_image) + pdf[proximity_image < self.minimum_distance] = 0 # controls minimum abberation size + pdf = pdf / np.sum(pdf) # normalize + ac_idx = np.argwhere(pdf != 0) + probs = pdf[ac_idx[:, 0], ac_idx[:, 1]] + anchorpoint = ac_idx[np.random.choice(range(ac_idx.shape[0]), 1, p=probs)] + + plt.imshow(pdf) + plt.colorbar() + plt.scatter(y=seed[:, 0], x=seed[:, 1], marker="X") + plt.scatter(y=anchorpoint[:, 0], x=anchorpoint[:, 1], marker="o") + + plt.show() + + +class RandomDraw(nn.Module): + def __init__(self): + super(RandomDraw, self).__init__() + + def forward(self, rad): + return mask_generator.generate_a_mask(rad=rad)