|
a |
|
b/bin/match_closest_cell.py |
|
|
1 |
""" |
|
|
2 |
Script for calculating pairwise distances between the cells of 2 adatas and
reporting, for each cell of the first, its closest matches in the second
|
|
3 |
""" |
|
|
4 |
import os |
|
|
5 |
import sys |
|
|
6 |
import logging |
|
|
7 |
import argparse |
|
|
8 |
import json |
|
|
9 |
|
|
|
10 |
import numpy as np |
|
|
11 |
|
|
|
12 |
from evaluate_bulk_rna_concordance import load_file_flex_format |
|
|
13 |
|
|
|
14 |
# Locate the "babel" directory that sits next to this script's parent
# directory and put it on sys.path so the project modules imported below
# resolve when the script is run directly.
SRC_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel",
)
# Explicit check rather than `assert`: assertions are stripped under
# `python -O`, which would turn a missing directory into a confusing
# ImportError a few lines further down.
if not os.path.isdir(SRC_DIR):
    raise NotADirectoryError(f"Source directory not found: {SRC_DIR}")
sys.path.append(SRC_DIR)
# These imports must come after the sys.path change above.
import adata_utils
import metrics

logging.basicConfig(level=logging.INFO)
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def build_parser():
    """Build the command line parser for this script.

    Positional arguments:
        adata1, adata2: paths of the two inputs to compare.

    Options:
        --output/-o: json file to write matches to (default "", no file).
        --method/-m: distance metric, "euclidean" (default) or "cosine".
        --log/-l: log transform inputs before computing distance.
        --numtop/-n: number of top matches to report; 0 (default) reports all.

    Returns:
        The configured argparse.ArgumentParser.
    """
    # The module docstring goes in `description` so argparse still generates
    # the real usage synopsis (positionals and options) for help and errors.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("adata1", type=str, help="First adata object to compare")
    parser.add_argument("adata2", type=str, help="Second adata object to compare")
    parser.add_argument(
        "--output", "-o", type=str, default="", help="Json file to write matches to"
    )
    parser.add_argument(
        "--method",
        "-m",
        type=str,
        choices=["euclidean", "cosine"],
        default="euclidean",
        help="Distance metric used to compare cells",
    )
    parser.add_argument(
        "--log",
        "-l",
        action="store_true",
        help="Log transform before computing distance",
    )
    parser.add_argument(
        "--numtop",
        "-n",
        type=int,
        default=0,
        help="Number of top matches to report. 0 indicates reporting all, closest first",
    )
    return parser
|
|
52 |
|
|
|
53 |
|
|
|
54 |
def main():
    """Run script.

    Steps:
        1. Parse and validate the command line arguments.
        2. Load both inputs and, with --log, apply log1p to their .X.
        3. Compute the pairwise distance table between the two inputs.
        4. For every row of that table, rank the column labels by ascending
           distance and keep the first --numtop of them (all when 0).
        5. With --output, write the matches as json.
        6. When both inputs have identical obs names, log a top-N accuracy.

    Exits through parser.error (status 2) when --output does not end in
    ".json" or --numtop is negative.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Validate before the expensive loading and distance computation, and
    # with explicit errors rather than `assert`, which is stripped under -O.
    if args.output and not args.output.endswith(".json"):
        parser.error(f"--output must end with .json, got: {args.output}")
    if args.numtop < 0:
        # A negative value would silently drop trailing matches when slicing.
        parser.error(f"--numtop must be >= 0, got: {args.numtop}")

    # Compute distances
    x = load_file_flex_format(args.adata1)
    logging.info(f"Loaded in {args.adata1} for {x.shape}")
    y = load_file_flex_format(args.adata2)
    logging.info(f"Loaded in {args.adata2} for {y.shape}")

    # Log, because often times in this project the output is actually linear space
    # and comparing expression is typically done in log space
    if args.log:
        logging.info("Log transforming inputs")
        x.X = np.log1p(x.X)
        y.X = np.log1p(y.X)

    pairwise_dist = adata_utils.evaluate_pairwise_cell_distance(
        x, y, method=args.method
    )
    # 0 means "report everything"; kept in a local so args is not mutated
    num_top = args.numtop if args.numtop > 0 else y.n_obs

    # Figure out the top few matches per cell (smallest distance first)
    candidates = pairwise_dist.columns
    matches = {
        cell: list(candidates[np.argsort(row.values)[:num_top]])
        for cell, row in pairwise_dist.iterrows()
    }
    if args.output:
        with open(args.output, "w") as sink:
            logging.info(f"Writing matches to {args.output}")
            json.dump(matches, sink, indent=4)

    # Report Top N accuracy if relevant
    if x.n_obs == y.n_obs and np.all(x.obs_names == y.obs_names):
        # When every match is reported, fall back to N = 10
        n = num_top if num_top < y.n_obs else 10
        acc = metrics.top_n_accuracy(matches.values(), matches.keys(), n=n)
        logging.info(f"Top {n} accuracy: {acc:.4f}")
|
|
95 |
|
|
|
96 |
|
|
|
97 |
# Entry point: only run the CLI when executed as a script, not on import.
if __name__ == "__main__":
    main()