Switch to unified view

a b/ehrapy/plot/_survival_analysis.py
1
from __future__ import annotations
2
3
import warnings
4
from typing import TYPE_CHECKING
5
6
import matplotlib.gridspec as gridspec
7
import matplotlib.pyplot as plt
8
import matplotlib.ticker as ticker
9
import numpy as np
10
import pandas as pd
11
from numpy import ndarray
12
13
from ehrapy.plot import scatter
14
15
if TYPE_CHECKING:
16
    from collections.abc import Iterable, Sequence
17
    from xmlrpc.client import Boolean
18
19
    from anndata import AnnData
20
    from lifelines import KaplanMeierFitter
21
    from matplotlib.axes import Axes
22
    from statsmodels.regression.linear_model import RegressionResults
23
24
25
def ols(
26
    adata: AnnData | None = None,
27
    x: str | None = None,
28
    y: str | None = None,
29
    scatter_plot: Boolean | None = True,
30
    ols_results: list[RegressionResults] | None = None,
31
    ols_color: list[str] | None | None = None,
32
    xlabel: str | None = None,
33
    ylabel: str | None = None,
34
    figsize: tuple[float, float] | None = None,
35
    lines: list[tuple[ndarray | float, ndarray | float]] | None = None,
36
    lines_color: list[str] | None | None = None,
37
    lines_style: list[str] | None | None = None,
38
    lines_label: list[str] | None | None = None,
39
    xlim: tuple[float, float] | None = None,
40
    ylim: tuple[float, float] | None = None,
41
    show: bool | None = None,
42
    ax: Axes | None = None,
43
    title: str | None = None,
44
    **kwds,
45
) -> Axes | None:
46
    """Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot.
47
48
    Args:
49
        adata: :class:`~anndata.AnnData` object containing all observations.
50
        x: x coordinate, for scatter plotting.
51
        y: y coordinate, for scatter plotting.
52
        scatter_plot: Whether to show a scatter plot.
53
        ols_results: List of RegressionResults from ehrapy.tl.ols. Example: [result_1, result_2]
54
        ols_color: List of colors for each ols_results. Example: ['red', 'blue'].
55
        xlabel: The x-axis label text.
56
        ylabel: The y-axis label text.
57
        figsize: Width, height in inches.
58
        lines: List of Tuples of (slope, intercept) or (x, y). Plot lines by slope and intercept or data points.
59
               Example: plot two lines (y = x + 2 and y = 2*x + 1): [(1, 2), (2, 1)]
60
        lines_color: List of colors for each line. Example: ['red', 'blue']
61
        lines_style: List of line styles for each line. Example: ['-', '--']
62
        lines_label: List of line labels for each line. Example: ['Line1', 'Line2']
63
        xlim: Set the x-axis view limits. Required for only plotting lines using slope and intercept.
64
        ylim: Set the y-axis view limits. Required for only plotting lines using slope and intercept.
65
        show: Show the plot, do not return axis.
66
        ax: A matplotlib axes object. Only works if plotting a single component.
67
        title: Set the title of the plot.
68
        kwds: Passed to matplotblib scatterplot.
69
70
    Examples:
71
        >>> import ehrapy as ep
72
        >>> adata = ep.dt.mimic_2(encoded=False)
73
        >>> co2_lm_result = ep.tl.ols(
74
        ...     adata, var_names=["pco2_first", "tco2_first"], formula="tco2_first ~ pco2_first", missing="drop"
75
        ... ).fit()
76
        >>> ep.pl.ols(
77
        ...     adata,
78
        ...     x="pco2_first",
79
        ...     y="tco2_first",
80
        ...     ols_results=[co2_lm_result],
81
        ...     ols_color=["red"],
82
        ...     xlabel="PCO2",
83
        ...     ylabel="TCO2",
84
        ... )
85
86
        .. image:: /_static/docstring_previews/ols_plot_1.png
87
88
        >>> import ehrapy as ep
89
        >>> adata = ep.dt.mimic_2(encoded=False)
90
        >>> ep.pl.ols(adata, x='pco2_first', y='tco2_first', lines=[(0.25, 10), (0.3, 20)],
91
        >>>           lines_color=['red', 'blue'], lines_style=['-', ':'], lines_label=['Line1', 'Line2'])
92
93
        .. image:: /_static/docstring_previews/ols_plot_2.png
94
95
        >>> import ehrapy as ep
96
        >>> ep.pl.ols(lines=[(0.25, 10), (0.3, 20)], lines_color=['red', 'blue'], lines_style=['-', ':'],
97
        >>>           lines_label=['Line1', 'Line2'], xlim=(0, 150), ylim=(0, 50))
98
99
        .. image:: /_static/docstring_previews/ols_plot_3.png
100
    """
