--- a +++ b/src/nichecompass/utils/analysis.py @@ -0,0 +1,1147 @@ +""" +This module contains utilities to analyze niches inferred by the NicheCompass +model. +""" + +from typing import Optional, Tuple + +#import holoviews as hv +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import scanpy as sc +import scipy.sparse as sp +import seaborn as sns +from anndata import AnnData +from matplotlib import cm, colors +from matplotlib.lines import Line2D +import networkx as nx + +from ..models import NicheCompass + + +def aggregate_obsp_matrix_per_cell_type( + adata: AnnData, + obsp_key: str, + cell_type_key: str="cell_type", + group_key: Optional[str]=None, + agg_rows: bool=False): + """ + Generic function to aggregate adjacency matrices stored in + ´adata.obsp[obsp_key]´ on cell type level. It can be used to aggregate the + node label aggregator aggregation weights alpha or the reconstructed adjacency + matrix of a trained NicheCompass model by neighbor cell type for downstream + analysis. + + Parameters + ---------- + adata: + AnnData object which contains outputs of NicheCompass model training. + obsp_key: + Key in ´adata.obsp´ where the matrix to be aggregated is stored. + cell_type_key: + Key in ´adata.obs´ where the cell type labels are stored. + group_key: + Key in ´adata.obs´ where additional grouping labels are stored. + agg_rows: + If ´True´, also aggregate over the observations on cell type level. + + Returns + ---------- + cell_type_agg_df: + Pandas DataFrame with the aggregated obsp values (dim: n_obs x + n_cell_types if ´agg_rows == False´, else n_cell_types x n_cell_types). + """ + n_obs = len(adata) + n_cell_types = adata.obs[cell_type_key].nunique() + sorted_cell_types = sorted(adata.obs[cell_type_key].unique().tolist()) + + cell_type_label_encoder = {k: v for k, v in zip( + sorted_cell_types, + range(n_cell_types))} + + # Retrieve non zero indices and non zero values, and create row-wise + # observation cell type index + nz_obsp_idx = adata.obsp[obsp_key].nonzero() + neighbor_cell_type_index = adata.obs[cell_type_key][nz_obsp_idx[1]].map( + cell_type_label_encoder).values + adata.obsp[obsp_key].eliminate_zeros() # In some sparse reps 0s can appear + nz_obsp = adata.obsp[obsp_key].data + + # Use non zero indices, non zero values and row-wise observation cell type + # index to construct new df with cell types as columns and row-wise + # aggregated values per cell type index as values + cell_type_agg = np.zeros((n_obs, n_cell_types)) + np.add.at(cell_type_agg, + (nz_obsp_idx[0], neighbor_cell_type_index), + nz_obsp) + cell_type_agg_df = pd.DataFrame( + cell_type_agg, + columns=sorted_cell_types) + + # Add cell type labels of observations + cell_type_agg_df[cell_type_key] = adata.obs[cell_type_key].values + + # If specified, add group label + if group_key is not None: + cell_type_agg_df[group_key] = adata.obs[group_key].values + + if agg_rows: + # In addition, aggregate values across rows to get a + # (n_cell_types x n_cell_types) df + if group_key is not None: + cell_type_agg_df = cell_type_agg_df.groupby( + [group_key, cell_type_key]).sum() + else: + cell_type_agg_df = cell_type_agg_df.groupby(cell_type_key).sum() + + # Sort index to have same order as columns + cell_type_agg_df = cell_type_agg_df.loc[ + sorted(cell_type_agg_df.index.tolist()), :] + + return cell_type_agg_df + + +def create_cell_type_chord_plot_from_df( + adata: AnnData, + df: pd.DataFrame, + link_threshold: float=0.01, + cell_type_key: str="cell_type", + group_key: Optional[str]=None, + groups: str="all", + plot_label: str="Niche", + save_fig: bool=False, + file_path: Optional[str]=None): + """ + Create a cell type chord diagram per group based on an input DataFrame. + + Parameters + ---------- + adata: + AnnData object which contains outputs of NicheCompass model training. + df: + A Pandas DataFrame that contains the connection values for the chord + plot (dim: (n_groups x n_cell_types) x n_cell_types). + link_threshold: + Ratio of link strength that a cell type pair needs to exceed compared to + the cell type pair with the maximum link strength to be considered a + link for the chord plot. + cell_type_key: + Key in ´adata.obs´ where the cell type labels are stored. + group_key: + Key in ´adata.obs´ where additional group labels are stored. + groups: + List of groups that will be plotted. If ´all´, plot all groups. + plot_label: + Shared label for the plots. + save_fig: + If ´True´, save the figure. + file_path: + Path where to save the figure. + """ + hv.extension("bokeh") + hv.output(size=200) + + sorted_cell_types = sorted(adata.obs[cell_type_key].unique().tolist()) + + # Get group labels + if (group_key is not None) & (groups == "all"): + group_labels = df.index.get_level_values( + df.index.names.index(group_key)).unique().tolist() + elif (group_key is not None) & (groups != "all"): + group_labels = groups + else: + group_labels = [""] + + chord_list = [] + for group_label in group_labels: + if group_label == "": + group_df = df + else: + group_df = df[df.index.get_level_values( + df.index.names.index(group_key)) == group_label] + + # Get max value (over rows and columns) of the group for thresholding + group_max = group_df.max().max() + + # Create group chord links + links_list = [] + for i in range(len(sorted_cell_types)): + for j in range(len(sorted_cell_types)): + if group_df.iloc[i, j] > group_max * link_threshold: + link_dict = {} + link_dict["source"] = j + link_dict["target"] = i + link_dict["value"] = group_df.iloc[i, j] + links_list.append(link_dict) + links = pd.DataFrame(links_list) + + # Create group chord nodes (only where links exist) + nodes_list = [] + nodes_idx = [] + for i, cell_type in enumerate(sorted_cell_types): + if i in (links["source"].values) or i in (links["target"].values): + nodes_idx.append(i) + nodes_dict = {} + nodes_dict["name"] = cell_type + nodes_dict["group"] = 1 + nodes_list.append(nodes_dict) + nodes = hv.Dataset(pd.DataFrame(nodes_list, index=nodes_idx), "index") + + # Create group chord plot + chord = hv.Chord((links, nodes)).select(value=(5, None)) + chord.opts(hv.opts.Chord(cmap="Category20", + edge_cmap="Category20", + edge_color=hv.dim("source").str(), + labels="name", + node_color=hv.dim("index").str(), + title=f"{plot_label} {group_label}")) + chord_list.append(chord) + + # Display chord plots + layout = hv.Layout(chord_list).cols(2) + hv.output(layout) + + # Save chord plots + if save_fig: + hv.save(layout, + file_path, + fmt="png") + + +def generate_enriched_gp_info_plots(plot_label: str, + model: NicheCompass, + sample_key: str, + differential_gp_test_results_key: str, + cat_key: str, + cat_palette: dict, + n_top_enriched_gp_start_idx: int=0, + n_top_enriched_gp_end_idx: int=10, + feature_spaces: list=["latent"], + n_top_genes_per_gp: int=3, + n_top_peaks_per_gp: int=0, + scale_omics_ft: bool=False, + save_figs: bool=False, + figure_folder_path: str="", + file_format: str="png", + spot_size: float=30.): + """ + Generate info plots of enriched gene programs. These show the enriched + category, the gp activities, as well as the counts (or log normalized + counts) of the top genes and/or peaks in a specified feature space. + + Parameters + ---------- + plot_label: + Main label of the plots. + model: + A trained NicheCompass model. + sample_key: + Key in ´adata.obs´ where the samples are stored. + differential_gp_test_results_key: + Key in ´adata.uns´ where the results of the differential gene program + testing are stored. + cat_key: + Key in ´adata.obs´ where the categories that are used as colors for the + enriched category plot are stored. + cat_palette: + Dictionary of colors that are used to highlight the categories, where + the category is the key of the dictionary and the color is the value. + n_top_enriched_gp_start_idx: + Number of top enriched gene program from which to start the creation + of plots. + n_top_enriched_gp_end_idx: + Number of top enriched gene program at which to stop the creation + of plots. + feature_spaces: + List of feature spaces used for the info plots. Can be ´latent´ to use + the latent embeddings for the plots, or it can be any of the samples + stored in ´adata.obs[sample_key]´ to use the respective physical + feature space for the plots. + n_top_genes_per_gp: + Number of top genes per gp to be considered in the info plots. + n_top_peaks_per_gp: + Number of top peaks per gp to be considered in the info plots. If ´>0´, + requires the model to be trained inlcuding ATAC modality. + scale_omics_ft: + If ´True´, scale genes and peaks before plotting. + save_figs: + If ´True´, save the figures. + figure_folder_path: + Folder path where the figures will be saved. + file_format: + Format with which the figures will be saved. + spot_size: + Spot size used for the spatial plots. + """ + model._check_if_trained(warn=True) + + adata = model.adata.copy() + if n_top_peaks_per_gp > 0: + if "atac" not in model.modalities_: + raise ValueError("The model needs to be trained with ATAC data if" + "'n_top_peaks_per_gp' > 0.") + adata_atac = model.adata_atac.copy() + + # TODO + if scale_omics_ft: + sc.pp.scale(adata) + if n_top_peaks_per_gp > 0: + sc.pp.scale(adata_atac) + adata.uns["omics_ft_pos_cmap"] = "RdBu" + adata.uns["omics_ft_neg_cmap"] = "RdBu_r" + else: + if n_top_peaks_per_gp > 0: + adata_atac.X = adata_atac.X.toarray() + adata.uns["omics_ft_pos_cmap"] = "Blues" + adata.uns["omics_ft_neg_cmap"] = "Reds" + + cats = list(adata.uns[differential_gp_test_results_key]["category"][ + n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx]) + gps = list(adata.uns[differential_gp_test_results_key]["gene_program"][ + n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx]) + log_bayes_factors = list(adata.uns[differential_gp_test_results_key]["log_bayes_factor"][ + n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx]) + + for gp in gps: + # Get source and target genes, gene importances and gene signs and store + # in temporary adata + gp_gene_importances_df = model.compute_gp_gene_importances( + selected_gp=gp) + + gp_source_genes_gene_importances_df = gp_gene_importances_df[ + gp_gene_importances_df["gene_entity"] == "source"] + gp_target_genes_gene_importances_df = gp_gene_importances_df[ + gp_gene_importances_df["gene_entity"] == "target"] + adata.uns["n_top_source_genes"] = n_top_genes_per_gp + adata.uns[f"{gp}_source_genes_top_genes"] = ( + gp_source_genes_gene_importances_df["gene"][ + :n_top_genes_per_gp].values) + adata.uns[f"{gp}_source_genes_top_gene_importances"] = ( + gp_source_genes_gene_importances_df["gene_importance"][ + :n_top_genes_per_gp].values) + adata.uns[f"{gp}_source_genes_top_gene_signs"] = ( + np.where(gp_source_genes_gene_importances_df[ + "gene_weight"] > 0, "+", "-")) + adata.uns["n_top_target_genes"] = n_top_genes_per_gp + adata.uns[f"{gp}_target_genes_top_genes"] = ( + gp_target_genes_gene_importances_df["gene"][ + :n_top_genes_per_gp].values) + adata.uns[f"{gp}_target_genes_top_gene_importances"] = ( + gp_target_genes_gene_importances_df["gene_importance"][ + :n_top_genes_per_gp].values) + adata.uns[f"{gp}_target_genes_top_gene_signs"] = ( + np.where(gp_target_genes_gene_importances_df[ + "gene_weight"] > 0, "+", "-")) + + if n_top_peaks_per_gp > 0: + # Get source and target peaks, peak importances and peak signs and + # store in temporary adata + gp_peak_importances_df = model.compute_gp_peak_importances( + selected_gp=gp) + gp_source_peaks_peak_importances_df = gp_peak_importances_df[ + gp_peak_importances_df["peak_entity"] == "source"] + gp_target_peaks_peak_importances_df = gp_peak_importances_df[ + gp_peak_importances_df["peak_entity"] == "target"] + adata.uns["n_top_source_peaks"] = n_top_peaks_per_gp + adata.uns[f"{gp}_source_peaks_top_peaks"] = ( + gp_source_peaks_peak_importances_df["peak"][ + :n_top_peaks_per_gp].values) + adata.uns[f"{gp}_source_peaks_top_peak_importances"] = ( + gp_source_peaks_peak_importances_df["peak_importance"][ + :n_top_peaks_per_gp].values) + adata.uns[f"{gp}_source_peaks_top_peak_signs"] = ( + np.where(gp_source_peaks_peak_importances_df[ + "peak_weight"] > 0, "+", "-")) + adata.uns["n_top_target_peaks"] = n_top_peaks_per_gp + adata.uns[f"{gp}_target_peaks_top_peaks"] = ( + gp_target_peaks_peak_importances_df["peak"][ + :n_top_peaks_per_gp].values) + adata.uns[f"{gp}_target_peaks_top_peak_importances"] = ( + gp_target_peaks_peak_importances_df["peak_importance"][ + :n_top_peaks_per_gp].values) + adata.uns[f"{gp}_target_peaks_top_peak_signs"] = ( + np.where(gp_target_peaks_peak_importances_df[ + "peak_weight"] > 0, "+", "-")) + + # Add peak counts to temporary adata for plotting + adata.obs[[peak for peak in + adata.uns[f"{gp}_target_peaks_top_peaks"]]] = ( + adata_atac.X[ + :, [adata_atac.var_names.tolist().index(peak) + for peak in adata.uns[f"{gp}_target_peaks_top_peaks"]]]) + adata.obs[[peak for peak in + adata.uns[f"{gp}_source_peaks_top_peaks"]]] = ( + adata_atac.X[ + :, [adata_atac.var_names.tolist().index(peak) + for peak in adata.uns[f"{gp}_source_peaks_top_peaks"]]]) + else: + adata.uns["n_top_source_peaks"] = 0 + adata.uns["n_top_target_peaks"] = 0 + + for feature_space in feature_spaces: + plot_enriched_gp_info_plots_( + adata=adata, + sample_key=sample_key, + gps=gps, + log_bayes_factors=log_bayes_factors, + cat_key=cat_key, + cat_palette=cat_palette, + cats=cats, + feature_space=feature_space, + spot_size=spot_size, + suptitle=f"{plot_label.replace('_', ' ').title()} " + f"Top {n_top_enriched_gp_start_idx} to " + f"{n_top_enriched_gp_end_idx} Enriched GPs: " + f"GP Scores and Omics Feature Counts in " + f"{feature_space} Feature Space", + save_fig=save_figs, + figure_folder_path=figure_folder_path, + fig_name=f"{plot_label}_top_{n_top_enriched_gp_start_idx}" + f"-{n_top_enriched_gp_end_idx}_enriched_gps_gp_scores_" + f"omics_feature_counts_in_{feature_space}_" + f"feature_space.{file_format}") + + +def plot_enriched_gp_info_plots_(adata: AnnData, + sample_key: str, + gps: list, + log_bayes_factors: list, + cat_key: str, + cat_palette: dict, + cats: list, + feature_space: str, + spot_size: float, + suptitle: str, + save_fig: bool, + figure_folder_path: str, + fig_name: str): + """ + This is a helper function to plot gene program info plots in a specified + feature space. + + Parameters + ---------- + adata: + An AnnData object with stored information about the gene programs to be + plotted. + sample_key: + Key in ´adata.obs´ where the samples are stored. + gps: + List of gene programs for which info plots will be created. + log_bayes_factors: + List of log bayes factors corresponding to gene programs + cat_key: + Key in ´adata.obs´ where the categories that are used as colors for the + enriched category plot are stored. + cat_palette: + Dictionary of colors that are used to highlight the categories, where + the category is the key of the dictionary and the color is the value. + cats: + List of categories for which the corresponding gene programs in ´gps´ + are enriched. + feature_space: + Feature space used for the plots. Can be ´latent´ to use the latent + embeddings for the plots, or it can be any of the samples stored in + ´adata.obs[sample_key]´ to use the respective physical feature space for + the plots. + spot_size: + Spot size used for the spatial plots. + subtitle: + Overall figure title. + save_fig: + If ´True´, save the figure. + figure_folder_path: + Path of the folder where the figure will be saved. + fig_name: + Name of the figure under which it will be saved. + """ + # Define figure configurations + ncols = (2 + + adata.uns["n_top_source_genes"] + + adata.uns["n_top_target_genes"] + + adata.uns["n_top_source_peaks"] + + adata.uns["n_top_target_peaks"]) + fig_width = (12 + (6 * ( + adata.uns["n_top_source_genes"] + + adata.uns["n_top_target_genes"] + + adata.uns["n_top_source_peaks"] + + adata.uns["n_top_target_peaks"]))) + wspace = 0.3 + fig, axs = plt.subplots(nrows=len(gps), + ncols=ncols, + figsize=(fig_width, 6*len(gps))) + if axs.ndim == 1: + axs = axs.reshape(1, -1) + title = fig.suptitle(t=suptitle, + x=0.55, + y=(1.1 if len(gps) == 1 else 0.97), + fontsize=20) + + # Plot enriched gp category and gene program latent scores + for i, gp in enumerate(gps): + if feature_space == "latent": + sc.pl.umap( + adata, + color=cat_key, + palette=cat_palette, + groups=cats[i], + ax=axs[i, 0], + title="Enriched GP Category", + legend_loc="on data", + na_in_legend=False, + show=False) + sc.pl.umap( + adata, + color=gps[i], + color_map="RdBu", + ax=axs[i, 1], + title=f"{gp[:gp.index('_')]}\n" + f"{gp[gp.index('_') + 1: gp.rindex('_')].replace('_', ' ')}" + f"\n{gp[gps[i].rindex('_') + 1:]} score (LBF: {round(log_bayes_factors[i])})", + colorbar_loc="bottom", + show=False) + else: + sc.pl.spatial( + adata=adata[adata.obs[sample_key] == feature_space], + color=cat_key, + palette=cat_palette, + groups=cats[i], + ax=axs[i, 0], + spot_size=spot_size, + title="Enriched GP Category", + legend_loc="on data", + na_in_legend=False, + show=False) + sc.pl.spatial( + adata=adata[adata.obs[sample_key] == feature_space], + color=gps[i], + color_map="RdBu", + spot_size=spot_size, + title=f"{gps[i].split('_', 1)[0]}\n{gps[i].split('_', 1)[1]} " + f"(LBF: {round(log_bayes_factors[i], 2)})", + legend_loc=None, + ax=axs[i, 1], + colorbar_loc="bottom", + show=False) + axs[i, 0].xaxis.label.set_visible(False) + axs[i, 0].yaxis.label.set_visible(False) + axs[i, 1].xaxis.label.set_visible(False) + axs[i, 1].yaxis.label.set_visible(False) + + # Plot omics feature counts (or log normalized counts) + modality_entities = [] + if len(adata.uns[f"{gp}_source_genes_top_genes"]) > 0: + modality_entities.append("source_genes") + if len(adata.uns[f"{gp}_target_genes_top_genes"]) > 0: + modality_entities.append("target_genes") + if f"{gp}_source_peaks_top_peaks" in adata.uns.keys(): + gp_n_source_peaks_top_peaks = ( + len(adata.uns[f"{gp}_source_peaks_top_peaks"])) + if len(adata.uns[f"{gp}_source_peaks_top_peaks"]) > 0: + modality_entities.append("source_peaks") + else: + gp_n_source_peaks_top_peaks = 0 + if f"{gp}_target_peaks_top_peaks" in adata.uns.keys(): + gp_n_target_peaks_top_peaks = ( + len(adata.uns[f"{gp}_target_peaks_top_peaks"])) + if len(adata.uns[f"{gp}_target_peaks_top_peaks"]) > 0: + modality_entities.append("target_peaks") + else: + gp_n_target_peaks_top_peaks = 0 + for modality_entity in modality_entities: + # Define k for index iteration + if modality_entity == "source_genes": + k = 0 + elif modality_entity == "target_genes": + k = len(adata.uns[f"{gp}_source_genes_top_genes"]) + elif modality_entity == "source_peaks": + k = (len(adata.uns[f"{gp}_source_genes_top_genes"]) + + len(adata.uns[f"{gp}_target_genes_top_genes"])) + elif modality_entity == "target_peaks": + k = (len(adata.uns[f"{gp}_source_genes_top_genes"]) + + len(adata.uns[f"{gp}_target_genes_top_genes"]) + + len(adata.uns[f"{gp}_source_peaks_top_peaks"])) + for j in range(len(adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1]}"])): + if feature_space == "latent": + sc.pl.umap( + adata, + color=adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1]}"][j], + color_map=(adata.uns["omics_ft_pos_cmap"] if + adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1][:-1]}" + "_signs"][j] == "+" else adata.uns["omics_ft_neg_cmap"]), + ax=axs[i, 2+k+j], + legend_loc="on data", + na_in_legend=False, + title=f"""{adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1]}" + ][j]}: """ + f"""{adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1][:-1]}" + "_importances"][j]:.2f} """ + f"({modality_entity[:-1]}; " + f"""{adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1][:-1]}" + "_signs"][j]})""", + colorbar_loc="bottom", + show=False) + else: + sc.pl.spatial( + adata=adata[adata.obs[sample_key] == feature_space], + color=adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1]}"][j], + color_map=(adata.uns["omics_ft_pos_cmap"] if + adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1][:-1]}" + "_signs"][j] == "+" else adata.uns["omics_ft_neg_cmap"]), + legend_loc="on data", + na_in_legend=False, + ax=axs[i, 2+k+j], + spot_size=spot_size, + title=f"""{adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1]}" + ][j]} \n""" + f"""({adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1][:-1]}" + "_importances"][j]:.2f}; """ + f"{modality_entity[:-1]}; " + f"""{adata.uns[f"{gp}_{modality_entity}_top_" + f"{modality_entity.split('_')[1][:-1]}" + "_signs"][j]})""", + colorbar_loc="bottom", + show=False) + axs[i, 2+k+j].xaxis.label.set_visible(False) + axs[i, 2+k+j].yaxis.label.set_visible(False) + # Remove unnecessary axes + for l in range(2 + + len(adata.uns[f"{gp}_source_genes_top_genes"]) + + len(adata.uns[f"{gp}_target_genes_top_genes"]) + + gp_n_source_peaks_top_peaks + + gp_n_target_peaks_top_peaks, ncols): + axs[i, l].set_visible(False) + + # Save and display plot + plt.subplots_adjust(wspace=wspace, hspace=0.275) + if save_fig: + fig.savefig(f"{figure_folder_path}/{fig_name}", + bbox_extra_artists=(title,), + bbox_inches="tight") + plt.show() + +default_color_dict = { + "0": "#66C5CC", + "1": "#F6CF71", + "2": "#F89C74", + "3": "#DCB0F2", + "4": "#87C55F", + "5": "#9EB9F3", + "6": "#FE88B1", + "7": "#C9DB74", + "8": "#8BE0A4", + "9": "#B497E7", + "10": "#D3B484", + "11": "#B3B3B3", + "12": "#276A8C", # Royal Blue + "13": "#DAB6C4", # Pink + "14": "#C38D9E", # Mauve-Pink + "15": "#9D88A2", # Mauve + "16": "#FF4D4D", # Light Red + "17": "#9B4DCA", # Lavender-Purple + "18": "#FF9CDA", # Bright Pink + "19": "#FF69B4", # Hot Pink + "20": "#FF00FF", # Magenta + "21": "#DA70D6", # Orchid + "22": "#BA55D3", # Medium Orchid + "23": "#8A2BE2", # Blue Violet + "24": "#9370DB", # Medium Purple + "25": "#7B68EE", # Medium Slate Blue + "26": "#4169E1", # Royal Blue + "27": "#FF8C8C", # Salmon Pink + "28": "#FFAA80", # Light Coral + "29": "#48D1CC", # Medium Turquoise + "30": "#40E0D0", # Turquoise + "31": "#00FF00", # Lime + "32": "#7FFF00", # Chartreuse + "33": "#ADFF2F", # Green Yellow + "34": "#32CD32", # Lime Green + "35": "#228B22", # Forest Green + "36": "#FFD8B8", # Peach + "37": "#008080", # Teal + "38": "#20B2AA", # Light Sea Green + "39": "#00FFFF", # Cyan + "40": "#00BFFF", # Deep Sky Blue + "41": "#4169E1", # Royal Blue + "42": "#0000CD", # Medium Blue + "43": "#00008B", # Dark Blue + "44": "#8B008B", # Dark Magenta + "45": "#FF1493", # Deep Pink + "46": "#FF4500", # Orange Red + "47": "#006400", # Dark Green + "48": "#FF6347", # Tomato + "49": "#FF7F50", # Coral + "50": "#CD5C5C", # Indian Red + "51": "#B22222", # Fire Brick + "52": "#FFB83F", # Light Orange + "53": "#8B0000", # Dark Red + "54": "#D2691E", # Chocolate + "55": "#A0522D", # Sienna + "56": "#800000", # Maroon + "57": "#808080", # Gray + "58": "#A9A9A9", # Dark Gray + "59": "#C0C0C0", # Silver + "60": "#9DD84A", + "61": "#F5F5F5", # White Smoke + "62": "#F17171", # Light Red + "63": "#000000", # Black + "64": "#FF8C42", # Tangerine + "65": "#F9A11F", # Bright Orange-Yellow + "66": "#FACC15", # Golden Yellow + "67": "#E2E062", # Pale Lime + "68": "#BADE92", # Soft Lime + "69": "#70C1B3", # Greenish-Blue + "70": "#41B3A3", # Turquoise + "71": "#5EAAA8", # Gray-Green + "72": "#72B01D", # Chartreuse + "73": "#9CD08F", # Light Green + "74": "#8EBA43", # Olive Green + "75": "#FAC8C3", # Light Pink + "76": "#E27D60", # Dark Salmon + "77": "#C38D9E", # Mauve-Pink + "78": "#937D64", # Light Brown + "79": "#B1C1CC", # Light Blue-Gray + "80": "#88A0A8", # Gray-Blue-Green + "81": "#4E598C", # Dark Blue-Purple + "82": "#4B4E6D", # Dark Gray-Blue + "83": "#8E9AAF", # Light Blue-Grey + "84": "#C0D6DF", # Pale Blue-Grey + "85": "#97C1A9", # Blue-Green + "86": "#4C6E5D", # Dark Green + "87": "#95B9C7", # Pale Blue-Green + "88": "#C1D5E0", # Pale Gray-Blue + "89": "#ECDB54", # Bright Yellow + "90": "#E89B3B", # Bright Orange + "91": "#CE5A57", # Deep Red + "92": "#C3525A", # Dark Red + "93": "#B85D8E", # Berry + "94": "#7D5295", # Deep Purple + "-1" : "#E1D9D1", + "None" : "#E1D9D1" +} + +def create_new_color_dict( + adata, + cat_key, + color_palette="default", + overwrite_color_dict={"-1" : "#E1D9D1"}, + skip_default_colors=0): + """ + Create a dictionary of color hexcodes for a specified category. + + Parameters + ---------- + adata: + AnnData object. + cat_key: + Key in ´adata.obs´ where the categories are stored for which color + hexcodes will be created. + color_palette: + Type of color palette. + overwrite_color_dict: + Dictionary with overwrite values that will take precedence over the + automatically created dictionary. + skip_default_colors: + Number of colors to skip from the default color dict. + + Returns + ---------- + new_color_dict: + The color dictionary with a hexcode for each category. + """ + new_categories = adata.obs[cat_key].unique().tolist() + if color_palette == "cell_type_30": + # https://github.com/scverse/scanpy/blob/master/scanpy/plotting/palettes.py#L40 + new_color_dict = {key: value for key, value in zip( + new_categories, + ["#023fa5", + "#7d87b9", + "#bec1d4", + "#d6bcc0", + "#bb7784", + "#8e063b", + "#4a6fe3", + "#8595e1", + "#b5bbe3", + "#e6afb9", + "#e07b91", + "#d33f6a", + "#11c638", + "#8dd593", + "#c6dec7", + "#ead3c6", + "#f0b98d", + "#ef9708", + "#0fcfc0", + "#9cded6", + "#d5eae7", + "#f3e1eb", + "#f6c4e1", + "#f79cd4", + '#7f7f7f', + "#c7c7c7", + "#1CE6FF", + "#336600"])} + elif color_palette == "cell_type_20": + # https://github.com/vega/vega/wiki/Scales#scale-range-literals (some adjusted) + new_color_dict = {key: value for key, value in zip( + new_categories, + ['#1f77b4', + '#ff7f0e', + '#279e68', + '#d62728', + '#aa40fc', + '#8c564b', + '#e377c2', + '#b5bd61', + '#17becf', + '#aec7e8', + '#ffbb78', + '#98df8a', + '#ff9896', + '#c5b0d5', + '#c49c94', + '#f7b6d2', + '#dbdb8d', + '#9edae5', + '#ad494a', + '#8c6d31'])} + elif color_palette == "cell_type_10": + # scanpy vega10 + new_color_dict = {key: value for key, value in zip( + new_categories, + ['#7f7f7f', + '#ff7f0e', + '#279e68', + '#e377c2', + '#17becf', + '#8c564b', + '#d62728', + '#1f77b4', + '#b5bd61', + '#aa40fc'])} + elif color_palette == "batch": + # sns.color_palette("colorblind").as_hex() + new_color_dict = {key: value for key, value in zip( + new_categories, + ['#0173b2', '#d55e00', '#ece133', '#ca9161', '#fbafe4', + '#949494', '#de8f05', '#029e73', '#cc78bc', '#56b4e9', + '#F0F8FF', '#FAEBD7', '#00FFFF', '#7FFFD4', '#F0FFFF', + '#F5F5DC', '#FFE4C4', '#000000', '#FFEBCD', '#0000FF', + '#8A2BE2', '#A52A2A', '#DEB887', '#5F9EA0', '#7FFF00', + '#D2691E', '#FF7F50', '#6495ED', '#FFF8DC', '#DC143C'])} + elif color_palette == "default": + new_color_dict = {key: value for key, value in zip(new_categories, list(default_color_dict.values())[skip_default_colors:])} + for key, val in overwrite_color_dict.items(): + new_color_dict[key] = val + return new_color_dict + + +def plot_non_zero_gene_count_means_dist( + adata: AnnData, + genes: list, + gene_label: str): + """ + Plot distribution of non zero gene count means in the adata over all + specified genes. + """ + gene_counts = adata[ + :, [gene for gene in adata.var_names if gene in genes]].layers["counts"] + nz_gene_means = np.mean( + np.ma.masked_equal(gene_counts.toarray(), 0), axis=0).data + + sns.kdeplot(nz_gene_means) + plt.title(f"{gene_label} Genes Average Non-Zero Gene Counts per Gene") + plt.xlabel("Average Non-zero Gene Counts") + plt.ylabel("Gene Density") + plt.show() + + +def compute_communication_gp_network( + gp_list: list, + model: NicheCompass, + group_key: str="niche", + filter_key: Optional[str]=None, + filter_cat: Optional[str]=None, + n_neighbors: int=90): + """ + Compute a network of category aggregated cell-pair communication strengths. + + First, compute cell-cell communication potential scores for each cell. + Then dot product them and take into account neighborhoods to compute + cell-pair communication strengths. Then, normalize cell-pair communication + strengths. + + Parameters + ---------- + gp_list: + List of GPs for which the cell-pair communication strengths are computed. + model: + A trained NicheCompass model. + group_key: + Key in ´adata.obs´ where the groups are stored over which the cell-pair + communication strengths will be aggregated. + filter_key: + Key in ´adata.obs´ that contains the category for which the results are + filtered. + filter_cat: + Category for which the results are filtered. + n_neighbors: + Number of neighbors for the gp-specific neighborhood graph. + + Returns + ---------- + network_df: + A pandas dataframe with aggregated, normalized cell-pair communication strengths. + """ + # Compute neighborhood graph + compute_knn = True + if 'spatial_cci' in model.adata.uns.keys(): + if model.adata.uns['spatial_cci']['params']['n_neighbors'] == n_neighbors: + compute_knn = False + if compute_knn: + sc.pp.neighbors(model.adata, + n_neighbors=n_neighbors, + use_rep="spatial", + key_added="spatial_cci") + + gp_network_dfs = [] + gp_summary_df = model.get_gp_summary() + for gp in gp_list: + gp_idx = model.adata.uns[model.gp_names_key_].tolist().index(gp) + active_gp_idx = model.adata.uns[model.active_gp_names_key_].tolist().index(gp) + gp_scores = model.adata.obsm[model.latent_key_][:, active_gp_idx] + gp_targets_cats = model.adata.varm[model.gp_targets_categories_mask_key_][:, gp_idx] + gp_sources_cats = model.adata.varm[model.gp_sources_categories_mask_key_][:, gp_idx] + targets_cats_label_encoder = model.adata.uns[model.targets_categories_label_encoder_key_] + sources_cats_label_encoder = model.adata.uns[model.sources_categories_label_encoder_key_] + + sources_cat_idx_dict = {} + for source_cat, source_cat_label in sources_cats_label_encoder.items(): + sources_cat_idx_dict[source_cat] = np.where(gp_sources_cats == source_cat_label)[0] + + targets_cat_idx_dict = {} + for target_cat, target_cat_label in targets_cats_label_encoder.items(): + targets_cat_idx_dict[target_cat] = np.where(gp_targets_cats == target_cat_label)[0] + + # Get indices of all source and target genes + source_genes_idx = np.array([], dtype=np.int64) + for key in sources_cat_idx_dict.keys(): + source_genes_idx = np.append(source_genes_idx, + sources_cat_idx_dict[key]) + target_genes_idx = np.array([], dtype=np.int64) + for key in targets_cat_idx_dict.keys(): + target_genes_idx = np.append(target_genes_idx, + targets_cat_idx_dict[key]) + + # Compute cell-cell communication potential scores + gp_source_scores = np.zeros((len(model.adata.obs), len(source_genes_idx))) + gp_target_scores = np.zeros((len(model.adata.obs), len(target_genes_idx))) + + for i, source_gene_idx in enumerate(source_genes_idx): + source_gene = model.adata.var_names[source_gene_idx] + gp_source_scores[:, i] = ( + model.adata[:, model.adata.var_names.tolist().index(source_gene)].X.toarray().flatten() / model.adata[:, model.adata.var_names.tolist().index(source_gene)].X.toarray().flatten().max() * + gp_summary_df[gp_summary_df["gp_name"] == gp]["gp_source_genes_weights"].values[0][gp_summary_df[gp_summary_df["gp_name"] == gp]["gp_source_genes"].values[0].index(source_gene)] * + gp_scores) + + for j, target_gene_idx in enumerate(target_genes_idx): + target_gene = model.adata.var_names[target_gene_idx] + gp_target_scores[:, j] = ( + model.adata[:, model.adata.var_names.tolist().index(target_gene)].X.toarray().flatten() / model.adata[:, model.adata.var_names.tolist().index(target_gene)].X.toarray().flatten().max() * + gp_summary_df[gp_summary_df["gp_name"] == gp]["gp_target_genes_weights"].values[0][gp_summary_df[gp_summary_df["gp_name"] == gp]["gp_target_genes"].values[0].index(target_gene)] * + gp_scores) + + agg_gp_source_score = gp_source_scores.mean(1).astype("float32") + agg_gp_target_score = gp_target_scores.mean(1).astype("float32") + agg_gp_source_score[agg_gp_source_score < 0] = 0. + agg_gp_target_score[agg_gp_target_score < 0] = 0. + + model.adata.obs[f"{gp}_source_score"] = agg_gp_source_score + model.adata.obs[f"{gp}_target_score"] = agg_gp_target_score + + del(gp_target_scores) + del(gp_source_scores) + + agg_gp_source_score = sp.csr_matrix(agg_gp_source_score) + agg_gp_target_score = sp.csr_matrix(agg_gp_target_score) + + model.adata.obsp[f"{gp}_connectivities"] = (model.adata.obsp["spatial_cci_connectivities"] > 0).multiply( + agg_gp_source_score.T.dot(agg_gp_target_score)) + + # Aggregate gp connectivities for each group + gp_network_df_pivoted = aggregate_obsp_matrix_per_cell_type( + adata=model.adata, + obsp_key=f"{gp}_connectivities", + cell_type_key=group_key, + group_key=filter_key, + agg_rows=True) + + if filter_key is not None: + gp_network_df_pivoted = gp_network_df_pivoted.loc[filter_cat, :] + + gp_network_df = gp_network_df_pivoted.melt(var_name="source", value_name="gp_score", ignore_index=False).reset_index() + gp_network_df.columns = ["source", "target", "strength"] + + gp_network_df = gp_network_df.sort_values("strength", ascending=False) + + # Normalize strength + min_value = gp_network_df["strength"].min() + max_value = gp_network_df["strength"].max() + gp_network_df["strength_unscaled"] = gp_network_df["strength"] + gp_network_df["strength"] = (gp_network_df["strength"] - min_value) / (max_value - min_value) + gp_network_df["strength"] = np.round(gp_network_df["strength"], 2) + gp_network_df = gp_network_df[gp_network_df["strength"] > 0] + + gp_network_df["edge_type"] = gp + gp_network_dfs.append(gp_network_df) + + network_df = pd.concat(gp_network_dfs, ignore_index=True) + return network_df + + +def visualize_communication_gp_network( + adata, + network_df, + cat_colors, + edge_type_colors: Optional[dict]=None, + edge_width_scale: int=20.0, + node_size: int=500, + fontsize: int=14, + figsize: Tuple[int, int]=(18, 16), + plot_legend: bool=True, + save: bool=False, + save_path: str="communication_gp_network.svg", + show: bool=True, + text_space: float=1.3, + connection_style="arc3, rad = 0.1", + cat_key: str="niche", + edge_attr: str="strength"): + """ + Visualize a communication gp network. + """ + # Assuming you have unique edge types in your 'edge_type' column + edge_types = np.unique(network_df['edge_type']) + + if edge_type_colors is None: + # Colorblindness adjusted vega_10 + # See https://github.com/theislab/scanpy/issues/387 + vega_10 = list(map(colors.to_hex, cm.tab10.colors)) + vega_10_scanpy = vega_10.copy() + vega_10_scanpy[2] = "#279e68" # green + vega_10_scanpy[4] = "#aa40fc" # purple + vega_10_scanpy[8] = "#b5bd61" # kakhi + edge_type_colors = vega_10_scanpy + + # Create a dictionary that maps edge types to colors + edge_type_color_dict = {edge_type: color for edge_type, color in zip(edge_types, edge_type_colors)} + + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize) + ax.axis("off") + G = nx.from_pandas_edgelist( + network_df, + source="source", + target="target", + edge_attr=["edge_type", edge_attr], + create_using=nx.DiGraph(), + ) + pos = nx.circular_layout(G) + + nx.set_node_attributes(G, cat_colors, "color") + node_color = nx.get_node_attributes(G, "color") + + description = nx.draw_networkx_labels(G, pos, font_size=fontsize) + n = adata.obs[cat_key].nunique() + node_list = sorted(G.nodes()) + angle = [] + angle_dict = {} + for i, node in zip(range(n), node_list): + theta = 2.0 * np.pi * i / n + angle.append((np.cos(theta), np.sin(theta))) + angle_dict[node] = theta + pos = {} + for node_i, node in enumerate(node_list): + pos[node] = angle[node_i] + + r = fig.canvas.get_renderer() + trans = plt.gca().transData.inverted() + for node, t in description.items(): + bb = t.get_window_extent(renderer=r) + bbdata = bb.transformed(trans) + radius = text_space + bbdata.width / 2.0 + position = (radius * np.cos(angle_dict[node]), radius * np.sin(angle_dict[node])) + t.set_position(position) + t.set_rotation(angle_dict[node] * 360.0 / (2.0 * np.pi)) + t.set_clip_on(False) + + edgelist = [(u, v) for u, v, e in G.edges(data=True) if u != v] + edge_colors = [edge_type_color_dict[edge_data['edge_type']] for u, v, edge_data in G.edges(data=True) if u != v] + width = [e[edge_attr] * edge_width_scale for u, v, e in G.edges(data=True) if u != v] + + h2 = nx.draw_networkx( + G, + pos, + with_labels=False, + node_size=node_size, + edgelist=edgelist, + width=width, + edge_vmin=0.0, + edge_vmax=1.0, + edge_color=edge_colors, # Use the edge type colors here + arrows=True, + arrowstyle="-|>", + arrowsize=20, + vmin=0.0, + vmax=1.0, + cmap=plt.cm.binary, # Use a colormap for node colors if needed + node_color=list(node_color.values()), + ax=ax, + connectionstyle=connection_style, + ) + + #https://stackoverflow.com/questions/19877666/add-legends-to-linecollection-plot - uses plotted data to define the color but here we already have colors defined, so just need a Line2D object. + def make_proxy(clr, mappable, **kwargs): + return Line2D([0, 1], [0, 1], color=clr, **kwargs) + + # generate proxies with the above function + proxies = [make_proxy(clr, h2, lw=5) for clr in set(edge_colors)] + labels = [edge.split("_")[0] + " GP" for edge in edge_types[::-1]] + + if plot_legend: + lgd = plt.legend(proxies, labels, loc="lower left") + + edgelist = [(u, v) for u, v, e in G.edges(data=True) if ((u == v))] + [(u, v) for u, v, e in G.edges(data=True) if ((u != v))] + edge_colors = [edge_type_color_dict[edge_data['edge_type']] for u, v, edge_data in G.edges(data=True) if u == v] + width = [e[edge_attr] * edge_width_scale for u, v, e in G.edges(data=True) if u == v] + [0 for u, v, e in G.edges(data=True) if ((u != v))] + nx.draw_networkx_edges( + G, + pos, + node_size=node_size, + edgelist=edgelist, + width=width, + edge_vmin=0.0, + edge_vmax=1.0, + edge_color=edge_colors, + arrows=False, + arrowstyle="-|>", + arrowsize=20, + ax=ax, + connectionstyle=connection_style) + plt.tight_layout() + if save: + plt.savefig(save_path) + if show: + plt.show() + plt.close(fig) + plt.ion()