Switch to unified view

a b/tests/featurizers/test_OnlineStatistics.py
1
import numpy as np
2
import pytest
3
4
from femr.featurizers.utils import OnlineStatistics
5
6
7
def _assert_correct_stats(stat: OnlineStatistics, values: list):
8
    TOLERANCE = 1e-6  # Allow for some floating point error
9
    true_mean = np.mean(values)
10
    true_sample_variance = np.var(values, ddof=1)
11
    true_m2 = true_sample_variance * (len(values) - 1)
12
    assert stat.current_count == len(values), f"{stat.current_count} != {len(values)}"
13
    assert np.isclose(stat.mean(), true_mean), f"{stat.mean()} != {true_mean}"
14
    assert np.isclose(
15
        stat.variance(), true_sample_variance, atol=TOLERANCE
16
    ), f"{stat.variance()} != {true_sample_variance}"
17
    assert np.isclose(stat.current_M2, true_m2, atol=TOLERANCE), f"{stat.current_M2} != {true_m2}"
18
19
20
def test_add():
21
    # Test adding things to the statistics
22
    def _run_test(values):
23
        stat = OnlineStatistics()
24
        for i in values:
25
            stat.add(i)
26
        _assert_correct_stats(stat, values)
27
28
    # Positive integers
29
    _run_test(range(51))
30
    _run_test(range(10, 10000, 3))
31
    # Negative integers
32
    _run_test(range(-400, -300))
33
    # Positive/negative integers
34
    _run_test(list(range(4, 900, 2)) + list(range(-1000, -300, 7)))
35
    _run_test(list(range(-100, 100, 7)) + list(range(-100, 100, 2)))
36
    # Decimals
37
    _run_test(np.linspace(0, 1, 100))
38
    _run_test(np.logspace(-100, 3, 100))
39
    # Small lists
40
    _run_test([0, 1])
41
    _run_test([-1, 1])
42
43
44
def test_constructor():
45
    # Test default
46
    stat = OnlineStatistics()
47
    assert stat.current_count == 0
48
    assert stat.current_mean == stat.mean() == 0
49
    assert stat.current_M2 == 0
50
51
    # Test explicitly setting args
52
    stat = OnlineStatistics(current_count=1, current_mean=2, current_variance=3)
53
    assert stat.current_count == 1
54
    assert stat.current_mean == stat.mean() == 2
55
    assert stat.current_M2 == 0
56
57
    # Test M2
58
    stat = OnlineStatistics(current_count=10, current_mean=20, current_variance=30)
59
    assert stat.current_count == 10
60
    assert stat.current_mean == 20
61
    assert stat.current_M2 == 30 * (10 - 1)
62
63
    # Test getters/setters
64
    stat = OnlineStatistics(current_count=10, current_mean=20, current_variance=30)
65
    assert stat.mean() == 20
66
    assert stat.variance() == 30
67
    assert stat.standard_deviation() == np.sqrt(30)
68
69
    # Test fail cases
70
    with pytest.raises(ValueError) as _:
71
        # Negative count
72
        stat = OnlineStatistics(current_count=-1, current_mean=2, current_variance=3)
73
    with pytest.raises(ValueError) as _:
74
        # Negative variance
75
        stat = OnlineStatistics(current_count=1, current_mean=2, current_variance=-3)
76
    with pytest.raises(ValueError) as _:
77
        # Positive variance with 0 count
78
        stat = OnlineStatistics(current_count=0, current_mean=2, current_variance=1)
79
    with pytest.raises(ValueError) as _:
80
        # Can only compute variance with >1 observation
81
        stat = OnlineStatistics()
82
        stat.add(1)
83
        stat.variance()
84
85
86
def test_merge_pair():
87
    # Simulate two statistics being merged via `merge_pair``
88
    stat1 = OnlineStatistics()
89
    values1 = list(range(-300, 300, 4)) + list(range(400, 450))
90
    for i in values1:
91
        stat1.add(i)
92
    stat2 = OnlineStatistics()
93
    values2 = list(range(100, 150))
94
    for i in values2:
95
        stat2.add(i)
96
    merged_stat = OnlineStatistics.merge_pair(stat1, stat2)
97
    merged_stat_values = values1 + values2
98
    _assert_correct_stats(merged_stat, merged_stat_values)
99
100
101
def test_merge():
102
    # Simulate parallel statistics being merged via `merge`
103
    stats = []
104
    values = [
105
        np.linspace(-100, 100, 50),
106
        np.linspace(100, 200, 50),
107
        np.linspace(100, 150, 100),
108
        np.linspace(-10, 100, 100),
109
        np.linspace(10, 200, 3),
110
    ]
111
    for i in range(len(values)):
112
        stat = OnlineStatistics()
113
        for v in values[i]:
114
            stat.add(v)
115
        stats.append(stat)
116
    merged_stat = OnlineStatistics.merge(stats)
117
    _assert_correct_stats(merged_stat, np.concatenate(values))