101
    if ax is None:
102
        _, ax = plt.subplots(figsize=figsize)
103
    if xlim is not None:
104
        plt.xlim(xlim)
105
    if ylim is not None:
106
        plt.ylim(ylim)
107
    if ols_color is None and ols_results is not None:
108
        ols_color = [None] * len(ols_results)
109
    if lines_color is None and lines is not None:
110
        lines_color = [None] * len(lines)
111
    if lines_style is None and lines is not None:
112
        lines_style = [None] * len(lines)
113
    if lines_label is None and lines is not None:
114
        lines_label = [None] * len(lines)
115
    if adata is not None and x is not None and y is not None:
116
        x_processed = np.array(adata[:, x].X).astype(float)
117
        x_processed = x_processed[~np.isnan(x_processed)]
118
        if scatter_plot is True:
119
            ax = scatter(adata, x=x, y=y, show=False, ax=ax, **kwds)
120
        if ols_results is not None:
121
            for i, ols_result in enumerate(ols_results):
122
                ax.plot(x_processed, ols_result.predict(), color=ols_color[i])
123
124
    if lines is not None:
125
        for i, line in enumerate(lines):
126
            a, b = line
127
            if np.ndim(a) == 0 and np.ndim(b) == 0:
128
                line_x = np.array(ax.get_xlim())
129
                line_y = a * line_x + b
130
                ax.plot(line_x, line_y, linestyle=lines_style[i], color=lines_color[i], label=lines_label[i])
131
            else:
132
                ax.plot(a, b, lines_style[i], color=lines_color[i], label=lines_label[i])
133
    plt.xlabel(xlabel)
134
    plt.ylabel(ylabel)
135
    if title:
136
        plt.title(title)
137
    if lines_label is not None and lines_label[0] is not None:
138
        plt.legend()
139
140
    if not show:
141
        return ax
142
    else:
143
        return None
144
145
146
def kmf(
147
    kmfs: Sequence[KaplanMeierFitter],
148
    ci_alpha: list[float] | None = None,
149
    ci_force_lines: list[Boolean] | None = None,
150
    ci_show: list[Boolean] | None = None,
151
    ci_legend: list[Boolean] | None = None,
152
    at_risk_counts: list[Boolean] | None = None,
153
    color: list[str] | None | None = None,
154
    grid: Boolean | None = False,
155
    xlim: tuple[float, float] | None = None,
156
    ylim: tuple[float, float] | None = None,
157
    xlabel: str | None = None,
158
    ylabel: str | None = None,
159
    figsize: tuple[float, float] | None = None,
160
    show: bool | None = None,
161
    title: str | None = None,
162
) -> Axes | None:
163
    warnings.warn(
164
        "This function is deprecated and will be removed in the next release. Use `ep.pl.kaplan_meier` instead.",
165
        DeprecationWarning,
166
        stacklevel=2,
167
    )
168
    return kaplan_meier(
169
        kmfs=kmfs,
170
        ci_alpha=ci_alpha,
171
        ci_force_lines=ci_force_lines,
172
        ci_show=ci_show,
173
        ci_legend=ci_legend,
174
        at_risk_counts=at_risk_counts,
175
        color=color,
176
        grid=grid,
177
        xlim=xlim,
178
        ylim=ylim,
179
        xlabel=xlabel,
180
        ylabel=ylabel,
181
        figsize=figsize,
182
        show=show,
183
        title=title,
184
    )
