|
a |
|
b/bin/plot_rna_scatter.py |
|
|
1 |
""" |
|
|
2 |
Short script to plot RNA scatterplots |
|
|
3 |
""" |
|
|
4 |
|
|
|
5 |
import os |
|
|
6 |
import sys |
|
|
7 |
import re |
|
|
8 |
import logging |
|
|
9 |
import argparse |
|
|
10 |
from typing import * |
|
|
11 |
|
|
|
12 |
import numpy as np |
|
|
13 |
import scipy |
|
|
14 |
import anndata as ad |
|
|
15 |
import scanpy as sc |
|
|
16 |
import matplotlib.pyplot as plt |
|
|
17 |
|
|
|
18 |
# Path to the "babel" directory that sits next to this script's parent
# directory: two dirname() calls walk up from this file to the directory
# above the one containing it, and "babel" is joined onto that.
SRC_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
)
# Fail at import time if the expected directory layout is not present.
assert os.path.isdir(SRC_DIR)
# The path must be extended BEFORE the two imports below; presumably
# plot_utils and utils are modules inside SRC_DIR -- TODO confirm.
sys.path.append(SRC_DIR)
import plot_utils
import utils


# Emit INFO-level (and above) log records from this script.
logging.basicConfig(level=logging.INFO)
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def sanitize_obs_names(names: List[str]) -> List[str]:
    """
    Sanitize the obs names

    Each name of the form "<prefix>#<name>" has the prefix removed. If the
    prefix ends in "_rep<N>" (1-indexed replicate), the 0-indexed replicate
    number is appended to the name as "-<N-1>". Afterwards every name is
    truncated to at most its first two dash-separated tokens. Names without
    a "#" skip the prefix handling but are still truncated.

    A warning is logged (no exception is raised) if the sanitized names are
    not all unique according to utils.is_all_unique.

    Raises AssertionError if a replicate number of 0 is encountered.

    >>> sanitize_obs_names(['a', 'b'])
    ['a', 'b']
    >>> sanitize_obs_names(['foo#a', 'bar#b'])
    ['a', 'b']
    >>> sanitize_obs_names(['10xPBMC#TAAGTGCAGCGCACAA-1', '10xPBMC#AGCTATGTCTATCTTG-1'])
    ['TAAGTGCAGCGCACAA-1', 'AGCTATGTCTATCTTG-1']
    >>> sanitize_obs_names(['sample_rep2#ACGT'])
    ['ACGT-1']
    >>> sanitize_obs_names(['odd#sample_rep3#ACGT'])
    ['ACGT-2']
    """
    # Strips out the prefix that archr inserts
    def relocate_rep_num(s: str) -> str:
        """
        Use the replicate as a suffix instead of prefix
        """
        if "#" not in s:
            return s
        # Split on the LAST "#" only, so a prefix that itself contains "#"
        # no longer breaks the two-name unpacking with a ValueError.
        prefix, samplename = s.rsplit("#", 1)
        # Capture the digits directly; str.strip("_rep") removes a character
        # set from both ends rather than a literal prefix.
        rep_match = re.search(r"_rep([0-9]+)$", prefix)
        if rep_match is None:
            return samplename
        # Reps are 1 indexed, names are 0 indexed
        num = int(rep_match.group(1)) - 1
        assert num >= 0, f"Error when processing {s}"
        return samplename + f"-{num}"

    def drop_extra_dash(s: str) -> str:
        """This may cause issues but it seems to be fine for now"""
        # Keep only the first two dash-separated tokens; note this also
        # discards a relocated replicate suffix when the name already
        # contained a dash.
        tokens = s.split("-")
        return "-".join(tokens[:2])

    retval = [drop_extra_dash(relocate_rep_num(n)) for n in names]
    if not utils.is_all_unique(retval):
        logging.warning("Got duplicated names after sanitization")
    return retval
