a b/ehrapy/tools/_sa.py
1
from __future__ import annotations
2
3
import warnings
4
from typing import TYPE_CHECKING, Literal
5
6
import numpy as np  # noqa: TC002
7
import pandas as pd
8
import statsmodels.api as sm
9
import statsmodels.formula.api as smf
10
from lifelines import (
11
    CoxPHFitter,
12
    KaplanMeierFitter,
13
    LogLogisticAFTFitter,
14
    NelsonAalenFitter,
15
    WeibullAFTFitter,
16
    WeibullFitter,
17
)
18
from lifelines.statistics import StatisticalResult, logrank_test
19
from scipy import stats
20
21
from ehrapy.anndata import anndata_to_df
22
from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG
23
24
if TYPE_CHECKING:
25
    from collections.abc import Iterable
26
27
    from anndata import AnnData
28
    from statsmodels.genmod.generalized_linear_model import GLMResultsWrapper
29
30
31
def ols(
    adata: AnnData,
    var_names: list[str] | None = None,
    formula: str | None = None,
    missing: Literal["none", "drop", "raise"] | None = "none",
    use_feature_types: bool = False,
) -> sm.OLS:
    """Create an Ordinary Least Squares (OLS) Model from a formula and AnnData.

    See https://www.statsmodels.org/stable/generated/statsmodels.formula.api.ols.html#statsmodels.formula.api.ols

    Args:
        adata: The AnnData object for the OLS model.
        var_names: The var names indicating which columns are for the OLS model.
            Any iterable of names (list, tuple, index, ...) is accepted; a single string is
            treated as one column name. If `None`, all columns of `adata` are used.
        formula: The formula specifying the model.
        use_feature_types: If True, the feature types in the AnnData objects .var are used.
            Otherwise every selected column is cast to float.
        missing: Available options are 'none', 'drop', and 'raise'.
                 If 'none', no nan checking is done. If 'drop', any observations with nans are dropped.
                 If 'raise', an error is raised.

    Returns:
        The OLS model instance.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> formula = "tco2_first ~ pco2_first"
        >>> var_names = ["tco2_first", "pco2_first"]
        >>> ols = ep.tl.ols(adata, var_names, formula, missing="drop")
    """
    if var_names is None:
        data = pd.DataFrame(adata.X, columns=adata.var_names)
    else:
        # Materialise once so tuples / indexes / generators are honoured instead of being
        # silently ignored (which would build the model from *all* columns).
        selected = [var_names] if isinstance(var_names, str) else list(var_names)
        data = pd.DataFrame(adata[:, selected].X, columns=selected)

    if use_feature_types:
        for col in data.columns:
            if col not in adata.var.index:
                continue
            feature_type = adata.var[FEATURE_TYPE_KEY][col]
            if feature_type == CATEGORICAL_TAG:
                data[col] = data[col].astype("category")
            elif feature_type == NUMERIC_TAG:
                data[col] = data[col].astype(float)
    else:
        data = data.astype(float)

    return smf.ols(formula, data=data, missing=missing)
80
81
82
def glm(
    adata: AnnData,
    var_names: Iterable[str] | None = None,
    formula: str | None = None,
    family: Literal["Gaussian", "Binomial", "Gamma", "InverseGaussian"] = "Gaussian",
    use_feature_types: bool = False,
    missing: Literal["none", "drop", "raise"] = "none",
    as_continuous: Iterable[str] | None = None,
) -> sm.GLM:
    """Create a Generalized Linear Model (GLM) from a formula, a distribution, and AnnData.

    See https://www.statsmodels.org/stable/generated/statsmodels.formula.api.glm.html#statsmodels.formula.api.glm

    Args:
        adata: The AnnData object for the GLM model.
        var_names: The var names indicating which columns are for the GLM model.
            Any iterable of names is accepted; a single string is treated as one column name.
            If `None`, all columns of `adata` are used.
        formula: The formula specifying the model.
        family: The distribution families. Available options are 'Gaussian', 'Binomial', 'Gamma', and 'InverseGaussian'.
            Non-string values are forwarded to statsmodels unchanged.
        use_feature_types: If True, the feature types in the AnnData objects .var are used.
        missing: Available options are 'none', 'drop', and 'raise'. If 'none', no nan checking is done.
                 If 'drop', any observations with nans are dropped. If 'raise', an error is raised.
        as_continuous: The var names indicating which columns are continuous rather than categorical.
                    The corresponding columns will be set as type float.

    Returns:
        The GLM model instance.

    Raises:
        ValueError: If `family` is a string that is not one of the supported family names.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> formula = "day_28_flg ~ age"
        >>> var_names = ["day_28_flg", "age"]
        >>> family = "Binomial"
        >>> glm = ep.tl.glm(adata, var_names, formula, family, missing="drop", as_continuous=["age"])
    """
    # Map names to the family classes; only the requested one is instantiated.
    family_classes = {
        "Gaussian": sm.families.Gaussian,
        "Binomial": sm.families.Binomial,
        "Gamma": sm.families.Gamma,
        "InverseGaussian": sm.families.InverseGaussian,
    }
    if isinstance(family, str):
        if family not in family_classes:
            raise ValueError(f"Unknown family {family!r}. Available options are: {', '.join(family_classes)}.")
        family = family_classes[family]()

    if var_names is None:
        data = pd.DataFrame(adata.X, columns=adata.var_names)
    else:
        # Materialise once so tuples / indexes / generators are honoured instead of being
        # silently ignored (which would build the model from *all* columns).
        selected = [var_names] if isinstance(var_names, str) else list(var_names)
        data = pd.DataFrame(adata[:, selected].X, columns=selected)

    if as_continuous is not None:
        # pandas interprets a tuple as a single (multi-level) key, so always index with a list.
        continuous = [as_continuous] if isinstance(as_continuous, str) else list(as_continuous)
        if continuous:
            data[continuous] = data[continuous].astype(float)

    if use_feature_types:
        for col in data.columns:
            if col not in adata.var.index:
                continue
            feature_type = adata.var[FEATURE_TYPE_KEY][col]
            if feature_type == CATEGORICAL_TAG:
                data[col] = data[col].astype("category")
            elif feature_type == NUMERIC_TAG:
                data[col] = data[col].astype(float)

    return smf.glm(formula, data=data, family=family, missing=missing)