185
186
187
def kaplan_meier(
188
    kmfs: Sequence[KaplanMeierFitter],
189
    *,
190
    display_survival_statistics: bool = False,
191
    ci_alpha: list[float] | None = None,
192
    ci_force_lines: list[Boolean] | None = None,
193
    ci_show: list[Boolean] | None = None,
194
    ci_legend: list[Boolean] | None = None,
195
    at_risk_counts: list[Boolean] | None = None,
196
    color: list[str] | None | None = None,
197
    grid: Boolean | None = False,
198
    xlim: tuple[float, float] | None = None,
199
    ylim: tuple[float, float] | None = None,
200
    xlabel: str | None = None,
201
    ylabel: str | None = None,
202
    figsize: tuple[float, float] | None = None,
203
    show: bool | None = None,
204
    title: str | None = None,
205
) -> Axes | None:
206
    """Plots a pretty figure of the Fitted KaplanMeierFitter model.
207
208
    See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html
209
210
    Args:
211
        kmfs: Iterables of fitted KaplanMeierFitter objects.
212
        display_survival_statistics: Whether to show survival statistics in a table below the plot.
213
        ci_alpha: The transparency level of the confidence interval. If more than one kmfs, this should be a list.
214
        ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas).
215
                        If more than one kmfs, this should be a list.
216
        ci_show: Show confidence intervals. If more than one kmfs, this should be a list.
217
        ci_legend: If ci_force_lines is True, this is a boolean flag to add the lines' labels to the legend.
218
                   If more than one kmfs, this should be a list.
219
        at_risk_counts: Show group sizes at time points. If more than one kmfs, this should be a list.
220
        color: List of colors for each kmf. If more than one kmfs, this should be a list.
221
        grid: If True, plot grid lines.
222
        xlim: Set the x-axis view limits.
223
        ylim: Set the y-axis view limits.
224
        xlabel: The x-axis label text.
225
        ylabel: The y-axis label text.
226
        figsize: Width, height in inches.
227
        show: Show the plot, do not return axis.
228
        title: Set the title of the plot.
229
230
    Examples:
231
        >>> import ehrapy as ep
232
        >>> import numpy as np
233
        >>> adata = ep.dt.mimic_2(encoded=False)
234
235
        # Because in MIMIC-II database, `censor_fl` is censored or death (binary: 0 = death, 1 = censored).
236
        # While in KaplanMeierFitter, `event_observed` is True if the the death was observed, False if the event was lost (right-censored).
237
        # So we need to flip `censor_fl` when pass `censor_fl` to KaplanMeierFitter
238
239
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
240
        >>> kmf = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg")
241
        >>> ep.pl.kaplan_meier(
242
        ...     [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True
243
        ... )
244
245
        .. image:: /_static/docstring_previews/kmf_plot_1.png
246
247
        >>> groups = adata[:, ["service_unit"]].X
248
        >>> adata_ficu = adata[groups == "FICU"]
249
        >>> adata_micu = adata[groups == "MICU"]
250
        >>> adata_sicu = adata[groups == "SICU"]
251
        >>> kmf_1 = ep.tl.kaplan_meier(adata_ficu, "mort_day_censored", "censor_flg", label="FICU")
252
        >>> kmf_2 = ep.tl.kaplan_meier(adata_micu, "mort_day_censored", "censor_flg", label="MICU")
253
        >>> kmf_3 = ep.tl.kaplan_meier(adata_sicu, "mort_day_censored", "censor_flg", label="SICU")
254
        >>> ep.pl.kaplan_meier([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'],
255
        >>>           xlim=[0, 750], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived")
256
257
        .. image:: /_static/docstring_previews/kmf_plot_2.png
258
    """
259
    if ci_alpha is None:
260
        ci_alpha = [0.3] * len(kmfs)
261
    if ci_force_lines is None:
262
        ci_force_lines = [False] * len(kmfs)
263
    if ci_show is None:
264
        ci_show = [True] * len(kmfs)
265
    if ci_legend is None:
266
        ci_legend = [False] * len(kmfs)
267
    if at_risk_counts is None:
268
        at_risk_counts = [False] * len(kmfs)
269
    if color is None:
270
        color = [None] * len(kmfs)
271
272
    fig = plt.figure(constrained_layout=True, figsize=figsize)
273
    spec = fig.add_gridspec(2, 1) if display_survival_statistics else fig.add_gridspec(1, 1)
274
    ax = plt.subplot(spec[0, 0])
275
276
    for i, kmf in enumerate(kmfs):
277
        if i == 0:
278
            ax = kmf.plot_survival_function(
279
                ci_alpha=ci_alpha[i],
280
                ci_force_lines=ci_force_lines[i],
281
                ci_show=ci_show[i],
282
                ci_legend=ci_legend[i],
283
                at_risk_counts=at_risk_counts[i],
284
                color=color[i],
285
            )
286
        else:
287
            ax = kmf.plot_survival_function(
288
                ax=ax,
289
                ci_alpha=ci_alpha[i],
290
                ci_force_lines=ci_force_lines[i],
291
                ci_show=ci_show[i],
292
                ci_legend=ci_legend[i],
293
                at_risk_counts=at_risk_counts[i],
294
                color=color[i],
295
            )
296
    # Configure plot appearance
297
    ax.grid(grid)
298
    ax.set_xlim(xlim)
299
    ax.set_ylim(ylim)
300
    ax.set_xlabel(xlabel)
