a b/src/nichecompass/utils/analysis.py
1
"""
2
This module contains utilities to analyze niches inferred by the NicheCompass
3
model.
4
"""
5
6
from typing import Optional, Tuple
7
8
#import holoviews as hv
9
import matplotlib.pyplot as plt
10
import numpy as np
11
import pandas as pd
12
import scanpy as sc
13
import scipy.sparse as sp
14
import seaborn as sns
15
from anndata import AnnData
16
from matplotlib import cm, colors
17
from matplotlib.lines import Line2D
18
import networkx as nx
19
20
from ..models import NicheCompass
21
22
23
def aggregate_obsp_matrix_per_cell_type(
        adata: AnnData,
        obsp_key: str,
        cell_type_key: str="cell_type",
        group_key: Optional[str]=None,
        agg_rows: bool=False) -> pd.DataFrame:
    """
    Generic function to aggregate adjacency matrices stored in
    ´adata.obsp[obsp_key]´ on cell type level. It can be used to aggregate the
    node label aggregator aggregation weights alpha or the reconstructed
    adjacency matrix of a trained NicheCompass model by neighbor cell type for
    downstream analysis.

    For every observation (row), the values of all its neighbors (columns) are
    summed up per neighbor cell type. The matrix in ´adata.obsp[obsp_key]´ is
    only read and never modified; it can be sparse or dense, and explicitly
    stored zeros are ignored.

    Parameters
    ----------
    adata:
        AnnData object which contains outputs of NicheCompass model training.
    obsp_key:
        Key in ´adata.obsp´ where the matrix to be aggregated is stored.
    cell_type_key:
        Key in ´adata.obs´ where the cell type labels are stored.
    group_key:
        Key in ´adata.obs´ where additional grouping labels are stored.
    agg_rows:
        If ´True´, also aggregate over the observations on cell type level.

    Returns
    ----------
    cell_type_agg_df:
        Pandas DataFrame with the aggregated obsp values (dim: n_obs x
        n_cell_types if ´agg_rows == False´, else n_cell_types x n_cell_types).
        If ´agg_rows == False´, the cell type labels (and, if specified, the
        group labels) of the observations are appended as additional columns.
        If ´agg_rows == True´, they form the (sorted) index instead.
    """
    sorted_cell_types = sorted(adata.obs[cell_type_key].unique().tolist())
    n_cell_types = len(sorted_cell_types)
    cell_type_label_encoder = {
        cell_type: idx for idx, cell_type in enumerate(sorted_cell_types)}

    # Encode the cell type of every observation once. Working on the plain
    # array keeps all later lookups strictly positional, independent of the
    # index of ´adata.obs´
    obs_cell_types = adata.obs[cell_type_key].to_numpy()
    n_obs = adata.obs.shape[0]
    obs_cell_type_index = np.array(
        [cell_type_label_encoder[cell_type] for cell_type in obs_cell_types],
        dtype=np.intp)

    # Retrieve coordinates and values from the same COO representation so that
    # they are guaranteed to be aligned, and drop zeros with a mask instead of
    # modifying the stored matrix (in some sparse reps 0s can appear)
    obsp_coo = sp.coo_matrix(adata.obsp[obsp_key])
    nz_mask = obsp_coo.data != 0
    nz_row_index = obsp_coo.row[nz_mask]
    neighbor_cell_type_index = obs_cell_type_index[obsp_coo.col[nz_mask]]
    nz_obsp = obsp_coo.data[nz_mask]

    # Use non zero indices, non zero values and row-wise observation cell type
    # index to construct new df with cell types as columns and row-wise
    # aggregated values per cell type index as values. ´np.add.at´ is needed
    # because the same (row, cell type) pair occurs multiple times
    cell_type_agg = np.zeros((n_obs, n_cell_types))
    np.add.at(cell_type_agg,
              (nz_row_index, neighbor_cell_type_index),
              nz_obsp)
    cell_type_agg_df = pd.DataFrame(
        cell_type_agg,
        columns=sorted_cell_types)

    # Add cell type labels of observations
    cell_type_agg_df[cell_type_key] = adata.obs[cell_type_key].values

    # If specified, add group label
    if group_key is not None:
        cell_type_agg_df[group_key] = adata.obs[group_key].values

    if agg_rows:
        # In addition, aggregate values across rows to get a
        # (n_cell_types x n_cell_types) df
        if group_key is not None:
            cell_type_agg_df = cell_type_agg_df.groupby(
                [group_key, cell_type_key]).sum()
        else:
            cell_type_agg_df = cell_type_agg_df.groupby(cell_type_key).sum()

        # Sort index to have same order as columns
        cell_type_agg_df = cell_type_agg_df.loc[
            sorted(cell_type_agg_df.index.tolist()), :]

    return cell_type_agg_df
