# bin/plot_rna_scatter.py
"""
Short script to plot RNA scatterplots
"""
4
5
import os
6
import sys
7
import re
8
import logging
9
import argparse
10
from typing import *
11
12
import numpy as np
13
import scipy
14
import anndata as ad
15
import scanpy as sc
16
import matplotlib.pyplot as plt
17
18
# Directory named "babel" that sits next to this script's parent directory.
# It is appended to sys.path so the project-local modules imported just below
# (plot_utils, utils) can be resolved.
SRC_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
)
assert os.path.isdir(SRC_DIR), f"Cannot find source directory: {SRC_DIR}"
sys.path.append(SRC_DIR)
import plot_utils
import utils

logging.basicConfig(level=logging.INFO)
28
29
30
def sanitize_obs_names(names: Iterable[str]) -> List[str]:
    """
    Sanitize the obs names

    Strips out the ``<prefix>#`` that archr inserts in front of each name. If
    the prefix ends in ``_rep<N>`` the (1-indexed) replicate number is moved
    to the end of the name as a 0-indexed ``-<N-1>`` suffix. Finally each name
    is truncated to at most its first two dash-separated tokens.

    NOTE(review): because of that final truncation, a replicate suffix is
    dropped again when the name already contains a dash (e.g. ``ACGT-1``
    becomes ``ACGT-1-0`` and is then cut back to ``ACGT-1``).

    Args:
        names: iterable of obs names to clean up.

    Returns:
        List of sanitized names, in the same order as the input. A warning is
        logged (but nothing is raised) if the result contains duplicates.

    Raises:
        ValueError: if a prefix carries a replicate number below 1.

    >>> sanitize_obs_names(['a', 'b'])
    ['a', 'b']
    >>> sanitize_obs_names(['foo#a', 'bar#b'])
    ['a', 'b']
    >>> sanitize_obs_names(['10xPBMC#TAAGTGCAGCGCACAA-1', '10xPBMC#AGCTATGTCTATCTTG-1'])
    ['TAAGTGCAGCGCACAA-1', 'AGCTATGTCTATCTTG-1']
    >>> sanitize_obs_names(['sample_rep2#ACGT'])
    ['ACGT-1']
    """

    def relocate_rep_num(s: str) -> str:
        """
        Use the replicate as a suffix instead of prefix
        """
        if "#" not in s:
            return s
        # Split on the last "#" only, so a prefix that itself contains "#"
        # does not blow up the two-way unpacking.
        prefix, samplename = s.rsplit("#", 1)
        rep_match = re.search(r"_rep([0-9]+)$", prefix)
        if rep_match is None:
            return samplename
        # Reps are 1 indexed, names are 0 indexed
        num = int(rep_match.group(1)) - 1
        if num < 0:
            raise ValueError(f"Error when processing {s}")
        return f"{samplename}-{num}"

    def drop_extra_dash(s: str) -> str:
        """This may cause issues but it seems to be fine for now"""
        return "-".join(s.split("-")[:2])

    retval = [drop_extra_dash(relocate_rep_num(n)) for n in names]
    # Duplicates are only reported here; the caller decides how to resolve them
    if len(set(retval)) != len(retval):
        logging.warning("Got duplicated names after sanitization")
    return retval
68
69
70
def build_parser():
    """
    Build a simple commandline parser

    Returns:
        An ``argparse.ArgumentParser`` taking two positional input files
        (``x_rna``, ``y_rna``) plus optional plotting controls. Defaults are
        shown in ``--help`` via ``ArgumentDefaultsHelpFormatter``.
    """
    # The module docstring goes in ``description`` rather than ``usage`` so
    # that argparse still generates the usage synopsis from the arguments.
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("x_rna", type=str, help="X axis RNA data")
    parser.add_argument("y_rna", type=str, help="Y axis RNA data")
    parser.add_argument(
        "--outfname", type=str, default="", required=False, help="Filename to save plot"
    )
    parser.add_argument(
        "--subset", "-s", type=int, default=100000, help="Subset amount (0 to disable)"
    )
    parser.add_argument(
        "-g", "--genelist", type=str, default="", help="File containing list to plot"
    )
    parser.add_argument(
        "--linear",
        action="store_true",
        help="Plot in linear space instead of log space",
    )
    parser.add_argument(
        "--density",
        action="store_true",
        help="Plot density scatterplot instead of individual points",
    )
    parser.add_argument(
        "--densitylogstretch",
        type=int,
        default=1000,
        help="Density logstretch for image normalization",
    )
    parser.add_argument("--title", "-t", type=str, default="", help="Plot title")
    parser.add_argument(
        "--xlabel", type=str, default="Original norm counts", help="X axis label"
    )
    parser.add_argument(
        "--ylabel", type=str, default="Inferred norm counts", help="Y axis label"
    )
    parser.add_argument(
        "--figsize", type=float, nargs=2, default=(7, 5), help="Figure size"
    )
    return parser