301
    ax.set_ylabel(ylabel)
302
    if title:
303
        ax.set_title(title)
304
305
    if display_survival_statistics:
306
        xticks = [x for x in ax.get_xticks() if x >= 0]
307
        xticks_space = xticks[1] - xticks[0]
308
        if xlabel is None:
309
            xlabel = "Time"
310
311
        yticks = np.arange(len(kmfs))
312
313
        ax_table = plt.subplot(spec[1, 0])
314
        ax_table.set_xticks(xticks)
315
        ax_table.set_xlim(-xticks_space / 2, xticks[-1] + xticks_space / 2)
316
        ax_table.set_ylim(-1, len(kmfs))
317
        ax_table.set_yticks(yticks)
318
        ax_table.set_yticklabels([kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs[::-1])])
319
320
        for i, kmf in enumerate(kmfs[::-1]):
321
            survival_probs = kmf.survival_function_at_times(xticks).values
322
            for j, prob in enumerate(survival_probs):
323
                ax_table.text(
324
                    xticks[j],  # x position
325
                    yticks[i],  # y position
326
                    f"{prob:.2f}",  # formatted survival probability
327
                    ha="center",
328
                    va="center",
329
                    bbox={"boxstyle": "round,pad=0.2", "edgecolor": "none", "facecolor": "lightgrey"},
330
                )
331
332
        ax_table.grid(grid)
333
        ax_table.spines["top"].set_visible(False)
334
        ax_table.spines["right"].set_visible(False)
335
        ax_table.spines["bottom"].set_visible(False)
336
        ax_table.spines["left"].set_visible(False)
337
338
    if not show:
339
        return fig, ax
340
341
    else:
342
        return None
343
344
345
def cox_ph_forestplot(
346
    adata: AnnData,
347
    *,
348
    uns_key: str = "cox_ph",
349
    labels: Iterable[str] | None = None,
350
    fig_size: tuple = (10, 10),
351
    t_adjuster: float = 0.1,
352
    ecolor: str = "dimgray",
353
    size: int = 3,
354
    marker: str = "o",
355
    decimal: int = 2,
356
    text_size: int = 12,
357
    color: str = "k",
358
    show: bool = None,
359
    title: str | None = None,
360
):
361
    """Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.
362
363
    The `adata` object must first be populated using the :func:`~ehrapy.tools.cox_ph` function. This function stores the summary table of the `CoxPHFitter` in the `.uns` attribute of `adata`.
364
    The summary table is created when the model is fitted using the :func:`~ehrapy.tools.cox_ph` function.
365
    For more information on the `CoxPHFitter`, see the `Lifelines documentation <https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html>`_.
366
367
    Inspired by `zepid.graphics.EffectMeasurePlot <https://readthedocs.org>`_ (zEpid Package, https://pypi.org/project/zepid/).
368
369
    Args:
370
        adata: :class:`~anndata.AnnData` object containing the summary table from the CoxPHFitter. This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tools.cox_ph`.
371
        uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph` function stored the summary table. See argument `uns_key` in :func:`~ehrapy.tools.cox_ph`.
372
        labels: List of labels for each coefficient, default uses the index of the summary ta
373
        fig_size: Width, height in inches.
374
        t_adjuster: Adjust the table to the right.
375
        ecolor: Color of the error bars.
376
        size: Size of the markers.
377
        marker: Marker style.
378
        decimal: Number of decimal places to display.
379
        text_size: Font size of the text.
380
        color: Color of the markers.
381
        show: Show the plot, do not return figure and axis.
382
        title: Set the title of the plot.
383
384
    Examples:
385
        >>> import ehrapy as ep
386
        >>> adata = ep.dt.mimic_2(encoded=False)
387
        >>> adata_subset = adata[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
388
        >>> coxph = ep.tl.cox_ph(adata_subset, event_col="censor_flg", duration_col="mort_day_censored")
389
        >>> ep.pl.cox_ph_forestplot(adata_subset)
390
391
        .. image:: /_static/docstring_previews/coxph_forestplot.png
392
393
    """
394
    if uns_key not in adata.uns:
395
        raise ValueError(f"Key {uns_key} not found in adata.uns. Please provide a valid key.")
396
397
    coxph_fitting_summary = adata.uns[
398
        uns_key
399
    ]  # pd.Dataframe with columns: coef, exp(coef), se(coef), z, p, lower 0.95, upper 0.95
400
    auc_col = "coef"
401
402
    if labels is None:
403
        labels = coxph_fitting_summary.index
404
    tval = []