103
104
105
def create_cell_type_chord_plot_from_df(
        adata: AnnData,
        df: pd.DataFrame,
        link_threshold: float=0.01,
        cell_type_key: str="cell_type",
        group_key: Optional[str]=None,
        groups: str="all",
        plot_label: str="Niche",
        save_fig: bool=False,
        file_path: Optional[str]=None):
    """
    Create a cell type chord diagram per group based on an input DataFrame.

    Parameters
    ----------
    adata:
        AnnData object which contains outputs of NicheCompass model training.
    df:
        A Pandas DataFrame that contains the connection values for the chord
        plot (dim: (n_groups x n_cell_types) x n_cell_types). Rows and columns
        are accessed by position and are assumed to follow the sorted order of
        the cell types in ´adata.obs[cell_type_key]´ within each group.
    link_threshold:
        Ratio of link strength that a cell type pair needs to exceed compared to
        the cell type pair with the maximum link strength to be considered a
        link for the chord plot.
    cell_type_key:
        Key in ´adata.obs´ where the cell type labels are stored.
    group_key:
        Key in ´adata.obs´ where additional group labels are stored. Needs to
        be the name of an index level of ´df´.
    groups:
        List of groups that will be plotted (a single group can also be passed
        as string). If ´all´, plot all groups.
    plot_label:
        Shared label for the plots.
    save_fig:
        If ´True´, save the figure.
    file_path:
        Path where to save the figure.

    Raises
    ----------
    ImportError:
        If the optional dependency ´holoviews´ is not installed.
    """
    # holoviews is not imported at module level, so it is imported here; this
    # keeps the rest of the module usable without it
    try:
        import holoviews as hv
    except ImportError as e:
        raise ImportError(
            "Creating chord plots requires the optional dependency "
            "'holoviews'. Install it with 'pip install holoviews'.") from e
    import warnings

    hv.extension("bokeh")
    hv.output(size=200)

    sorted_cell_types = sorted(adata.obs[cell_type_key].unique().tolist())
    n_cell_types = len(sorted_cell_types)

    # Get group labels
    if group_key is None:
        group_labels = [""]
    elif isinstance(groups, str) and groups == "all":
        group_level = df.index.names.index(group_key)
        group_labels = df.index.get_level_values(
            group_level).unique().tolist()
    elif isinstance(groups, str):
        # A single group name; do not iterate over its characters
        group_labels = [groups]
    else:
        group_labels = list(groups)

    chord_list = []
    for group_label in group_labels:
        if group_key is None:
            group_df = df
        else:
            group_df = df[df.index.get_level_values(
                df.index.names.index(group_key)) == group_label]

        # Get max value (over rows and columns) of the group for thresholding
        group_max = group_df.max().max()
        min_link_value = group_max * link_threshold

        # Create group chord links. Rows are used as link targets and columns
        # as link sources
        links_list = []
        for i in range(n_cell_types):
            for j in range(n_cell_types):
                link_value = group_df.iloc[i, j]
                if link_value > min_link_value:
                    links_list.append({"source": j,
                                       "target": i,
                                       "value": link_value})
        if not links_list:
            # Without links there are no 'source' / 'target' columns to build
            # nodes from, so there is nothing to plot for this group
            warnings.warn(
                f"No cell type links exceed the link threshold for group "
                f"'{group_label}'. Skipping its chord plot.")
            continue
        links = pd.DataFrame(links_list)

        # Create group chord nodes (only where links exist)
        linked_idx = (set(links["source"].tolist()) |
                      set(links["target"].tolist()))
        nodes_idx = [i for i in range(n_cell_types) if i in linked_idx]
        nodes_list = [{"name": sorted_cell_types[i], "group": 1}
                      for i in nodes_idx]
        nodes = hv.Dataset(pd.DataFrame(nodes_list, index=nodes_idx), "index")

        # Create group chord plot
        # NOTE(review): ´select(value=(5, None))´ additionally restricts the
        # plot to links with an absolute value of at least 5, independent of
        # ´link_threshold´ -- confirm that this hard-coded bound is intended
        chord = hv.Chord((links, nodes)).select(value=(5, None))
        chord.opts(hv.opts.Chord(cmap="Category20",
                                 edge_cmap="Category20",
                                 edge_color=hv.dim("source").str(),
                                 labels="name",
                                 node_color=hv.dim("index").str(),
                                 title=f"{plot_label} {group_label}"))
        chord_list.append(chord)

    if not chord_list:
        return

    # Display chord plots
    layout = hv.Layout(chord_list).cols(2)
    hv.output(layout)

    # Save chord plots
    if save_fig:
        hv.save(layout,
                file_path,
                fmt="png")
210
211
        
212
def generate_enriched_gp_info_plots(plot_label: str,
                                    model: NicheCompass,
                                    sample_key: str,
                                    differential_gp_test_results_key: str,
                                    cat_key: str,
                                    cat_palette: dict,
                                    n_top_enriched_gp_start_idx: int=0,
                                    n_top_enriched_gp_end_idx: int=10,
                                    feature_spaces: list=["latent"],
                                    n_top_genes_per_gp: int=3,
                                    n_top_peaks_per_gp: int=0,
                                    scale_omics_ft: bool=False,
                                    save_figs: bool=False,
                                    figure_folder_path: str="",
                                    file_format: str="png",
                                    spot_size: float=30.):
    """
    Generate info plots of enriched gene programs. These show the enriched
    category, the gp activities, as well as the counts (or log normalized
    counts) of the top genes and/or peaks in a specified feature space.

    One figure is created per entry of ´feature_spaces´. All intermediate
    information is stored in copies of the model's AnnData objects.

    Parameters
    ----------
    plot_label:
        Main label of the plots.
    model:
        A trained NicheCompass model.
    sample_key:
        Key in ´adata.obs´ where the samples are stored.
    differential_gp_test_results_key:
        Key in ´adata.uns´ where the results of the differential gene program
        testing are stored.
    cat_key:
        Key in ´adata.obs´ where the categories that are used as colors for the
        enriched category plot are stored.
    cat_palette:
        Dictionary of colors that are used to highlight the categories, where
        the category is the key of the dictionary and the color is the value.
    n_top_enriched_gp_start_idx:
        Number of top enriched gene program from which to start the creation
        of plots.
    n_top_enriched_gp_end_idx:
        Number of top enriched gene program at which to stop the creation
        of plots.
    feature_spaces:
        List of feature spaces used for the info plots. Can be ´latent´ to use
        the latent embeddings for the plots, or it can be any of the samples
        stored in ´adata.obs[sample_key]´ to use the respective physical
        feature space for the plots. The list is only read.
    n_top_genes_per_gp:
        Number of top genes per gp to be considered in the info plots.
    n_top_peaks_per_gp:
        Number of top peaks per gp to be considered in the info plots. If ´>0´,
        requires the model to be trained including ATAC modality.
    scale_omics_ft:
        If ´True´, scale genes and peaks before plotting.
    save_figs:
        If ´True´, save the figures.
    figure_folder_path:
        Folder path where the figures will be saved.
    file_format:
        Format with which the figures will be saved.
    spot_size:
        Spot size used for the spatial plots.

    Raises
    ----------
    ValueError:
        If ´n_top_peaks_per_gp > 0´ but the model was not trained with the
        ATAC modality.
    """
    model._check_if_trained(warn=True)

    # Work on copies; the scaling and the plot information below are written
    # to these copies only
    adata = model.adata.copy()
    use_peaks = n_top_peaks_per_gp > 0
    if use_peaks:
        if "atac" not in model.modalities_:
            raise ValueError("The model needs to be trained with ATAC data if "
                             "'n_top_peaks_per_gp' > 0.")
        adata_atac = model.adata_atac.copy()

    # TODO
    if scale_omics_ft:
        sc.pp.scale(adata)
        if use_peaks:
            sc.pp.scale(adata_atac)
        adata.uns["omics_ft_pos_cmap"] = "RdBu"
        adata.uns["omics_ft_neg_cmap"] = "RdBu_r"
    else:
        # Only densify if the peak matrix is actually sparse; a dense array
        # has no ´toarray´ method
        if use_peaks and sp.issparse(adata_atac.X):
            adata_atac.X = adata_atac.X.toarray()
        adata.uns["omics_ft_pos_cmap"] = "Blues"
        adata.uns["omics_ft_neg_cmap"] = "Reds"

    if use_peaks:
        # Built once instead of once per peak lookup
        atac_var_names = adata_atac.var_names.tolist()

    test_results = adata.uns[differential_gp_test_results_key]
    cats = list(test_results["category"][
        n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])
    gps = list(test_results["gene_program"][
        n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])
    log_bayes_factors = list(test_results["log_bayes_factor"][
        n_top_enriched_gp_start_idx:n_top_enriched_gp_end_idx])

    # Number of plot columns reserved per feature type (read by the plotting
    # helper); these do not depend on the gene program
    adata.uns["n_top_source_genes"] = n_top_genes_per_gp
    adata.uns["n_top_target_genes"] = n_top_genes_per_gp
    adata.uns["n_top_source_peaks"] = n_top_peaks_per_gp if use_peaks else 0
    adata.uns["n_top_target_peaks"] = n_top_peaks_per_gp if use_peaks else 0

    for gp in gps:
        # Get source and target genes, gene importances and gene signs and store
        # in temporary adata
        _store_top_gp_features_(
            adata=adata,
            importances_df=model.compute_gp_gene_importances(selected_gp=gp),
            gp=gp,
            feature="gene",
            n_top=n_top_genes_per_gp)

        if use_peaks:
            # Get source and target peaks, peak importances and peak signs and
            # store in temporary adata
            _store_top_gp_features_(
                adata=adata,
                importances_df=model.compute_gp_peak_importances(
                    selected_gp=gp),
                gp=gp,
                feature="peak",
                n_top=n_top_peaks_per_gp)

            # Add peak counts to temporary adata for plotting
            for entity in ("target", "source"):
                top_peaks = list(adata.uns[f"{gp}_{entity}_peaks_top_peaks"])
                if not top_peaks:
                    continue
                peak_col_idx = [atac_var_names.index(peak)
                                for peak in top_peaks]
                adata.obs[top_peaks] = adata_atac.X[:, peak_col_idx]

    for feature_space in feature_spaces:
        plot_enriched_gp_info_plots_(
            adata=adata,
            sample_key=sample_key,
            gps=gps,
            log_bayes_factors=log_bayes_factors,
            cat_key=cat_key,
            cat_palette=cat_palette,
            cats=cats,
            feature_space=feature_space,
            spot_size=spot_size,
            suptitle=f"{plot_label.replace('_', ' ').title()} "
                     f"Top {n_top_enriched_gp_start_idx} to "
                     f"{n_top_enriched_gp_end_idx} Enriched GPs: "
                     f"GP Scores and Omics Feature Counts in "
                     f"{feature_space} Feature Space",
            save_fig=save_figs,
            figure_folder_path=figure_folder_path,
            fig_name=f"{plot_label}_top_{n_top_enriched_gp_start_idx}"
                     f"-{n_top_enriched_gp_end_idx}_enriched_gps_gp_scores_"
                     f"omics_feature_counts_in_{feature_space}_"
                     f"feature_space.{file_format}")


