[d01132]: / bin / plot_rna_clustering.py

Download this file

114 lines (99 with data), 3.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Script for plotting clustering, while also evaluating distance between clusters
"""
import os
import sys
import argparse
import logging
import itertools
# Locate the "babel" directory that sits next to the directory containing this
# script (two dirname() hops up from this file, then "babel") and put it on
# sys.path. This must run before the project-local imports that follow
# (presumably interpretation, adata_utils, plot_utils, utils live there --
# TODO confirm against the repository layout).
SRC_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel",
)
# Fail fast if the expected source directory is missing.
# NOTE(review): assert statements are skipped when Python runs with -O, in
# which case a missing directory would only surface as a later ImportError.
assert os.path.isdir(SRC_DIR)
# Appended (not inserted at the front), so earlier sys.path entries win on
# module-name clashes.
sys.path.append(SRC_DIR)
import interpretation
import adata_utils
import plot_utils
import utils
from evaluate_bulk_rna_concordance import load_file_flex_format
# Reference marker-gene panels, keyed by the name accepted by the --markers
# command line option (see build_parser).
REF_MARKER_GENES = dict(
    PBMC=interpretation.PBMC_MARKER_GENES,
    PBMC_Seurat=interpretation.SEURAT_PBMC_MARKER_GENES,
)
def build_parser():
    """Build the command line argument parser for this script.

    Positional arguments:
        fname: file to plot clustering for.
        plotprefix: prefix prepended to every output file name.

    Options:
        --resolution / -r: clustering resolution (float, default 1.0).
        --markers / -m: which entry of REF_MARKER_GENES to label clusters with
            (default "PBMC").
        --fast: skip the pairwise cluster distance computations.

    Returns:
        argparse.ArgumentParser: the configured, not yet parsed, parser.
    """
    parser = argparse.ArgumentParser(
        # The module docstring is descriptive text, so it goes in
        # ``description``. Passing it as ``usage`` would replace the
        # auto-generated argument synopsis in -h and in error messages.
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("fname", type=str, help="File to plot clustering for")
    parser.add_argument(
        "plotprefix", type=str, help="File prefix to save plots to"
    )
    parser.add_argument(
        "--resolution",
        "-r",
        type=float,
        default=1.0,
        help="Clustering resolution",
    )
    parser.add_argument(
        "--markers",
        "-m",
        type=str,
        # Iterating the dict yields its keys in insertion order.
        choices=list(REF_MARKER_GENES),
        default="PBMC",
        help="Marker genes to evaluate with",
    )
    parser.add_argument(
        "--fast",
        action="store_true",
        help="Skip pairwise comparisons for speed",
    )
    return parser
def _write_cluster_distances(adata, stratify, means_fname, sds_fname):
    """Compute pairwise cluster distances for one grouping and save them as CSV.

    Calls adata_utils.evaluate_pairwise_cluster_distance(adata, stratify=...)
    and writes the two objects it returns (means, then standard deviations)
    with their ``to_csv`` methods.
    """
    distance_means, distance_sds = adata_utils.evaluate_pairwise_cluster_distance(
        adata, stratify=stratify
    )
    distance_means.to_csv(means_fname)
    distance_sds.to_csv(sds_fname)


def main():
    """Cluster the input file, plot the clusters, and label them by marker genes.

    Steps, in order:
      1. Parse the command line and load ``fname``.
      2. Preprocess/cluster via plot_utils.preprocess_anndata and save
         ``<plotprefix>_leiden_clustering.pdf``.
      3. Unless ``--fast``: write pairwise distances between "leiden" groups
         to ``<plotprefix>_cluster_dist_{means,sds}.csv``.
      4. Find marker genes, annotate clusters against the selected reference
         panel, and save ``<plotprefix>_celltype_clustering.pdf``.
      5. Unless ``--fast``: write pairwise distances between
         "leiden_celltypes" groups to
         ``<plotprefix>_labelled_cluster_dist_{means,sds}.csv``.
    """
    # The root logger defaults to WARNING, so without this the INFO progress
    # messages below are silently dropped. basicConfig is a no-op when the
    # root logger already has handlers, so an existing configuration wins.
    logging.basicConfig(level=logging.INFO)

    args = build_parser().parse_args()
    prefix = args.plotprefix

    adata = load_file_flex_format(args.fname)
    logging.info(
        f"Read in {os.path.abspath(args.fname)} for a matrix of {adata.shape}"
    )

    # Do clustering; the same resolution is passed for both algorithms and
    # the seed is fixed so repeated runs use the same value.
    logging.info(f"Clustering with resolution {args.resolution}")
    plot_utils.preprocess_anndata(
        adata,
        louvain_resolution=args.resolution,
        leiden_resolution=args.resolution,
        seed=1234,
    )
    plot_utils.plot_clustering_anndata(
        adata, label_counter=True, fname=prefix + "_leiden_clustering.pdf"
    )

    if not args.fast:
        logging.info("Computing pairwise distances between clusters")
        _write_cluster_distances(
            adata,
            "leiden",
            prefix + "_cluster_dist_means.csv",
            prefix + "_cluster_dist_sds.csv",
        )

    # Find marker genes and label clusters
    logging.info("Finding marker genes")
    adata_utils.find_marker_genes(adata, n_genes=25)

    reference_markers = REF_MARKER_GENES[args.markers]
    # NOTE(review): this flattens whatever iterating the reference panel
    # yields. If the panel is a dict, that iterates its keys (and therefore
    # their characters) rather than the gene lists -- confirm whether
    # ``.values()`` was intended. Behavior is kept as in the original.
    n_markers = len(set(itertools.chain.from_iterable(reference_markers)))
    logging.info(
        f"Labelling clusters using {args.markers} marker genes (n={n_markers})"
    )
    # The return value was never used, so it is no longer bound to a name.
    # Presumably this call adds the "leiden_celltypes" annotation that the
    # plot and distance steps below read -- TODO confirm in interpretation.
    interpretation.annotate_clusters_to_celltypes(adata, reference_markers)
    plot_utils.plot_clustering_anndata(
        adata,
        color="leiden_celltypes",
        label_counter=True,
        fname=prefix + "_celltype_clustering.pdf",
    )

    if not args.fast:
        logging.info("Computing pairwise distances between labelled clusters")
        _write_cluster_distances(
            adata,
            "leiden_celltypes",
            prefix + "_labelled_cluster_dist_means.csv",
            prefix + "_labelled_cluster_dist_sds.csv",
        )


if __name__ == "__main__":
    main()