Diff of /bin/plot_rna_scatter.py [000000] .. [d01132]

Switch to side-by-side view

--- a
+++ b/bin/plot_rna_scatter.py
@@ -0,0 +1,205 @@
+"""
+Short script to plot RNA scatterplots
+"""
+
+import os
+import sys
+import re
+import logging
+import argparse
+from typing import *
+
+import numpy as np
+import scipy
+import anndata as ad
+import scanpy as sc
+import matplotlib.pyplot as plt
+
+SRC_DIR = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
+)
+assert os.path.isdir(SRC_DIR)
+sys.path.append(SRC_DIR)
+import plot_utils
+import utils
+
+
+logging.basicConfig(level=logging.INFO)
+
+
def sanitize_obs_names(names: List[str]) -> List[str]:
    """
    Sanitize the obs names

    Each name goes through two steps:

    1. If the name contains ``#``, everything up to and including the LAST
       ``#`` (the prefix that archr inserts) is removed. When that prefix
       ends in ``_rep<N>``, the replicate is re-attached to the name as a
       ``-<N - 1>`` suffix (reps are 1 indexed, names are 0 indexed).
    2. Only the first two ``-``-separated tokens of the result are kept.
       Note that this means a replicate suffix appended in step 1 is
       dropped again when the name already contained a dash.

    Duplicated names after sanitization are only reported with a logged
    warning; the (possibly non-unique) list is still returned.

    Raises AssertionError if a prefix ends in ``_rep0``, since that would
    map to a negative suffix.

    >>> sanitize_obs_names(['a', 'b'])
    ['a', 'b']
    >>> sanitize_obs_names(['foo#a', 'bar#b'])
    ['a', 'b']
    >>> sanitize_obs_names(['10xPBMC#TAAGTGCAGCGCACAA-1', '10xPBMC#AGCTATGTCTATCTTG-1'])
    ['TAAGTGCAGCGCACAA-1', 'AGCTATGTCTATCTTG-1']
    """
    # Compiled once per call instead of once per name; the capture group
    # yields the replicate digits directly.
    rep_suffix = re.compile(r"_rep([0-9]+)$")

    def relocate_rep_num(s: str) -> str:
        """
        Use the replicate as a suffix instead of prefix
        """
        # rpartition splits on the last "#" only, so a name holding several
        # "#" characters no longer fails tuple unpacking.
        prefix, sep, samplename = s.rpartition("#")
        if not sep:
            # No "#" present: nothing to strip
            return s
        rep_match = rep_suffix.search(prefix)
        if rep_match is None:
            return samplename
        # Reps are 1 indexed, names are 0 indexed
        num = int(rep_match.group(1)) - 1
        assert num >= 0, f"Error when processing {s}"
        return f"{samplename}-{num}"

    def drop_extra_dash(s: str) -> str:
        """This may cause issues but it seems to be fine for now"""
        return "-".join(s.split("-")[:2])

    retval = [drop_extra_dash(relocate_rep_num(n)) for n in names]
    if not utils.is_all_unique(retval):
        logging.warning("Got duplicated names after sanitization")
    return retval
+
+
def build_parser():
    """
    Build a simple commandline parser

    Positional arguments are the two RNA data files (x axis, then y axis);
    every other option has a default, which the help formatter prints for
    each argument that carries a help string.

    Returns the configured argparse.ArgumentParser (arguments not yet parsed).
    """
    # The module docstring is passed as the description rather than as
    # ``usage``: overriding ``usage`` would replace the auto-generated
    # synopsis of arguments in --help and in error messages.
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("x_rna", type=str, help="X axis RNA data")
    parser.add_argument("y_rna", type=str, help="Y axis RNA data")
    parser.add_argument(
        "--outfname", type=str, default="", required=False, help="Filename to save plot"
    )
    parser.add_argument(
        "--subset", "-s", type=int, default=100000, help="Subset amount (0 to disable)"
    )
    parser.add_argument(
        "-g", "--genelist", type=str, default="", help="File containing list to plot"
    )
    parser.add_argument(
        "--linear",
        action="store_true",
        help="Plot in linear space instead of log space",
    )
    parser.add_argument(
        "--density",
        action="store_true",
        help="Plot density scatterplot instead of individual points",
    )
    parser.add_argument(
        "--densitylogstretch",
        type=int,
        default=1000,
        help="Density logstretch for image normalization",
    )
    parser.add_argument("--title", "-t", type=str, default="", help="Plot title")
    parser.add_argument(
        "--xlabel", type=str, default="Original norm counts", help="X axis label"
    )
    parser.add_argument(
        "--ylabel", type=str, default="Inferred norm counts", help="Y axis label"
    )
    parser.add_argument(
        "--figsize", type=float, nargs=2, default=(7, 5), help="Figure size"
    )
    return parser
