a b/tests/tools/test_sa.py
1
import numpy as np
2
import pytest
3
import statsmodels
4
from lifelines import (
5
    CoxPHFitter,
6
    KaplanMeierFitter,
7
    LogLogisticAFTFitter,
8
    NelsonAalenFitter,
9
    WeibullAFTFitter,
10
    WeibullFitter,
11
)
12
13
import ehrapy as ep
14
15
16
@pytest.fixture
def mimic_2_sa():
    """Provide a two-column survival dataset plus the names of its columns.

    Loads ``ep.dt.mimic_2(encoded=False)``, inverts the ``censor_flg`` column
    (0 becomes 1, every other value becomes 0) and keeps only the duration and
    event columns.

    Returns:
        A tuple ``(adata, duration_col, event_col)``.
    """
    duration_col = "mort_day_censored"
    event_col = "censor_flg"

    full_adata = ep.dt.mimic_2(encoded=False)
    # Invert the flag in place on the loaded object before subsetting.
    inverted_flag = np.where(full_adata[:, [event_col]].X == 0, 1, 0)
    full_adata[:, [event_col]].X = inverted_flag

    adata = full_adata[:, [duration_col, event_col]].copy()
    return adata, duration_col, event_col
24
25
26
class TestSA:
    """Tests for the regression and survival-analysis helpers exposed under ``ep.tl``."""

    # Key under which the survival helpers are asked to store their summary in ``adata.uns``.
    _UNS_KEY = "test"

    def test_ols(self):
        """``ep.tl.ols`` returns a statsmodels OLS model with the expected coefficients."""
        adata = ep.dt.mimic_2(encoded=False)
        formula = "tco2_first ~ pco2_first"
        var_names = ["tco2_first", "pco2_first"]
        ols = ep.tl.ols(adata, var_names, formula, missing="drop")
        assert isinstance(ols, statsmodels.regression.linear_model.OLS)

        # Fit once and reuse the parameters instead of refitting for every coefficient.
        params = ols.fit().params
        intercept = params.iloc[0]
        slope = params.iloc[1]
        assert 0.18857179158259973 == pytest.approx(slope)
        assert 16.210859352601442 == pytest.approx(intercept)

    def test_glm(self):
        """``ep.tl.glm`` returns a statsmodels GLM model with the expected coefficients."""
        adata = ep.dt.mimic_2(encoded=False)
        formula = "day_28_flg ~ age"
        var_names = ["day_28_flg", "age"]
        family = "Binomial"
        glm = ep.tl.glm(adata, var_names, formula, family, missing="drop", as_continuous=["age"])
        assert isinstance(glm, statsmodels.genmod.generalized_linear_model.GLM)

        # Fit once and reuse the parameters instead of refitting for every coefficient.
        params = glm.fit().params
        intercept = params.iloc[0]
        age = params.iloc[1]
        assert 5.778006344870297 == pytest.approx(intercept)
        assert -0.06523274132877163 == pytest.approx(age)

    @pytest.mark.parametrize("weightings", ["wilcoxon", "tarone-ware", "peto", "fleming-harrington"])
    def test_calculate_logrank_pvalue(self, weightings):
        """``ep.tl.test_kmf_logrank`` yields a p-value strictly between 0 and 1."""
        # NOTE(review): ``weightings`` is parametrized but never forwarded to
        # ``ep.tl.test_kmf_logrank``, so all four cases currently run the same code.
        # Forwarding it may require extra arguments for "fleming-harrington"
        # (presumably ``p``/``q`` in lifelines) - confirm before changing the call.
        durations_A = [1, 2, 3]
        event_observed_A = [1, 1, 0]
        durations_B = [1, 2, 3, 4]
        event_observed_B = [1, 0, 0, 1]

        kmf1 = KaplanMeierFitter()
        kmf1.fit(durations_A, event_observed_A)

        kmf2 = KaplanMeierFitter()
        kmf2.fit(durations_B, event_observed_B)

        results_pairwise = ep.tl.test_kmf_logrank(kmf1, kmf2)
        p_value_pairwise = results_pairwise.p_value
        assert 0 < p_value_pairwise < 1

    def test_anova_glm(self):
        """``ep.tl.anova_glm`` compares two fitted GLMs and returns a 2 x 6 table."""
        adata = ep.dt.mimic_2(encoded=False)
        family = "Binomial"
        age_formula = "day_28_flg ~ age"
        ageunit_formula = "day_28_flg ~ age + service_unit"

        age_glm = ep.tl.glm(
            adata, ["day_28_flg", "age"], age_formula, family, missing="drop", as_continuous=["age"]
        )
        age_glm_result = age_glm.fit()

        ageunit_glm = ep.tl.glm(
            adata,
            ["day_28_flg", "age", "service_unit"],
            ageunit_formula,
            family=family,
            missing="drop",
            as_continuous=["age"],
        )
        ageunit_glm_result = ageunit_glm.fit()

        dataframe = ep.tl.anova_glm(age_glm_result, ageunit_glm_result, age_formula, ageunit_formula)

        assert len(dataframe) == 2
        assert dataframe.shape == (2, 6)
        # Columns 4 and 5 of the second row are checked positionally; presumably the
        # degrees-of-freedom difference and the p-value - confirm against ``ep.tl.anova_glm``.
        assert dataframe.iloc[1, 4] == 2
        assert pytest.approx(dataframe.iloc[1, 5], rel=0.1) == 0.103185

    def _sa_function_assert(self, model, model_class, adata=None):
        """Check a fitted survival model and, optionally, the summary stored in ``adata.uns``.

        Args:
            model: The fitted model returned by the function under test.
            model_class: The class ``model`` is expected to be an instance of.
            adata: When given, ``adata.uns[self._UNS_KEY]`` must equal the model's
                ``event_table`` (Kaplan-Meier / Nelson-Aalen) or ``summary`` (all others).
        """
        assert isinstance(model, model_class)
        assert len(model.durations) == 1776
        assert sum(model.event_observed) == 497

        if adata is not None:  # doing it this way, due to legacy kmf function
            model_summary = adata.uns.get(self._UNS_KEY)
            assert model_summary is not None
            # kmf and nelson_aalen have event_table
            if isinstance(model, (KaplanMeierFitter, NelsonAalenFitter)):
                assert model_summary.equals(model.event_table)
            else:
                assert model_summary.equals(model.summary)

    def _sa_func_test(self, sa_function, sa_class, mimic_2_sa):
        """Run ``sa_function`` on the fixture data and validate the returned model."""
        adata, duration_col, event_col = mimic_2_sa
        sa = sa_function(adata, duration_col=duration_col, event_col=event_col, uns_key=self._UNS_KEY)

        self._sa_function_assert(sa, sa_class, adata)

    def test_kmf(self, mimic_2_sa):
        """The legacy ``ep.tl.kmf`` emits a DeprecationWarning and still fits a Kaplan-Meier model."""
        adata, duration_col, event_col = mimic_2_sa

        # check for deprecation warning; keep only the deprecated call inside the block
        with pytest.warns(DeprecationWarning):
            kmf = ep.tl.kmf(adata[:, [duration_col]].X, adata[:, [event_col]].X)

        self._sa_function_assert(kmf, KaplanMeierFitter)

    def test_kaplan_meier(self, mimic_2_sa):
        self._sa_func_test(ep.tl.kaplan_meier, KaplanMeierFitter, mimic_2_sa)

    def test_cox_ph(self, mimic_2_sa):
        self._sa_func_test(ep.tl.cox_ph, CoxPHFitter, mimic_2_sa)

    def test_nelson_aalen(self, mimic_2_sa):
        self._sa_func_test(ep.tl.nelson_aalen, NelsonAalenFitter, mimic_2_sa)

    def test_weibull(self, mimic_2_sa):
        self._sa_func_test(ep.tl.weibull, WeibullFitter, mimic_2_sa)

    def test_weibull_aft(self, mimic_2_sa):
        self._sa_func_test(ep.tl.weibull_aft, WeibullAFTFitter, mimic_2_sa)

    def test_log_logistic(self, mimic_2_sa):
        self._sa_func_test(ep.tl.log_logistic_aft, LogLogisticAFTFitter, mimic_2_sa)