--- a +++ b/utils.py @@ -0,0 +1,79 @@ +import numpy as np + +from functools import lru_cache +from scipy.stats import multivariate_normal + + +@lru_cache() +def get_gauss_pdf(sigma): + n = sigma * 8 + + x, y = np.mgrid[0:n, 0:n] + pos = np.empty(x.shape + (2,)) + pos[:, :, 0] = x + pos[:, :, 1] = y + + rv = multivariate_normal([n / 2, n / 2], [[sigma ** 2, 0], [0, sigma ** 2]]) + pdf = rv.pdf(pos) + + heatmap = pdf / np.max(pdf) + + return heatmap + + +def to_int(num): + return int(round(num)) + + +@lru_cache() +def get_binary_mask(diameter): + d = diameter + _map = np.zeros((d, d), dtype = np.float32) + + r = d / 2 + s = int(d / 2) + + y, x = np.ogrid[-s:d - s, -s:d - s] + mask = x * x + y * y <= r * r + + _map[mask] = 1.0 + + return _map + + +def get_binary_heat_map(shape, is_present, centers, diameter = 9): + n = diameter + r = int(n / 2) + hn = int(2 * n) + qn = int(4 * n) + pl = np.zeros((shape[0], shape[1] + qn, shape[2] + qn, shape[3]), dtype = np.float32) + + for i in range(shape[0]): + for j in range(shape[3]): + my = centers[i, 0, j] - r + mx = centers[i, 1, j] - r + + if -n < my < shape[1] and -n < mx < shape[2] and is_present[i, j]: + pl[i, my + hn:my + 3 * n, mx + hn:mx + 3 * n, j] = get_binary_mask(diameter) + + return pl[:, hn:-hn, hn:-hn, :] + + +def get_gauss_heat_map(shape, is_present, mean, sigma = 5): + n = to_int(sigma * 8) + hn = to_int(n / 2) + dn = int(2 * n) + qn = int(4 * n) + pl = np.zeros((shape[0], shape[1] + qn, shape[2] + qn, shape[3]), dtype = np.float32) + + for i in range(shape[0]): + for j in range(shape[3]): + my = mean[i, 0, j] - hn + mx = mean[i, 1, j] - hn + + if -n < my < shape[1] and -n < mx < shape[2] and is_present[i, j]: + pl[i, my + dn:my + 3 * n, mx + dn:mx + 3 * n, j] = get_gauss_pdf(sigma) + # else: + # print(my, mx) + + return pl[:, dn:-dn, dn:-dn, :]