143
144
145
def kmf(
    durations: Iterable,
    event_observed: Iterable | None = None,
    timeline: Iterable | None = None,
    entry: Iterable | None = None,
    label: str | None = None,
    alpha: float | None = None,
    ci_labels: tuple[str, str] | None = None,
    weights: Iterable | None = None,
    censoring: Literal["right", "left"] | None = None,
) -> KaplanMeierFitter:
    """DEPRECATION WARNING: This function is deprecated and will be removed in the next release. Use `kaplan_meier` instead.

    Fit the Kaplan-Meier estimate for the survival function.

    The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data.
    In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment.

    See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator
        https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter

    Args:
        durations: length n -- duration (relative to subject's birth) the subject was alive for.
        event_observed: True if the death was observed, False if the event was lost (right-censored). Defaults to all True if event_observed is equal to `None`.
        timeline: return the best estimate at the values in timelines (positively increasing)
        entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
               If None, all members of the population entered study when they were "born".
        label: A string to name the column of the estimate.
        alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
        ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
        weights: If providing a weighted dataset. For example, instead of providing every subject
                 as a single element of `durations` and `event_observed`, one could weigh subject differently.
        censoring: 'right' for fitting the model to a right-censored dataset.
                   'left' for fitting the model to a left-censored dataset (default: fit the model to a right-censored dataset).

    Returns:
        Fitted KaplanMeierFitter.

    Raises:
        ValueError: If `censoring` is neither `None`, 'right' nor 'left'.

    Examples:
        >>> import numpy as np
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # Flip 'censor_flg' because 0 = death and 1 = censored
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
    """
    warnings.warn(
        "This function is deprecated and will be removed in the next release. Use `ep.tl.kaplan_meier` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    fit_kwargs = {
        "durations": durations,
        "event_observed": event_observed,
        "timeline": timeline,
        "entry": entry,
        "label": label,
        "alpha": alpha,
        "ci_labels": ci_labels,
        "weights": weights,
    }

    fitter = KaplanMeierFitter()
    # The previous check `censoring == "None" or "right"` was always truthy, which made the
    # left-censoring branch unreachable. The string "None" is still accepted for compatibility.
    if censoring in (None, "None", "right"):
        fitter.fit(**fit_kwargs)
    elif censoring == "left":
        fitter.fit_left_censoring(**fit_kwargs)
    else:
        raise ValueError(f"Unknown censoring type {censoring!r}. Use 'right' or 'left'.")

    return fitter
