|
a |
|
b/utils.py |
|
|
1 |
import numpy as np |
|
|
2 |
|
|
|
3 |
from functools import lru_cache |
|
|
4 |
from scipy.stats import multivariate_normal |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
@lru_cache() |
|
|
8 |
def get_gauss_pdf(sigma): |
|
|
9 |
n = sigma * 8 |
|
|
10 |
|
|
|
11 |
x, y = np.mgrid[0:n, 0:n] |
|
|
12 |
pos = np.empty(x.shape + (2,)) |
|
|
13 |
pos[:, :, 0] = x |
|
|
14 |
pos[:, :, 1] = y |
|
|
15 |
|
|
|
16 |
rv = multivariate_normal([n / 2, n / 2], [[sigma ** 2, 0], [0, sigma ** 2]]) |
|
|
17 |
pdf = rv.pdf(pos) |
|
|
18 |
|
|
|
19 |
heatmap = pdf / np.max(pdf) |
|
|
20 |
|
|
|
21 |
return heatmap |
|
|
22 |
|
|
|
23 |
|
|
|
24 |
def to_int(num): |
|
|
25 |
return int(round(num)) |
|
|
26 |
|
|
|
27 |
|
|
|
28 |
@lru_cache() |
|
|
29 |
def get_binary_mask(diameter): |
|
|
30 |
d = diameter |
|
|
31 |
_map = np.zeros((d, d), dtype = np.float32) |
|
|
32 |
|
|
|
33 |
r = d / 2 |
|
|
34 |
s = int(d / 2) |
|
|
35 |
|
|
|
36 |
y, x = np.ogrid[-s:d - s, -s:d - s] |
|
|
37 |
mask = x * x + y * y <= r * r |
|
|
38 |
|
|
|
39 |
_map[mask] = 1.0 |
|
|
40 |
|
|
|
41 |
return _map |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def get_binary_heat_map(shape, is_present, centers, diameter = 9): |
|
|
45 |
n = diameter |
|
|
46 |
r = int(n / 2) |
|
|
47 |
hn = int(2 * n) |
|
|
48 |
qn = int(4 * n) |
|
|
49 |
pl = np.zeros((shape[0], shape[1] + qn, shape[2] + qn, shape[3]), dtype = np.float32) |
|
|
50 |
|
|
|
51 |
for i in range(shape[0]): |
|
|
52 |
for j in range(shape[3]): |
|
|
53 |
my = centers[i, 0, j] - r |
|
|
54 |
mx = centers[i, 1, j] - r |
|
|
55 |
|
|
|
56 |
if -n < my < shape[1] and -n < mx < shape[2] and is_present[i, j]: |
|
|
57 |
pl[i, my + hn:my + 3 * n, mx + hn:mx + 3 * n, j] = get_binary_mask(diameter) |
|
|
58 |
|
|
|
59 |
return pl[:, hn:-hn, hn:-hn, :] |
|
|
60 |
|
|
|
61 |
|
|
|
62 |
def get_gauss_heat_map(shape, is_present, mean, sigma = 5): |
|
|
63 |
n = to_int(sigma * 8) |
|
|
64 |
hn = to_int(n / 2) |
|
|
65 |
dn = int(2 * n) |
|
|
66 |
qn = int(4 * n) |
|
|
67 |
pl = np.zeros((shape[0], shape[1] + qn, shape[2] + qn, shape[3]), dtype = np.float32) |
|
|
68 |
|
|
|
69 |
for i in range(shape[0]): |
|
|
70 |
for j in range(shape[3]): |
|
|
71 |
my = mean[i, 0, j] - hn |
|
|
72 |
mx = mean[i, 1, j] - hn |
|
|
73 |
|
|
|
74 |
if -n < my < shape[1] and -n < mx < shape[2] and is_present[i, j]: |
|
|
75 |
pl[i, my + dn:my + 3 * n, mx + dn:mx + 3 * n, j] = get_gauss_pdf(sigma) |
|
|
76 |
# else: |
|
|
77 |
# print(my, mx) |
|
|
78 |
|
|
|
79 |
return pl[:, dn:-dn, dn:-dn, :] |