def _store_top_gp_features_(adata: AnnData,
                            importances_df: pd.DataFrame,
                            gp: str,
                            feature: str,
                            n_top: int):
    """
    This is a helper function to store the names, importances and signs of the
    top source and target features (genes or peaks) of a gene program in
    ´adata.uns´.

    Parameters
    ----------
    adata:
        AnnData object in whose ´uns´ the information is stored.
    importances_df:
        DataFrame with the columns ´{feature}_entity´, ´{feature}´,
        ´{feature}_importance´ and ´{feature}_weight´. The first ´n_top´ rows
        per entity are used, i.e. the rows are presumably already ordered by
        importance -- confirm against the model method that creates them.
    gp:
        Name of the gene program.
    feature:
        Type of the features; either ´gene´ or ´peak´.
    n_top:
        Number of top features to store per entity.
    """
    for entity in ("source", "target"):
        entity_df = importances_df[
            importances_df[f"{feature}_entity"] == entity]
        top_df = entity_df.iloc[:n_top]
        key_prefix = f"{gp}_{entity}_{feature}s_top_{feature}"
        adata.uns[f"{key_prefix}s"] = top_df[feature].values
        adata.uns[f"{key_prefix}_importances"] = (
            top_df[f"{feature}_importance"].values)
        # Signs are restricted to the top features so that all three arrays
        # have the same length
        adata.uns[f"{key_prefix}_signs"] = (
            np.where(top_df[f"{feature}_weight"] > 0, "+", "-"))
404
            
405
            