220
221
222
def kaplan_meier(
    adata: AnnData,
    duration_col: str,
    event_col: str | None = None,
    *,
    uns_key: str = "kaplan_meier",
    timeline: list[float] | None = None,
    entry: str | None = None,
    label: str | None = None,
    alpha: float | None = None,
    ci_labels: list[str] | None = None,
    weights: list[float] | None = None,
    fit_options: dict | None = None,
    censoring: Literal["right", "left"] = "right",
) -> KaplanMeierFitter:
    """Fit the Kaplan-Meier estimate for the survival function.

    The Kaplan–Meier (product limit) estimator is a non-parametric statistic for estimating the survival function from lifetime data.
    A typical medical use is measuring the fraction of patients still alive a given time after treatment.
    The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'kaplan_meier' unless specified otherwise in the `uns_key` parameter.

    See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator
        https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter

    Args:
        adata: AnnData object.
        duration_col: The name of the column in the AnnData object that contains the subjects’ lifetimes.
        event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored.
            Column values are `True` if the event was observed, `False` if the event was lost (right-censored).
            If left `None`, all individuals are assumed to be uncensored.
        uns_key: The key to use for the `.uns` slot in the AnnData object.
        timeline: Return the best estimate at the values in timelines (positively increasing)
        entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
               If None, all members of the population entered study when they were "born".
        label: A string to name the column of the estimate.
        alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
        ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
        weights: If providing a weighted dataset. For example, instead of providing every subject
                 as a single element of `durations` and `event_observed`, one could weigh subject differently.
        fit_options: Additional keyword arguments to pass into the estimator.
        censoring: 'right' for fitting the model to a right-censored dataset. (default, calls fit).
                   'left' for fitting the model to a left-censored dataset (calls fit_left_censoring).

    Returns:
        Fitted KaplanMeierFitter.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # Flip 'censor_flg' because 0 = death and 1 = censored
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> kmf = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg", label="Mortality")
    """
    # All work is delegated to the shared univariate helper; arguments are positional and
    # must stay in exactly this order.
    # NOTE(review): the literal `True` is the helper's sixth positional parameter — its
    # meaning is defined by `_univariate_model`, confirm there before changing it.
    fitted_model = _univariate_model(
        adata,
        duration_col,
        event_col,
        KaplanMeierFitter,
        uns_key,
        True,
        timeline,
        entry,
        label,
        alpha,
        ci_labels,
        weights,
        fit_options,
        censoring,
    )
    return fitted_model
291
292
293
def test_kmf_logrank(
    kmf_A: KaplanMeierFitter,
    kmf_B: KaplanMeierFitter,
    t_0: float | None = -1,
    weightings: Literal["wilcoxon", "tarone-ware", "peto", "fleming-harrington"] | None = None,
    **kwargs,
) -> StatisticalResult:
    """Run the logrank test comparing the survival functions of two groups.

    Measures and reports on whether two intensity processes are different.
    That is, given two event series, determines whether the data generating processes are statistically different.
    The test-statistic is chi-squared under the null hypothesis.

    See https://lifelines.readthedocs.io/en/latest/lifelines.statistics.html

    Args:
        kmf_A: The first KaplanMeierFitter object containing the durations and events.
        kmf_B: The second KaplanMeierFitter object containing the durations and events.
        t_0: The final time period under observation, and subjects who experience the event after this time are set to be censored.
             Specify -1 to use all time.
        weightings: Apply a weighted logrank test: options are "wilcoxon" for Wilcoxon (also known as Breslow), "tarone-ware"
                    for Tarone-Ware, "peto" for Peto test and "fleming-harrington" for Fleming-Harrington test.
                    These are useful for testing for early or late differences in the survival curve. For the Fleming-Harrington
                    test, keyword arguments p and q must also be provided with non-negative values.
        **kwargs: Additional keyword arguments forwarded to :func:`lifelines.statistics.logrank_test`,
                  e.g. `p` and `q` for the Fleming-Harrington test.

    Returns:
        The lifelines `StatisticalResult` of the logrank test; its `p_value` attribute holds the p-value.
    """
    return logrank_test(
        durations_A=kmf_A.durations,
        durations_B=kmf_B.durations,
        event_observed_A=kmf_A.event_observed,
        event_observed_B=kmf_B.event_observed,
        weights_A=kmf_A.weights,
        weights_B=kmf_B.weights,
        t_0=t_0,
        weightings=weightings,
        **kwargs,
    )


# The `test_` prefix is part of the public API name; tell pytest this is not a test case so
# importing it into a test module does not trigger collection (and a missing-fixture error).
test_kmf_logrank.__test__ = False
332
333
334
def test_nested_f_statistic(small_model: GLMResultsWrapper, big_model: GLMResultsWrapper) -> float:
    """Calculate the P value indicating if a larger GLM, encompassing a smaller GLM's parameters, adds explanatory power.

    The F statistic is the drop in deviance from the small to the big model, divided by the number
    of additional parameters times the big model's scale. The p-value is the survival function of
    the F distribution evaluated at that statistic.

    See https://stackoverflow.com/questions/27328623/anova-test-for-glm-in-python/60769343#60769343

    Args:
        small_model: fitted generalized linear model (the nested, smaller one).
        big_model: fitted generalized linear model (the larger one). It is expected to have a
            larger `df_model` than `small_model`; with equal `df_model` the statistic divides by zero.

    Returns:
        float: p_value of Anova test.
    """
    addtl_params = big_model.df_model - small_model.df_model
    f_stat = (small_model.deviance - big_model.deviance) / (addtl_params * big_model.scale)
    df_numerator = addtl_params
    # The number of observations is taken from the length of the fitted values.
    df_denom = big_model.fittedvalues.shape[0] - big_model.df_model

    return stats.f.sf(f_stat, df_numerator, df_denom)