|
|
68 |
|
|
|
69 |
|
|
|
70 |
def build_parser():
    """
    Construct the command line parser for this script.

    Two positional inputs (x_rna, y_rna) are followed by optional flags
    controlling output file, subsetting, gene list, axis scaling, density
    rendering, labels and figure size. Defaults are shown in --help via
    ArgumentDefaultsHelpFormatter.
    """
    cli = argparse.ArgumentParser(
        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # (names/flags, keyword options), registered in this exact order
    arg_specs = [
        (("x_rna",), {"type": str, "help": "X axis RNA data"}),
        (("y_rna",), {"type": str, "help": "Y axis RNA data"}),
        (
            ("--outfname",),
            {
                "type": str,
                "default": "",
                "required": False,
                "help": "Filename to save plot",
            },
        ),
        (
            ("--subset", "-s"),
            {"type": int, "default": 100000, "help": "Subset amount (0 to disable)"},
        ),
        (
            ("-g", "--genelist"),
            {"type": str, "default": "", "help": "File containing list to plot"},
        ),
        (
            ("--linear",),
            {
                "action": "store_true",
                "help": "Plot in linear space instead of log space",
            },
        ),
        (
            ("--density",),
            {
                "action": "store_true",
                "help": "Plot density scatterplot instead of individual points",
            },
        ),
        (
            ("--densitylogstretch",),
            {
                "type": int,
                "default": 1000,
                "help": "Density logstretch for image normalization",
            },
        ),
        (("--title", "-t"), {"type": str, "default": ""}),
        (("--xlabel",), {"type": str, "default": "Original norm counts"}),
        (("--ylabel",), {"type": str, "default": "Inferred norm counts"}),
        (
            ("--figsize",),
            {"type": float, "nargs": 2, "default": (7, 5), "help": "Figure size"},
        ),
    ]
    for flags, options in arg_specs:
        cli.add_argument(*flags, **options)
    return cli
|
|
109 |
|
|
|
110 |
|
|
|
111 |
def _read_rna(fname: str):
    """
    Load one RNA dataset and normalize it for comparison.

    Files ending in ".h5ad" are read with anndata.read_h5ad; files ending in
    ".h5" are read with scanpy.read_10x_h5 (gex_only=False). Any other
    extension raises ValueError. After loading, X is passed through
    utils.ensure_arr (presumably converting to a dense array -- TODO
    confirm), obs names are sanitized and then made unique.
    """
    # ".h5ad" must be tested before ".h5" is irrelevant for endswith, but
    # the order is kept to mirror the accepted formats.
    if fname.endswith(".h5ad"):
        adata = ad.read_h5ad(fname)
    elif fname.endswith(".h5"):
        adata = sc.read_10x_h5(fname, gex_only=False)
    else:
        raise ValueError(f"Unrecognized file extension: {fname}")
    adata.X = utils.ensure_arr(adata.X)
    adata.obs_names = sanitize_obs_names(adata.obs_names)
    adata.obs_names_make_unique()
    logging.info(f"Read in {fname} for {adata.shape}")
    return adata


def main():
    """
    Entry point: read two RNA datasets, align them, and plot a scatterplot.

    Steps:
      1. Parse command line arguments.
      2. Load the x and y datasets (see _read_rna).
      3. If obs names differ, restrict both to the sorted shared obs names.
      4. If var names differ, restrict both to the sorted shared var names.
      5. Optionally restrict both to the genes listed in --genelist.
      6. Hand both matrices to plot_utils.plot_scatter_with_r.

    Raises ValueError for unrecognized input extensions and AssertionError
    if the two datasets share no obs/var names or end up with mismatched
    shapes.
    """
    parser = build_parser()
    args = parser.parse_args()

    x_rna = _read_rna(args.x_rna)
    y_rna = _read_rna(args.y_rna)

    # Length is compared first so the elementwise comparison is only
    # attempted on equally sized indexes.
    if not (
        len(x_rna.obs_names) == len(y_rna.obs_names)
        and np.all(x_rna.obs_names == y_rna.obs_names)
    ):
        logging.warning("Rematching obs axis")
        shared_obs_names = sorted(set(x_rna.obs_names).intersection(y_rna.obs_names))
        logging.info(f"Found {len(shared_obs_names)} shared obs")
        assert shared_obs_names, (
            "Got empty list of shared obs"
            + "\n"
            + str(x_rna.obs_names)
            + "\n"
            + str(y_rna.obs_names)
        )
        x_rna = x_rna[shared_obs_names]
        y_rna = y_rna[shared_obs_names]
    assert np.all(x_rna.obs_names == y_rna.obs_names)

    if not (
        len(x_rna.var_names) == len(y_rna.var_names)
        and np.all(x_rna.var_names == y_rna.var_names)
    ):
        logging.warning("Rematching variable axis")
        shared_var_names = sorted(set(x_rna.var_names).intersection(y_rna.var_names))
        logging.info(f"Found {len(shared_var_names)} shared variables")
        assert shared_var_names, (
            "Got empty list of shared vars"
            + "\n"
            + str(x_rna.var_names)
            + "\n"
            + str(y_rna.var_names)
        )
        x_rna = x_rna[:, shared_var_names]
        y_rna = y_rna[:, shared_var_names]
    assert np.all(x_rna.var_names == y_rna.var_names)

    # Subset by gene list if given
    if args.genelist:
        gene_list = utils.read_delimited_file(args.genelist)
        logging.info(f"Read {len(gene_list)} genes from {args.genelist}")
        x_rna = x_rna[:, gene_list]
        y_rna = y_rna[:, gene_list]

    assert x_rna.shape == y_rna.shape, f"Mismatched shapes {x_rna.shape} {y_rna.shape}"

    # The return value is not used here; fname=args.outfname is forwarded so
    # the plotting helper can handle saving (behavior defined in plot_utils).
    plot_utils.plot_scatter_with_r(
        x_rna.X,
        y_rna.X,
        subset=args.subset,
        one_to_one=True,
        logscale=not args.linear,
        density_heatmap=args.density,
        density_logstretch=args.densitylogstretch,
        fname=args.outfname,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        figsize=args.figsize,
    )
|
|
199 |
|
|
|
200 |
|
|
|
201 |
if __name__ == "__main__":
    # Run the doctests embedded in this module, then the CLI entry point.
    from doctest import testmod

    testmod()
    main()