|
a |
|
b/tests/tools/test_sa.py |
|
|
1 |
import numpy as np |
|
|
2 |
import pytest |
|
|
3 |
import statsmodels |
|
|
4 |
from lifelines import ( |
|
|
5 |
CoxPHFitter, |
|
|
6 |
KaplanMeierFitter, |
|
|
7 |
LogLogisticAFTFitter, |
|
|
8 |
NelsonAalenFitter, |
|
|
9 |
WeibullAFTFitter, |
|
|
10 |
WeibullFitter, |
|
|
11 |
) |
|
|
12 |
|
|
|
13 |
import ehrapy as ep |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
@pytest.fixture
def mimic_2_sa():
    """Provide the unencoded MIMIC-II dataset reduced to survival-analysis columns.

    Returns:
        A tuple ``(adata, duration_col, event_col)``, where ``adata`` is a copy holding only the
        duration and event columns, and the two strings name those columns.
    """
    duration_col = "mort_day_censored"
    event_col = "censor_flg"

    full = ep.dt.mimic_2(encoded=False)

    # Recode the event column: 0 -> 1, every other value -> 0.
    # Presumably censor_flg == 0 means the death was observed, so afterwards 1 marks an event;
    # confirm against the dataset docs.
    # NOTE(review): this assigns `.X` on a view and assumes AnnData writes the values back
    # into `full`, so the copy below sees the recoded column.
    raw_flags = full[:, [event_col]].X
    full[:, [event_col]].X = np.where(raw_flags == 0, 1, 0)

    survival_data = full[:, [duration_col, event_col]].copy()
    return survival_data, duration_col, event_col