+
+
def _load_rna(fname: str):
    """Read one RNA input (.h5ad or 10x .h5) and sanitize its obs names."""
    if fname.endswith(".h5ad"):
        adata = ad.read_h5ad(fname)
    elif fname.endswith(".h5"):
        adata = sc.read_10x_h5(fname, gex_only=False)
    else:
        raise ValueError(f"Unrecognized file extension: {fname}")
    adata.X = utils.ensure_arr(adata.X)
    adata.obs_names = sanitize_obs_names(adata.obs_names)
    # Sanitization may introduce duplicates; make the index unique again
    adata.obs_names_make_unique()
    logging.info(f"Read in {fname} for {adata.shape}")
    return adata


def _names_match(x_names, y_names) -> bool:
    """Return True when both name indexes have equal length and equal entries in order."""
    # The length check comes first so the elementwise comparison is only
    # attempted on equally sized indexes.
    return len(x_names) == len(y_names) and np.all(x_names == y_names)


def main():
    """
    Command line entry point.

    Reads the x and y RNA files named on the command line, restricts both to
    their shared observations and shared variables (in sorted order) when
    the respective axes do not already match, optionally subsets to the genes
    listed in --genelist, and passes the two matrices together with the plot
    options to plot_utils.plot_scatter_with_r.

    Raises ValueError for an input file that ends in neither .h5ad nor .h5,
    and AssertionError when the inputs share no observations or variables or
    end up with different shapes.
    """
    args = build_parser().parse_args()

    x_rna = _load_rna(args.x_rna)
    y_rna = _load_rna(args.y_rna)

    if not _names_match(x_rna.obs_names, y_rna.obs_names):
        logging.warning("Rematching obs axis")
        shared_obs_names = sorted(set(x_rna.obs_names).intersection(y_rna.obs_names))
        logging.info(f"Found {len(shared_obs_names)} shared obs")
        assert shared_obs_names, (
            "Got empty list of shared obs"
            + "\n"
            + str(x_rna.obs_names)
            + "\n"
            + str(y_rna.obs_names)
        )
        x_rna = x_rna[shared_obs_names]
        y_rna = y_rna[shared_obs_names]
    assert np.all(x_rna.obs_names == y_rna.obs_names)

    if not _names_match(x_rna.var_names, y_rna.var_names):
        logging.warning("Rematching variable axis")
        shared_var_names = sorted(set(x_rna.var_names).intersection(y_rna.var_names))
        logging.info(f"Found {len(shared_var_names)} shared variables")
        assert shared_var_names, (
            "Got empty list of shared vars"
            + "\n"
            + str(x_rna.var_names)
            + "\n"
            + str(y_rna.var_names)
        )
        x_rna = x_rna[:, shared_var_names]
        y_rna = y_rna[:, shared_var_names]
    assert np.all(x_rna.var_names == y_rna.var_names)

    # Subset by gene list if given
    if args.genelist:
        gene_list = utils.read_delimited_file(args.genelist)
        logging.info(f"Read {len(gene_list)} genes from {args.genelist}")
        x_rna = x_rna[:, gene_list]
        y_rna = y_rna[:, gene_list]

    assert x_rna.shape == y_rna.shape, f"Mismatched shapes {x_rna.shape} {y_rna.shape}"

    # The return value is not needed here; the output filename is handed to
    # the plotting helper via ``fname``.
    plot_utils.plot_scatter_with_r(
        x_rna.X,
        y_rna.X,
        subset=args.subset,
        one_to_one=True,
        logscale=not args.linear,
        density_heatmap=args.density,
        density_logstretch=args.densitylogstretch,
        fname=args.outfname,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        figsize=args.figsize,
    )
+
+
if __name__ == "__main__":
    # Run this module's docstring examples first, then the command line
    # entry point.
    from doctest import testmod

    testmod()
    main()