Diff of /bin/plot_var_vs_acc.py [000000] .. [d01132]

Switch to unified view

a b/bin/plot_var_vs_acc.py
1
"""
2
Plot the variance of a gene versus how accurately we can predict it
3
"""
4
5
import os
6
import sys
7
import argparse
8
import logging
9
import itertools
10
11
# Location of the project's "babel" modules: two directory levels up from this
# file, then into "babel".
SRC_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel")
# Stop right away if that directory does not exist; otherwise add it to the
# module search path so the project-local imports below resolve.
assert os.path.isdir(SRC_DIR)
sys.path.append(SRC_DIR)
16
import interpretation
17
import adata_utils
18
import plot_utils
19
import utils
20
21
from evaluate_bulk_rna_concordance import load_file_flex_format
22
23
# Named reference gene collections; the keys double as the valid values of the
# --highlight command line option.
REF_MARKER_GENES = {
    # Union of every gene listed under any entry of PBMC_MARKER_GENES
    "PBMC": {
        gene
        for gene_group in interpretation.PBMC_MARKER_GENES.values()
        for gene in gene_group
    },
    # Union of every gene listed under any entry of SEURAT_PBMC_MARKER_GENES
    "PBMC_Seurat": {
        gene
        for gene_group in interpretation.SEURAT_PBMC_MARKER_GENES.values()
        for gene in gene_group
    },
    # Read from data/housekeeper_genes.txt, located next to the source directory
    "Housekeeper": utils.read_delimited_file(
        os.path.join(os.path.dirname(SRC_DIR), "data", "housekeeper_genes.txt")
    ),
}
34
35
36
def build_parser():
    """Build basic CLI parser

    Positional arguments are the predictions file, the ground truth file, and
    the output plot file. Optional flags control the outlier gene list output,
    which reference gene sets are highlighted, and axis scaling/constraints.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("preds", type=str, help="File with predicted expression")
    parser.add_argument("truth", type=str, help="File with ground truth expression")
    parser.add_argument("plotname", type=str, help="File to write plot to")
    parser.add_argument(
        "--genelist",
        "-g",
        type=str,
        default="",
        help="File to write outliers in explained variance",
    )
    parser.add_argument(
        "--highlight",
        choices=list(REF_MARKER_GENES),
        nargs="*",
        default=["Housekeeper"],
        help="Genes to highlight",
    )
    parser.add_argument(
        "--linear", action="store_true", help="Plot in linear space instead of log"
    )
    # NOTE: the flag name is misspelled, but it is part of the existing CLI and
    # is read as args.unconstriained, so it is intentionally left unchanged.
    parser.add_argument(
        "--unconstriained", action="store_true", help="Do not constrain axes"
    )
    parser.add_argument("--outliers", action="store_true", help="Label outliers")
    return parser
66
67
68
def main():
    """Run script

    Parses the command line, loads the ground truth and predicted expression
    files, verifies that both contain the same observations, and hands them to
    plot_utils.plot_var_vs_explained_var to produce the plot (and optional
    gene list).

    Raises:
        ValueError: if the observation names of truth and preds do not fully
            overlap.
    """
    parser = build_parser()
    args = parser.parse_args()

    truth = load_file_flex_format(args.truth)
    truth.X = utils.ensure_arr(truth.X)
    logging.info(f"Loaded truth {args.truth}: {truth.shape}")
    preds = load_file_flex_format(args.preds)
    preds.X = utils.ensure_arr(preds.X)
    logging.info(f"Loaded preds {args.preds}: {preds.shape}")

    # Only the number of shared genes is reported here
    common_genes = set(truth.var_names).intersection(preds.var_names)
    logging.info(f"Shared genes: {len(common_genes)}")

    # All obs names should intersect between preds and truth. This is raised
    # explicitly (rather than asserted) so the check survives `python -O`.
    common_obs = set(truth.obs_names).intersection(preds.obs_names)
    if not (len(common_obs) == len(truth.obs_names) == len(preds.obs_names)):
        raise ValueError(
            "Observation names differ between truth and preds: "
            f"{len(truth.obs_names)} in truth, {len(preds.obs_names)} in preds, "
            f"{len(common_obs)} shared"
        )

    plot_utils.plot_var_vs_explained_var(
        truth,
        preds,
        highlight_genes={k: REF_MARKER_GENES[k] for k in args.highlight},
        logscale=not args.linear,
        constrain_y_axis=not args.unconstriained,
        label_outliers=args.outliers,
        fname=args.plotname,
        fname_gene_list=args.genelist,
    )
97
98
99
# Run the command line entry point only when this file is executed as a script
if __name__ == "__main__":
    main()