405
    ytick = []
406
    for row_index in range(len(coxph_fitting_summary)):
407
        if not np.isnan(coxph_fitting_summary[auc_col][row_index]):
408
            if (
409
                (isinstance(coxph_fitting_summary[auc_col][row_index], float))
410
                & (isinstance(coxph_fitting_summary["coef lower 95%"][row_index], float))
411
                & (isinstance(coxph_fitting_summary["coef upper 95%"][row_index], float))
412
            ):
413
                tval.append(
414
                    [
415
                        round(coxph_fitting_summary[auc_col][row_index], decimal),
416
                        (
417
                            "("
418
                            + str(round(coxph_fitting_summary["coef lower 95%"][row_index], decimal))
419
                            + ", "
420
                            + str(round(coxph_fitting_summary["coef upper 95%"][row_index], decimal))
421
                            + ")"
422
                        ),
423
                    ]
424
                )
425
            else:
426
                tval.append(
427
                    [
428
                        coxph_fitting_summary[auc_col][row_index],
429
                        (
430
                            "("
431
                            + str(coxph_fitting_summary["coef lower 95%"][row_index])
432
                            + ", "
433
                            + str(coxph_fitting_summary["coef upper 95%"][row_index])
434
                            + ")"
435
                        ),
436
                    ]
437
                )
438
            ytick.append(row_index)
439
        else:
440
            tval.append([" ", " "])
441
            ytick.append(row_index)
442
443
    x_axis_upper_bound = round(((pd.to_numeric(coxph_fitting_summary["coef upper 95%"])).max() + 0.1), 2)
444
445
    x_axis_lower_bound = round(((pd.to_numeric(coxph_fitting_summary["coef lower 95%"])).min() - 0.1), 1)
446
447
    fig = plt.figure(figsize=fig_size)
448
    gspec = gridspec.GridSpec(1, 6)
449
    plot = plt.subplot(gspec[0, 0:4])
450
    table = plt.subplot(gspec[0, 4:])
451
    plot.set_ylim(-1, (len(coxph_fitting_summary)))  # spacing out y-axis properly
452
453
    plot.axvline(1, color="gray", zorder=1)
454
    lower_diff = coxph_fitting_summary[auc_col] - coxph_fitting_summary["coef lower 95%"]
455
    upper_diff = coxph_fitting_summary["coef upper 95%"] - coxph_fitting_summary[auc_col]
456
    plot.errorbar(
457
        coxph_fitting_summary[auc_col],
458
        coxph_fitting_summary.index,
459
        xerr=[lower_diff, upper_diff],
460
        marker="None",
461
        zorder=2,
462
        ecolor=ecolor,
463
        linewidth=0,
464
        elinewidth=1,
465
    )
466
    # plot markers
467
    plot.scatter(
468
        coxph_fitting_summary[auc_col],
469
        coxph_fitting_summary.index,
470
        c=color,
471
        s=(size * 25),
472
        marker=marker,
473
        zorder=3,
474
        edgecolors="None",
475
    )
476
    # plot settings
477
    plot.xaxis.set_ticks_position("bottom")
478
    plot.yaxis.set_ticks_position("left")
479
    plot.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
480
    plot.get_xaxis().set_minor_formatter(ticker.NullFormatter())
481
    plot.set_yticks(ytick)
482
    plot.set_xlim([x_axis_lower_bound, x_axis_upper_bound])
483
    plot.set_xticks([x_axis_lower_bound, 1, x_axis_upper_bound])
484
    plot.set_xticklabels([x_axis_lower_bound, 1, x_axis_upper_bound])
485
    plot.set_yticklabels(labels)
486
    plot.tick_params(axis="y", labelsize=text_size)
487
    plot.yaxis.set_ticks_position("none")
488
    plot.invert_yaxis()  # invert y-axis to align values properly with table
489
    tb = table.table(
490
        cellText=tval, cellLoc="center", loc="right", colLabels=[auc_col, "95% CI"], bbox=[0, t_adjuster, 1, 1]
491
    )
492
    table.axis("off")
493
    tb.auto_set_font_size(False)
494
    tb.set_fontsize(text_size)
495
    for _, cell in tb.get_celld().items():
496
        cell.set_linewidth(0)
497
498
    # remove spines
499
    plot.spines["top"].set_visible(False)
500
    plot.spines["right"].set_visible(False)
501
    plot.spines["left"].set_visible(False)
502
503
    if title:
504
        plt.title(title)
505
506
    if not show:
507
        return fig, plot
508
509
    else:
510
        return None