--- a
+++ b/bin/plot_rna_clustering.py
@@ -0,0 +1,113 @@
+"""
+Script for plotting clustering, while also evaluating distance between clusters
+"""
+
+import os
+import sys
+import argparse
+import logging
+import itertools
+
+SRC_DIR = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel",
+)
+assert os.path.isdir(SRC_DIR)
+sys.path.append(SRC_DIR)
+import interpretation
+import adata_utils
+import plot_utils
+import utils
+
+from evaluate_bulk_rna_concordance import load_file_flex_format
+
+REF_MARKER_GENES = {
+    "PBMC": interpretation.PBMC_MARKER_GENES,
+    "PBMC_Seurat": interpretation.SEURAT_PBMC_MARKER_GENES,
+}
+
+
+def build_parser():
+    """Build commandline argument parser"""
+    parser = argparse.ArgumentParser(
+        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("fname", type=str, help="File to plot clustering for")
+    parser.add_argument("plotprefix", type=str, help="File prefix to save plots to")
+    parser.add_argument(
+        "--resolution", "-r", type=float, default=1.0, help="Clustering resolution"
+    )
+    parser.add_argument(
+        "--markers",
+        "-m",
+        type=str,
+        choices=REF_MARKER_GENES.keys(),
+        default="PBMC",
+        help="Marker genes to evaluate with",
+    )
+    parser.add_argument(
+        "--fast", action="store_true", help="Skip pairwise comparisons for speed"
+    )
+    return parser
+
+
+def main():
+    parser = build_parser()
+    args = parser.parse_args()
+
+    adata = load_file_flex_format(args.fname)
+    logging.info(f"Read in {os.path.abspath(args.fname)} for a matrix of {adata.shape}")
+
+    # Do clustering
+    logging.info(f"Clustering with resolution {args.resolution}")
+    plot_utils.preprocess_anndata(
+        adata,
+        louvain_resolution=args.resolution,
+        leiden_resolution=args.resolution,
+        seed=1234,
+    )
+
+    plot_utils.plot_clustering_anndata(
+        adata, label_counter=True, fname=args.plotprefix + "_leiden_clustering.pdf"
+    )
+    if not args.fast:
+        logging.info(f"Computing pairwise distances between clusters")
+        (
+            clustering_distance_means,
+            clustering_distance_sds,
+        ) = adata_utils.evaluate_pairwise_cluster_distance(adata, stratify="leiden")
+        clustering_distance_means.to_csv(args.plotprefix + "_cluster_dist_means.csv")
+        clustering_distance_sds.to_csv(args.plotprefix + "_cluster_dist_sds.csv")
+
+    # Find marker genes and label clusters
+    logging.info(f"Finding marker genes")
+    adata_utils.find_marker_genes(adata, n_genes=25)
+    logging.info(
+        f"Labelling clusters using {args.markers} marker genes (n={len(set(itertools.chain.from_iterable(REF_MARKER_GENES[args.markers])))})"
+    )
+    marker_matches = interpretation.annotate_clusters_to_celltypes(
+        adata, REF_MARKER_GENES[args.markers],
+    )
+    plot_utils.plot_clustering_anndata(
+        adata,
+        color="leiden_celltypes",
+        label_counter=True,
+        fname=args.plotprefix + "_celltype_clustering.pdf",
+    )
+    if not args.fast:
+        logging.info("Computing pairwise distances between labelled clusters")
+        (
+            labelled_cluster_distance_means,
+            labelled_cluster_distance_sds,
+        ) = adata_utils.evaluate_pairwise_cluster_distance(
+            adata, stratify="leiden_celltypes"
+        )
+        labelled_cluster_distance_means.to_csv(
+            args.plotprefix + "_labelled_cluster_dist_means.csv"
+        )
+        labelled_cluster_distance_sds.to_csv(
+            args.plotprefix + "_labelled_cluster_dist_sds.csv"
+        )
+
+
+if __name__ == "__main__":
+    main()