406
def plot_enriched_gp_info_plots_(adata: AnnData,
                                 sample_key: str,
                                 gps: list,
                                 log_bayes_factors: list,
                                 cat_key: str,
                                 cat_palette: dict,
                                 cats: list,
                                 feature_space: str,
                                 spot_size: float,
                                 suptitle: str,
                                 save_fig: bool,
                                 figure_folder_path: str,
                                 fig_name: str):
    """
    This is a helper function to plot gene program info plots in a specified
    feature space.

    The figure has one row per gene program. The first column shows the
    enriched category, the second column the gene program scores, and the
    remaining columns the top source genes, target genes, source peaks and
    target peaks (in this order). Columns that are not needed for a gene
    program are hidden.

    Parameters
    ----------
    adata:
        An AnnData object with stored information about the gene programs to be
        plotted. ´adata.uns´ needs to contain the ´n_top_*´,
        ´omics_ft_*_cmap´ and ´{gp}_*´ entries that are written by
        ´generate_enriched_gp_info_plots()´.
    sample_key:
        Key in ´adata.obs´ where the samples are stored.
    gps:
        List of gene programs for which info plots will be created.
    log_bayes_factors:
        List of log bayes factors corresponding to gene programs
    cat_key:
        Key in ´adata.obs´ where the categories that are used as colors for the
        enriched category plot are stored.
    cat_palette:
        Dictionary of colors that are used to highlight the categories, where
        the category is the key of the dictionary and the color is the value.
    cats:
        List of categories for which the corresponding gene programs in ´gps´
        are enriched.
    feature_space:
        Feature space used for the plots. Can be ´latent´ to use the latent
        embeddings for the plots, or it can be any of the samples stored in
        ´adata.obs[sample_key]´ to use the respective physical feature space for
        the plots.
    spot_size:
        Spot size used for the spatial plots.
    suptitle:
        Overall figure title.
    save_fig:
        If ´True´, save the figure.
    figure_folder_path:
        Path of the folder where the figure will be saved.
    fig_name:
        Name of the figure under which it will be saved.
    """
    # Define figure configurations
    n_omics_ft_cols = (adata.uns["n_top_source_genes"] +
                       adata.uns["n_top_target_genes"] +
                       adata.uns["n_top_source_peaks"] +
                       adata.uns["n_top_target_peaks"])
    ncols = 2 + n_omics_ft_cols
    fig_width = 12 + (6 * n_omics_ft_cols)
    wspace = 0.3
    # ´squeeze=False´ always returns a 2D array of axes, also for a single row
    fig, axs = plt.subplots(nrows=len(gps),
                            ncols=ncols,
                            figsize=(fig_width, 6*len(gps)),
                            squeeze=False)
    title = fig.suptitle(t=suptitle,
                         x=0.55,
                         y=(1.1 if len(gps) == 1 else 0.97),
                         fontsize=20)

    # Select the observations to plot once instead of once per panel
    is_latent = feature_space == "latent"
    if is_latent:
        plot_adata = adata
    else:
        plot_adata = adata[adata.obs[sample_key] == feature_space]

    for i, gp in enumerate(gps):
        # Plot enriched gp category and gene program latent scores
        if is_latent:
            sc.pl.umap(
                plot_adata,
                color=cat_key,
                palette=cat_palette,
                groups=cats[i],
                ax=axs[i, 0],
                title="Enriched GP Category",
                legend_loc="on data",
                na_in_legend=False,
                show=False)
            sc.pl.umap(
                plot_adata,
                color=gp,
                color_map="RdBu",
                ax=axs[i, 1],
                title=f"{gp[:gp.index('_')]}\n"
                      f"{gp[gp.index('_') + 1: gp.rindex('_')].replace('_', ' ')}"
                      f"\n{gp[gp.rindex('_') + 1:]} score (LBF: {round(log_bayes_factors[i])})",
                colorbar_loc="bottom",
                show=False)
        else:
            sc.pl.spatial(
                adata=plot_adata,
                color=cat_key,
                palette=cat_palette,
                groups=cats[i],
                ax=axs[i, 0],
                spot_size=spot_size,
                title="Enriched GP Category",
                legend_loc="on data",
                na_in_legend=False,
                show=False)
            sc.pl.spatial(
                adata=plot_adata,
                color=gp,
                color_map="RdBu",
                spot_size=spot_size,
                title=f"{gp.split('_', 1)[0]}\n{gp.split('_', 1)[1]} "
                      f"(LBF: {round(log_bayes_factors[i], 2)})",
                legend_loc=None,
                ax=axs[i, 1],
                colorbar_loc="bottom",
                show=False)
        for ax in (axs[i, 0], axs[i, 1]):
            ax.xaxis.label.set_visible(False)
            ax.yaxis.label.set_visible(False)

        # Plot omics feature counts (or log normalized counts). ´next_col´ is
        # the first free column; it advances by the number of features that
        # are stored for each modality entity
        next_col = 2
        for modality_entity in ("source_genes",
                                "target_genes",
                                "source_peaks",
                                "target_peaks"):
            ft_type = modality_entity.split("_")[1]  # "genes" or "peaks"
            key_prefix = f"{gp}_{modality_entity}_top_{ft_type[:-1]}"
            # Peak entries only exist if peaks were requested; gene entries
            # are always expected
            if ft_type == "peaks" and f"{key_prefix}s" not in adata.uns.keys():
                continue
            features = adata.uns[f"{key_prefix}s"]
            if len(features) == 0:
                continue
            importances = adata.uns[f"{key_prefix}_importances"]
            signs = adata.uns[f"{key_prefix}_signs"]

            for j, feature in enumerate(features):
                ax = axs[i, next_col + j]
                sign = signs[j]
                color_map = (adata.uns["omics_ft_pos_cmap"] if sign == "+"
                             else adata.uns["omics_ft_neg_cmap"])
                if is_latent:
                    sc.pl.umap(
                        plot_adata,
                        color=feature,
                        color_map=color_map,
                        ax=ax,
                        legend_loc="on data",
                        na_in_legend=False,
                        title=f"{feature}: {importances[j]:.2f} "
                              f"({modality_entity[:-1]}; {sign})",
                        colorbar_loc="bottom",
                        show=False)
                else:
                    sc.pl.spatial(
                        adata=plot_adata,
                        color=feature,
                        color_map=color_map,
                        legend_loc="on data",
                        na_in_legend=False,
                        ax=ax,
                        spot_size=spot_size,
                        title=f"{feature} \n"
                              f"({importances[j]:.2f}; "
                              f"{modality_entity[:-1]}; {sign})",
                        colorbar_loc="bottom",
                        show=False)
                ax.xaxis.label.set_visible(False)
                ax.yaxis.label.set_visible(False)
            next_col += len(features)

        # Remove unnecessary axes. This is done once per gene program and
        # independent of the modality entities, so that the empty axes are
        # also hidden if a gene program has no features to plot
        for unused_col in range(next_col, ncols):
            axs[i, unused_col].set_visible(False)

    # Save and display plot
    plt.subplots_adjust(wspace=wspace, hspace=0.275)
    if save_fig:
        fig.savefig(f"{figure_folder_path}/{fig_name}",
                    bbox_extra_artists=(title,),
                    bbox_inches="tight")
    plt.show()