# The `test_` prefix is part of the public API name; tell pytest this is not a test case so
# importing it into a test module does not trigger collection (and a missing-fixture error).
test_nested_f_statistic.__test__ = False
353
354
355
def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_1: str, formula_2: str) -> pd.DataFrame:
    """Build an Anova table comparing two fitted generalized linear models.

    The table has one row per model. The second row additionally carries the difference in
    model degrees of freedom and the p-value returned by `test_nested_f_statistic`; both of
    those cells are `None` in the first row.

    Args:
        result_1: first fitted generalized linear model (passed as the small model).
        result_2: second fitted generalized linear model (passed as the big model).
        formula_1: The formula specifying the first model.
        formula_2: The formula specifying the second model.

    Returns:
        pd.DataFrame: Anova table.
    """
    fitted = (result_1, result_2)
    nested_p_value = test_nested_f_statistic(result_1, result_2)

    return pd.DataFrame(
        {
            "Model": [1, 2],
            "formula": [formula_1, formula_2],
            "Df Resid.": [model.df_resid for model in fitted],
            "Dev.": [model.deviance for model in fitted],
            "Df_diff": [None, result_2.df_model - result_1.df_model],
            "Pr(>Chi)": [None, nested_p_value],
        }
    )
379
380
381
def _build_model_input_dataframe(adata: AnnData, duration_col: str, accept_zero_duration=True):
    """Convert `adata` into a NaN-free DataFrame suitable as regression model input.

    Args:
        adata: AnnData object to convert via `anndata_to_df`.
        duration_col: Name of the duration column.
        accept_zero_duration: If False, durations equal to 0 are shifted by 1e-5.

    Returns:
        The DataFrame with all rows containing missing values dropped.
    """
    frame = anndata_to_df(adata).dropna()

    if accept_zero_duration:
        return frame

    # Nudge exact zeros upwards for callers that do not accept zero durations.
    is_zero = frame[duration_col] == 0
    frame.loc[is_zero, duration_col] += 1e-5
    return frame
