Switch to side-by-side view

--- a
+++ b/src/nichecompass/utils/analysis.py
@@ -0,0 +1,1147 @@
+"""
+This module contains utilities to analyze niches inferred by the NicheCompass
+model.
+"""
+
+from typing import Optional, Tuple
+
+#import holoviews as hv
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import scanpy as sc
+import scipy.sparse as sp
+import seaborn as sns
+from anndata import AnnData
+from matplotlib import cm, colors
+from matplotlib.lines import Line2D
+import networkx as nx
+
+from ..models import NicheCompass
+
+
+def aggregate_obsp_matrix_per_cell_type(
+        adata: AnnData,
+        obsp_key: str,
+        cell_type_key: str="cell_type",
+        group_key: Optional[str]=None,
+        agg_rows: bool=False):
+    """
+    Generic function to aggregate adjacency matrices stored in
+    ´adata.obsp[obsp_key]´ on cell type level. It can be used to aggregate the
+    node label aggregator aggregation weights alpha or the reconstructed adjacency
+    matrix of a trained NicheCompass model by neighbor cell type for downstream
+    analysis.
+
+    Parameters
+    ----------
+    adata:
+        AnnData object which contains outputs of NicheCompass model training.
+    obsp_key:
+        Key in ´adata.obsp´ where the matrix to be aggregated is stored.
+    cell_type_key:
+        Key in ´adata.obs´ where the cell type labels are stored.
+    group_key:
+        Key in ´adata.obs´ where additional grouping labels are stored.    
+    agg_rows:
+        If ´True´, also aggregate over the observations on cell type level.
+
+    Returns
+    ----------
+    cell_type_agg_df:
+        Pandas DataFrame with the aggregated obsp values (dim: n_obs x
+        n_cell_types if ´agg_rows == False´, else n_cell_types x n_cell_types).
+    """
+    n_obs = len(adata)
+    n_cell_types = adata.obs[cell_type_key].nunique()
+    sorted_cell_types = sorted(adata.obs[cell_type_key].unique().tolist())
+
+    cell_type_label_encoder = {k: v for k, v in zip(
+        sorted_cell_types,
+        range(n_cell_types))}
+
+    # Retrieve non zero indices and non zero values, and create row-wise
+    # observation cell type index
+    nz_obsp_idx = adata.obsp[obsp_key].nonzero()
+    neighbor_cell_type_index = adata.obs[cell_type_key][nz_obsp_idx[1]].map(
+        cell_type_label_encoder).values
+    adata.obsp[obsp_key].eliminate_zeros() # In some sparse reps 0s can appear
+    nz_obsp = adata.obsp[obsp_key].data
+
+    # Use non zero indices, non zero values and row-wise observation cell type
+    # index to construct new df with cell types as columns and row-wise
+    # aggregated values per cell type index as values
+    cell_type_agg = np.zeros((n_obs, n_cell_types))
+    np.add.at(cell_type_agg,
+              (nz_obsp_idx[0], neighbor_cell_type_index),
+              nz_obsp)
+    cell_type_agg_df = pd.DataFrame(
+        cell_type_agg,
+        columns=sorted_cell_types)
+    
+    # Add cell type labels of observations
+    cell_type_agg_df[cell_type_key] = adata.obs[cell_type_key].values
+
+    # If specified, add group label
+    if group_key is not None:
+        cell_type_agg_df[group_key] = adata.obs[group_key].values
+
+    if agg_rows:
+        # In addition, aggregate values across rows to get a
+        # (n_cell_types x n_cell_types) df
+        if group_key is not None:
+            cell_type_agg_df = cell_type_agg_df.groupby(
+                [group_key, cell_type_key]).sum()
+        else:
+            cell_type_agg_df = cell_type_agg_df.groupby(cell_type_key).sum()
+
+        # Sort index to have same order as columns
+        cell_type_agg_df = cell_type_agg_df.loc[
+            sorted(cell_type_agg_df.index.tolist()), :]
+        
+    return cell_type_agg_df
+
+
+def create_cell_type_chord_plot_from_df(
+        adata: AnnData,
+        df: pd.DataFrame,
+        link_threshold: float=0.01,
+        cell_type_key: str="cell_type",
+        group_key: Optional[str]=None,
+        groups: str="all",
+        plot_label: str="Niche",
+        save_fig: bool=False,
+        file_path: Optional[str]=None):
+    """
+    Create a cell type chord diagram per group based on an input DataFrame.
+
+    Parameters
+    ----------
+    adata:
+        AnnData object which contains outputs of NicheCompass model training.
+    df:
+        A Pandas DataFrame that contains the connection values for the chord
+        plot (dim: (n_groups x n_cell_types) x n_cell_types).
+    link_threshold:
+        Ratio of link strength that a cell type pair needs to exceed compared to
+        the cell type pair with the maximum link strength to be considered a
+        link for the chord plot.
+    cell_type_key:
+        Key in ´adata.obs´ where the cell type labels are stored.
+    group_key:
+        Key in ´adata.obs´ where additional group labels are stored.
+    groups:
+        List of groups that will be plotted. If ´all´, plot all groups.
+    plot_label:
+        Shared label for the plots.
+    save_fig:
+        If ´True´, save the figure.
+    file_path:
+        Path where to save the figure.
+    """
+    hv.extension("bokeh")
+    hv.output(size=200)
+
+    sorted_cell_types = sorted(adata.obs[cell_type_key].unique().tolist())
+
+    # Get group labels
+    if (group_key is not None) & (groups == "all"):
+        group_labels = df.index.get_level_values(
+            df.index.names.index(group_key)).unique().tolist()
+    elif (group_key is not None) & (groups != "all"):
+        group_labels = groups
+    else:
+        group_labels = [""]
+
+    chord_list = []
+    for group_label in group_labels:
+        if group_label == "":
+            group_df = df
+        else:
+            group_df = df[df.index.get_level_values(
+                df.index.names.index(group_key)) == group_label]
+        
+        # Get max value (over rows and columns) of the group for thresholding
+        group_max = group_df.max().max()
+
+        # Create group chord links
+        links_list = []
+        for i in range(len(sorted_cell_types)):
+            for j in range(len(sorted_cell_types)):
+                if group_df.iloc[i, j] > group_max * link_threshold:
+                    link_dict = {}
+                    link_dict["source"] = j
+                    link_dict["target"] = i
+                    link_dict["value"] = group_df.iloc[i, j]
+                    links_list.append(link_dict)
+        links = pd.DataFrame(links_list)
+
+        # Create group chord nodes (only where links exist)
+        nodes_list = []
+        nodes_idx = []
+        for i, cell_type in enumerate(sorted_cell_types):
+            if i in (links["source"].values) or i in (links["target"].values):
+                nodes_idx.append(i)
+                nodes_dict = {}
+                nodes_dict["name"] = cell_type
+                nodes_dict["group"] = 1
+                nodes_list.append(nodes_dict)
+        nodes = hv.Dataset(pd.DataFrame(nodes_list, index=nodes_idx), "index")
+
+        # Create group chord plot
+        chord = hv.Chord((links, nodes)).select(value=(5, None))
+        chord.opts(hv.opts.Chord(cmap="Category20",
+                                 edge_cmap="Category20",
+                                 edge_color=hv.dim("source").str(),
+                                 labels="name",
+                                 node_color=hv.dim("index").str(),
+                                 title=f"{plot_label} {group_label}"))
+        chord_list.append(chord)
+    
+    # Display chord plots
+    layout = hv.Layout(chord_list).cols(2)
+    hv.output(layout)
+
+    # Save chord plots
+    if save_fig:
+        hv.save(layout,
+                file_path,
+                fmt="png")
+
+        
+def generate_enriched_gp_info_plots(plot_label: str,
+                                    model: NicheCompass,
+                                    sample_key: str,
+                                    differential_gp_test_results_key: str,
+                                    cat_key: str,
+                                    cat_palette: dict,
+                                    n_top_enriched_gp_start_idx: int=0,
+                                    n_top_enriched_gp_end_idx: int=10,
+                                    feature_spaces: list=["latent"],
+                                    n_top_genes_per_gp: int=3,
+                                    n_top_peaks_per_gp: int=0,
+                                    scale_omics_ft: bool=False,
+                                    save_figs: bool=False,
+                                    figure_folder_path: str="",
+                                    file_format: str="png",
+                                    spot_size: float=30.):
+    """
+    Generate info plots of enriched gene programs. These show the enriched
+    category, the gp activities, as well as the counts (or log normalized
+    counts) of the top genes and/or peaks in a specified feature space.
+    
+    Parameters
+    ----------
+    plot_label:
+        Main label of the plots.
+    model:
+        A trained NicheCompass model.
+    sample_key:
+        Key in ´adata.obs´ where the samples are stored.
+    differential_gp_test_results_key:
+        Key in ´adata.uns´ where the results of the differential gene program
+        testing are stored.
+    cat_key:
+        Key in ´adata.obs´ where the categories that are used as colors for the
+        enriched category plot are stored.
+    cat_palette:
+        Dictionary of colors that are used to highlight the categories, where
+        the category is the key of the dictionary and the color is the value.
+    n_top_enriched_gp_start_idx:
+        Number of top enriched gene program from which to start the creation
+        of plots.
+    n_top_enriched_gp_end_idx:
+        Number of top enriched gene program at which to stop the creation
+        of plots.
+    feature_spaces:
+        List of feature spaces used for the info plots. Can be ´latent´ to use
+        the latent embeddings for the plots, or it can be any of the samples
+        stored in ´adata.obs[sample_key]´ to use the respective physical
+        feature space for the plots.
+    n_top_genes_per_gp:
+        Number of top genes per gp to be considered in the info plots.
+    n_top_peaks_per_gp:
+        Number of top peaks per gp to be considered in the info plots. If ´>0´,
+        requires the model to be trained inlcuding ATAC modality.
+    scale_omics_ft:
+        If ´True´, scale genes and peaks before plotting.
+    save_figs:
+        If ´True´, save the figures.
+    figure_folder_path:
+        Folder path where the figures will be saved.
+    file_format:
+        Format with which the figures will be saved.
+    spot_size:
+        Spot size used for the spatial plots.
+    """
+    model._check_if_trained(warn=True)
+
+    adata = model.adata.copy()
+    if n_top_peaks_per_gp > 0:
+        if "atac" not in model.modalities_:
+            raise ValueError("The model needs to be trained with ATAC data if"
+                             "'n_top_peaks_per_gp' > 0.")
+        adata_atac = model.adata_atac.copy()
+    
+    # TODO
+    if scale_omics_ft:
+        sc.pp.scale(adata)
+        if n_top_peaks_per_gp > 0:
+            sc.pp.scale(adata_atac)
+        adata.uns["omics_ft_pos_cmap"] = "RdBu"
+        adata.uns["omics_ft_neg_cmap"] = "RdBu_r"
+    else:
+        if n_top_peaks_per_gp > 0:
+            adata_atac.X = adata_atac.X.toarray()
+        adata.uns["omics_ft_pos_cmap"] = "Blues"
+        adata.uns["omics_ft_neg_cmap"] = "Reds"
+        
+    cats = list(adata.uns[differential_gp_test_results_key]["category"][
+        n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])
+    gps = list(adata.uns[differential_gp_test_results_key]["gene_program"][
+        n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])
+    log_bayes_factors = list(adata.uns[differential_gp_test_results_key]["log_bayes_factor"][
+        n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])
+    
+    for gp in gps:
+        # Get source and target genes, gene importances and gene signs and store
+        # in temporary adata
+        gp_gene_importances_df = model.compute_gp_gene_importances(
+            selected_gp=gp)
+        
+        gp_source_genes_gene_importances_df = gp_gene_importances_df[
+            gp_gene_importances_df["gene_entity"] == "source"]
+        gp_target_genes_gene_importances_df = gp_gene_importances_df[
+            gp_gene_importances_df["gene_entity"] == "target"]
+        adata.uns["n_top_source_genes"] = n_top_genes_per_gp
+        adata.uns[f"{gp}_source_genes_top_genes"] = (
+            gp_source_genes_gene_importances_df["gene"][
+                :n_top_genes_per_gp].values)
+        adata.uns[f"{gp}_source_genes_top_gene_importances"] = (
+            gp_source_genes_gene_importances_df["gene_importance"][
+                :n_top_genes_per_gp].values)
+        adata.uns[f"{gp}_source_genes_top_gene_signs"] = (
+            np.where(gp_source_genes_gene_importances_df[
+                "gene_weight"] > 0, "+", "-"))
+        adata.uns["n_top_target_genes"] = n_top_genes_per_gp
+        adata.uns[f"{gp}_target_genes_top_genes"] = (
+            gp_target_genes_gene_importances_df["gene"][
+                :n_top_genes_per_gp].values)
+        adata.uns[f"{gp}_target_genes_top_gene_importances"] = (
+            gp_target_genes_gene_importances_df["gene_importance"][
+                :n_top_genes_per_gp].values)
+        adata.uns[f"{gp}_target_genes_top_gene_signs"] = (
+            np.where(gp_target_genes_gene_importances_df[
+                "gene_weight"] > 0, "+", "-"))
+
+        if n_top_peaks_per_gp > 0:
+            # Get source and target peaks, peak importances and peak signs and
+            # store in temporary adata
+            gp_peak_importances_df = model.compute_gp_peak_importances(
+                selected_gp=gp)
+            gp_source_peaks_peak_importances_df = gp_peak_importances_df[
+                gp_peak_importances_df["peak_entity"] == "source"]
+            gp_target_peaks_peak_importances_df = gp_peak_importances_df[
+                gp_peak_importances_df["peak_entity"] == "target"]
+            adata.uns["n_top_source_peaks"] = n_top_peaks_per_gp
+            adata.uns[f"{gp}_source_peaks_top_peaks"] = (
+                gp_source_peaks_peak_importances_df["peak"][
+                    :n_top_peaks_per_gp].values)
+            adata.uns[f"{gp}_source_peaks_top_peak_importances"] = (
+                gp_source_peaks_peak_importances_df["peak_importance"][
+                    :n_top_peaks_per_gp].values)
+            adata.uns[f"{gp}_source_peaks_top_peak_signs"] = (
+                np.where(gp_source_peaks_peak_importances_df[
+                    "peak_weight"] > 0, "+", "-"))
+            adata.uns["n_top_target_peaks"] = n_top_peaks_per_gp
+            adata.uns[f"{gp}_target_peaks_top_peaks"] = (
+                gp_target_peaks_peak_importances_df["peak"][
+                    :n_top_peaks_per_gp].values)
+            adata.uns[f"{gp}_target_peaks_top_peak_importances"] = (
+                gp_target_peaks_peak_importances_df["peak_importance"][
+                    :n_top_peaks_per_gp].values)
+            adata.uns[f"{gp}_target_peaks_top_peak_signs"] = (
+                np.where(gp_target_peaks_peak_importances_df[
+                    "peak_weight"] > 0, "+", "-"))
+            
+            # Add peak counts to temporary adata for plotting
+            adata.obs[[peak for peak in 
+                       adata.uns[f"{gp}_target_peaks_top_peaks"]]] = (
+                adata_atac.X[
+                    :, [adata_atac.var_names.tolist().index(peak)
+                        for peak in adata.uns[f"{gp}_target_peaks_top_peaks"]]])
+            adata.obs[[peak for peak in
+                       adata.uns[f"{gp}_source_peaks_top_peaks"]]] = (
+                adata_atac.X[
+                    :, [adata_atac.var_names.tolist().index(peak)
+                        for peak in adata.uns[f"{gp}_source_peaks_top_peaks"]]])
+        else:
+            adata.uns["n_top_source_peaks"] = 0
+            adata.uns["n_top_target_peaks"] = 0
+
+    for feature_space in feature_spaces:
+        plot_enriched_gp_info_plots_(
+            adata=adata,
+            sample_key=sample_key,
+            gps=gps,
+            log_bayes_factors=log_bayes_factors,
+            cat_key=cat_key,
+            cat_palette=cat_palette,
+            cats=cats,
+            feature_space=feature_space,
+            spot_size=spot_size,
+            suptitle=f"{plot_label.replace('_', ' ').title()} "
+                     f"Top {n_top_enriched_gp_start_idx} to "
+                     f"{n_top_enriched_gp_end_idx} Enriched GPs: "
+                     f"GP Scores and Omics Feature Counts in "
+                     f"{feature_space} Feature Space",
+            save_fig=save_figs,
+            figure_folder_path=figure_folder_path,
+            fig_name=f"{plot_label}_top_{n_top_enriched_gp_start_idx}"
+                     f"-{n_top_enriched_gp_end_idx}_enriched_gps_gp_scores_"
+                     f"omics_feature_counts_in_{feature_space}_"
+                     f"feature_space.{file_format}")
+            
+            
+def plot_enriched_gp_info_plots_(adata: AnnData,
+                                 sample_key: str,
+                                 gps: list,
+                                 log_bayes_factors: list,
+                                 cat_key: str,
+                                 cat_palette: dict,
+                                 cats: list,
+                                 feature_space: str,
+                                 spot_size: float,
+                                 suptitle: str,
+                                 save_fig: bool,
+                                 figure_folder_path: str,
+                                 fig_name: str):
+    """
+    This is a helper function to plot gene program info plots in a specified
+    feature space.
+    
+    Parameters
+    ----------
+    adata:
+        An AnnData object with stored information about the gene programs to be
+        plotted.
+    sample_key:
+        Key in ´adata.obs´ where the samples are stored.
+    gps:
+        List of gene programs for which info plots will be created.
+    log_bayes_factors:
+        List of log bayes factors corresponding to gene programs
+    cat_key:
+        Key in ´adata.obs´ where the categories that are used as colors for the
+        enriched category plot are stored.
+    cat_palette:
+        Dictionary of colors that are used to highlight the categories, where
+        the category is the key of the dictionary and the color is the value.
+    cats:
+        List of categories for which the corresponding gene programs in ´gps´
+        are enriched.
+    feature_space:
+        Feature space used for the plots. Can be ´latent´ to use the latent
+        embeddings for the plots, or it can be any of the samples stored in
+        ´adata.obs[sample_key]´ to use the respective physical feature space for
+        the plots.
+    spot_size:
+        Spot size used for the spatial plots.
+    subtitle:
+        Overall figure title.
+    save_fig:
+        If ´True´, save the figure.
+    figure_folder_path:
+        Path of the folder where the figure will be saved.
+    fig_name:
+        Name of the figure under which it will be saved.
+    """
+    # Define figure configurations
+    ncols = (2 +
+             adata.uns["n_top_source_genes"] +
+             adata.uns["n_top_target_genes"] +
+             adata.uns["n_top_source_peaks"] +
+             adata.uns["n_top_target_peaks"])
+    fig_width = (12 + (6 * (
+        adata.uns["n_top_source_genes"] +
+        adata.uns["n_top_target_genes"] +
+        adata.uns["n_top_source_peaks"] +
+        adata.uns["n_top_target_peaks"])))
+    wspace = 0.3
+    fig, axs = plt.subplots(nrows=len(gps),
+                            ncols=ncols,
+                            figsize=(fig_width, 6*len(gps)))
+    if axs.ndim == 1:
+        axs = axs.reshape(1, -1)
+    title = fig.suptitle(t=suptitle,
+                         x=0.55,
+                         y=(1.1 if len(gps) == 1 else 0.97),
+                         fontsize=20)
+    
+    # Plot enriched gp category and gene program latent scores
+    for i, gp in enumerate(gps):
+        if feature_space == "latent":
+            sc.pl.umap(
+                adata,
+                color=cat_key,
+                palette=cat_palette,
+                groups=cats[i],
+                ax=axs[i, 0],
+                title="Enriched GP Category",
+                legend_loc="on data",
+                na_in_legend=False,
+                show=False)
+            sc.pl.umap(
+                adata,
+                color=gps[i],
+                color_map="RdBu",
+                ax=axs[i, 1],
+                title=f"{gp[:gp.index('_')]}\n"
+                      f"{gp[gp.index('_') + 1: gp.rindex('_')].replace('_', ' ')}"
+                      f"\n{gp[gps[i].rindex('_') + 1:]} score (LBF: {round(log_bayes_factors[i])})",
+                colorbar_loc="bottom",
+                show=False)
+        else:
+            sc.pl.spatial(
+                adata=adata[adata.obs[sample_key] == feature_space],
+                color=cat_key,
+                palette=cat_palette,
+                groups=cats[i],
+                ax=axs[i, 0],
+                spot_size=spot_size,
+                title="Enriched GP Category",
+                legend_loc="on data",
+                na_in_legend=False,
+                show=False)
+            sc.pl.spatial(
+                adata=adata[adata.obs[sample_key] == feature_space],
+                color=gps[i],
+                color_map="RdBu",
+                spot_size=spot_size,
+                title=f"{gps[i].split('_', 1)[0]}\n{gps[i].split('_', 1)[1]} "
+                      f"(LBF: {round(log_bayes_factors[i], 2)})",
+                legend_loc=None,
+                ax=axs[i, 1],
+                colorbar_loc="bottom",
+                show=False) 
+        axs[i, 0].xaxis.label.set_visible(False)
+        axs[i, 0].yaxis.label.set_visible(False)
+        axs[i, 1].xaxis.label.set_visible(False)
+        axs[i, 1].yaxis.label.set_visible(False)
+
+        # Plot omics feature counts (or log normalized counts)
+        modality_entities = []
+        if len(adata.uns[f"{gp}_source_genes_top_genes"]) > 0:
+            modality_entities.append("source_genes")
+        if len(adata.uns[f"{gp}_target_genes_top_genes"]) > 0:
+            modality_entities.append("target_genes")
+        if f"{gp}_source_peaks_top_peaks" in adata.uns.keys():
+            gp_n_source_peaks_top_peaks = (
+                len(adata.uns[f"{gp}_source_peaks_top_peaks"]))
+            if len(adata.uns[f"{gp}_source_peaks_top_peaks"]) > 0:
+                modality_entities.append("source_peaks")
+        else:
+            gp_n_source_peaks_top_peaks = 0
+        if f"{gp}_target_peaks_top_peaks" in adata.uns.keys():
+            gp_n_target_peaks_top_peaks = (
+                len(adata.uns[f"{gp}_target_peaks_top_peaks"]))
+            if len(adata.uns[f"{gp}_target_peaks_top_peaks"]) > 0:
+                modality_entities.append("target_peaks")
+        else:
+            gp_n_target_peaks_top_peaks = 0
+        for modality_entity in modality_entities:
+            # Define k for index iteration
+            if modality_entity == "source_genes":
+                k = 0
+            elif modality_entity == "target_genes":
+                k = len(adata.uns[f"{gp}_source_genes_top_genes"])
+            elif modality_entity == "source_peaks":
+                k = (len(adata.uns[f"{gp}_source_genes_top_genes"]) +
+                     len(adata.uns[f"{gp}_target_genes_top_genes"]))
+            elif modality_entity == "target_peaks":
+                k = (len(adata.uns[f"{gp}_source_genes_top_genes"]) +
+                     len(adata.uns[f"{gp}_target_genes_top_genes"]) +
+                     len(adata.uns[f"{gp}_source_peaks_top_peaks"]))
+            for j in range(len(adata.uns[f"{gp}_{modality_entity}_top_"
+                                         f"{modality_entity.split('_')[1]}"])):
+                if feature_space == "latent":
+                    sc.pl.umap(
+                        adata,
+                        color=adata.uns[f"{gp}_{modality_entity}_top_"
+                                        f"{modality_entity.split('_')[1]}"][j],
+                        color_map=(adata.uns["omics_ft_pos_cmap"] if
+                                   adata.uns[f"{gp}_{modality_entity}_top_"
+                                             f"{modality_entity.split('_')[1][:-1]}"
+                                             "_signs"][j] == "+" else adata.uns["omics_ft_neg_cmap"]),
+                        ax=axs[i, 2+k+j],
+                        legend_loc="on data",
+                        na_in_legend=False,
+                        title=f"""{adata.uns[f"{gp}_{modality_entity}_top_"
+                                             f"{modality_entity.split('_')[1]}"
+                                             ][j]}: """
+                              f"""{adata.uns[f"{gp}_{modality_entity}_top_"
+                                             f"{modality_entity.split('_')[1][:-1]}"
+                                             "_importances"][j]:.2f} """
+                              f"({modality_entity[:-1]}; "
+                              f"""{adata.uns[f"{gp}_{modality_entity}_top_"
+                                             f"{modality_entity.split('_')[1][:-1]}"
+                                             "_signs"][j]})""",
+                        colorbar_loc="bottom",
+                        show=False)
+                else:
+                    sc.pl.spatial(
+                        adata=adata[adata.obs[sample_key] == feature_space],
+                        color=adata.uns[f"{gp}_{modality_entity}_top_"
+                                        f"{modality_entity.split('_')[1]}"][j],
+                        color_map=(adata.uns["omics_ft_pos_cmap"] if
+                                   adata.uns[f"{gp}_{modality_entity}_top_"
+                                             f"{modality_entity.split('_')[1][:-1]}"
+                                             "_signs"][j] == "+" else adata.uns["omics_ft_neg_cmap"]),
+                        legend_loc="on data",
+                        na_in_legend=False,
+                        ax=axs[i, 2+k+j],
+                        spot_size=spot_size,
+                        title=f"""{adata.uns[f"{gp}_{modality_entity}_top_"
+                                             f"{modality_entity.split('_')[1]}"
+                                             ][j]} \n"""
+                              f"""({adata.uns[f"{gp}_{modality_entity}_top_"
+                                             f"{modality_entity.split('_')[1][:-1]}"
+                                             "_importances"][j]:.2f}; """
+                              f"{modality_entity[:-1]}; "
+                              f"""{adata.uns[f"{gp}_{modality_entity}_top_"
+                                             f"{modality_entity.split('_')[1][:-1]}"
+                                             "_signs"][j]})""",
+                        colorbar_loc="bottom",
+                        show=False)
+                axs[i, 2+k+j].xaxis.label.set_visible(False)
+                axs[i, 2+k+j].yaxis.label.set_visible(False)
+            # Remove unnecessary axes
+            for l in range(2 +
+                           len(adata.uns[f"{gp}_source_genes_top_genes"]) +
+                           len(adata.uns[f"{gp}_target_genes_top_genes"]) +
+                           gp_n_source_peaks_top_peaks +
+                           gp_n_target_peaks_top_peaks, ncols):
+                axs[i, l].set_visible(False)
+
+    # Save and display plot
+    plt.subplots_adjust(wspace=wspace, hspace=0.275)
+    if save_fig:
+        fig.savefig(f"{figure_folder_path}/{fig_name}",
+                    bbox_extra_artists=(title,),
+                    bbox_inches="tight")
+    plt.show()
+
+default_color_dict = {
+    "0": "#66C5CC",
+    "1": "#F6CF71",
+    "2": "#F89C74",
+    "3": "#DCB0F2",
+    "4": "#87C55F",
+    "5": "#9EB9F3",
+    "6": "#FE88B1",
+    "7": "#C9DB74",
+    "8": "#8BE0A4",
+    "9": "#B497E7",
+    "10": "#D3B484",
+    "11": "#B3B3B3",
+    "12": "#276A8C", # Royal Blue
+    "13": "#DAB6C4", # Pink
+    "14": "#C38D9E", # Mauve-Pink
+    "15": "#9D88A2", # Mauve
+    "16": "#FF4D4D", # Light Red
+    "17": "#9B4DCA", # Lavender-Purple
+    "18": "#FF9CDA", # Bright Pink
+    "19": "#FF69B4", # Hot Pink
+    "20": "#FF00FF", # Magenta
+    "21": "#DA70D6", # Orchid
+    "22": "#BA55D3", # Medium Orchid
+    "23": "#8A2BE2", # Blue Violet
+    "24": "#9370DB", # Medium Purple
+    "25": "#7B68EE", # Medium Slate Blue
+    "26": "#4169E1", # Royal Blue
+    "27": "#FF8C8C", # Salmon Pink
+    "28": "#FFAA80", # Light Coral
+    "29": "#48D1CC", # Medium Turquoise
+    "30": "#40E0D0", # Turquoise
+    "31": "#00FF00", # Lime
+    "32": "#7FFF00", # Chartreuse
+    "33": "#ADFF2F", # Green Yellow
+    "34": "#32CD32", # Lime Green
+    "35": "#228B22", # Forest Green
+    "36": "#FFD8B8", # Peach
+    "37": "#008080", # Teal
+    "38": "#20B2AA", # Light Sea Green
+    "39": "#00FFFF", # Cyan
+    "40": "#00BFFF", # Deep Sky Blue
+    "41": "#4169E1", # Royal Blue
+    "42": "#0000CD", # Medium Blue
+    "43": "#00008B", # Dark Blue
+    "44": "#8B008B", # Dark Magenta
+    "45": "#FF1493", # Deep Pink
+    "46": "#FF4500", # Orange Red
+    "47": "#006400", # Dark Green
+    "48": "#FF6347", # Tomato
+    "49": "#FF7F50", # Coral
+    "50": "#CD5C5C", # Indian Red
+    "51": "#B22222", # Fire Brick
+    "52": "#FFB83F",  # Light Orange
+    "53": "#8B0000", # Dark Red
+    "54": "#D2691E", # Chocolate
+    "55": "#A0522D", # Sienna
+    "56": "#800000", # Maroon
+    "57": "#808080", # Gray
+    "58": "#A9A9A9", # Dark Gray
+    "59": "#C0C0C0", # Silver
+    "60": "#9DD84A",
+    "61": "#F5F5F5", # White Smoke
+    "62": "#F17171", # Light Red
+    "63": "#000000", # Black
+    "64": "#FF8C42", # Tangerine
+    "65": "#F9A11F", # Bright Orange-Yellow
+    "66": "#FACC15", # Golden Yellow
+    "67": "#E2E062", # Pale Lime
+    "68": "#BADE92", # Soft Lime
+    "69": "#70C1B3", # Greenish-Blue
+    "70": "#41B3A3", # Turquoise
+    "71": "#5EAAA8", # Gray-Green
+    "72": "#72B01D", # Chartreuse
+    "73": "#9CD08F", # Light Green
+    "74": "#8EBA43", # Olive Green
+    "75": "#FAC8C3", # Light Pink
+    "76": "#E27D60", # Dark Salmon
+    "77": "#C38D9E", # Mauve-Pink
+    "78": "#937D64", # Light Brown
+    "79": "#B1C1CC", # Light Blue-Gray
+    "80": "#88A0A8", # Gray-Blue-Green
+    "81": "#4E598C", # Dark Blue-Purple
+    "82": "#4B4E6D", # Dark Gray-Blue
+    "83": "#8E9AAF", # Light Blue-Grey
+    "84": "#C0D6DF", # Pale Blue-Grey
+    "85": "#97C1A9", # Blue-Green
+    "86": "#4C6E5D", # Dark Green
+    "87": "#95B9C7", # Pale Blue-Green
+    "88": "#C1D5E0", # Pale Gray-Blue
+    "89": "#ECDB54", # Bright Yellow
+    "90": "#E89B3B", # Bright Orange
+    "91": "#CE5A57", # Deep Red
+    "92": "#C3525A", # Dark Red
+    "93": "#B85D8E", # Berry
+    "94": "#7D5295", # Deep Purple
+    "-1" : "#E1D9D1",
+    "None" : "#E1D9D1"
+}
+
+def create_new_color_dict(
+        adata,
+        cat_key,
+        color_palette="default",
+        overwrite_color_dict={"-1" : "#E1D9D1"},
+        skip_default_colors=0):
+    """
+    Create a dictionary of color hexcodes for a specified category.
+
+    Parameters
+    ----------
+    adata:
+        AnnData object.
+    cat_key:
+        Key in ´adata.obs´ where the categories are stored for which color
+        hexcodes will be created.
+    color_palette:
+        Type of color palette.
+    overwrite_color_dict:
+        Dictionary with overwrite values that will take precedence over the
+        automatically created dictionary.
+    skip_default_colors:
+        Number of colors to skip from the default color dict.
+
+    Returns
+    ----------
+    new_color_dict:
+        The color dictionary with a hexcode for each category.
+    """
+    new_categories = adata.obs[cat_key].unique().tolist()
+    if color_palette == "cell_type_30":
+        # https://github.com/scverse/scanpy/blob/master/scanpy/plotting/palettes.py#L40
+        new_color_dict = {key: value for key, value in zip(
+            new_categories,
+            ["#023fa5",
+             "#7d87b9",
+             "#bec1d4",
+             "#d6bcc0",
+             "#bb7784",
+             "#8e063b",
+             "#4a6fe3",
+             "#8595e1",
+             "#b5bbe3",
+             "#e6afb9",
+             "#e07b91",
+             "#d33f6a",
+             "#11c638",
+             "#8dd593",
+             "#c6dec7",
+             "#ead3c6",
+             "#f0b98d",
+             "#ef9708",
+             "#0fcfc0",
+             "#9cded6",
+             "#d5eae7",
+             "#f3e1eb",
+             "#f6c4e1",
+             "#f79cd4",
+             '#7f7f7f',
+             "#c7c7c7",
+             "#1CE6FF",
+             "#336600"])}
+    elif color_palette == "cell_type_20":
+        # https://github.com/vega/vega/wiki/Scales#scale-range-literals (some adjusted)
+        new_color_dict = {key: value for key, value in zip(
+            new_categories,
+            ['#1f77b4',
+             '#ff7f0e',
+             '#279e68',
+             '#d62728',
+             '#aa40fc',
+             '#8c564b',
+             '#e377c2',
+             '#b5bd61',
+             '#17becf',
+             '#aec7e8',
+             '#ffbb78',
+             '#98df8a',
+             '#ff9896',
+             '#c5b0d5',
+             '#c49c94',
+             '#f7b6d2',
+             '#dbdb8d',
+             '#9edae5',
+             '#ad494a',
+             '#8c6d31'])}
+    elif color_palette == "cell_type_10":
+        # scanpy vega10
+        new_color_dict = {key: value for key, value in zip(
+            new_categories,
+            ['#7f7f7f',
+             '#ff7f0e',
+             '#279e68',
+             '#e377c2',
+             '#17becf',
+             '#8c564b',
+             '#d62728',
+             '#1f77b4',
+             '#b5bd61',
+             '#aa40fc'])}
+    elif color_palette == "batch":
+        # sns.color_palette("colorblind").as_hex()
+        new_color_dict = {key: value for key, value in zip(
+            new_categories,
+            ['#0173b2', '#d55e00', '#ece133', '#ca9161', '#fbafe4',
+             '#949494', '#de8f05', '#029e73', '#cc78bc', '#56b4e9',
+             '#F0F8FF', '#FAEBD7', '#00FFFF', '#7FFFD4', '#F0FFFF',
+             '#F5F5DC', '#FFE4C4', '#000000', '#FFEBCD', '#0000FF',
+             '#8A2BE2', '#A52A2A', '#DEB887', '#5F9EA0', '#7FFF00',
+             '#D2691E', '#FF7F50', '#6495ED', '#FFF8DC', '#DC143C'])}
+    elif color_palette == "default":
+        new_color_dict = {key: value for key, value in zip(new_categories, list(default_color_dict.values())[skip_default_colors:])}
+    for key, val in overwrite_color_dict.items():
+        new_color_dict[key] = val
+    return new_color_dict
+
+
+def plot_non_zero_gene_count_means_dist(
+        adata: AnnData,
+        genes: list,
+        gene_label: str):
+    """
+    Plot distribution of non zero gene count means in the adata over all 
+    specified genes.
+    """
+    gene_counts = adata[
+        :, [gene for gene in adata.var_names if gene in genes]].layers["counts"]
+    nz_gene_means = np.mean(
+        np.ma.masked_equal(gene_counts.toarray(), 0), axis=0).data
+    
+    sns.kdeplot(nz_gene_means)
+    plt.title(f"{gene_label} Genes Average Non-Zero Gene Counts per Gene")
+    plt.xlabel("Average Non-zero Gene Counts")
+    plt.ylabel("Gene Density")
+    plt.show()
+
+
+def compute_communication_gp_network(
+    gp_list: list,
+    model: NicheCompass,
+    group_key: str="niche",
+    filter_key: Optional[str]=None,
+    filter_cat: Optional[str]=None,
+    n_neighbors: int=90):
+    """
+    Compute a network of category aggregated cell-pair communication strengths.
+    
+    First, compute cell-cell communication potential scores for each cell.
+    Then dot product them and take into account neighborhoods to compute
+    cell-pair communication strengths. Then, normalize cell-pair communication
+    strengths.
+    
+    Parameters
+    ----------
+    gp_list:
+        List of GPs for which the cell-pair communication strengths are computed.
+    model:
+        A trained NicheCompass model.
+    group_key:
+        Key in ´adata.obs´ where the groups are stored over which the cell-pair
+        communication strengths will be aggregated.
+    filter_key:
+        Key in ´adata.obs´ that contains the category for which the results are
+        filtered.
+    filter_cat:
+        Category for which the results are filtered.
+    n_neighbors:
+        Number of neighbors for the gp-specific neighborhood graph.
+
+    Returns
+    ----------
+    network_df:
+        A pandas dataframe with aggregated, normalized cell-pair communication strengths.
+    """
+    # Compute neighborhood graph
+    compute_knn = True
+    if 'spatial_cci' in model.adata.uns.keys():
+        if model.adata.uns['spatial_cci']['params']['n_neighbors'] == n_neighbors:
+            compute_knn = False
+    if compute_knn:
+        sc.pp.neighbors(model.adata,
+                        n_neighbors=n_neighbors,
+                        use_rep="spatial",
+                        key_added="spatial_cci")
+    
+    gp_network_dfs = []
+    gp_summary_df = model.get_gp_summary()
+    for gp in gp_list:
+        gp_idx = model.adata.uns[model.gp_names_key_].tolist().index(gp)
+        active_gp_idx = model.adata.uns[model.active_gp_names_key_].tolist().index(gp)
+        gp_scores = model.adata.obsm[model.latent_key_][:, active_gp_idx]
+        gp_targets_cats = model.adata.varm[model.gp_targets_categories_mask_key_][:, gp_idx]
+        gp_sources_cats = model.adata.varm[model.gp_sources_categories_mask_key_][:, gp_idx]
+        targets_cats_label_encoder = model.adata.uns[model.targets_categories_label_encoder_key_]
+        sources_cats_label_encoder = model.adata.uns[model.sources_categories_label_encoder_key_]
+
+        sources_cat_idx_dict = {}
+        for source_cat, source_cat_label in sources_cats_label_encoder.items():
+            sources_cat_idx_dict[source_cat] = np.where(gp_sources_cats == source_cat_label)[0]
+
+        targets_cat_idx_dict = {}
+        for target_cat, target_cat_label in targets_cats_label_encoder.items():
+            targets_cat_idx_dict[target_cat] = np.where(gp_targets_cats == target_cat_label)[0]
+
+        # Get indices of all source and target genes
+        source_genes_idx = np.array([], dtype=np.int64)
+        for key in sources_cat_idx_dict.keys():
+            source_genes_idx = np.append(source_genes_idx,
+                                         sources_cat_idx_dict[key])
+        target_genes_idx = np.array([], dtype=np.int64)
+        for key in targets_cat_idx_dict.keys():
+            target_genes_idx = np.append(target_genes_idx,
+                                         targets_cat_idx_dict[key])
+
+        # Compute cell-cell communication potential scores
+        gp_source_scores = np.zeros((len(model.adata.obs), len(source_genes_idx)))
+        gp_target_scores = np.zeros((len(model.adata.obs), len(target_genes_idx)))
+
+        for i, source_gene_idx in enumerate(source_genes_idx):
+            source_gene = model.adata.var_names[source_gene_idx]
+            gp_source_scores[:, i] = (
+                model.adata[:, model.adata.var_names.tolist().index(source_gene)].X.toarray().flatten() / model.adata[:, model.adata.var_names.tolist().index(source_gene)].X.toarray().flatten().max() *
+                gp_summary_df[gp_summary_df["gp_name"] == gp]["gp_source_genes_weights"].values[0][gp_summary_df[gp_summary_df["gp_name"] == gp]["gp_source_genes"].values[0].index(source_gene)] *
+                gp_scores)
+
+        for j, target_gene_idx in enumerate(target_genes_idx):
+            target_gene = model.adata.var_names[target_gene_idx]
+            gp_target_scores[:, j] = (
+                model.adata[:, model.adata.var_names.tolist().index(target_gene)].X.toarray().flatten() / model.adata[:, model.adata.var_names.tolist().index(target_gene)].X.toarray().flatten().max() *
+                gp_summary_df[gp_summary_df["gp_name"] == gp]["gp_target_genes_weights"].values[0][gp_summary_df[gp_summary_df["gp_name"] == gp]["gp_target_genes"].values[0].index(target_gene)] *
+                gp_scores)
+
+        agg_gp_source_score = gp_source_scores.mean(1).astype("float32")
+        agg_gp_target_score = gp_target_scores.mean(1).astype("float32")
+        agg_gp_source_score[agg_gp_source_score < 0] = 0.
+        agg_gp_target_score[agg_gp_target_score < 0] = 0.
+
+        model.adata.obs[f"{gp}_source_score"] = agg_gp_source_score
+        model.adata.obs[f"{gp}_target_score"] = agg_gp_target_score
+        
+        del(gp_target_scores)
+        del(gp_source_scores)
+
+        agg_gp_source_score = sp.csr_matrix(agg_gp_source_score)
+        agg_gp_target_score = sp.csr_matrix(agg_gp_target_score)
+
+        model.adata.obsp[f"{gp}_connectivities"] = (model.adata.obsp["spatial_cci_connectivities"] > 0).multiply(
+            agg_gp_source_score.T.dot(agg_gp_target_score))
+
+        # Aggregate gp connectivities for each group
+        gp_network_df_pivoted = aggregate_obsp_matrix_per_cell_type(
+            adata=model.adata,
+            obsp_key=f"{gp}_connectivities",
+            cell_type_key=group_key,
+            group_key=filter_key,
+            agg_rows=True)
+
+        if filter_key is not None:
+            gp_network_df_pivoted = gp_network_df_pivoted.loc[filter_cat, :]
+
+        gp_network_df = gp_network_df_pivoted.melt(var_name="source", value_name="gp_score", ignore_index=False).reset_index()
+        gp_network_df.columns = ["source", "target", "strength"]
+
+        gp_network_df = gp_network_df.sort_values("strength", ascending=False)
+
+        # Normalize strength
+        min_value = gp_network_df["strength"].min()
+        max_value = gp_network_df["strength"].max()
+        gp_network_df["strength_unscaled"] = gp_network_df["strength"]
+        gp_network_df["strength"] = (gp_network_df["strength"] - min_value) / (max_value - min_value)
+        gp_network_df["strength"] = np.round(gp_network_df["strength"], 2)
+        gp_network_df = gp_network_df[gp_network_df["strength"] > 0]
+
+        gp_network_df["edge_type"] = gp
+        gp_network_dfs.append(gp_network_df)
+
+    network_df = pd.concat(gp_network_dfs, ignore_index=True)
+    return network_df
+
+
+def visualize_communication_gp_network(
+    adata,
+    network_df,
+    cat_colors,
+    edge_type_colors: Optional[dict]=None,
+    edge_width_scale: int=20.0,
+    node_size: int=500,
+    fontsize: int=14,
+    figsize: Tuple[int, int]=(18, 16),
+    plot_legend: bool=True,
+    save: bool=False,
+    save_path: str="communication_gp_network.svg",
+    show: bool=True,
+    text_space: float=1.3,
+    connection_style="arc3, rad = 0.1",
+    cat_key: str="niche",
+    edge_attr: str="strength"):
+    """
+    Visualize a communication gp network.
+    """
+    # Assuming you have unique edge types in your 'edge_type' column
+    edge_types = np.unique(network_df['edge_type'])
+    
+    if edge_type_colors is None:
+        # Colorblindness adjusted vega_10
+        # See https://github.com/theislab/scanpy/issues/387
+        vega_10 = list(map(colors.to_hex, cm.tab10.colors))
+        vega_10_scanpy = vega_10.copy()
+        vega_10_scanpy[2] = "#279e68"  # green
+        vega_10_scanpy[4] = "#aa40fc"  # purple
+        vega_10_scanpy[8] = "#b5bd61"  # kakhi
+        edge_type_colors = vega_10_scanpy
+
+    # Create a dictionary that maps edge types to colors
+    edge_type_color_dict = {edge_type: color for edge_type, color in zip(edge_types, edge_type_colors)}
+
+    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
+    ax.axis("off")
+    G = nx.from_pandas_edgelist(
+        network_df,
+        source="source",
+        target="target",
+        edge_attr=["edge_type", edge_attr],
+        create_using=nx.DiGraph(),
+    )
+    pos = nx.circular_layout(G)
+
+    nx.set_node_attributes(G, cat_colors, "color")
+    node_color = nx.get_node_attributes(G, "color")
+
+    description = nx.draw_networkx_labels(G, pos, font_size=fontsize)
+    n = adata.obs[cat_key].nunique()
+    node_list = sorted(G.nodes())
+    angle = []
+    angle_dict = {}
+    for i, node in zip(range(n), node_list):
+        theta = 2.0 * np.pi * i / n
+        angle.append((np.cos(theta), np.sin(theta)))
+        angle_dict[node] = theta
+    pos = {}
+    for node_i, node in enumerate(node_list):
+        pos[node] = angle[node_i]
+
+    r = fig.canvas.get_renderer()
+    trans = plt.gca().transData.inverted()
+    for node, t in description.items():
+        bb = t.get_window_extent(renderer=r)
+        bbdata = bb.transformed(trans)
+        radius = text_space + bbdata.width / 2.0
+        position = (radius * np.cos(angle_dict[node]), radius * np.sin(angle_dict[node]))
+        t.set_position(position)
+        t.set_rotation(angle_dict[node] * 360.0 / (2.0 * np.pi))
+        t.set_clip_on(False)
+
+    edgelist = [(u, v) for u, v, e in G.edges(data=True) if u != v]
+    edge_colors = [edge_type_color_dict[edge_data['edge_type']] for u, v, edge_data in G.edges(data=True) if u != v]
+    width = [e[edge_attr] * edge_width_scale for u, v, e in G.edges(data=True) if u != v]
+
+    h2 = nx.draw_networkx(
+        G,
+        pos,
+        with_labels=False,
+        node_size=node_size,
+        edgelist=edgelist,
+        width=width,
+        edge_vmin=0.0,
+        edge_vmax=1.0,
+        edge_color=edge_colors,  # Use the edge type colors here
+        arrows=True,
+        arrowstyle="-|>",
+        arrowsize=20,
+        vmin=0.0,
+        vmax=1.0,
+        cmap=plt.cm.binary,  # Use a colormap for node colors if needed
+        node_color=list(node_color.values()),
+        ax=ax,
+        connectionstyle=connection_style,
+    )
+
+    #https://stackoverflow.com/questions/19877666/add-legends-to-linecollection-plot - uses plotted data to define the color but here we already have colors defined, so just need a Line2D object.
+    def make_proxy(clr, mappable, **kwargs):
+        return Line2D([0, 1], [0, 1], color=clr, **kwargs)
+
+    # generate proxies with the above function
+    proxies = [make_proxy(clr, h2, lw=5) for clr in set(edge_colors)]
+    labels = [edge.split("_")[0] + " GP" for edge in edge_types[::-1]]
+
+    if plot_legend:
+        lgd = plt.legend(proxies, labels, loc="lower left")
+
+    edgelist = [(u, v) for u, v, e in G.edges(data=True) if ((u == v))] + [(u, v) for u, v, e in G.edges(data=True) if ((u != v))]
+    edge_colors = [edge_type_color_dict[edge_data['edge_type']] for u, v, edge_data in G.edges(data=True) if u == v]
+    width = [e[edge_attr] * edge_width_scale for u, v, e in G.edges(data=True) if u == v] + [0 for u, v, e in G.edges(data=True) if ((u != v))]
+    nx.draw_networkx_edges(
+        G,
+        pos,
+        node_size=node_size,
+        edgelist=edgelist, 
+        width=width,
+        edge_vmin=0.0,
+        edge_vmax=1.0,
+        edge_color=edge_colors,
+        arrows=False,
+        arrowstyle="-|>",
+        arrowsize=20,
+        ax=ax,
+        connectionstyle=connection_style)
+    plt.tight_layout()
+    if save:
+        plt.savefig(save_path)
+    if show:
+        plt.show()
+    plt.close(fig)
+    plt.ion()