633
634
# Default color hexcodes for categories. The keys are the stringified category
# numbers "0" to "94" (in this order), followed by the two placeholder
# categories "-1" and "None", which share one color.
default_color_dict = {
    str(category): hexcode for category, hexcode in enumerate((
        "#66C5CC",  # 0
        "#F6CF71",  # 1
        "#F89C74",  # 2
        "#DCB0F2",  # 3
        "#87C55F",  # 4
        "#9EB9F3",  # 5
        "#FE88B1",  # 6
        "#C9DB74",  # 7
        "#8BE0A4",  # 8
        "#B497E7",  # 9
        "#D3B484",  # 10
        "#B3B3B3",  # 11
        "#276A8C",  # 12: Royal Blue
        "#DAB6C4",  # 13: Pink
        "#C38D9E",  # 14: Mauve-Pink
        "#9D88A2",  # 15: Mauve
        "#FF4D4D",  # 16: Light Red
        "#9B4DCA",  # 17: Lavender-Purple
        "#FF9CDA",  # 18: Bright Pink
        "#FF69B4",  # 19: Hot Pink
        "#FF00FF",  # 20: Magenta
        "#DA70D6",  # 21: Orchid
        "#BA55D3",  # 22: Medium Orchid
        "#8A2BE2",  # 23: Blue Violet
        "#9370DB",  # 24: Medium Purple
        "#7B68EE",  # 25: Medium Slate Blue
        "#4169E1",  # 26: Royal Blue
        "#FF8C8C",  # 27: Salmon Pink
        "#FFAA80",  # 28: Light Coral
        "#48D1CC",  # 29: Medium Turquoise
        "#40E0D0",  # 30: Turquoise
        "#00FF00",  # 31: Lime
        "#7FFF00",  # 32: Chartreuse
        "#ADFF2F",  # 33: Green Yellow
        "#32CD32",  # 34: Lime Green
        "#228B22",  # 35: Forest Green
        "#FFD8B8",  # 36: Peach
        "#008080",  # 37: Teal
        "#20B2AA",  # 38: Light Sea Green
        "#00FFFF",  # 39: Cyan
        "#00BFFF",  # 40: Deep Sky Blue
        "#4169E1",  # 41: Royal Blue
        "#0000CD",  # 42: Medium Blue
        "#00008B",  # 43: Dark Blue
        "#8B008B",  # 44: Dark Magenta
        "#FF1493",  # 45: Deep Pink
        "#FF4500",  # 46: Orange Red
        "#006400",  # 47: Dark Green
        "#FF6347",  # 48: Tomato
        "#FF7F50",  # 49: Coral
        "#CD5C5C",  # 50: Indian Red
        "#B22222",  # 51: Fire Brick
        "#FFB83F",  # 52: Light Orange
        "#8B0000",  # 53: Dark Red
        "#D2691E",  # 54: Chocolate
        "#A0522D",  # 55: Sienna
        "#800000",  # 56: Maroon
        "#808080",  # 57: Gray
        "#A9A9A9",  # 58: Dark Gray
        "#C0C0C0",  # 59: Silver
        "#9DD84A",  # 60
        "#F5F5F5",  # 61: White Smoke
        "#F17171",  # 62: Light Red
        "#000000",  # 63: Black
        "#FF8C42",  # 64: Tangerine
        "#F9A11F",  # 65: Bright Orange-Yellow
        "#FACC15",  # 66: Golden Yellow
        "#E2E062",  # 67: Pale Lime
        "#BADE92",  # 68: Soft Lime
        "#70C1B3",  # 69: Greenish-Blue
        "#41B3A3",  # 70: Turquoise
        "#5EAAA8",  # 71: Gray-Green
        "#72B01D",  # 72: Chartreuse
        "#9CD08F",  # 73: Light Green
        "#8EBA43",  # 74: Olive Green
        "#FAC8C3",  # 75: Light Pink
        "#E27D60",  # 76: Dark Salmon
        "#C38D9E",  # 77: Mauve-Pink
        "#937D64",  # 78: Light Brown
        "#B1C1CC",  # 79: Light Blue-Gray
        "#88A0A8",  # 80: Gray-Blue-Green
        "#4E598C",  # 81: Dark Blue-Purple
        "#4B4E6D",  # 82: Dark Gray-Blue
        "#8E9AAF",  # 83: Light Blue-Grey
        "#C0D6DF",  # 84: Pale Blue-Grey
        "#97C1A9",  # 85: Blue-Green
        "#4C6E5D",  # 86: Dark Green
        "#95B9C7",  # 87: Pale Blue-Green
        "#C1D5E0",  # 88: Pale Gray-Blue
        "#ECDB54",  # 89: Bright Yellow
        "#E89B3B",  # 90: Bright Orange
        "#CE5A57",  # 91: Deep Red
        "#C3525A",  # 92: Dark Red
        "#B85D8E",  # 93: Berry
        "#7D5295",  # 94: Deep Purple
    ))
}
# Placeholder categories come last, as in a plain literal
default_color_dict["-1"] = "#E1D9D1"
default_color_dict["None"] = "#E1D9D1"
733
734
def create_new_color_dict(
        adata,
        cat_key,
        color_palette="default",
        overwrite_color_dict=None,
        skip_default_colors=0):
    """
    Create a dictionary of color hexcodes for a specified category.

    The unique values of ´adata.obs[cat_key]´ are paired, in order of
    appearance, with the colors of the selected palette. Categories beyond the
    length of the palette do not receive a color. Afterwards the entries of
    ´overwrite_color_dict´ are written into the result (also for keys that are
    not present in ´adata.obs[cat_key]´).

    Parameters
    ----------
    adata:
        AnnData object.
    cat_key:
        Key in ´adata.obs´ where the categories are stored for which color
        hexcodes will be created.
    color_palette:
        Type of color palette. One of ´"default"´, ´"cell_type_30"´,
        ´"cell_type_20"´, ´"cell_type_10"´ or ´"batch"´.
    overwrite_color_dict:
        Dictionary with overwrite values that will take precedence over the
        automatically created dictionary. If ´None´, ´{"-1" : "#E1D9D1"}´ is
        used.
    skip_default_colors:
        Number of colors to skip from the default color dict (only used if
        ´color_palette == "default"´).

    Returns
    ----------
    new_color_dict:
        The color dictionary with a hexcode for each category.

    Raises
    ----------
    ValueError:
        If ´color_palette´ is not one of the supported palette names.
    """
    # Use a ´None´ sentinel instead of a mutable default argument
    if overwrite_color_dict is None:
        overwrite_color_dict = {"-1" : "#E1D9D1"}

    fixed_palettes = {
        # https://github.com/scverse/scanpy/blob/master/scanpy/plotting/palettes.py#L40
        "cell_type_30": [
            "#023fa5", "#7d87b9", "#bec1d4", "#d6bcc0", "#bb7784", "#8e063b",
            "#4a6fe3", "#8595e1", "#b5bbe3", "#e6afb9", "#e07b91", "#d33f6a",
            "#11c638", "#8dd593", "#c6dec7", "#ead3c6", "#f0b98d", "#ef9708",
            "#0fcfc0", "#9cded6", "#d5eae7", "#f3e1eb", "#f6c4e1", "#f79cd4",
            "#7f7f7f", "#c7c7c7", "#1CE6FF", "#336600"],
        # https://github.com/vega/vega/wiki/Scales#scale-range-literals (some adjusted)
        "cell_type_20": [
            "#1f77b4", "#ff7f0e", "#279e68", "#d62728", "#aa40fc", "#8c564b",
            "#e377c2", "#b5bd61", "#17becf", "#aec7e8", "#ffbb78", "#98df8a",
            "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d2", "#dbdb8d", "#9edae5",
            "#ad494a", "#8c6d31"],
        # scanpy vega10
        "cell_type_10": [
            "#7f7f7f", "#ff7f0e", "#279e68", "#e377c2", "#17becf", "#8c564b",
            "#d62728", "#1f77b4", "#b5bd61", "#aa40fc"],
        # sns.color_palette("colorblind").as_hex()
        "batch": [
            "#0173b2", "#d55e00", "#ece133", "#ca9161", "#fbafe4",
            "#949494", "#de8f05", "#029e73", "#cc78bc", "#56b4e9",
            "#F0F8FF", "#FAEBD7", "#00FFFF", "#7FFFD4", "#F0FFFF",
            "#F5F5DC", "#FFE4C4", "#000000", "#FFEBCD", "#0000FF",
            "#8A2BE2", "#A52A2A", "#DEB887", "#5F9EA0", "#7FFF00",
            "#D2691E", "#FF7F50", "#6495ED", "#FFF8DC", "#DC143C"],
    }

    if color_palette == "default":
        # Looked up lazily so that the fixed palettes do not depend on the
        # module-level default color dict
        palette = list(default_color_dict.values())[skip_default_colors:]
    elif color_palette in fixed_palettes:
        palette = fixed_palettes[color_palette]
    else:
        raise ValueError(
            f"Unknown color palette '{color_palette}'. Supported palettes are "
            f"{['default'] + list(fixed_palettes)}.")

    new_categories = adata.obs[cat_key].unique().tolist()
    # ´zip´ stops at the shorter input: surplus categories get no color
    new_color_dict = dict(zip(new_categories, palette))
    new_color_dict.update(overwrite_color_dict)
    return new_color_dict