390
391
392
def cox_ph(
    adata: AnnData,
    duration_col: str,
    event_col: str | None = None,
    *,
    uns_key: str = "cox_ph",
    alpha: float = 0.05,
    label: str | None = None,
    baseline_estimation_method: Literal["breslow", "spline", "piecewise"] = "breslow",
    penalizer: float | np.ndarray = 0.0,
    l1_ratio: float = 0.0,
    strata: list[str] | str | None = None,
    n_baseline_knots: int = 4,
    knots: list[float] | None = None,
    breakpoints: list[float] | None = None,
    weights_col: str | None = None,
    cluster_col: str | None = None,
    entry_col: str | None = None,
    robust: bool = False,
    formula: str | None = None,
    batch_mode: bool | None = None,
    show_progress: bool = False,
    initial_point: np.ndarray | None = None,
    fit_options: dict | None = None,
) -> CoxPHFitter:
    """Fit the Cox’s proportional hazard for the survival function.

    The Cox proportional hazards model (CoxPH) examines the relationship between the survival time of subjects and one or more predictor variables.
    It models the hazard rate as a product of a baseline hazard function and an exponential function of the predictors, assuming proportional hazards over time.
    The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'cox_ph' unless specified otherwise in the `uns_key` parameter.

    See https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html

    Args:
        adata: AnnData object.
        duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes.
        event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored.
            Column values are `True` if the event was observed, `False` if the event was lost (right-censored).
            If left `None`, all individuals are assumed to be uncensored.
        uns_key: The key to use for the `.uns` slot in the AnnData object.
        alpha: The alpha value in the confidence intervals.
        label: The name of the column of the estimate.
        baseline_estimation_method: The method used to estimate the baseline hazard. Options are 'breslow', 'spline', and 'piecewise'.
        penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates.
        l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above.
        strata: specify a list of columns to use in stratification. This is useful if a categorical covariate does not obey the proportional hazard assumption. This is used similar to the strata expression in R. See http://courses.washington.edu/b515/l17.pdf.
        n_baseline_knots: Used when baseline_estimation_method="spline". Set the number of knots (interior & exterior) in the baseline hazard, which will be placed evenly along the time axis. Should be at least 2. Royston et. al, the authors of this model, suggest 4 to start, but any values between 2 and 8 are reasonable. If you need to customize the timestamps used to calculate the curve, use the knots parameter instead.
        knots: When baseline_estimation_method="spline", this allows customizing the points in the time axis for the baseline hazard curve. To use evenly-spaced points in time, the n_baseline_knots parameter can be employed instead.
        breakpoints: Used when baseline_estimation_method="piecewise". Set the positions of the baseline hazard breakpoints.
        weights_col: The name of the column in DataFrame that contains the weights for each subject.
        cluster_col: The name of the column in DataFrame that contains the cluster variable. Using this forces the sandwich estimator (robust variance estimator) to be used.
        entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
        robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ.
        formula: a Wilkinson formula, like in R and statsmodels, for the right-hand-side. If left as None, all columns not assigned as durations, weights, etc. are used. Uses the library Formulaic for parsing.
        batch_mode: Enabling batch_mode can be faster for datasets with a large number of ties. If left as `None`, lifelines will choose the best option.
        show_progress: Since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing.
        initial_point: set the starting point for the iterative solver.
        fit_options: Additional keyword arguments to pass into the estimator.

    Returns:
        Fitted CoxPHFitter.

    Examples:
        >>> import numpy as np
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # Flip 'censor_flg' because 0 = death and 1 = censored
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg")
    """
    # Rows with missing values are dropped by the helper before fitting.
    df = _build_model_input_dataframe(adata, duration_col)

    fitter = CoxPHFitter(
        alpha=alpha,
        label=label,
        strata=strata,
        baseline_estimation_method=baseline_estimation_method,
        penalizer=penalizer,
        l1_ratio=l1_ratio,
        n_baseline_knots=n_baseline_knots,
        knots=knots,
        breakpoints=breakpoints,
    )
    fitter.fit(
        df,
        duration_col=duration_col,
        event_col=event_col,
        entry_col=entry_col,
        robust=robust,
        initial_point=initial_point,
        weights_col=weights_col,
        cluster_col=cluster_col,
        batch_mode=batch_mode,
        formula=formula,
        fit_options=fit_options,
        show_progress=show_progress,
    )

    # Persist the fitted model's summary table on the AnnData object.
    adata.uns[uns_key] = fitter.summary

    return fitter
492
493
494
def weibull_aft(
    adata: AnnData,
    duration_col: str,
    event_col: str | None = None,
    *,
    uns_key: str = "weibull_aft",
    alpha: float = 0.05,
    fit_intercept: bool = True,
    penalizer: float | np.ndarray = 0.0,
    l1_ratio: float = 0.0,
    model_ancillary: bool = True,
    ancillary: bool | pd.DataFrame | str | None = None,
    show_progress: bool = False,
    weights_col: str | None = None,
    robust: bool = False,
    initial_point: np.ndarray | None = None,
    entry_col: str | None = None,
    formula: str | None = None,
    fit_options: dict | None = None,
) -> WeibullAFTFitter:
    """Fit the Weibull accelerated failure time regression for the survival function.

    The Weibull Accelerated Failure Time (AFT) survival regression model is a statistical method used to analyze time-to-event data,
    where the underlying assumption is that the logarithm of survival time follows a Weibull distribution.
    It models the survival time as an exponential function of the predictors, assuming a specific shape parameter
    for the distribution and allowing for accelerated or decelerated failure times based on the covariates.
    The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'weibull_aft' unless specified otherwise in the `uns_key` parameter.

    See https://lifelines.readthedocs.io/en/latest/fitters/regression/WeibullAFTFitter.html

    Args:
        adata: AnnData object.
        duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes.
        event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored.
            Column values are `True` if the event was observed, `False` if the event was lost (right-censored).
            If left `None`, all individuals are assumed to be uncensored.
        uns_key: The key to use for the `.uns` slot in the AnnData object.
        alpha: The alpha value in the confidence intervals.
        fit_intercept: Whether to fit an intercept term in the model.
        penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates.
        l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above.
        model_ancillary: set the model instance to always model the ancillary parameter with the supplied Dataframe. This is useful for grid-search optimization.
        ancillary: Choose to model the ancillary parameters.
            If None or False, explicitly do not fit the ancillary parameters using any covariates.
            If True, model the ancillary parameters with the same covariates as ``df``.
            If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``.
            If str, should be a formula
        show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing.
        weights_col: The name of the column in DataFrame that contains the weights for each subject.
        robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ.
        initial_point: set the starting point for the iterative solver.
        entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
        formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/
            If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.)
        fit_options: Additional keyword arguments to pass into the estimator.

    Returns:
        Fitted WeibullAFTFitter.

    Examples:
        >>> import numpy as np
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> adata = adata[:, ["mort_day_censored", "censor_flg"]]
        >>> aft = ep.tl.weibull_aft(adata, duration_col="mort_day_censored", event_col="censor_flg")
        >>> aft.print_summary()
    """
    # Rows with missing values are dropped and zero durations are shifted by a small
    # epsilon by the helper (accept_zero_duration=False).
    df = _build_model_input_dataframe(adata, duration_col, accept_zero_duration=False)

    fitter = WeibullAFTFitter(
        alpha=alpha,
        fit_intercept=fit_intercept,
        penalizer=penalizer,
        l1_ratio=l1_ratio,
        model_ancillary=model_ancillary,
    )
    fitter.fit(
        df,
        duration_col=duration_col,
        event_col=event_col,
        entry_col=entry_col,
        ancillary=ancillary,
        show_progress=show_progress,
        weights_col=weights_col,
        robust=robust,
        initial_point=initial_point,
        formula=formula,
        fit_options=fit_options,
    )

    # Persist the fitted model's summary table on the AnnData object.
    adata.uns[uns_key] = fitter.summary

    return fitter
