b/src/nichecompass/utils/analysis.py

"""
This module contains utilities to analyze niches inferred by the NicheCompass
model.
"""

from typing import Optional, Tuple

#import holoviews as hv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import seaborn as sns
from anndata import AnnData
from matplotlib import cm, colors
from matplotlib.lines import Line2D
import networkx as nx

from ..models import NicheCompass


def aggregate_obsp_matrix_per_cell_type(
        adata: AnnData,
        obsp_key: str,
        cell_type_key: str="cell_type",
        group_key: Optional[str]=None,
        agg_rows: bool=False):
    """
    Generic function to aggregate adjacency matrices stored in
    ´adata.obsp[obsp_key]´ on cell type level. It can be used to aggregate the
    node label aggregator aggregation weights alpha or the reconstructed
    adjacency matrix of a trained NicheCompass model by neighbor cell type for
    downstream analysis.

    Parameters
    ----------
    adata:
        AnnData object which contains outputs of NicheCompass model training.
    obsp_key:
        Key in ´adata.obsp´ where the matrix to be aggregated is stored.
    cell_type_key:
        Key in ´adata.obs´ where the cell type labels are stored.
    group_key:
        Key in ´adata.obs´ where additional grouping labels are stored.
    agg_rows:
        If ´True´, also aggregate over the observations on cell type level.

    Returns
    ----------
    cell_type_agg_df:
        Pandas DataFrame with the aggregated obsp values (dim: n_obs x
        n_cell_types if ´agg_rows == False´, else n_cell_types x n_cell_types).
    """
    n_obs = len(adata)
    n_cell_types = adata.obs[cell_type_key].nunique()
    sorted_cell_types = sorted(adata.obs[cell_type_key].unique().tolist())

    cell_type_label_encoder = {k: v for k, v in zip(
        sorted_cell_types,
        range(n_cell_types))}

    # Retrieve non zero indices and non zero values, and create row-wise
    # observation cell type index
    nz_obsp_idx = adata.obsp[obsp_key].nonzero()
    neighbor_cell_type_index = adata.obs[cell_type_key][nz_obsp_idx[1]].map(
        cell_type_label_encoder).values
    adata.obsp[obsp_key].eliminate_zeros()  # In some sparse reps 0s can appear
    nz_obsp = adata.obsp[obsp_key].data

    # Use non zero indices, non zero values and row-wise observation cell type
    # index to construct new df with cell types as columns and row-wise
    # aggregated values per cell type index as values
    cell_type_agg = np.zeros((n_obs, n_cell_types))
    np.add.at(cell_type_agg,
              (nz_obsp_idx[0], neighbor_cell_type_index),
              nz_obsp)
    cell_type_agg_df = pd.DataFrame(
        cell_type_agg,
        columns=sorted_cell_types)

    # Add cell type labels of observations
    cell_type_agg_df[cell_type_key] = adata.obs[cell_type_key].values

    # If specified, add group label
    if group_key is not None:
        cell_type_agg_df[group_key] = adata.obs[group_key].values

    if agg_rows:
        # In addition, aggregate values across rows to get a
        # (n_cell_types x n_cell_types) df
        if group_key is not None:
            cell_type_agg_df = cell_type_agg_df.groupby(
                [group_key, cell_type_key]).sum()
        else:
            cell_type_agg_df = cell_type_agg_df.groupby(cell_type_key).sum()

        # Sort index to have same order as columns
        cell_type_agg_df = cell_type_agg_df.loc[
            sorted(cell_type_agg_df.index.tolist()), :]

    return cell_type_agg_df


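# Illustrative usage sketch (kept as a comment so nothing runs at import
# time). It shows how a trained model's reconstructed adjacency could be
# aggregated by neighbor cell type; the obsp key used here is a hypothetical
# example, substitute whichever key your NicheCompass run stored in
# ´adata.obsp´.
#
#     agg_df = aggregate_obsp_matrix_per_cell_type(
#         adata=model.adata,
#         obsp_key="nichecompass_agg_alpha",  # hypothetical key
#         cell_type_key="cell_type",
#         agg_rows=True)  # returns a (n_cell_types x n_cell_types) DataFrame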
def create_cell_type_chord_plot_from_df(
        adata: AnnData,
        df: pd.DataFrame,
        link_threshold: float=0.01,
        cell_type_key: str="cell_type",
        group_key: Optional[str]=None,
        groups: str="all",
        plot_label: str="Niche",
        save_fig: bool=False,
        file_path: Optional[str]=None):
    """
    Create a cell type chord diagram per group based on an input DataFrame.

    Parameters
    ----------
    adata:
        AnnData object which contains outputs of NicheCompass model training.
    df:
        A Pandas DataFrame that contains the connection values for the chord
        plot (dim: (n_groups x n_cell_types) x n_cell_types).
    link_threshold:
        Ratio of link strength that a cell type pair needs to exceed compared
        to the cell type pair with the maximum link strength to be considered
        a link for the chord plot.
    cell_type_key:
        Key in ´adata.obs´ where the cell type labels are stored.
    group_key:
        Key in ´adata.obs´ where additional group labels are stored.
    groups:
        List of groups that will be plotted. If ´all´, plot all groups.
    plot_label:
        Shared label for the plots.
    save_fig:
        If ´True´, save the figure.
    file_path:
        Path where to save the figure.
    """
    # Import here so that holoviews stays an optional dependency (the
    # module-level import is commented out)
    import holoviews as hv

    hv.extension("bokeh")
    hv.output(size=200)

    sorted_cell_types = sorted(adata.obs[cell_type_key].unique().tolist())

    # Get group labels
    if (group_key is not None) and (groups == "all"):
        group_labels = df.index.get_level_values(
            df.index.names.index(group_key)).unique().tolist()
    elif (group_key is not None) and (groups != "all"):
        group_labels = groups
    else:
        group_labels = [""]

    chord_list = []
    for group_label in group_labels:
        if group_label == "":
            group_df = df
        else:
            group_df = df[df.index.get_level_values(
                df.index.names.index(group_key)) == group_label]

        # Get max value (over rows and columns) of the group for thresholding
        group_max = group_df.max().max()

        # Create group chord links
        links_list = []
        for i in range(len(sorted_cell_types)):
            for j in range(len(sorted_cell_types)):
                if group_df.iloc[i, j] > group_max * link_threshold:
                    link_dict = {}
                    link_dict["source"] = j
                    link_dict["target"] = i
                    link_dict["value"] = group_df.iloc[i, j]
                    links_list.append(link_dict)
        links = pd.DataFrame(links_list)

        # Create group chord nodes (only where links exist)
        nodes_list = []
        nodes_idx = []
        for i, cell_type in enumerate(sorted_cell_types):
            if i in (links["source"].values) or i in (links["target"].values):
                nodes_idx.append(i)
                nodes_dict = {}
                nodes_dict["name"] = cell_type
                nodes_dict["group"] = 1
                nodes_list.append(nodes_dict)
        nodes = hv.Dataset(pd.DataFrame(nodes_list, index=nodes_idx), "index")

        # Create group chord plot
        chord = hv.Chord((links, nodes)).select(value=(5, None))
        chord.opts(hv.opts.Chord(cmap="Category20",
                                 edge_cmap="Category20",
                                 edge_color=hv.dim("source").str(),
                                 labels="name",
                                 node_color=hv.dim("index").str(),
                                 title=f"{plot_label} {group_label}"))
        chord_list.append(chord)

    # Display chord plots
    layout = hv.Layout(chord_list).cols(2)
    hv.output(layout)

    # Save chord plots
    if save_fig:
        hv.save(layout,
                file_path,
                fmt="png")


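# Illustrative usage sketch (kept as a comment so nothing runs at import
# time). It assumes ´agg_df´ was produced by
# ´aggregate_obsp_matrix_per_cell_type´ with ´agg_rows=True´ and that
# holoviews with the bokeh backend is installed.
#
#     create_cell_type_chord_plot_from_df(
#         adata=model.adata,
#         df=agg_df,
#         link_threshold=0.05,
#         cell_type_key="cell_type",
#         plot_label="Niche")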
def generate_enriched_gp_info_plots(plot_label: str,
                                    model: NicheCompass,
                                    sample_key: str,
                                    differential_gp_test_results_key: str,
                                    cat_key: str,
                                    cat_palette: dict,
                                    n_top_enriched_gp_start_idx: int=0,
                                    n_top_enriched_gp_end_idx: int=10,
                                    feature_spaces: list=["latent"],
                                    n_top_genes_per_gp: int=3,
                                    n_top_peaks_per_gp: int=0,
                                    scale_omics_ft: bool=False,
                                    save_figs: bool=False,
                                    figure_folder_path: str="",
                                    file_format: str="png",
                                    spot_size: float=30.):
    """
    Generate info plots of enriched gene programs. These show the enriched
    category, the gp activities, as well as the counts (or log normalized
    counts) of the top genes and/or peaks in a specified feature space.

    Parameters
    ----------
    plot_label:
        Main label of the plots.
    model:
        A trained NicheCompass model.
    sample_key:
        Key in ´adata.obs´ where the samples are stored.
    differential_gp_test_results_key:
        Key in ´adata.uns´ where the results of the differential gene program
        testing are stored.
    cat_key:
        Key in ´adata.obs´ where the categories that are used as colors for the
        enriched category plot are stored.
    cat_palette:
        Dictionary of colors that are used to highlight the categories, where
        the category is the key of the dictionary and the color is the value.
    n_top_enriched_gp_start_idx:
        Index of the top enriched gene program at which to start the creation
        of plots.
    n_top_enriched_gp_end_idx:
        Index of the top enriched gene program at which to stop the creation
        of plots (exclusive, as in Python slicing).
    feature_spaces:
        List of feature spaces used for the info plots. Can be ´latent´ to use
        the latent embeddings for the plots, or it can be any of the samples
        stored in ´adata.obs[sample_key]´ to use the respective physical
        feature space for the plots.
    n_top_genes_per_gp:
        Number of top genes per gp to be considered in the info plots.
    n_top_peaks_per_gp:
        Number of top peaks per gp to be considered in the info plots. If ´>0´,
        requires the model to be trained including ATAC modality.
    scale_omics_ft:
        If ´True´, scale genes and peaks before plotting.
    save_figs:
        If ´True´, save the figures.
    figure_folder_path:
        Folder path where the figures will be saved.
    file_format:
        Format with which the figures will be saved.
    spot_size:
        Spot size used for the spatial plots.
    """
    model._check_if_trained(warn=True)

    adata = model.adata.copy()
    if n_top_peaks_per_gp > 0:
        if "atac" not in model.modalities_:
            raise ValueError("The model needs to be trained with ATAC data if "
                             "'n_top_peaks_per_gp' > 0.")
        adata_atac = model.adata_atac.copy()

    # TODO
    if scale_omics_ft:
        sc.pp.scale(adata)
        if n_top_peaks_per_gp > 0:
            sc.pp.scale(adata_atac)
        adata.uns["omics_ft_pos_cmap"] = "RdBu"
        adata.uns["omics_ft_neg_cmap"] = "RdBu_r"
    else:
        if n_top_peaks_per_gp > 0:
            adata_atac.X = adata_atac.X.toarray()
        adata.uns["omics_ft_pos_cmap"] = "Blues"
        adata.uns["omics_ft_neg_cmap"] = "Reds"

    cats = list(adata.uns[differential_gp_test_results_key]["category"][
        n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])
    gps = list(adata.uns[differential_gp_test_results_key]["gene_program"][
        n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])
    log_bayes_factors = list(
        adata.uns[differential_gp_test_results_key]["log_bayes_factor"][
            n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])

    for gp in gps:
        # Get source and target genes, gene importances and gene signs and
        # store in temporary adata
        gp_gene_importances_df = model.compute_gp_gene_importances(
            selected_gp=gp)

        gp_source_genes_gene_importances_df = gp_gene_importances_df[
            gp_gene_importances_df["gene_entity"] == "source"]
        gp_target_genes_gene_importances_df = gp_gene_importances_df[
            gp_gene_importances_df["gene_entity"] == "target"]
        adata.uns["n_top_source_genes"] = n_top_genes_per_gp
        adata.uns[f"{gp}_source_genes_top_genes"] = (
            gp_source_genes_gene_importances_df["gene"][
                :n_top_genes_per_gp].values)
        adata.uns[f"{gp}_source_genes_top_gene_importances"] = (
            gp_source_genes_gene_importances_df["gene_importance"][
                :n_top_genes_per_gp].values)
        adata.uns[f"{gp}_source_genes_top_gene_signs"] = (
            np.where(gp_source_genes_gene_importances_df[
                "gene_weight"] > 0, "+", "-"))
        adata.uns["n_top_target_genes"] = n_top_genes_per_gp
        adata.uns[f"{gp}_target_genes_top_genes"] = (
            gp_target_genes_gene_importances_df["gene"][
                :n_top_genes_per_gp].values)
        adata.uns[f"{gp}_target_genes_top_gene_importances"] = (
            gp_target_genes_gene_importances_df["gene_importance"][
                :n_top_genes_per_gp].values)
        adata.uns[f"{gp}_target_genes_top_gene_signs"] = (
            np.where(gp_target_genes_gene_importances_df[
                "gene_weight"] > 0, "+", "-"))

        if n_top_peaks_per_gp > 0:
            # Get source and target peaks, peak importances and peak signs and
            # store in temporary adata
            gp_peak_importances_df = model.compute_gp_peak_importances(
                selected_gp=gp)
            gp_source_peaks_peak_importances_df = gp_peak_importances_df[
                gp_peak_importances_df["peak_entity"] == "source"]
            gp_target_peaks_peak_importances_df = gp_peak_importances_df[
                gp_peak_importances_df["peak_entity"] == "target"]
            adata.uns["n_top_source_peaks"] = n_top_peaks_per_gp
            adata.uns[f"{gp}_source_peaks_top_peaks"] = (
                gp_source_peaks_peak_importances_df["peak"][
                    :n_top_peaks_per_gp].values)
            adata.uns[f"{gp}_source_peaks_top_peak_importances"] = (
                gp_source_peaks_peak_importances_df["peak_importance"][
                    :n_top_peaks_per_gp].values)
            adata.uns[f"{gp}_source_peaks_top_peak_signs"] = (
                np.where(gp_source_peaks_peak_importances_df[
                    "peak_weight"] > 0, "+", "-"))
            adata.uns["n_top_target_peaks"] = n_top_peaks_per_gp
            adata.uns[f"{gp}_target_peaks_top_peaks"] = (
                gp_target_peaks_peak_importances_df["peak"][
                    :n_top_peaks_per_gp].values)
            adata.uns[f"{gp}_target_peaks_top_peak_importances"] = (
                gp_target_peaks_peak_importances_df["peak_importance"][
                    :n_top_peaks_per_gp].values)
            adata.uns[f"{gp}_target_peaks_top_peak_signs"] = (
                np.where(gp_target_peaks_peak_importances_df[
                    "peak_weight"] > 0, "+", "-"))

            # Add peak counts to temporary adata for plotting
            adata.obs[[peak for peak in
                       adata.uns[f"{gp}_target_peaks_top_peaks"]]] = (
                adata_atac.X[
                    :, [adata_atac.var_names.tolist().index(peak)
                        for peak in adata.uns[f"{gp}_target_peaks_top_peaks"]]])
            adata.obs[[peak for peak in
                       adata.uns[f"{gp}_source_peaks_top_peaks"]]] = (
                adata_atac.X[
                    :, [adata_atac.var_names.tolist().index(peak)
                        for peak in adata.uns[f"{gp}_source_peaks_top_peaks"]]])
        else:
            adata.uns["n_top_source_peaks"] = 0
            adata.uns["n_top_target_peaks"] = 0

    for feature_space in feature_spaces:
        plot_enriched_gp_info_plots_(
            adata=adata,
            sample_key=sample_key,
            gps=gps,
            log_bayes_factors=log_bayes_factors,
            cat_key=cat_key,
            cat_palette=cat_palette,
            cats=cats,
            feature_space=feature_space,
            spot_size=spot_size,
            suptitle=f"{plot_label.replace('_', ' ').title()} "
                     f"Top {n_top_enriched_gp_start_idx} to "
                     f"{n_top_enriched_gp_end_idx} Enriched GPs: "
                     f"GP Scores and Omics Feature Counts in "
                     f"{feature_space} Feature Space",
            save_fig=save_figs,
            figure_folder_path=figure_folder_path,
            fig_name=f"{plot_label}_top_{n_top_enriched_gp_start_idx}"
                     f"-{n_top_enriched_gp_end_idx}_enriched_gps_gp_scores_"
                     f"omics_feature_counts_in_{feature_space}_"
                     f"feature_space.{file_format}")


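# Illustrative usage sketch (kept as a comment so nothing runs at import
# time). It shows how the top 10 enriched GPs of a differential GP test could
# be plotted; the uns key and the obs keys below are assumptions, substitute
# the keys used in your own analysis.
#
#     generate_enriched_gp_info_plots(
#         plot_label="niche_enriched_gps",
#         model=model,
#         sample_key="batch",  # hypothetical key
#         differential_gp_test_results_key="differential_gp_test_results",
#         cat_key="niche",  # hypothetical key
#         cat_palette=create_new_color_dict(adata=model.adata, cat_key="niche"),
#         n_top_enriched_gp_end_idx=10,
#         feature_spaces=["latent"])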
def plot_enriched_gp_info_plots_(adata: AnnData,
                                 sample_key: str,
                                 gps: list,
                                 log_bayes_factors: list,
                                 cat_key: str,
                                 cat_palette: dict,
                                 cats: list,
                                 feature_space: str,
                                 spot_size: float,
                                 suptitle: str,
                                 save_fig: bool,
                                 figure_folder_path: str,
                                 fig_name: str):
    """
    This is a helper function to plot gene program info plots in a specified
    feature space.

    Parameters
    ----------
    adata:
        An AnnData object with stored information about the gene programs to be
        plotted.
    sample_key:
        Key in ´adata.obs´ where the samples are stored.
    gps:
        List of gene programs for which info plots will be created.
    log_bayes_factors:
        List of log bayes factors corresponding to the gene programs in ´gps´.
    cat_key:
        Key in ´adata.obs´ where the categories that are used as colors for the
        enriched category plot are stored.
    cat_palette:
        Dictionary of colors that are used to highlight the categories, where
        the category is the key of the dictionary and the color is the value.
    cats:
        List of categories for which the corresponding gene programs in ´gps´
        are enriched.
    feature_space:
        Feature space used for the plots. Can be ´latent´ to use the latent
        embeddings for the plots, or it can be any of the samples stored in
        ´adata.obs[sample_key]´ to use the respective physical feature space
        for the plots.
    spot_size:
        Spot size used for the spatial plots.
    suptitle:
        Overall figure title.
    save_fig:
        If ´True´, save the figure.
    figure_folder_path:
        Path of the folder where the figure will be saved.
    fig_name:
        Name of the figure under which it will be saved.
    """
    # Define figure configurations
    ncols = (2 +
             adata.uns["n_top_source_genes"] +
             adata.uns["n_top_target_genes"] +
             adata.uns["n_top_source_peaks"] +
             adata.uns["n_top_target_peaks"])
    fig_width = (12 + (6 * (
        adata.uns["n_top_source_genes"] +
        adata.uns["n_top_target_genes"] +
        adata.uns["n_top_source_peaks"] +
        adata.uns["n_top_target_peaks"])))
    wspace = 0.3
    fig, axs = plt.subplots(nrows=len(gps),
                            ncols=ncols,
                            figsize=(fig_width, 6*len(gps)))
    if axs.ndim == 1:
        axs = axs.reshape(1, -1)
    title = fig.suptitle(t=suptitle,
                         x=0.55,
                         y=(1.1 if len(gps) == 1 else 0.97),
                         fontsize=20)

    # Plot enriched gp category and gene program latent scores
    for i, gp in enumerate(gps):
        if feature_space == "latent":
            sc.pl.umap(
                adata,
                color=cat_key,
                palette=cat_palette,
                groups=cats[i],
                ax=axs[i, 0],
                title="Enriched GP Category",
                legend_loc="on data",
                na_in_legend=False,
                show=False)
            sc.pl.umap(
                adata,
                color=gps[i],
                color_map="RdBu",
                ax=axs[i, 1],
                title=f"{gp[:gp.index('_')]}\n"
                      f"{gp[gp.index('_') + 1: gp.rindex('_')].replace('_', ' ')}"
                      f"\n{gp[gps[i].rindex('_') + 1:]} score "
                      f"(LBF: {round(log_bayes_factors[i])})",
                colorbar_loc="bottom",
                show=False)
        else:
            sc.pl.spatial(
                adata=adata[adata.obs[sample_key] == feature_space],
                color=cat_key,
                palette=cat_palette,
                groups=cats[i],
                ax=axs[i, 0],
                spot_size=spot_size,
                title="Enriched GP Category",
                legend_loc="on data",
                na_in_legend=False,
                show=False)
            sc.pl.spatial(
                adata=adata[adata.obs[sample_key] == feature_space],
                color=gps[i],
                color_map="RdBu",
                spot_size=spot_size,
                title=f"{gps[i].split('_', 1)[0]}\n{gps[i].split('_', 1)[1]} "
                      f"(LBF: {round(log_bayes_factors[i], 2)})",
                legend_loc=None,
                ax=axs[i, 1],
                colorbar_loc="bottom",
                show=False)
        axs[i, 0].xaxis.label.set_visible(False)
        axs[i, 0].yaxis.label.set_visible(False)
        axs[i, 1].xaxis.label.set_visible(False)
        axs[i, 1].yaxis.label.set_visible(False)

        # Plot omics feature counts (or log normalized counts)
        modality_entities = []
        if len(adata.uns[f"{gp}_source_genes_top_genes"]) > 0:
            modality_entities.append("source_genes")
        if len(adata.uns[f"{gp}_target_genes_top_genes"]) > 0:
            modality_entities.append("target_genes")
        if f"{gp}_source_peaks_top_peaks" in adata.uns.keys():
            gp_n_source_peaks_top_peaks = (
                len(adata.uns[f"{gp}_source_peaks_top_peaks"]))
            if len(adata.uns[f"{gp}_source_peaks_top_peaks"]) > 0:
                modality_entities.append("source_peaks")
        else:
            gp_n_source_peaks_top_peaks = 0
        if f"{gp}_target_peaks_top_peaks" in adata.uns.keys():
            gp_n_target_peaks_top_peaks = (
                len(adata.uns[f"{gp}_target_peaks_top_peaks"]))
            if len(adata.uns[f"{gp}_target_peaks_top_peaks"]) > 0:
                modality_entities.append("target_peaks")
        else:
            gp_n_target_peaks_top_peaks = 0
        for modality_entity in modality_entities:
            # Define k for index iteration
            if modality_entity == "source_genes":
                k = 0
            elif modality_entity == "target_genes":
                k = len(adata.uns[f"{gp}_source_genes_top_genes"])
            elif modality_entity == "source_peaks":
                k = (len(adata.uns[f"{gp}_source_genes_top_genes"]) +
                     len(adata.uns[f"{gp}_target_genes_top_genes"]))
            elif modality_entity == "target_peaks":
                k = (len(adata.uns[f"{gp}_source_genes_top_genes"]) +
                     len(adata.uns[f"{gp}_target_genes_top_genes"]) +
                     len(adata.uns[f"{gp}_source_peaks_top_peaks"]))
            for j in range(len(adata.uns[f"{gp}_{modality_entity}_top_"
                                         f"{modality_entity.split('_')[1]}"])):
                if feature_space == "latent":
                    sc.pl.umap(
                        adata,
                        color=adata.uns[f"{gp}_{modality_entity}_top_"
                                        f"{modality_entity.split('_')[1]}"][j],
                        color_map=(adata.uns["omics_ft_pos_cmap"] if
                                   adata.uns[f"{gp}_{modality_entity}_top_"
                                             f"{modality_entity.split('_')[1][:-1]}"
                                             "_signs"][j] == "+" else adata.uns["omics_ft_neg_cmap"]),
                        ax=axs[i, 2+k+j],
                        legend_loc="on data",
                        na_in_legend=False,
                        title=f"""{adata.uns[f"{gp}_{modality_entity}_top_"
                                             f"{modality_entity.split('_')[1]}"
                                             ][j]}: """
                              f"""{adata.uns[f"{gp}_{modality_entity}_top_"
                                             f"{modality_entity.split('_')[1][:-1]}"
                                             "_importances"][j]:.2f} """
                              f"({modality_entity[:-1]}; "
                              f"""{adata.uns[f"{gp}_{modality_entity}_top_"
                                             f"{modality_entity.split('_')[1][:-1]}"
                                             "_signs"][j]})""",
                        colorbar_loc="bottom",
                        show=False)
                else:
                    sc.pl.spatial(
                        adata=adata[adata.obs[sample_key] == feature_space],
                        color=adata.uns[f"{gp}_{modality_entity}_top_"
                                        f"{modality_entity.split('_')[1]}"][j],
                        color_map=(adata.uns["omics_ft_pos_cmap"] if
                                   adata.uns[f"{gp}_{modality_entity}_top_"
                                             f"{modality_entity.split('_')[1][:-1]}"
                                             "_signs"][j] == "+" else adata.uns["omics_ft_neg_cmap"]),
                        legend_loc="on data",
                        na_in_legend=False,
                        ax=axs[i, 2+k+j],
                        spot_size=spot_size,
                        title=f"""{adata.uns[f"{gp}_{modality_entity}_top_"
                                             f"{modality_entity.split('_')[1]}"
                                             ][j]} \n"""
                              f"""({adata.uns[f"{gp}_{modality_entity}_top_"
                                              f"{modality_entity.split('_')[1][:-1]}"
                                              "_importances"][j]:.2f}; """
                              f"{modality_entity[:-1]}; "
                              f"""{adata.uns[f"{gp}_{modality_entity}_top_"
                                             f"{modality_entity.split('_')[1][:-1]}"
                                             "_signs"][j]})""",
                        colorbar_loc="bottom",
                        show=False)
                axs[i, 2+k+j].xaxis.label.set_visible(False)
                axs[i, 2+k+j].yaxis.label.set_visible(False)
        # Remove unnecessary axes
        for l in range(2 +
                       len(adata.uns[f"{gp}_source_genes_top_genes"]) +
                       len(adata.uns[f"{gp}_target_genes_top_genes"]) +
                       gp_n_source_peaks_top_peaks +
                       gp_n_target_peaks_top_peaks, ncols):
            axs[i, l].set_visible(False)

    # Save and display plot
    plt.subplots_adjust(wspace=wspace, hspace=0.275)
    if save_fig:
        fig.savefig(f"{figure_folder_path}/{fig_name}",
                    bbox_extra_artists=(title,),
                    bbox_inches="tight")
    plt.show()


default_color_dict = {
    "0": "#66C5CC",
    "1": "#F6CF71",
    "2": "#F89C74",
    "3": "#DCB0F2",
    "4": "#87C55F",
    "5": "#9EB9F3",
    "6": "#FE88B1",
    "7": "#C9DB74",
    "8": "#8BE0A4",
    "9": "#B497E7",
    "10": "#D3B484",
    "11": "#B3B3B3",
    "12": "#276A8C",  # Royal Blue
    "13": "#DAB6C4",  # Pink
    "14": "#C38D9E",  # Mauve-Pink
    "15": "#9D88A2",  # Mauve
    "16": "#FF4D4D",  # Light Red
    "17": "#9B4DCA",  # Lavender-Purple
    "18": "#FF9CDA",  # Bright Pink
    "19": "#FF69B4",  # Hot Pink
    "20": "#FF00FF",  # Magenta
    "21": "#DA70D6",  # Orchid
    "22": "#BA55D3",  # Medium Orchid
    "23": "#8A2BE2",  # Blue Violet
    "24": "#9370DB",  # Medium Purple
    "25": "#7B68EE",  # Medium Slate Blue
    "26": "#4169E1",  # Royal Blue
    "27": "#FF8C8C",  # Salmon Pink
    "28": "#FFAA80",  # Light Coral
    "29": "#48D1CC",  # Medium Turquoise
    "30": "#40E0D0",  # Turquoise
    "31": "#00FF00",  # Lime
    "32": "#7FFF00",  # Chartreuse
    "33": "#ADFF2F",  # Green Yellow
    "34": "#32CD32",  # Lime Green
    "35": "#228B22",  # Forest Green
    "36": "#FFD8B8",  # Peach
    "37": "#008080",  # Teal
    "38": "#20B2AA",  # Light Sea Green
    "39": "#00FFFF",  # Cyan
    "40": "#00BFFF",  # Deep Sky Blue
    "41": "#4169E1",  # Royal Blue
    "42": "#0000CD",  # Medium Blue
    "43": "#00008B",  # Dark Blue
    "44": "#8B008B",  # Dark Magenta
    "45": "#FF1493",  # Deep Pink
    "46": "#FF4500",  # Orange Red
    "47": "#006400",  # Dark Green
    "48": "#FF6347",  # Tomato
    "49": "#FF7F50",  # Coral
    "50": "#CD5C5C",  # Indian Red
    "51": "#B22222",  # Fire Brick
    "52": "#FFB83F",  # Light Orange
    "53": "#8B0000",  # Dark Red
    "54": "#D2691E",  # Chocolate
    "55": "#A0522D",  # Sienna
    "56": "#800000",  # Maroon
    "57": "#808080",  # Gray
    "58": "#A9A9A9",  # Dark Gray
    "59": "#C0C0C0",  # Silver
    "60": "#9DD84A",
    "61": "#F5F5F5",  # White Smoke
    "62": "#F17171",  # Light Red
    "63": "#000000",  # Black
    "64": "#FF8C42",  # Tangerine
    "65": "#F9A11F",  # Bright Orange-Yellow
    "66": "#FACC15",  # Golden Yellow
    "67": "#E2E062",  # Pale Lime
    "68": "#BADE92",  # Soft Lime
    "69": "#70C1B3",  # Greenish-Blue
    "70": "#41B3A3",  # Turquoise
    "71": "#5EAAA8",  # Gray-Green
    "72": "#72B01D",  # Chartreuse
    "73": "#9CD08F",  # Light Green
    "74": "#8EBA43",  # Olive Green
    "75": "#FAC8C3",  # Light Pink
    "76": "#E27D60",  # Dark Salmon
    "77": "#C38D9E",  # Mauve-Pink
    "78": "#937D64",  # Light Brown
    "79": "#B1C1CC",  # Light Blue-Gray
    "80": "#88A0A8",  # Gray-Blue-Green
    "81": "#4E598C",  # Dark Blue-Purple
    "82": "#4B4E6D",  # Dark Gray-Blue
    "83": "#8E9AAF",  # Light Blue-Grey
    "84": "#C0D6DF",  # Pale Blue-Grey
    "85": "#97C1A9",  # Blue-Green
    "86": "#4C6E5D",  # Dark Green
    "87": "#95B9C7",  # Pale Blue-Green
    "88": "#C1D5E0",  # Pale Gray-Blue
    "89": "#ECDB54",  # Bright Yellow
    "90": "#E89B3B",  # Bright Orange
    "91": "#CE5A57",  # Deep Red
    "92": "#C3525A",  # Dark Red
    "93": "#B85D8E",  # Berry
    "94": "#7D5295",  # Deep Purple
    "-1": "#E1D9D1",
    "None": "#E1D9D1"
}


def create_new_color_dict(
        adata,
        cat_key,
        color_palette="default",
        overwrite_color_dict={"-1": "#E1D9D1"},
        skip_default_colors=0):
    """
    Create a dictionary of color hexcodes for a specified category.

    Parameters
    ----------
    adata:
        AnnData object.
    cat_key:
        Key in ´adata.obs´ where the categories are stored for which color
        hexcodes will be created.
    color_palette:
        Type of color palette.
    overwrite_color_dict:
        Dictionary with overwrite values that will take precedence over the
        automatically created dictionary.
    skip_default_colors:
        Number of colors to skip from the default color dict.

    Returns
    ----------
    new_color_dict:
        The color dictionary with a hexcode for each category.
    """
    new_categories = adata.obs[cat_key].unique().tolist()
    if color_palette == "cell_type_30":
        # https://github.com/scverse/scanpy/blob/master/scanpy/plotting/palettes.py#L40
        new_color_dict = {key: value for key, value in zip(
            new_categories,
            ["#023fa5",
             "#7d87b9",
             "#bec1d4",
             "#d6bcc0",
             "#bb7784",
             "#8e063b",
             "#4a6fe3",
             "#8595e1",
             "#b5bbe3",
             "#e6afb9",
             "#e07b91",
             "#d33f6a",
             "#11c638",
             "#8dd593",
             "#c6dec7",
             "#ead3c6",
             "#f0b98d",
             "#ef9708",
             "#0fcfc0",
             "#9cded6",
             "#d5eae7",
             "#f3e1eb",
             "#f6c4e1",
             "#f79cd4",
             "#7f7f7f",
             "#c7c7c7",
             "#1CE6FF",
             "#336600"])}
    elif color_palette == "cell_type_20":
        # https://github.com/vega/vega/wiki/Scales#scale-range-literals (some adjusted)
        new_color_dict = {key: value for key, value in zip(
            new_categories,
            ["#1f77b4",
             "#ff7f0e",
             "#279e68",
             "#d62728",
             "#aa40fc",
             "#8c564b",
             "#e377c2",
             "#b5bd61",
             "#17becf",
             "#aec7e8",
             "#ffbb78",
             "#98df8a",
             "#ff9896",
             "#c5b0d5",
             "#c49c94",
             "#f7b6d2",
             "#dbdb8d",
             "#9edae5",
             "#ad494a",
             "#8c6d31"])}
    elif color_palette == "cell_type_10":
        # scanpy vega10
        new_color_dict = {key: value for key, value in zip(
            new_categories,
            ["#7f7f7f",
             "#ff7f0e",
             "#279e68",
             "#e377c2",
             "#17becf",
             "#8c564b",
             "#d62728",
             "#1f77b4",
             "#b5bd61",
             "#aa40fc"])}
    elif color_palette == "batch":
        # sns.color_palette("colorblind").as_hex()
        new_color_dict = {key: value for key, value in zip(
            new_categories,
            ["#0173b2", "#d55e00", "#ece133", "#ca9161", "#fbafe4",
             "#949494", "#de8f05", "#029e73", "#cc78bc", "#56b4e9",
             "#F0F8FF", "#FAEBD7", "#00FFFF", "#7FFFD4", "#F0FFFF",
             "#F5F5DC", "#FFE4C4", "#000000", "#FFEBCD", "#0000FF",
             "#8A2BE2", "#A52A2A", "#DEB887", "#5F9EA0", "#7FFF00",
             "#D2691E", "#FF7F50", "#6495ED", "#FFF8DC", "#DC143C"])}
    elif color_palette == "default":
        new_color_dict = {key: value for key, value in zip(
            new_categories,
            list(default_color_dict.values())[skip_default_colors:])}
    else:
        # Guard against unknown palettes (would otherwise raise an
        # UnboundLocalError below)
        raise ValueError(f"Unknown color palette '{color_palette}'.")
    for key, val in overwrite_color_dict.items():
        new_color_dict[key] = val
    return new_color_dict


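# Illustrative usage sketch (kept as a comment so nothing runs at import
# time). It builds a palette for a hypothetical ´adata.obs["niche"]´ column,
# with the unassigned category "-1" kept light grey via
# ´overwrite_color_dict´.
#
#     niche_colors = create_new_color_dict(
#         adata=model.adata,
#         cat_key="niche",  # hypothetical key
#         color_palette="default")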
def plot_non_zero_gene_count_means_dist(
        adata: AnnData,
        genes: list,
        gene_label: str):
    """
    Plot distribution of non zero gene count means in the adata over all
    specified genes.
    """
    gene_counts = adata[
        :, [gene for gene in adata.var_names if gene in genes]].layers["counts"]
    nz_gene_means = np.mean(
        np.ma.masked_equal(gene_counts.toarray(), 0), axis=0).data

    sns.kdeplot(nz_gene_means)
    plt.title(f"{gene_label} Genes Average Non-Zero Gene Counts per Gene")
    plt.xlabel("Average Non-zero Gene Counts")
    plt.ylabel("Gene Density")
    plt.show()


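# Illustrative usage sketch (kept as a comment so nothing runs at import
# time). It assumes raw counts are stored as a sparse matrix in
# ´adata.layers["counts"]´; the gene list below is hypothetical.
#
#     plot_non_zero_gene_count_means_dist(
#         adata=model.adata,
#         genes=["Dll1", "Notch1", "Cxcl12"],  # hypothetical gene list
#         gene_label="Ligand/Receptor")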
def compute_communication_gp_network(
        gp_list: list,
        model: NicheCompass,
        group_key: str="niche",
        filter_key: Optional[str]=None,
        filter_cat: Optional[str]=None,
        n_neighbors: int=90):
    """
    Compute a network of category-aggregated cell-pair communication strengths.

    First, compute source and target cell-cell communication potential scores
    for each cell. Then, compute cell-pair communication strengths as the
    product of source and target potential scores, restricted to spatially
    neighboring cell pairs. Finally, aggregate the cell-pair communication
    strengths per group and min-max normalize them.

    Parameters
    ----------
    gp_list:
        List of GPs for which the cell-pair communication strengths are
        computed.
    model:
        A trained NicheCompass model.
    group_key:
        Key in ´adata.obs´ where the groups are stored over which the cell-pair
        communication strengths will be aggregated.
    filter_key:
        Key in ´adata.obs´ that contains the category for which the results are
        filtered.
    filter_cat:
        Category for which the results are filtered.
    n_neighbors:
        Number of neighbors for the gp-specific neighborhood graph.

    Returns
    ----------
    network_df:
        A pandas dataframe with aggregated, normalized cell-pair communication
        strengths.
    """
    # Compute neighborhood graph
    compute_knn = True
    if "spatial_cci" in model.adata.uns.keys():
        if model.adata.uns["spatial_cci"]["params"]["n_neighbors"] == n_neighbors:
            compute_knn = False
    if compute_knn:
        sc.pp.neighbors(model.adata,
                        n_neighbors=n_neighbors,
                        use_rep="spatial",
                        key_added="spatial_cci")

    gp_network_dfs = []
    gp_summary_df = model.get_gp_summary()
    for gp in gp_list:
        gp_idx = model.adata.uns[model.gp_names_key_].tolist().index(gp)
        active_gp_idx = model.adata.uns[
            model.active_gp_names_key_].tolist().index(gp)
        gp_scores = model.adata.obsm[model.latent_key_][:, active_gp_idx]
        gp_targets_cats = model.adata.varm[
            model.gp_targets_categories_mask_key_][:, gp_idx]
        gp_sources_cats = model.adata.varm[
            model.gp_sources_categories_mask_key_][:, gp_idx]
        targets_cats_label_encoder = model.adata.uns[
            model.targets_categories_label_encoder_key_]
        sources_cats_label_encoder = model.adata.uns[
            model.sources_categories_label_encoder_key_]

        sources_cat_idx_dict = {}
        for source_cat, source_cat_label in sources_cats_label_encoder.items():
            sources_cat_idx_dict[source_cat] = np.where(
                gp_sources_cats == source_cat_label)[0]

        targets_cat_idx_dict = {}
        for target_cat, target_cat_label in targets_cats_label_encoder.items():
            targets_cat_idx_dict[target_cat] = np.where(
                gp_targets_cats == target_cat_label)[0]

        # Get indices of all source and target genes
        source_genes_idx = np.array([], dtype=np.int64)
        for key in sources_cat_idx_dict.keys():
            source_genes_idx = np.append(source_genes_idx,
                                         sources_cat_idx_dict[key])
        target_genes_idx = np.array([], dtype=np.int64)
        for key in targets_cat_idx_dict.keys():
            target_genes_idx = np.append(target_genes_idx,
                                         targets_cat_idx_dict[key])

        # Compute cell-cell communication potential scores
        gp_source_scores = np.zeros((len(model.adata.obs), len(source_genes_idx)))
        gp_target_scores = np.zeros((len(model.adata.obs), len(target_genes_idx)))

        for i, source_gene_idx in enumerate(source_genes_idx):
            source_gene = model.adata.var_names[source_gene_idx]
            source_gene_expr = model.adata[
                :, model.adata.var_names.tolist().index(source_gene)
                ].X.toarray().flatten()
            source_gene_weight = gp_summary_df[
                gp_summary_df["gp_name"] == gp][
                "gp_source_genes_weights"].values[0][
                gp_summary_df[gp_summary_df["gp_name"] == gp][
                    "gp_source_genes"].values[0].index(source_gene)]
            # Max-normalized expression * gene weight in the GP * GP score
            gp_source_scores[:, i] = (
                source_gene_expr / source_gene_expr.max() *
                source_gene_weight *
                gp_scores)

        for j, target_gene_idx in enumerate(target_genes_idx):
            target_gene = model.adata.var_names[target_gene_idx]
            target_gene_expr = model.adata[
                :, model.adata.var_names.tolist().index(target_gene)
                ].X.toarray().flatten()
            target_gene_weight = gp_summary_df[
                gp_summary_df["gp_name"] == gp][
                "gp_target_genes_weights"].values[0][
                gp_summary_df[gp_summary_df["gp_name"] == gp][
                    "gp_target_genes"].values[0].index(target_gene)]
            # Max-normalized expression * gene weight in the GP * GP score
            gp_target_scores[:, j] = (
                target_gene_expr / target_gene_expr.max() *
                target_gene_weight *
                gp_scores)

        agg_gp_source_score = gp_source_scores.mean(1).astype("float32")
        agg_gp_target_score = gp_target_scores.mean(1).astype("float32")
        agg_gp_source_score[agg_gp_source_score < 0] = 0.
        agg_gp_target_score[agg_gp_target_score < 0] = 0.

        model.adata.obs[f"{gp}_source_score"] = agg_gp_source_score
        model.adata.obs[f"{gp}_target_score"] = agg_gp_target_score

        del gp_target_scores
        del gp_source_scores

        agg_gp_source_score = sp.csr_matrix(agg_gp_source_score)
        agg_gp_target_score = sp.csr_matrix(agg_gp_target_score)

        model.adata.obsp[f"{gp}_connectivities"] = (
            model.adata.obsp["spatial_cci_connectivities"] > 0).multiply(
            agg_gp_source_score.T.dot(agg_gp_target_score))

        # Aggregate gp connectivities for each group
        gp_network_df_pivoted = aggregate_obsp_matrix_per_cell_type(
            adata=model.adata,
            obsp_key=f"{gp}_connectivities",
            cell_type_key=group_key,
            group_key=filter_key,
            agg_rows=True)

        if filter_key is not None:
            gp_network_df_pivoted = gp_network_df_pivoted.loc[filter_cat, :]

        gp_network_df = gp_network_df_pivoted.melt(
            var_name="source",
            value_name="gp_score",
            ignore_index=False).reset_index()
        gp_network_df.columns = ["source", "target", "strength"]

        gp_network_df = gp_network_df.sort_values("strength", ascending=False)

        # Normalize strength
        min_value = gp_network_df["strength"].min()
        max_value = gp_network_df["strength"].max()
        gp_network_df["strength_unscaled"] = gp_network_df["strength"]
        gp_network_df["strength"] = (
            (gp_network_df["strength"] - min_value) / (max_value - min_value))
        gp_network_df["strength"] = np.round(gp_network_df["strength"], 2)
        gp_network_df = gp_network_df[gp_network_df["strength"] > 0]

        gp_network_df["edge_type"] = gp
        gp_network_dfs.append(gp_network_df)

    network_df = pd.concat(gp_network_dfs, ignore_index=True)
    return network_df


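# Illustrative usage sketch (kept as a comment so nothing runs at import
# time). It computes an aggregated communication network for two GPs of
# interest; the GP names below are hypothetical, use GPs that are active in
# your trained model.
#
#     network_df = compute_communication_gp_network(
#         gp_list=["Dll1_ligand_receptor_target_gene_GP",     # hypothetical
#                  "Cxcl12_ligand_receptor_target_gene_GP"],  # hypothetical
#         model=model,
#         group_key="niche",
#         n_neighbors=90)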
def visualize_communication_gp_network(
        adata,
        network_df,
        cat_colors,
        edge_type_colors: Optional[dict]=None,
        edge_width_scale: float=20.0,
        node_size: int=500,
        fontsize: int=14,
        figsize: Tuple[int, int]=(18, 16),
        plot_legend: bool=True,
        save: bool=False,
        save_path: str="communication_gp_network.svg",
        show: bool=True,
        text_space: float=1.3,
        connection_style="arc3, rad = 0.1",
        cat_key: str="niche",
        edge_attr: str="strength"):
    """
    Visualize a communication gp network.
    """
    # Unique edge types (one per GP) in the 'edge_type' column
    edge_types = np.unique(network_df["edge_type"])

    if edge_type_colors is None:
        # Colorblindness adjusted vega_10
        # See https://github.com/theislab/scanpy/issues/387
        vega_10 = list(map(colors.to_hex, cm.tab10.colors))
        vega_10_scanpy = vega_10.copy()
        vega_10_scanpy[2] = "#279e68"  # green
        vega_10_scanpy[4] = "#aa40fc"  # purple
        vega_10_scanpy[8] = "#b5bd61"  # khaki
        edge_type_colors = vega_10_scanpy

    # Create a dictionary that maps edge types to colors
    edge_type_color_dict = {edge_type: color for edge_type, color in zip(
        edge_types, edge_type_colors)}

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    ax.axis("off")
    G = nx.from_pandas_edgelist(
        network_df,
        source="source",
        target="target",
        edge_attr=["edge_type", edge_attr],
        create_using=nx.DiGraph(),
    )
    pos = nx.circular_layout(G)

    nx.set_node_attributes(G, cat_colors, "color")
    node_color = nx.get_node_attributes(G, "color")

    description = nx.draw_networkx_labels(G, pos, font_size=fontsize)
    n = adata.obs[cat_key].nunique()
    node_list = sorted(G.nodes())
    angle = []
    angle_dict = {}
    for i, node in zip(range(n), node_list):
        theta = 2.0 * np.pi * i / n
        angle.append((np.cos(theta), np.sin(theta)))
        angle_dict[node] = theta
    pos = {}
    for node_i, node in enumerate(node_list):
        pos[node] = angle[node_i]

    r = fig.canvas.get_renderer()
    trans = plt.gca().transData.inverted()
    for node, t in description.items():
        bb = t.get_window_extent(renderer=r)
        bbdata = bb.transformed(trans)
        radius = text_space + bbdata.width / 2.0
        position = (radius * np.cos(angle_dict[node]),
                    radius * np.sin(angle_dict[node]))
        t.set_position(position)
        t.set_rotation(angle_dict[node] * 360.0 / (2.0 * np.pi))
        t.set_clip_on(False)

    edgelist = [(u, v) for u, v, e in G.edges(data=True) if u != v]
    edge_colors = [edge_type_color_dict[edge_data["edge_type"]]
                   for u, v, edge_data in G.edges(data=True) if u != v]
    width = [e[edge_attr] * edge_width_scale
             for u, v, e in G.edges(data=True) if u != v]

    h2 = nx.draw_networkx(
        G,
        pos,
        with_labels=False,
        node_size=node_size,
        edgelist=edgelist,
        width=width,
        edge_vmin=0.0,
        edge_vmax=1.0,
        edge_color=edge_colors,  # Use the edge type colors here
        arrows=True,
        arrowstyle="-|>",
        arrowsize=20,
        vmin=0.0,
        vmax=1.0,
        cmap=plt.cm.binary,  # Use a colormap for node colors if needed
        node_color=list(node_color.values()),
        ax=ax,
        connectionstyle=connection_style,
    )

    # https://stackoverflow.com/questions/19877666/add-legends-to-linecollection-plot
    # uses plotted data to define the color, but here we already have colors
    # defined, so we just need a Line2D object.
    def make_proxy(clr, mappable, **kwargs):
        return Line2D([0, 1], [0, 1], color=clr, **kwargs)

    # Generate legend proxies in the same order as the edge type labels so
    # that legend colors and labels match
    proxies = [make_proxy(edge_type_color_dict[edge_type], h2, lw=5)
               for edge_type in edge_types]
    labels = [edge_type.split("_")[0] + " GP" for edge_type in edge_types]

    if plot_legend:
        lgd = plt.legend(proxies, labels, loc="lower left")

    edgelist = ([(u, v) for u, v, e in G.edges(data=True) if ((u == v))] +
                [(u, v) for u, v, e in G.edges(data=True) if ((u != v))])
    edge_colors = [edge_type_color_dict[edge_data["edge_type"]]
                   for u, v, edge_data in G.edges(data=True) if u == v]
    width = ([e[edge_attr] * edge_width_scale
              for u, v, e in G.edges(data=True) if u == v] +
             [0 for u, v, e in G.edges(data=True) if ((u != v))])
    nx.draw_networkx_edges(
        G,
        pos,
        node_size=node_size,
        edgelist=edgelist,
        width=width,
        edge_vmin=0.0,
        edge_vmax=1.0,
        edge_color=edge_colors,
        arrows=False,
        arrowstyle="-|>",
        arrowsize=20,
        ax=ax,
        connectionstyle=connection_style)
    plt.tight_layout()
    if save:
        plt.savefig(save_path)
    if show:
        plt.show()
    plt.close(fig)
    plt.ion()
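

# Illustrative usage sketch (kept as a comment so nothing runs at import
# time). It visualizes the network computed by
# ´compute_communication_gp_network´, reusing the hypothetical niche colors
# from ´create_new_color_dict´ for the nodes.
#
#     visualize_communication_gp_network(
#         adata=model.adata,
#         network_df=network_df,
#         cat_colors=niche_colors,
#         cat_key="niche",
#         save=False)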