|
a |
|
b/ehrapy/plot/_survival_analysis.py |
|
|
1 |
from __future__ import annotations |
|
|
2 |
|
|
|
3 |
import warnings |
|
|
4 |
from typing import TYPE_CHECKING |
|
|
5 |
|
|
|
6 |
import matplotlib.gridspec as gridspec |
|
|
7 |
import matplotlib.pyplot as plt |
|
|
8 |
import matplotlib.ticker as ticker |
|
|
9 |
import numpy as np |
|
|
10 |
import pandas as pd |
|
|
11 |
from numpy import ndarray |
|
|
12 |
|
|
|
13 |
from ehrapy.plot import scatter |
|
|
14 |
|
|
|
15 |
if TYPE_CHECKING: |
|
|
16 |
from collections.abc import Iterable, Sequence |
|
|
17 |
from xmlrpc.client import Boolean |
|
|
18 |
|
|
|
19 |
from anndata import AnnData |
|
|
20 |
from lifelines import KaplanMeierFitter |
|
|
21 |
from matplotlib.axes import Axes |
|
|
22 |
from statsmodels.regression.linear_model import RegressionResults |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def ols(
    adata: AnnData | None = None,
    x: str | None = None,
    y: str | None = None,
    scatter_plot: bool | None = True,
    ols_results: list[RegressionResults] | None = None,
    ols_color: list[str] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    figsize: tuple[float, float] | None = None,
    lines: list[tuple[ndarray | float, ndarray | float]] | None = None,
    lines_color: list[str] | None = None,
    lines_style: list[str] | None = None,
    lines_label: list[str] | None = None,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    show: bool | None = None,
    ax: Axes | None = None,
    title: str | None = None,
    **kwds,
) -> Axes | None:
    """Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot.

    All axis limits, labels, the title and the legend are applied to `ax` (the axes passed in, or the
    axes created by this function), not to whichever axes happens to be current in pyplot.

    Args:
        adata: :class:`~anndata.AnnData` object containing all observations.
        x: x coordinate, for scatter plotting.
        y: y coordinate, for scatter plotting.
        scatter_plot: Whether to show a scatter plot. Only drawn when `adata`, `x` and `y` are all given.
        ols_results: List of RegressionResults from ehrapy.tl.ols. Example: [result_1, result_2].
            Only drawn when `adata`, `x` and `y` are all given.
        ols_color: List of colors for each ols_results. Example: ['red', 'blue'].
        xlabel: The x-axis label text.
        ylabel: The y-axis label text.
        figsize: Width, height in inches. Only used when `ax` is not given.
        lines: List of Tuples of (slope, intercept) or (x, y). Plot lines by slope and intercept or data points.
            Example: plot two lines (y = x + 2 and y = 2*x + 1): [(1, 2), (2, 1)]
        lines_color: List of colors for each line. Example: ['red', 'blue']
        lines_style: List of line styles for each line. Example: ['-', '--']
        lines_label: List of line labels for each line. Example: ['Line1', 'Line2']
        xlim: Set the x-axis view limits. Required for only plotting lines using slope and intercept.
        ylim: Set the y-axis view limits. Required for only plotting lines using slope and intercept.
        show: If truthy, return `None` instead of the axes. This function does not call `plt.show()` itself.
        ax: A matplotlib axes object. Only works if plotting a single component.
        title: Set the title of the plot.
        kwds: Passed to the scatter plot function.

    Returns:
        The axes that were drawn on, or `None` if `show` is truthy.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> co2_lm_result = ep.tl.ols(
        ...     adata, var_names=["pco2_first", "tco2_first"], formula="tco2_first ~ pco2_first", missing="drop"
        ... ).fit()
        >>> ep.pl.ols(
        ...     adata,
        ...     x="pco2_first",
        ...     y="tco2_first",
        ...     ols_results=[co2_lm_result],
        ...     ols_color=["red"],
        ...     xlabel="PCO2",
        ...     ylabel="TCO2",
        ... )

        .. image:: /_static/docstring_previews/ols_plot_1.png

        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> ep.pl.ols(
        ...     adata,
        ...     x="pco2_first",
        ...     y="tco2_first",
        ...     lines=[(0.25, 10), (0.3, 20)],
        ...     lines_color=["red", "blue"],
        ...     lines_style=["-", ":"],
        ...     lines_label=["Line1", "Line2"],
        ... )

        .. image:: /_static/docstring_previews/ols_plot_2.png

        >>> import ehrapy as ep
        >>> ep.pl.ols(
        ...     lines=[(0.25, 10), (0.3, 20)],
        ...     lines_color=["red", "blue"],
        ...     lines_style=["-", ":"],
        ...     lines_label=["Line1", "Line2"],
        ...     xlim=(0, 150),
        ...     ylim=(0, 50),
        ... )

        .. image:: /_static/docstring_previews/ols_plot_3.png
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Limits are set before drawing because slope/intercept lines are spanned across ax.get_xlim().
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    # Fill in per-item defaults so the lists can be indexed below.
    if ols_results is not None and ols_color is None:
        ols_color = [None] * len(ols_results)
    if lines is not None:
        if lines_color is None:
            lines_color = [None] * len(lines)
        if lines_style is None:
            lines_style = [None] * len(lines)
        if lines_label is None:
            lines_label = [None] * len(lines)

    if adata is not None and x is not None and y is not None:
        x_processed = np.array(adata[:, x].X).astype(float)
        x_processed = x_processed[~np.isnan(x_processed)]
        if scatter_plot is True:
            ax = scatter(adata, x=x, y=y, show=False, ax=ax, **kwds)
        if ols_results is not None:
            for i, ols_result in enumerate(ols_results):
                ax.plot(x_processed, ols_result.predict(), color=ols_color[i])

    if lines is not None:
        for i, (a, b) in enumerate(lines):
            if np.ndim(a) == 0 and np.ndim(b) == 0:
                # Two scalars: interpret as (slope, intercept) and span the current x-range.
                line_x = np.array(ax.get_xlim())
                line_y = a * line_x + b
                ax.plot(line_x, line_y, linestyle=lines_style[i], color=lines_color[i], label=lines_label[i])
            else:
                # Data points: the style is forwarded as a positional format string, but only when one
                # was supplied, because a positional None is not a valid format string.
                fmt = () if lines_style[i] is None else (lines_style[i],)
                ax.plot(a, b, *fmt, color=lines_color[i], label=lines_label[i])

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    # Show a legend as soon as any line carries a label; also safe for an empty `lines` list.
    if lines_label is not None and any(label is not None for label in lines_label):
        ax.legend()

    return None if show else ax
|
|
144 |
|
|
|
145 |
|
|
|
146 |
def kmf(
    kmfs: Sequence[KaplanMeierFitter],
    ci_alpha: list[float] | None = None,
    ci_force_lines: list[Boolean] | None = None,
    ci_show: list[Boolean] | None = None,
    ci_legend: list[Boolean] | None = None,
    at_risk_counts: list[Boolean] | None = None,
    color: list[str] | None | None = None,
    grid: Boolean | None = False,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    figsize: tuple[float, float] | None = None,
    show: bool | None = None,
    title: str | None = None,
) -> Axes | None:
    """Deprecated entry point that forwards to :func:`kaplan_meier`.

    Emits a `DeprecationWarning` attributed to the caller (`stacklevel=2`) and then passes every
    argument through unchanged, by keyword, to :func:`kaplan_meier`. The return value is whatever
    :func:`kaplan_meier` returns. See that function for the meaning of each parameter.
    """
    warnings.warn(
        "This function is deprecated and will be removed in the next release. Use `ep.pl.kaplan_meier` instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    # Collect the pass-through options once so the forwarding call stays short.
    forwarded = {
        "ci_alpha": ci_alpha,
        "ci_force_lines": ci_force_lines,
        "ci_show": ci_show,
        "ci_legend": ci_legend,
        "at_risk_counts": at_risk_counts,
        "color": color,
        "grid": grid,
        "xlim": xlim,
        "ylim": ylim,
        "xlabel": xlabel,
        "ylabel": ylabel,
        "figsize": figsize,
        "show": show,
        "title": title,
    }
    return kaplan_meier(kmfs=kmfs, **forwarded)
|
|
185 |
|
|
|
186 |
|
|
|
187 |
def kaplan_meier(
    kmfs: Sequence[KaplanMeierFitter],
    *,
    display_survival_statistics: bool = False,
    ci_alpha: list[float] | None = None,
    ci_force_lines: list[bool] | None = None,
    ci_show: list[bool] | None = None,
    ci_legend: list[bool] | None = None,
    at_risk_counts: list[bool] | None = None,
    color: list[str] | None = None,
    grid: bool | None = False,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    figsize: tuple[float, float] | None = None,
    show: bool | None = None,
    title: str | None = None,
) -> tuple[plt.Figure, Axes] | None:
    """Plots a pretty figure of the Fitted KaplanMeierFitter model.

    See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html

    Args:
        kmfs: Sequence of fitted KaplanMeierFitter objects.
        display_survival_statistics: Whether to show survival statistics in a table below the plot.
        ci_alpha: The transparency level of the confidence interval. One entry per fitter in `kmfs`.
        ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas).
            One entry per fitter in `kmfs`.
        ci_show: Show confidence intervals. One entry per fitter in `kmfs`.
        ci_legend: If ci_force_lines is True, this is a boolean flag to add the lines' labels to the legend.
            One entry per fitter in `kmfs`.
        at_risk_counts: Show group sizes at time points. One entry per fitter in `kmfs`.
        color: List of colors for each kmf. One entry per fitter in `kmfs`.
        grid: If True, plot grid lines.
        xlim: Set the x-axis view limits.
        ylim: Set the y-axis view limits.
        xlabel: The x-axis label text.
        ylabel: The y-axis label text.
        figsize: Width, height in inches.
        show: If truthy, return `None` instead of the figure and axes. This function does not call
            `plt.show()` itself.
        title: Set the title of the plot.

    Returns:
        A `(figure, axes)` tuple, where `axes` is the survival-curve axes, or `None` if `show` is truthy.

    Examples:
        >>> import ehrapy as ep
        >>> import numpy as np
        >>> adata = ep.dt.mimic_2(encoded=False)

        >>> # In the MIMIC-II database, `censor_flg` is censored or death (binary: 0 = death, 1 = censored),
        >>> # while in KaplanMeierFitter, `event_observed` is True if the death was observed and
        >>> # False if the event was lost (right-censored).
        >>> # So we need to flip `censor_flg` before passing it to KaplanMeierFitter.

        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> kmf = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg")
        >>> ep.pl.kaplan_meier(
        ...     [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True
        ... )

        .. image:: /_static/docstring_previews/kmf_plot_1.png

        >>> groups = adata[:, ["service_unit"]].X
        >>> adata_ficu = adata[groups == "FICU"]
        >>> adata_micu = adata[groups == "MICU"]
        >>> adata_sicu = adata[groups == "SICU"]
        >>> kmf_1 = ep.tl.kaplan_meier(adata_ficu, "mort_day_censored", "censor_flg", label="FICU")
        >>> kmf_2 = ep.tl.kaplan_meier(adata_micu, "mort_day_censored", "censor_flg", label="MICU")
        >>> kmf_3 = ep.tl.kaplan_meier(adata_sicu, "mort_day_censored", "censor_flg", label="SICU")
        >>> ep.pl.kaplan_meier(
        ...     [kmf_1, kmf_2, kmf_3],
        ...     ci_show=[False, False, False],
        ...     color=["k", "r", "g"],
        ...     xlim=[0, 750],
        ...     ylim=[0, 1],
        ...     xlabel="Days",
        ...     ylabel="Proportion Survived",
        ... )

        .. image:: /_static/docstring_previews/kmf_plot_2.png
    """
    n_curves = len(kmfs)

    # Per-curve options default to one identical entry per fitter.
    if ci_alpha is None:
        ci_alpha = [0.3] * n_curves
    if ci_force_lines is None:
        ci_force_lines = [False] * n_curves
    if ci_show is None:
        ci_show = [True] * n_curves
    if ci_legend is None:
        ci_legend = [False] * n_curves
    if at_risk_counts is None:
        at_risk_counts = [False] * n_curves
    if color is None:
        color = [None] * n_curves

    fig = plt.figure(constrained_layout=True, figsize=figsize)
    # A second grid row is reserved for the statistics table only when it is requested.
    spec = fig.add_gridspec(2 if display_survival_statistics else 1, 1)
    ax = plt.subplot(spec[0, 0])

    # Every curve, including the first, is drawn onto the axes created above.
    for i, fitter in enumerate(kmfs):
        ax = fitter.plot_survival_function(
            ax=ax,
            ci_alpha=ci_alpha[i],
            ci_force_lines=ci_force_lines[i],
            ci_show=ci_show[i],
            ci_legend=ci_legend[i],
            at_risk_counts=at_risk_counts[i],
            color=color[i],
        )

    # Configure plot appearance
    ax.grid(grid)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)

    if display_survival_statistics:
        # The table columns sit at the non-negative x ticks of the survival plot.
        xticks = [tick for tick in ax.get_xticks() if tick >= 0]
        xticks_space = xticks[1] - xticks[0]
        yticks = np.arange(n_curves)

        # Rows are filled bottom-up from the reversed list, so the first fitter ends up in the top row.
        table_rows = kmfs[::-1]

        ax_table = plt.subplot(spec[1, 0])
        ax_table.set_xticks(xticks)
        ax_table.set_xlim(-xticks_space / 2, xticks[-1] + xticks_space / 2)
        ax_table.set_ylim(-1, n_curves)
        ax_table.set_yticks(yticks)
        # Fallback names are numbered by position in `kmfs`, not by position in the reversed list.
        ax_table.set_yticklabels(
            [fitter.label if fitter.label else f"Group {n_curves - row}" for row, fitter in enumerate(table_rows)]
        )

        for row, fitter in enumerate(table_rows):
            survival_probs = fitter.survival_function_at_times(xticks).values
            for tick, prob in zip(xticks, survival_probs):
                ax_table.text(
                    tick,  # x position
                    yticks[row],  # y position
                    f"{prob:.2f}",  # formatted survival probability
                    ha="center",
                    va="center",
                    bbox={"boxstyle": "round,pad=0.2", "edgecolor": "none", "facecolor": "lightgrey"},
                )

        ax_table.grid(grid)
        for side in ("top", "right", "bottom", "left"):
            ax_table.spines[side].set_visible(False)

    return None if show else (fig, ax)
|
|
343 |
|
|
|
344 |
|
|
|
345 |
def cox_ph_forestplot(
    adata: AnnData,
    *,
    uns_key: str = "cox_ph",
    labels: Iterable[str] | None = None,
    fig_size: tuple = (10, 10),
    t_adjuster: float = 0.1,
    ecolor: str = "dimgray",
    size: int = 3,
    marker: str = "o",
    decimal: int = 2,
    text_size: int = 12,
    color: str = "k",
    show: bool | None = None,
    title: str | None = None,
) -> tuple[plt.Figure, Axes] | None:
    """Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.

    The `adata` object must first be populated using the :func:`~ehrapy.tools.cox_ph` function. This function stores the summary table of the `CoxPHFitter` in the `.uns` attribute of `adata`.
    The summary table is created when the model is fitted using the :func:`~ehrapy.tools.cox_ph` function.
    For more information on the `CoxPHFitter`, see the `Lifelines documentation <https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html>`_.

    Inspired by `zepid.graphics.EffectMeasurePlot <https://readthedocs.org>`_ (zEpid Package, https://pypi.org/project/zepid/).

    Args:
        adata: :class:`~anndata.AnnData` object containing the summary table from the CoxPHFitter. This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tools.cox_ph`.
        uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph` function stored the summary table. See argument `uns_key` in :func:`~ehrapy.tools.cox_ph`.
        labels: List of labels for each coefficient, default uses the index of the summary table.
        fig_size: Width, height in inches.
        t_adjuster: Adjust the table to the right.
        ecolor: Color of the error bars.
        size: Size of the markers.
        marker: Marker style.
        decimal: Number of decimal places to display.
        text_size: Font size of the text.
        color: Color of the markers.
        show: If truthy, return `None` instead of the figure and axes. This function does not call
            `plt.show()` itself.
        title: Set the title of the plot.

    Returns:
        A `(figure, axes)` tuple, where `axes` is the forest-plot axes, or `None` if `show` is truthy.

    Raises:
        ValueError: If `uns_key` is not present in `adata.uns`.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> adata_subset = adata[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
        >>> coxph = ep.tl.cox_ph(adata_subset, event_col="censor_flg", duration_col="mort_day_censored")
        >>> ep.pl.cox_ph_forestplot(adata_subset)

        .. image:: /_static/docstring_previews/coxph_forestplot.png

    """
    if uns_key not in adata.uns:
        raise ValueError(f"Key {uns_key} not found in adata.uns. Please provide a valid key.")

    # pd.Dataframe with columns: coef, exp(coef), se(coef), z, p, lower 0.95, upper 0.95
    summary = adata.uns[uns_key]
    coef_col = "coef"
    lower_col = "coef lower 95%"
    upper_col = "coef upper 95%"

    if labels is None:
        labels = summary.index

    # Build the text table row by row. The columns are iterated by position, so the result does not
    # depend on how the summary index is labelled.
    table_rows = []
    for coef, lower, upper in zip(summary[coef_col], summary[lower_col], summary[upper_col]):
        if np.isnan(coef):
            # Keep an empty row so table rows stay aligned with the plot rows.
            table_rows.append([" ", " "])
            continue
        if isinstance(coef, float) and isinstance(lower, float) and isinstance(upper, float):
            coef, lower, upper = round(coef, decimal), round(lower, decimal), round(upper, decimal)
        table_rows.append([coef, f"({lower!s}, {upper!s})"])

    n_rows = len(summary)
    ytick = list(range(n_rows))  # one tick per summary row, including rows with a NaN coefficient

    x_axis_upper_bound = round(pd.to_numeric(summary[upper_col]).max() + 0.1, 2)
    x_axis_lower_bound = round(pd.to_numeric(summary[lower_col]).min() - 0.1, 1)

    # Left four of six grid columns hold the forest plot, the right two hold the text table.
    fig = plt.figure(figsize=fig_size)
    gspec = gridspec.GridSpec(1, 6)
    plot = plt.subplot(gspec[0, 0:4])
    table = plt.subplot(gspec[0, 4:])
    plot.set_ylim(-1, n_rows)  # spacing out y-axis properly

    # NOTE(review): the reference line and middle tick sit at 1 although the plotted column is "coef";
    # presumably inherited from a ratio-scale plot -- confirm whether 0 is intended here.
    plot.axvline(1, color="gray", zorder=1)
    lower_diff = summary[coef_col] - summary[lower_col]
    upper_diff = summary[upper_col] - summary[coef_col]
    plot.errorbar(
        summary[coef_col],
        summary.index,
        xerr=[lower_diff, upper_diff],
        marker="None",
        zorder=2,
        ecolor=ecolor,
        linewidth=0,
        elinewidth=1,
    )
    # plot markers
    plot.scatter(
        summary[coef_col],
        summary.index,
        c=color,
        s=(size * 25),
        marker=marker,
        zorder=3,
        edgecolors="None",
    )
    # plot settings
    plot.xaxis.set_ticks_position("bottom")
    plot.yaxis.set_ticks_position("left")
    plot.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
    plot.get_xaxis().set_minor_formatter(ticker.NullFormatter())
    plot.set_yticks(ytick)
    plot.set_xlim([x_axis_lower_bound, x_axis_upper_bound])
    plot.set_xticks([x_axis_lower_bound, 1, x_axis_upper_bound])
    plot.set_xticklabels([x_axis_lower_bound, 1, x_axis_upper_bound])
    plot.set_yticklabels(labels)
    plot.tick_params(axis="y", labelsize=text_size)
    plot.yaxis.set_ticks_position("none")
    plot.invert_yaxis()  # invert y-axis to align values properly with table

    tb = table.table(
        cellText=table_rows,
        cellLoc="center",
        loc="right",
        colLabels=[coef_col, "95% CI"],
        bbox=[0, t_adjuster, 1, 1],
    )
    table.axis("off")
    tb.auto_set_font_size(False)
    tb.set_fontsize(text_size)
    for cell in tb.get_celld().values():
        cell.set_linewidth(0)

    # remove spines
    for side in ("top", "right", "left"):
        plot.spines[side].set_visible(False)

    if title:
        # NOTE(review): plt.title applies to pyplot's current axes, which is presumably the table axes
        # created last above -- confirm this placement is intended.
        plt.title(title)

    return None if show else (fig, plot)