590
591
592
def log_logistic_aft(
    adata: AnnData,
    duration_col: str,
    event_col: str | None = None,
    *,
    uns_key: str = "log_logistic_aft",
    alpha: float = 0.05,
    fit_intercept: bool = True,
    penalizer: float | np.ndarray = 0.0,
    l1_ratio: float = 0.0,
    model_ancillary: bool = False,
    ancillary: bool | pd.DataFrame | str | None = None,
    show_progress: bool = False,
    weights_col: str | None = None,
    robust: bool = False,
    initial_point=None,
    entry_col: str | None = None,
    formula: str | None = None,
    fit_options: dict | None = None,
) -> LogLogisticAFTFitter:
    """Fit a log-logistic accelerated failure time (AFT) regression model.

    The log-logistic AFT model is a parametric survival regression for time-to-event data.
    It assumes that the survival time follows a log-logistic distribution (equivalently, that the logarithm
    of the survival time follows a logistic distribution) and describes how covariates accelerate or
    decelerate the time until the event occurs.
    The summary of the fitted model is stored in the `.uns` slot of the :class:`AnnData` object under the key
    'log_logistic_aft', unless a different key is supplied through the `uns_key` parameter.

    See https://lifelines.readthedocs.io/en/latest/fitters/regression/LogLogisticAFTFitter.html

    Args:
        adata: AnnData object.
        duration_col: Name of the column in the AnnData object that holds the subjects' lifetimes.
        event_col: Name of the column in the AnnData object that flags whether the event was observed or censored.
            Values are `True` for an observed event and `False` for a lost (right-censored) one.
            If left `None`, all individuals are assumed to be uncensored.
        uns_key: Key under which the model summary is written to the `.uns` slot of the AnnData object.
        alpha: The alpha value in the confidence intervals.
        fit_intercept: Whether an intercept term is fitted.
        penalizer: Penalty attached to the size of the coefficients during regression.
            This improves the stability of the estimates and controls for high correlation between covariates.
        l1_ratio: Ratio assigned to the L1 versus the L2 penalty, as in scikit-learn. See `penalizer`.
        model_ancillary: Make the model instance always model the ancillary parameter with the supplied DataFrame.
            Useful for grid-search optimization.
        ancillary: How the ancillary parameters are modelled.
            If None or False, the ancillary parameters are explicitly not fitted using any covariates.
            If True, they are modelled with the same covariates as ``df``.
            If a DataFrame, it provides the covariates for the ancillary parameters and must have the same row count as ``df``.
            If a str, it should be a formula.
        show_progress: Show convergence diagnostics of the iterative fitter. Useful if convergence is failing.
        weights_col: Name of the column that contains the weight of each subject.
        robust: Compute robust errors using the Huber sandwich estimator, aka Wei-Lin estimate.
            This does not handle ties, so results may differ significantly if there are many ties.
        initial_point: Starting point for the iterative solver.
        entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
        formula: R-style formula for modelling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/
            Without a formula, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.)
        fit_options: Additional keyword arguments to pass into the estimator.

    Returns:
        Fitted LogLogisticAFTFitter.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # Flip 'censor_flg' because 0 = death and 1 = censored
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> adata = adata[:, ["mort_day_censored", "censor_flg"]]
        >>> llf = ep.tl.log_logistic_aft(adata, duration_col="mort_day_censored", event_col="censor_flg")
    """
    model_input = _build_model_input_dataframe(adata, duration_col, accept_zero_duration=False)

    # Options consumed when the estimator is constructed.
    constructor_options = {
        "alpha": alpha,
        "fit_intercept": fit_intercept,
        "penalizer": penalizer,
        "l1_ratio": l1_ratio,
        "model_ancillary": model_ancillary,
    }
    # Options consumed by the actual fitting call.
    fit_arguments = {
        "duration_col": duration_col,
        "event_col": event_col,
        "entry_col": entry_col,
        "ancillary": ancillary,
        "show_progress": show_progress,
        "weights_col": weights_col,
        "robust": robust,
        "initial_point": initial_point,
        "formula": formula,
        "fit_options": fit_options,
    }

    fitter = LogLogisticAFTFitter(**constructor_options)
    fitter.fit(model_input, **fit_arguments)

    # Expose the fitted model's summary on the AnnData object for later inspection.
    adata.uns[uns_key] = fitter.summary

    return fitter