849
850
851
def plot_non_zero_gene_count_means_dist(
        adata: AnnData,
        genes: list,
        gene_label: str):
    """
    Plot distribution of non zero gene count means in the adata over all
    specified genes.

    For every gene in ´genes´ that is contained in ´adata.var_names´, the mean
    of its non-zero entries in ´adata.layers["counts"]´ is computed over all
    observations. The resulting per-gene means are shown as a kernel density
    estimate.

    Parameters
    ----------
    adata:
        AnnData object with raw counts stored in ´adata.layers["counts"]´.
    genes:
        Genes to be included in the plot. Genes that are not in
        ´adata.var_names´ are ignored.
    gene_label:
        Label of the gene set that is used in the plot title.
    """
    # Set membership keeps the gene selection linear in the number of genes
    genes_of_interest = set(genes)
    selected_genes = [
        gene for gene in adata.var_names if gene in genes_of_interest]

    gene_counts = adata[:, selected_genes].layers["counts"]
    # Sparse layers have to be densified for masking; dense layers (which do
    # not provide ´toarray´) are used as they are
    if hasattr(gene_counts, "toarray"):
        gene_counts = gene_counts.toarray()

    # Mask zeros so that only non-zero counts enter the per-gene mean
    nz_gene_means = np.mean(
        np.ma.masked_equal(np.asarray(gene_counts), 0), axis=0).data

    sns.kdeplot(nz_gene_means)
    plt.title(f"{gene_label} Genes Average Non-Zero Gene Counts per Gene")
    plt.xlabel("Average Non-zero Gene Counts")
    plt.ylabel("Gene Density")
    plt.show()
