Diff of /bin/match_closest_cell.py [000000] .. [d01132]

Switch to unified view

a b/bin/match_closest_cell.py
1
"""
2
Script for calculating pairwise distances between 2 adatas
3
"""
4
import os
5
import sys
6
import logging
7
import argparse
8
import json
9
10
import numpy as np
11
12
from evaluate_bulk_rna_concordance import load_file_flex_format
13
14
SRC_DIR = os.path.join(
15
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel",
16
)
17
assert os.path.isdir(SRC_DIR)
18
sys.path.append(SRC_DIR)
19
import adata_utils
20
import metrics
21
22
logging.basicConfig(level=logging.INFO)
23
24
25
def build_parser():
26
    """Build CLI parser"""
27
    parser = argparse.ArgumentParser(
28
        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
29
    )
30
    parser.add_argument("adata1", type=str, help="First adata object to compare")
31
    parser.add_argument("adata2", type=str, help="Second adata object to compare")
32
    parser.add_argument(
33
        "--output", "-o", type=str, default="", help="Json file to write matches to"
34
    )
35
    parser.add_argument(
36
        "--method", "-m", type=str, choices=["euclidean", "cosine"], default="euclidean"
37
    )
38
    parser.add_argument(
39
        "--log",
40
        "-l",
41
        action="store_true",
42
        help="Log transform before computing distance",
43
    )
44
    parser.add_argument(
45
        "--numtop",
46
        "-n",
47
        type=int,
48
        default=0,
49
        help="Number of top matches to report. 0 indicates reporting all in descending order",
50
    )
51
    return parser
52
53
54
def main():
55
    """Run script"""
56
    parser = build_parser()
57
    args = parser.parse_args()
58
59
    # Compute distances
60
    x = load_file_flex_format(args.adata1)
61
    logging.info(f"Loaded in {args.adata1} for {x.shape}")
62
    y = load_file_flex_format(args.adata2)
63
    logging.info(f"Loaded in {args.adata2} for {y.shape}")
64
65
    # Log, because often times in this project the output is actually linear space
66
    # and comparing expression is typically done in log space
67
    if args.log:
68
        logging.info("Log transforming inputs")
69
        x.X = np.log1p(x.X)
70
        y.X = np.log1p(y.X)
71
72
    pairwise_dist = adata_utils.evaluate_pairwise_cell_distance(
73
        x, y, method=args.method
74
    )
75
    if args.numtop == 0:
76
        args.numtop = y.n_obs
77
78
    # Figure out the top few matches per cell
79
    matches = {}
80
    for i, row in pairwise_dist.iterrows():
81
        sorted_idx = np.argsort(row.values)
82
        cell_matches = pairwise_dist.columns[sorted_idx[: args.numtop]]
83
        matches[i] = list(cell_matches)
84
    if args.output:
85
        assert args.output.endswith(".json")
86
        with open(args.output, "w") as sink:
87
            logging.info(f"Writing matches to {args.output}")
88
            json.dump(matches, sink, indent=4)
89
90
    # Report Top N accuracy if relevant
91
    if x.n_obs == y.n_obs and np.all(x.obs_names == y.obs_names):
92
        n = args.numtop if args.numtop < y.n_obs else 10
93
        acc = metrics.top_n_accuracy(matches.values(), matches.keys(), n=n)
94
        logging.info(f"Top {n} accuracy: {acc:.4f}")
95
96
97
if __name__ == "__main__":
98
    main()