687
688
689
def _univariate_model(
    adata: AnnData,
    duration_col: str,
    event_col: str | None,
    model_class,
    uns_key: str,
    accept_zero_duration=True,
    timeline: list[float] | None = None,
    entry: str | None = None,
    label: str | None = None,
    alpha: float | None = None,
    ci_labels: list[str] | None = None,
    weights: list[float] | None = None,
    fit_options: dict | None = None,
    censoring: Literal["right", "left"] = "right",
):
    """Convenience function for univariate models.

    Builds the model input dataframe from `adata`, instantiates `model_class` without arguments,
    fits it on the duration (and, if given, event) column and stores a result table in `adata.uns[uns_key]`.

    Args:
        adata: AnnData object the model input dataframe is built from.
        duration_col: Name of the column holding the durations.
        event_col: Name of the column holding the event indicator.
            If `None`, no event indicator is passed to the fit function (`event_observed=None`).
        model_class: Univariate fitter class; it is instantiated with no arguments.
        uns_key: Key under which the result table is stored in `adata.uns`.
        accept_zero_duration: Forwarded to `_build_model_input_dataframe`.
        timeline: Forwarded to the fit function.
        entry: Forwarded to the fit function.
        label: Forwarded to the fit function.
        alpha: Forwarded to the fit function.
        ci_labels: Forwarded to the fit function.
        weights: Forwarded to the fit function.
        fit_options: Forwarded to the fit function.
        censoring: 'right' calls the model's `fit`; any other value looks up `fit_left_censoring`
            and falls back to `fit` if the model has no such method.

    Returns:
        The fitted model instance.
    """
    df = _build_model_input_dataframe(adata, duration_col, accept_zero_duration)
    durations = df[duration_col]
    # Public wrappers allow ``event_col=None``; indexing the dataframe with ``None`` would raise a
    # KeyError, so hand ``None`` through and let the fitter apply its default for ``event_observed``.
    event_observed = None if event_col is None else df[event_col]

    model = model_class()
    function_name = "fit" if censoring == "right" else "fit_left_censoring"
    # get fit function, default to fit if not found
    fit_function = getattr(model, function_name, model.fit)

    # NOTE(review): `entry` is annotated as a str but is forwarded unchanged (it is not resolved to a
    # dataframe column here) — confirm what the fit function expects.
    fit_function(
        durations,
        event_observed=event_observed,
        timeline=timeline,
        entry=entry,
        label=label,
        alpha=alpha,
        ci_labels=ci_labels,
        weights=weights,
        fit_options=fit_options,
    )

    # NelsonAalenFitter and KaplanMeierFitter have no summary attribute, so their event table is stored instead.
    if isinstance(model, (NelsonAalenFitter, KaplanMeierFitter)):
        summary = model.event_table
    else:
        summary = model.summary
    adata.uns[uns_key] = summary

    return model