|
|
24 |
|
|
|
25 |
|
|
|
26 |
class TestSA:
    """Tests for ehrapy's regression (OLS/GLM/ANOVA) and survival-analysis tools in ``ep.tl``."""

    def test_ols(self):
        adata = ep.dt.mimic_2(encoded=False)
        formula = "tco2_first ~ pco2_first"
        var_names = ["tco2_first", "pco2_first"]
        ols = ep.tl.ols(adata, var_names, formula, missing="drop")
        # Fit once and read both coefficients from the same result instead of refitting per coefficient.
        params = ols.fit().params
        intercept, slope = params.iloc[0], params.iloc[1]

        assert isinstance(ols, statsmodels.regression.linear_model.OLS)
        assert slope == pytest.approx(0.18857179158259973)
        assert intercept == pytest.approx(16.210859352601442)

    def test_glm(self):
        adata = ep.dt.mimic_2(encoded=False)
        formula = "day_28_flg ~ age"
        var_names = ["day_28_flg", "age"]
        family = "Binomial"
        glm = ep.tl.glm(adata, var_names, formula, family, missing="drop", as_continuous=["age"])
        # Fit once and read both coefficients from the same result.
        params = glm.fit().params
        intercept, age = params.iloc[0], params.iloc[1]

        assert isinstance(glm, statsmodels.genmod.generalized_linear_model.GLM)
        assert intercept == pytest.approx(5.778006344870297)
        assert age == pytest.approx(-0.06523274132877163)

    @pytest.mark.parametrize("weightings", ["wilcoxon", "tarone-ware", "peto", "fleming-harrington"])
    def test_calculate_logrank_pvalue(self, weightings):
        # NOTE(review): `weightings` is parametrized but never passed to `ep.tl.test_kmf_logrank`,
        # so all four cases currently run the same unweighted test. Before forwarding it, confirm
        # the function accepts a `weightings` argument. lifelines' "fleming-harrington" weighting
        # also needs `p`/`q` keyword arguments.
        durations_A = [1, 2, 3]
        event_observed_A = [1, 1, 0]
        durations_B = [1, 2, 3, 4]
        event_observed_B = [1, 0, 0, 1]

        kmf1 = KaplanMeierFitter()
        kmf1.fit(durations_A, event_observed_A)

        kmf2 = KaplanMeierFitter()
        kmf2.fit(durations_B, event_observed_B)

        results_pairwise = ep.tl.test_kmf_logrank(kmf1, kmf2)
        p_value_pairwise = results_pairwise.p_value
        assert 0 < p_value_pairwise < 1

    def test_anova_glm(self):
        adata = ep.dt.mimic_2(encoded=False)
        family = "Binomial"

        age_formula = "day_28_flg ~ age"
        age_glm = ep.tl.glm(adata, ["day_28_flg", "age"], age_formula, family, missing="drop", as_continuous=["age"])
        age_glm_result = age_glm.fit()

        ageunit_formula = "day_28_flg ~ age + service_unit"
        ageunit_glm = ep.tl.glm(
            adata,
            ["day_28_flg", "age", "service_unit"],
            ageunit_formula,
            family=family,
            missing="drop",
            as_continuous=["age"],
        )
        ageunit_glm_result = ageunit_glm.fit()

        # Reuse the formula strings the models were fitted with, so they cannot drift apart.
        dataframe = ep.tl.anova_glm(age_glm_result, ageunit_glm_result, age_formula, ageunit_formula)

        # The shape check also pins the row count (one row per model).
        assert dataframe.shape == (2, 6)
        assert dataframe.iloc[1, 4] == 2
        assert dataframe.iloc[1, 5] == pytest.approx(0.103185, rel=0.1)

    def _sa_function_assert(self, model, model_class, adata=None):
        """Check a survival model fitted on the ``mimic_2_sa`` fixture data.

        Args:
            model: The fitted model returned by the function under test.
            model_class: The lifelines class ``model`` must be an instance of.
            adata: Optional AnnData the model was fitted on. When given, its ``uns["test"]`` entry
                must equal the model's ``event_table`` (Kaplan-Meier / Nelson-Aalen) or its
                ``summary`` (all other models).
        """
        assert isinstance(model, model_class)
        assert len(model.durations) == 1776
        assert sum(model.event_observed) == 497

        # The legacy `kmf` API takes raw arrays, not an AnnData, so there is no `uns` entry to check.
        if adata is None:
            return

        model_summary = adata.uns.get("test")
        assert model_summary is not None
        # Kaplan-Meier and Nelson-Aalen fitters are compared on their event table, all others on their summary.
        if isinstance(model, (KaplanMeierFitter, NelsonAalenFitter)):
            expected = model.event_table
        else:
            expected = model.summary
        assert model_summary.equals(expected)

    def _sa_func_test(self, sa_function, sa_class, mimic_2_sa):
        """Run ``sa_function`` on the fixture data, storing its output under ``uns_key="test"``, and validate it."""
        adata, duration_col, event_col = mimic_2_sa
        sa = sa_function(adata, duration_col=duration_col, event_col=event_col, uns_key="test")

        self._sa_function_assert(sa, sa_class, adata)

    def test_kmf(self, mimic_2_sa):
        adata, duration_col, event_col = mimic_2_sa
        # Only the deprecated call belongs inside the warns block. Fixture unpacking and the
        # result assertions are not what is being checked for the warning.
        with pytest.warns(DeprecationWarning):
            kmf = ep.tl.kmf(adata[:, [duration_col]].X, adata[:, [event_col]].X)
        self._sa_function_assert(kmf, KaplanMeierFitter)

    def test_kaplan_meier(self, mimic_2_sa):
        self._sa_func_test(ep.tl.kaplan_meier, KaplanMeierFitter, mimic_2_sa)

    def test_cox_ph(self, mimic_2_sa):
        self._sa_func_test(ep.tl.cox_ph, CoxPHFitter, mimic_2_sa)

    def test_nelson_aalen(self, mimic_2_sa):
        self._sa_func_test(ep.tl.nelson_aalen, NelsonAalenFitter, mimic_2_sa)

    def test_weibull(self, mimic_2_sa):
        self._sa_func_test(ep.tl.weibull, WeibullFitter, mimic_2_sa)

    def test_weibull_aft(self, mimic_2_sa):
        self._sa_func_test(ep.tl.weibull_aft, WeibullAFTFitter, mimic_2_sa)

    def test_log_logistic(self, mimic_2_sa):
        self._sa_func_test(ep.tl.log_logistic_aft, LogLogisticAFTFitter, mimic_2_sa)