109
110
111
def _read_rna(fname: str) -> "ad.AnnData":
    """
    Read one RNA input file and normalize its obs names.

    ``.h5ad`` files are read with ``anndata.read_h5ad`` and ``.h5`` files with
    ``scanpy.read_10x_h5(..., gex_only=False)``. The ``X`` matrix is passed
    through ``utils.ensure_arr`` (presumably densifying it; confirm against
    utils), obs names are sanitized and then made unique.

    Raises:
        ValueError: if the file name ends in neither ``.h5ad`` nor ``.h5``.
    """
    if fname.endswith(".h5ad"):
        adata = ad.read_h5ad(fname)
    elif fname.endswith(".h5"):
        adata = sc.read_10x_h5(fname, gex_only=False)
    else:
        raise ValueError(f"Unrecognized file extension: {fname}")
    adata.X = utils.ensure_arr(adata.X)
    adata.obs_names = sanitize_obs_names(adata.obs_names)
    adata.obs_names_make_unique()
    logging.info(f"Read in {fname} for {adata.shape}")
    return adata


def main():
    """
    Read the two RNA inputs, align them, and draw the scatterplot.

    Steps: parse the command line, load both files, restrict both to their
    shared (sorted) obs names and var names when the axes do not already
    match, optionally subset to the genes listed in ``--genelist``, then hand
    the two matrices to ``plot_utils.plot_scatter_with_r``.
    """
    args = build_parser().parse_args()

    x_rna = _read_rna(args.x_rna)
    y_rna = _read_rna(args.y_rna)

    # Align the observation axis. The length check comes first so the
    # elementwise comparison is only made on equal-length indexes.
    if not (
        len(x_rna.obs_names) == len(y_rna.obs_names)
        and np.all(x_rna.obs_names == y_rna.obs_names)
    ):
        logging.warning("Rematching obs axis")
        shared_obs_names = sorted(
            set(x_rna.obs_names).intersection(y_rna.obs_names)
        )
        logging.info(f"Found {len(shared_obs_names)} shared obs")
        assert shared_obs_names, (
            "Got empty list of shared obs"
            + "\n"
            + str(x_rna.obs_names)
            + "\n"
            + str(y_rna.obs_names)
        )
        x_rna = x_rna[shared_obs_names]
        y_rna = y_rna[shared_obs_names]
    assert np.all(x_rna.obs_names == y_rna.obs_names)

    # Align the variable axis in the same way
    if not (
        len(x_rna.var_names) == len(y_rna.var_names)
        and np.all(x_rna.var_names == y_rna.var_names)
    ):
        logging.warning("Rematching variable axis")
        shared_var_names = sorted(
            set(x_rna.var_names).intersection(y_rna.var_names)
        )
        logging.info(f"Found {len(shared_var_names)} shared variables")
        assert shared_var_names, (
            "Got empty list of shared vars"
            + "\n"
            + str(x_rna.var_names)
            + "\n"
            + str(y_rna.var_names)
        )
        x_rna = x_rna[:, shared_var_names]
        y_rna = y_rna[:, shared_var_names]
    assert np.all(x_rna.var_names == y_rna.var_names)

    # Subset by gene list if given
    if args.genelist:
        gene_list = utils.read_delimited_file(args.genelist)
        logging.info(f"Read {len(gene_list)} genes from {args.genelist}")
        x_rna = x_rna[:, gene_list]
        y_rna = y_rna[:, gene_list]

    assert x_rna.shape == y_rna.shape, f"Mismatched shapes {x_rna.shape} {y_rna.shape}"

    # The return value is not needed here; args.outfname is passed through
    # as ``fname`` for the plotting helper to use.
    plot_utils.plot_scatter_with_r(
        x_rna.X,
        y_rna.X,
        subset=args.subset,
        one_to_one=True,
        logscale=not args.linear,
        density_heatmap=args.density,
        density_logstretch=args.densitylogstretch,
        fname=args.outfname,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        figsize=args.figsize,
    )
199
200
201
if __name__ == "__main__":
    import doctest

    # Run the embedded doctests on every invocation before doing real work;
    # failures are reported by doctest but do not stop main() from running.
    doctest.testmod()
    main()