736
737
738
def nelson_aalen(
    adata: AnnData,
    duration_col: str,
    event_col: str | None = None,
    *,
    uns_key: str = "nelson_aalen",
    timeline: list[float] | None = None,
    entry: str | None = None,
    label: str | None = None,
    alpha: float | None = None,
    ci_labels: list[str] | None = None,
    weights: list[float] | None = None,
    fit_options: dict | None = None,
    censoring: Literal["right", "left"] = "right",
) -> NelsonAalenFitter:
    """Estimate the cumulative hazard function from censored survival data with the Nelson-Aalen estimator.

    The Nelson-Aalen estimator is a non-parametric survival analysis method for estimating the cumulative hazard function.
    It is well suited to censored data because it accounts for individuals whose event times are unknown due to censoring.
    The estimated cumulative hazard describes how the risk of an event accumulates over time.
    The results are stored in the `.uns` slot of the :class:`AnnData` object under the key 'nelson_aalen',
    unless a different key is supplied through the `uns_key` parameter.

    See https://lifelines.readthedocs.io/en/latest/fitters/univariate/NelsonAalenFitter.html

    Args:
        adata: AnnData object.
        duration_col: Name of the column in the AnnData object that holds the subjects' lifetimes.
        event_col: Name of the column in the AnnData object that flags whether the event was observed or censored.
            Values are `True` for an observed event and `False` for a lost (right-censored) one.
            `None` is meant to indicate that all individuals are uncensored.
        uns_key: Key under which the results are written to the `.uns` slot of the AnnData object.
        timeline: Return the best estimate at the values in timelines (positively increasing).
        entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
               If None, all members of the population entered the study when they were "born".
        label: A string to name the column of the estimate.
        alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
        ci_labels: Custom column names for the generated confidence intervals as a length-2 list:
            [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
        weights: Weights for a weighted dataset. For example, instead of providing every subject
                 as a single element of `durations` and `event_observed`, one could weigh subjects differently.
        fit_options: Additional keyword arguments to pass into the estimator.
        censoring: 'right' for fitting the model to a right-censored dataset (default, calls fit).
                   'left' for fitting the model to a left-censored dataset (calls fit_left_censoring).

    Returns:
        Fitted NelsonAalenFitter.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # Flip 'censor_flg' because 0 = death and 1 = censored
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> naf = ep.tl.nelson_aalen(adata, "mort_day_censored", "censor_flg")
    """
    # Everything that is handed to the fitter's fit call unchanged.
    fit_arguments = {
        "timeline": timeline,
        "entry": entry,
        "label": label,
        "alpha": alpha,
        "ci_labels": ci_labels,
        "weights": weights,
        "fit_options": fit_options,
        "censoring": censoring,
    }
    fitted_model = _univariate_model(
        adata=adata,
        duration_col=duration_col,
        event_col=event_col,
        model_class=NelsonAalenFitter,
        uns_key=uns_key,
        accept_zero_duration=True,
        **fit_arguments,
    )
    return fitted_model
806
807
808
def weibull(
    adata: AnnData,
    duration_col: str,
    event_col: str,
    *,
    uns_key: str = "weibull",
    timeline: list[float] | None = None,
    entry: str | None = None,
    label: str | None = None,
    alpha: float | None = None,
    ci_labels: list[str] | None = None,
    weights: list[float] | None = None,
    fit_options: dict | None = None,
) -> WeibullFitter:
    """Fit a univariate Weibull model to censored survival data.

    Unlike the non-parametric Nelson-Aalen estimator, the Weibull model is parametric: it describes the
    survival data through a shape and a scale parameter, which allows a more structured analysis.
    It is well suited to censored data because it accounts for individuals whose event times are unknown due to censoring.
    The estimated parameters give insight into the hazard rate over time and make it easier to compare
    different groups or treatments.
    The results are stored in the `.uns` slot of the :class:`AnnData` object under the key 'weibull',
    unless a different key is supplied through the `uns_key` parameter.

    See https://lifelines.readthedocs.io/en/latest/fitters/univariate/WeibullFitter.html

    Args:
        adata: AnnData object.
        duration_col: Name of the column in the AnnData object that holds the subjects' lifetimes.
        event_col: Name of the column in the AnnData object that flags whether the event was observed or censored.
            Values are `True` for an observed event and `False` for a lost (right-censored) one.
        uns_key: Key under which the results are written to the `.uns` slot of the AnnData object.
        timeline: Return the best estimate at the values in timelines (positively increasing).
        entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
               If None, all members of the population entered the study when they were "born".
        label: A string to name the column of the estimate.
        alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
        ci_labels: Custom column names for the generated confidence intervals as a length-2 list:
            [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
        weights: Weights for a weighted dataset. For example, instead of providing every subject
                 as a single element of `durations` and `event_observed`, one could weigh subjects differently.
        fit_options: Additional keyword arguments to pass into the estimator.

    Returns:
        Fitted WeibullFitter.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # Flip 'censor_flg' because 0 = death and 1 = censored
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> wf = ep.tl.weibull(adata, "mort_day_censored", "censor_flg")
    """
    # Everything that is handed to the fitter's fit call unchanged.
    fit_arguments = {
        "timeline": timeline,
        "entry": entry,
        "label": label,
        "alpha": alpha,
        "ci_labels": ci_labels,
        "weights": weights,
        "fit_options": fit_options,
    }
    fitted_model = _univariate_model(
        adata=adata,
        duration_col=duration_col,
        event_col=event_col,
        model_class=WeibullFitter,
        uns_key=uns_key,
        # In contrast to the Nelson-Aalen wrapper, zero durations are not accepted here.
        accept_zero_duration=False,
        **fit_arguments,
    )
    return fitted_model