Diff of /utils.py [000000] .. [968c76]

Switch to unified view

a b/utils.py
1
import numpy as np
2
3
from functools import lru_cache
4
from scipy.stats import multivariate_normal
5
6
7
@lru_cache()
8
def get_gauss_pdf(sigma):
9
    n = sigma * 8
10
11
    x, y = np.mgrid[0:n, 0:n]
12
    pos = np.empty(x.shape + (2,))
13
    pos[:, :, 0] = x
14
    pos[:, :, 1] = y
15
16
    rv = multivariate_normal([n / 2, n / 2], [[sigma ** 2, 0], [0, sigma ** 2]])
17
    pdf = rv.pdf(pos)
18
19
    heatmap = pdf / np.max(pdf)
20
21
    return heatmap
22
23
24
def to_int(num):
25
    return int(round(num))
26
27
28
@lru_cache()
29
def get_binary_mask(diameter):
30
    d = diameter
31
    _map = np.zeros((d, d), dtype = np.float32)
32
33
    r = d / 2
34
    s = int(d / 2)
35
36
    y, x = np.ogrid[-s:d - s, -s:d - s]
37
    mask = x * x + y * y <= r * r
38
39
    _map[mask] = 1.0
40
41
    return _map
42
43
44
def get_binary_heat_map(shape, is_present, centers, diameter = 9):
45
    n = diameter
46
    r = int(n / 2)
47
    hn = int(2 * n)
48
    qn = int(4 * n)
49
    pl = np.zeros((shape[0], shape[1] + qn, shape[2] + qn, shape[3]), dtype = np.float32)
50
51
    for i in range(shape[0]):
52
        for j in range(shape[3]):
53
            my = centers[i, 0, j] - r
54
            mx = centers[i, 1, j] - r
55
56
            if -n < my < shape[1] and -n < mx < shape[2] and is_present[i, j]:
57
                pl[i, my + hn:my + 3 * n, mx + hn:mx + 3 * n, j] = get_binary_mask(diameter)
58
59
    return pl[:, hn:-hn, hn:-hn, :]
60
61
62
def get_gauss_heat_map(shape, is_present, mean, sigma = 5):
63
    n = to_int(sigma * 8)
64
    hn = to_int(n / 2)
65
    dn = int(2 * n)
66
    qn = int(4 * n)
67
    pl = np.zeros((shape[0], shape[1] + qn, shape[2] + qn, shape[3]), dtype = np.float32)
68
69
    for i in range(shape[0]):
70
        for j in range(shape[3]):
71
            my = mean[i, 0, j] - hn
72
            mx = mean[i, 1, j] - hn
73
74
            if -n < my < shape[1] and -n < mx < shape[2] and is_present[i, j]:
75
                pl[i, my + dn:my + 3 * n, mx + dn:mx + 3 * n, j] = get_gauss_pdf(sigma)
76
                # else:
77
                #     print(my, mx)
78
79
    return pl[:, dn:-dn, dn:-dn, :]