--- a +++ b/bin/plot_rna_scatter.py @@ -0,0 +1,205 @@ +""" +Short script to plot RNA scatterplots +""" + +import os +import sys +import re +import logging +import argparse +from typing import * + +import numpy as np +import scipy +import anndata as ad +import scanpy as sc +import matplotlib.pyplot as plt + +SRC_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel" +) +assert os.path.isdir(SRC_DIR) +sys.path.append(SRC_DIR) +import plot_utils +import utils + + +logging.basicConfig(level=logging.INFO) + + +def sanitize_obs_names(names: List[str]) -> List[str]: + """ + Sanitize the obs names + >>> sanitize_obs_names(['a', 'b']) + ['a', 'b'] + >>> sanitize_obs_names(['foo#a', 'bar#b']) + ['a', 'b'] + >>> sanitize_obs_names(['10xPBMC#TAAGTGCAGCGCACAA-1', '10xPBMC#AGCTATGTCTATCTTG-1']) + ['TAAGTGCAGCGCACAA-1', 'AGCTATGTCTATCTTG-1'] + """ + # Strips out the prefix that archr inserts + def relocate_rep_num(s: str) -> str: + """ + Use the replicate as a suffix instead of prefix + """ + if "#" not in s: + return s + prefix, samplename = s.split("#") + rep_matches = re.findall(f"_rep[0-9]+$", prefix) + if rep_matches: + rep_match = rep_matches.pop() + # Reps are 1 indexed, names are 0 indexed + num = int(rep_match.strip("_rep")) - 1 + assert num >= 0, f"Error when processing {s}" + return samplename + f"-{num}" + else: + return samplename + + def drop_extra_dash(s: str) -> str: + """This may cause issues but it seems to be fine for now""" + tokens = s.split("-") + return "-".join(tokens[:2]) + + retval = [relocate_rep_num(n) for n in names] + retval = [drop_extra_dash(n) for n in retval] + if not utils.is_all_unique(retval): + logging.warning("Got duplicated names after sanitization") + return retval + + +def build_parser(): + """Build a simple commandline parser""" + parser = argparse.ArgumentParser( + usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("x_rna", type=str, help="X axis RNA data") + parser.add_argument("y_rna", type=str, help="Y axis RNA data") + parser.add_argument( + "--outfname", type=str, default="", required=False, help="Filename to save plot" + ) + parser.add_argument( + "--subset", "-s", type=int, default=100000, help="Subset amount (0 to disable)" + ) + parser.add_argument( + "-g", "--genelist", type=str, default="", help="File containing list to plot" + ) + parser.add_argument( + "--linear", + action="store_true", + help="Plot in linear space instead of log space", + ) + parser.add_argument( + "--density", + action="store_true", + help="Plot density scatterplot instead of individual points", + ) + parser.add_argument( + "--densitylogstretch", + type=int, + default=1000, + help="Density logstretch for image normalization", + ) + parser.add_argument("--title", "-t", type=str, default="") + parser.add_argument("--xlabel", type=str, default="Original norm counts") + parser.add_argument("--ylabel", type=str, default="Inferred norm counts") + parser.add_argument( + "--figsize", type=float, nargs=2, default=(7, 5), help="Figure size" + ) + return parser + + +def main(): + parser = build_parser() + args = parser.parse_args() + + if args.x_rna.endswith(".h5ad"): + x_rna = ad.read_h5ad(args.x_rna) + elif args.x_rna.endswith(".h5"): + x_rna = sc.read_10x_h5(args.x_rna, gex_only=False) + else: + raise ValueError(f"Unrecognized file extension: {args.x_rna}") + x_rna.X = utils.ensure_arr(x_rna.X) + x_rna.obs_names = sanitize_obs_names(x_rna.obs_names) + x_rna.obs_names_make_unique() + logging.info(f"Read in {args.x_rna} for {x_rna.shape}") + + if args.y_rna.endswith(".h5ad"): + y_rna = ad.read_h5ad(args.y_rna) + elif args.y_rna.endswith(".h5"): + y_rna = sc.read_10x_h5(args.y_rna, gex_only=False) + else: + raise ValueError(f"Unrecognized file extension: {args.y_rna}") + y_rna.X = utils.ensure_arr(y_rna.X) + y_rna.obs_names = sanitize_obs_names(y_rna.obs_names) + y_rna.obs_names_make_unique() + logging.info(f"Read in {args.y_rna} for {y_rna.shape}") + + if not ( + len(x_rna.obs_names) == len(y_rna.obs_names) + and np.all(x_rna.obs_names == y_rna.obs_names) + ): + logging.warning("Rematching obs axis") + shared_obs_names = sorted( + list(set(x_rna.obs_names).intersection(y_rna.obs_names)) + ) + logging.info(f"Found {len(shared_obs_names)} shared obs") + assert shared_obs_names, ( + "Got empty list of shared obs" + + "\n" + + str(x_rna.obs_names) + + "\n" + + str(y_rna.obs_names) + ) + x_rna = x_rna[shared_obs_names] + y_rna = y_rna[shared_obs_names] + assert np.all(x_rna.obs_names == y_rna.obs_names) + if not ( + len(x_rna.var_names) == len(y_rna.var_names) + and np.all(x_rna.var_names == y_rna.var_names) + ): + logging.warning("Rematching variable axis") + shared_var_names = sorted( + list(set(x_rna.var_names).intersection(y_rna.var_names)) + ) + logging.info(f"Found {len(shared_var_names)} shared variables") + assert shared_var_names, ( + "Got empty list of shared vars" + + "\n" + + str(x_rna.var_names) + + "\n" + + str(y_rna.var_names) + ) + x_rna = x_rna[:, shared_var_names] + y_rna = y_rna[:, shared_var_names] + assert np.all(x_rna.var_names == y_rna.var_names) + + # Subset by gene list if given + if args.genelist: + gene_list = utils.read_delimited_file(args.genelist) + logging.info(f"Read {len(gene_list)} genes from {args.genelist}") + x_rna = x_rna[:, gene_list] + y_rna = y_rna[:, gene_list] + + assert x_rna.shape == y_rna.shape, f"Mismatched shapes {x_rna.shape} {y_rna.shape}" + + fig = plot_utils.plot_scatter_with_r( + x_rna.X, + y_rna.X, + subset=args.subset, + one_to_one=True, + logscale=not args.linear, + density_heatmap=args.density, + density_logstretch=args.densitylogstretch, + fname=args.outfname, + title=args.title, + xlabel=args.xlabel, + ylabel=args.ylabel, + figsize=args.figsize, + ) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() + main()