869
870
871
def compute_communication_gp_network(
    gp_list: list,
    model: NicheCompass,
    group_key: str="niche",
    filter_key: Optional[str]=None,
    filter_cat: Optional[str]=None,
    n_neighbors: int=90):
    """
    Compute a network of category aggregated cell-pair communication strengths.

    First, compute cell-cell communication potential scores for each cell.
    Then dot product them and take into account neighborhoods to compute
    cell-pair communication strengths. Then, normalize cell-pair communication
    strengths.

    As a side effect, the following is written to ´model.adata´:
    - a neighborhood graph under the key ´spatial_cci´ (only computed if it
      does not exist yet with the requested ´n_neighbors´),
    - ´obs["{gp}_source_score"]´ and ´obs["{gp}_target_score"]´ for each GP,
    - ´obsp["{gp}_connectivities"]´ for each GP.

    Parameters
    ----------
    gp_list:
        List of GPs for which the cell-pair communication strengths are computed.
    model:
        A trained NicheCompass model.
    group_key:
        Key in ´adata.obs´ where the groups are stored over which the cell-pair
        communication strengths will be aggregated.
    filter_key:
        Key in ´adata.obs´ that contains the category for which the results are
        filtered.
    filter_cat:
        Category for which the results are filtered.
    n_neighbors:
        Number of neighbors for the gp-specific neighborhood graph.

    Returns
    ----------
    network_df:
        A pandas dataframe with aggregated, normalized cell-pair communication
        strengths (columns ´source´, ´target´, ´strength´,
        ´strength_unscaled´ and ´edge_type´). Only rows with a rounded
        normalized strength > 0 are kept.
    """
    adata = model.adata

    # Compute neighborhood graph (reuse an existing one if it was computed
    # with the same number of neighbors)
    compute_knn = not (
        "spatial_cci" in adata.uns.keys() and
        adata.uns["spatial_cci"]["params"]["n_neighbors"] == n_neighbors)
    if compute_knn:
        sc.pp.neighbors(adata,
                        n_neighbors=n_neighbors,
                        use_rep="spatial",
                        key_added="spatial_cci")

    # Nothing to aggregate; ´pd.concat´ would fail on an empty list
    if len(gp_list) == 0:
        return pd.DataFrame(columns=["source",
                                     "target",
                                     "strength",
                                     "strength_unscaled",
                                     "edge_type"])

    # Loop invariants
    var_names = adata.var_names.tolist()
    gp_names = adata.uns[model.gp_names_key_].tolist()
    active_gp_names = adata.uns[model.active_gp_names_key_].tolist()
    targets_cats_label_encoder = adata.uns[
        model.targets_categories_label_encoder_key_]
    sources_cats_label_encoder = adata.uns[
        model.sources_categories_label_encoder_key_]
    neighbor_mask = adata.obsp["spatial_cci_connectivities"] > 0
    gp_summary_df = model.get_gp_summary()

    gp_network_dfs = []
    for gp in gp_list:
        gp_idx = gp_names.index(gp)
        active_gp_idx = active_gp_names.index(gp)
        gp_scores = adata.obsm[model.latent_key_][:, active_gp_idx]
        gp_targets_cats = adata.varm[
            model.gp_targets_categories_mask_key_][:, gp_idx]
        gp_sources_cats = adata.varm[
            model.gp_sources_categories_mask_key_][:, gp_idx]

        # Get indices of all source and target genes
        source_genes_idx = _collect_gp_gene_indices(
            gp_sources_cats, sources_cats_label_encoder)
        target_genes_idx = _collect_gp_gene_indices(
            gp_targets_cats, targets_cats_label_encoder)

        # Compute cell-cell communication potential scores
        gp_summary_row = gp_summary_df[gp_summary_df["gp_name"] == gp]
        agg_gp_source_score = _compute_agg_gp_gene_score(
            adata=adata,
            var_names=var_names,
            genes_idx=source_genes_idx,
            gp_summary_row=gp_summary_row,
            genes_col="gp_source_genes",
            weights_col="gp_source_genes_weights",
            gp_scores=gp_scores)
        agg_gp_target_score = _compute_agg_gp_gene_score(
            adata=adata,
            var_names=var_names,
            genes_idx=target_genes_idx,
            gp_summary_row=gp_summary_row,
            genes_col="gp_target_genes",
            weights_col="gp_target_genes_weights",
            gp_scores=gp_scores)

        adata.obs[f"{gp}_source_score"] = agg_gp_source_score
        adata.obs[f"{gp}_target_score"] = agg_gp_target_score

        # Outer product of source and target scores, restricted to
        # neighboring cell pairs
        agg_gp_source_score = sp.csr_matrix(agg_gp_source_score)
        agg_gp_target_score = sp.csr_matrix(agg_gp_target_score)
        adata.obsp[f"{gp}_connectivities"] = neighbor_mask.multiply(
            agg_gp_source_score.T.dot(agg_gp_target_score))

        # Aggregate gp connectivities for each group
        gp_network_df_pivoted = aggregate_obsp_matrix_per_cell_type(
            adata=adata,
            obsp_key=f"{gp}_connectivities",
            cell_type_key=group_key,
            group_key=filter_key,
            agg_rows=True)

        if filter_key is not None:
            gp_network_df_pivoted = gp_network_df_pivoted.loc[filter_cat, :]

        gp_network_df = gp_network_df_pivoted.melt(
            var_name="source",
            value_name="gp_score",
            ignore_index=False).reset_index()
        # After ´reset_index´ the former row index is the first column; it is
        # labeled as source and the melted column labels as target
        gp_network_df.columns = ["source", "target", "strength"]

        gp_network_df = gp_network_df.sort_values("strength", ascending=False)

        # Normalize strength (min-max scaling). If all strengths are equal,
        # the division yields NaN and the rows are dropped by the filter below
        min_value = gp_network_df["strength"].min()
        max_value = gp_network_df["strength"].max()
        gp_network_df["strength_unscaled"] = gp_network_df["strength"]
        gp_network_df["strength"] = (
            (gp_network_df["strength"] - min_value) / (max_value - min_value))
        gp_network_df["strength"] = np.round(gp_network_df["strength"], 2)
        # Copy so that the column assignment below does not write to a slice
        gp_network_df = gp_network_df[gp_network_df["strength"] > 0].copy()

        gp_network_df["edge_type"] = gp
        gp_network_dfs.append(gp_network_df)

    network_df = pd.concat(gp_network_dfs, ignore_index=True)
    return network_df


def _collect_gp_gene_indices(gp_cats_mask, cats_label_encoder):
    """
    Collect the indices of all genes whose entry in ´gp_cats_mask´ equals one
    of the labels of ´cats_label_encoder´, concatenated category by category.
    """
    genes_idx = np.array([], dtype=np.int64)
    for _, cat_label in cats_label_encoder.items():
        genes_idx = np.append(genes_idx,
                              np.where(gp_cats_mask == cat_label)[0])
    return genes_idx


def _compute_agg_gp_gene_score(adata,
                               var_names,
                               genes_idx,
                               gp_summary_row,
                               genes_col,
                               weights_col,
                               gp_scores):
    """
    Compute the per-cell communication potential score of one side (source or
    target) of a GP: max-normalized gene expression x gene weight x GP score,
    averaged over the genes and clipped at 0.
    """
    gene_scores = np.zeros((len(adata.obs), len(genes_idx)))
    if len(genes_idx) > 0:
        gp_genes = gp_summary_row[genes_col].values[0]
        gp_genes_weights = gp_summary_row[weights_col].values[0]

    for i, gene_idx in enumerate(genes_idx):
        gene = adata.var_names[gene_idx]
        gene_weight = gp_genes_weights[gp_genes.index(gene)]
        gene_expr = adata[:, var_names.index(gene)].X.toarray().flatten()
        max_expr = gene_expr.max()
        if max_expr == 0:
            # Dividing by a zero maximum would turn the score of every cell
            # into NaN and propagate into the aggregated score; such a gene
            # contributes a score of 0 instead
            continue
        gene_scores[:, i] = gene_expr / max_expr * gene_weight * gp_scores

    agg_gene_score = gene_scores.mean(1).astype("float32")
    agg_gene_score[agg_gene_score < 0] = 0.
    return agg_gene_score
