|
a |
|
b/tests/featurizers/test_OnlineStatistics.py |
|
|
1 |
import numpy as np |
|
|
2 |
import pytest |
|
|
3 |
|
|
|
4 |
from femr.featurizers.utils import OnlineStatistics |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def _assert_correct_stats(stat: OnlineStatistics, values: list):
    """Verify that `stat` agrees with numpy's mean / sample variance / M2 over `values`."""
    atol = 1e-6  # Allow for some floating point error
    expected_mean = np.mean(values)
    expected_variance = np.var(values, ddof=1)  # sample (n-1) variance
    expected_m2 = expected_variance * (len(values) - 1)

    assert stat.current_count == len(values), f"{stat.current_count} != {len(values)}"
    assert np.isclose(stat.mean(), expected_mean), f"{stat.mean()} != {expected_mean}"
    assert np.isclose(
        stat.variance(), expected_variance, atol=atol
    ), f"{stat.variance()} != {expected_variance}"
    assert np.isclose(stat.current_M2, expected_m2, atol=atol), f"{stat.current_M2} != {expected_m2}"
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def test_add():
    """Feed values one at a time through `add` and check the running statistics."""

    def _check(values):
        tracker = OnlineStatistics()
        for value in values:
            tracker.add(value)
        _assert_correct_stats(tracker, values)

    # Positive integers
    _check(range(51))
    _check(range(10, 10000, 3))
    # Negative integers
    _check(range(-400, -300))
    # Positive/negative integers
    _check(list(range(4, 900, 2)) + list(range(-1000, -300, 7)))
    _check(list(range(-100, 100, 7)) + list(range(-100, 100, 2)))
    # Decimals
    _check(np.linspace(0, 1, 100))
    _check(np.logspace(-100, 3, 100))
    # Small lists
    _check([0, 1])
    _check([-1, 1])
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def test_constructor():
    """Exercise construction defaults, explicit args, accessors, and error cases."""
    # Test default: everything starts at zero.
    stat = OnlineStatistics()
    assert stat.current_count == 0
    assert stat.current_mean == stat.mean() == 0
    assert stat.current_M2 == 0

    # Test explicitly setting args; with count == 1 there is no spread, so M2 is 0.
    stat = OnlineStatistics(current_count=1, current_mean=2, current_variance=3)
    assert stat.current_count == 1
    assert stat.current_mean == stat.mean() == 2
    assert stat.current_M2 == 0

    # Test M2: derived from sample variance as M2 = variance * (count - 1).
    stat = OnlineStatistics(current_count=10, current_mean=20, current_variance=30)
    assert stat.current_count == 10
    assert stat.current_mean == 20
    assert stat.current_M2 == 30 * (10 - 1)

    # Test getters/setters
    stat = OnlineStatistics(current_count=10, current_mean=20, current_variance=30)
    assert stat.mean() == 20
    assert stat.variance() == 30
    assert stat.standard_deviation() == np.sqrt(30)

    # Test fail cases.
    # Keep ONLY the statement expected to raise inside each `pytest.raises`
    # block, so an unexpected ValueError from setup code cannot make the
    # test pass for the wrong reason (per pytest.raises best practice).
    with pytest.raises(ValueError):
        # Negative count
        OnlineStatistics(current_count=-1, current_mean=2, current_variance=3)
    with pytest.raises(ValueError):
        # Negative variance
        OnlineStatistics(current_count=1, current_mean=2, current_variance=-3)
    with pytest.raises(ValueError):
        # Positive variance with 0 count
        OnlineStatistics(current_count=0, current_mean=2, current_variance=1)
    # Can only compute variance with >1 observation: setup stays OUTSIDE the
    # raises block (previously all three statements were inside it).
    stat = OnlineStatistics()
    stat.add(1)
    with pytest.raises(ValueError):
        stat.variance()
|
|
84 |
|
|
|
85 |
|
|
|
86 |
def test_merge_pair():
    """Merging two independently built statistics via `merge_pair` must match
    the statistics of the concatenated data."""
    left_values = list(range(-300, 300, 4)) + list(range(400, 450))
    left = OnlineStatistics()
    for value in left_values:
        left.add(value)

    right_values = list(range(100, 150))
    right = OnlineStatistics()
    for value in right_values:
        right.add(value)

    merged = OnlineStatistics.merge_pair(left, right)
    _assert_correct_stats(merged, left_values + right_values)
|
|
99 |
|
|
|
100 |
|
|
|
101 |
def test_merge():
    """Simulate parallel statistics being merged via `merge`: the merged result
    must match the statistics of all chunks pooled together."""
    values = [
        np.linspace(-100, 100, 50),
        np.linspace(100, 200, 50),
        np.linspace(100, 150, 100),
        np.linspace(-10, 100, 100),
        np.linspace(10, 200, 3),
    ]
    stats = []
    # Build one OnlineStatistics per chunk, as independent workers would.
    # (Iterate the chunks directly instead of `for i in range(len(values))`.)
    for chunk in values:
        stat = OnlineStatistics()
        for v in chunk:
            stat.add(v)
        stats.append(stat)

    merged_stat = OnlineStatistics.merge(stats)
    _assert_correct_stats(merged_stat, np.concatenate(values))