1012
1013
1014
def visualize_communication_gp_network(
    adata,
    network_df,
    cat_colors,
    edge_type_colors: Optional[dict]=None,
    edge_width_scale: float=20.0,
    node_size: int=500,
    fontsize: int=14,
    figsize: Tuple[int, int]=(18, 16),
    plot_legend: bool=True,
    save: bool=False,
    save_path: str="communication_gp_network.svg",
    show: bool=True,
    text_space: float=1.3,
    connection_style="arc3, rad = 0.1",
    cat_key: str="niche",
    edge_attr: str="strength"):
    """
    Visualize a communication gp network.

    The categories are drawn as nodes on a circle and the rows of
    ´network_df´ as directed edges between them. Edges are colored by their
    ´edge_type´ and their width is proportional to ´edge_attr´.

    Parameters
    ----------
    adata:
        AnnData object. The number of unique values in ´adata.obs[cat_key]´
        determines the number of slots on the circle.
    network_df:
        Pandas DataFrame with the columns ´source´, ´target´, ´edge_type´ and
        ´edge_attr´.
    cat_colors:
        Dictionary that maps nodes (categories) to colors.
    edge_type_colors:
        Either a dictionary that maps edge types to colors or a sequence of
        colors that is assigned to the sorted unique edge types in order (and
        cycled if there are more edge types than colors). If ´None´, the
        scanpy-adjusted vega10 palette is used.
    edge_width_scale:
        Factor with which ´edge_attr´ is multiplied to obtain the edge width.
    node_size:
        Size of the nodes.
    fontsize:
        Font size of the node labels.
    figsize:
        Size of the figure.
    plot_legend:
        If ´True´, plot a legend with one entry per edge type.
    save:
        If ´True´, save the figure to ´save_path´.
    save_path:
        Path where the figure is saved.
    show:
        If ´True´, show the figure.
    text_space:
        Radial distance of the node labels from the center of the circle.
    connection_style:
        Connection style passed on to the networkx drawing functions.
    cat_key:
        Key in ´adata.obs´ where the categories are stored.
    edge_attr:
        Column in ´network_df´ that determines the edge width.
    """
    # Assuming you have unique edge types in your 'edge_type' column
    edge_types = np.unique(network_df["edge_type"])

    if edge_type_colors is None:
        # Colorblindness adjusted vega_10
        # See https://github.com/theislab/scanpy/issues/387
        vega_10_scanpy = list(map(colors.to_hex, cm.tab10.colors))
        vega_10_scanpy[2] = "#279e68"  # green
        vega_10_scanpy[4] = "#aa40fc"  # purple
        vega_10_scanpy[8] = "#b5bd61"  # khaki
        edge_type_colors = vega_10_scanpy

    # Create a dictionary that maps edge types to colors
    if isinstance(edge_type_colors, dict):
        # Zipping a dict would pair the edge types with its keys instead of
        # its colors, so a mapping is used as is
        edge_type_color_dict = dict(edge_type_colors)
    else:
        palette = list(edge_type_colors)
        if len(palette) == 0 and len(edge_types) > 0:
            raise ValueError("´edge_type_colors´ must not be empty.")
        # Cycle the palette so that every edge type receives a color
        edge_type_color_dict = {
            edge_type: palette[i % len(palette)]
            for i, edge_type in enumerate(edge_types)}

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    ax.axis("off")
    G = nx.from_pandas_edgelist(
        network_df,
        source="source",
        target="target",
        edge_attr=["edge_type", edge_attr],
        create_using=nx.DiGraph(),
    )

    nx.set_node_attributes(G, cat_colors, "color")
    node_color = nx.get_node_attributes(G, "color")

    # Labels are created at preliminary positions and moved outside of the
    # circle below
    description = nx.draw_networkx_labels(
        G, nx.circular_layout(G), font_size=fontsize, ax=ax)

    # Place the sorted nodes on a circle; the circle has one slot per category
    # (but at least one per node so that every node obtains a position)
    node_list = sorted(G.nodes())
    n = max(adata.obs[cat_key].nunique(), len(node_list))
    angle_dict = {
        node: 2.0 * np.pi * i / n for i, node in enumerate(node_list)}
    pos = {
        node: (np.cos(theta), np.sin(theta))
        for node, theta in angle_dict.items()}

    # Move each label radially outwards and rotate it according to its angle
    r = fig.canvas.get_renderer()
    trans = ax.transData.inverted()
    for node, t in description.items():
        bb = t.get_window_extent(renderer=r)
        bbdata = bb.transformed(trans)
        radius = text_space + bbdata.width / 2.0
        theta = angle_dict[node]
        t.set_position((radius * np.cos(theta), radius * np.sin(theta)))
        t.set_rotation(theta * 360.0 / (2.0 * np.pi))
        t.set_clip_on(False)

    loop_edges = [(u, v, e) for u, v, e in G.edges(data=True) if u == v]
    non_loop_edges = [(u, v, e) for u, v, e in G.edges(data=True) if u != v]

    # First pass: nodes and all edges between different nodes (with arrows)
    nx.draw_networkx(
        G,
        pos,
        with_labels=False,
        node_size=node_size,
        edgelist=[(u, v) for u, v, _ in non_loop_edges],
        width=[e[edge_attr] * edge_width_scale for _, _, e in non_loop_edges],
        edge_vmin=0.0,
        edge_vmax=1.0,
        edge_color=[edge_type_color_dict[e["edge_type"]]
                    for _, _, e in non_loop_edges],
        arrows=True,
        arrowstyle="-|>",
        arrowsize=20,
        vmin=0.0,
        vmax=1.0,
        cmap=plt.cm.binary,  # Use a colormap for node colors if needed
        node_color=list(node_color.values()),
        ax=ax,
        connectionstyle=connection_style,
    )

    if plot_legend:
        # The colors are already defined, so Line2D proxies are sufficient as
        # legend handles. Handles and labels are built from the same edge
        # type so that each label is shown with the color of its edge type
        legend_edge_types = [
            edge_type for edge_type in edge_types
            if edge_type in edge_type_color_dict]
        proxies = [
            Line2D([0, 1], [0, 1], color=edge_type_color_dict[edge_type], lw=5)
            for edge_type in legend_edge_types]
        labels = [
            edge_type.split("_")[0] + " GP" for edge_type in legend_edge_types]
        ax.legend(proxies, labels, loc="lower left")

    # Second pass: self-loops (without arrows). The other edges are passed
    # again with a width of 0.
    # NOTE(review): the color list only covers the self-loops while the
    # edgelist also contains the zero-width edges; presumably intentional,
    # confirm before changing
    nx.draw_networkx_edges(
        G,
        pos,
        node_size=node_size,
        edgelist=([(u, v) for u, v, _ in loop_edges] +
                  [(u, v) for u, v, _ in non_loop_edges]),
        width=([e[edge_attr] * edge_width_scale for _, _, e in loop_edges] +
               [0 for _ in non_loop_edges]),
        edge_vmin=0.0,
        edge_vmax=1.0,
        edge_color=[edge_type_color_dict[e["edge_type"]]
                    for _, _, e in loop_edges],
        arrows=False,
        arrowstyle="-|>",
        arrowsize=20,
        ax=ax,
        connectionstyle=connection_style)
    plt.tight_layout()
    if save:
        plt.savefig(save_path)
    if show:
        plt.show()
    plt.close(fig)
    plt.ion()