a b/bin/plot_pca.py
1
"""
2
Basic script for plotting PCA of a dataset(s)
3
"""
4
5
import os
6
import sys
7
import argparse
8
import logging
9
import itertools
10
11
import numpy as np
12
import matplotlib.pyplot as plt
13
14
# Directory of project modules: "<parent of this script's directory>/babel".
SRC_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel",
)
# Explicit check instead of `assert`: assertions are stripped under
# `python -O`, which would turn a missing directory into a confusing
# ImportError further down.
if not os.path.isdir(SRC_DIR):
    raise FileNotFoundError(f"Expected source directory does not exist: {SRC_DIR}")
# Presumably where the project modules imported below live; appending it to
# sys.path lets them be imported by bare name.
sys.path.append(SRC_DIR)
19
import interpretation
20
import adata_utils
21
import plot_utils
22
import utils
23
24
from evaluate_bulk_rna_concordance import load_file_flex_format
25
26
27
def var_barplot(var: np.ndarray, fname: str = ""):
    """Draw a bar plot of the explained variance of each principal component.

    Args:
        var: Explained variance values, one entry per principal component.
            Entries are plotted in the order given and their sum is reported
            in the plot title.
        fname: Path to save the figure to. Nothing is written to disk when
            this is empty (the default).

    Returns:
        The matplotlib figure containing the bar plot.
    """
    fig, ax = plt.subplots(dpi=300)
    # Bars sit at x = 0 .. len(var) - 1, i.e. components are numbered from zero.
    ax.bar(np.arange(len(var)), var)
    ax.set(
        xlabel="Principal component",
        ylabel="Explained variance",
        title=f"Top {len(var)} PCs ({np.sum(var):.4f} explained variance)",
    )
    if fname:
        fig.savefig(fname, bbox_inches="tight")
    return fig
39
40
41
def _positive_int(value: str) -> int:
    """Parse *value* as an integer >= 1, for use as an argparse ``type``."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number


def build_parser():
    """Build a basic CLI parser.

    Positional arguments:
        adata_fname: Adata object to plot PCA for.
        plot_prefix: Prefix to save plots to.

    Options:
        --numdims / -n: Number of top PCs to consider (default 16). Must be a
            positive integer; zero or negative values are rejected because
            the value is used as a slice bound, where a negative number
            would silently drop trailing components instead.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("adata_fname", type=str, help="Adata object to plot PCA for")
    parser.add_argument("plot_prefix", type=str, help="Prefix to save plots to")
    parser.add_argument(
        "--numdims",
        "-n",
        type=_positive_int,
        default=16,
        help="Number of top PCs to consider",
    )
    return parser
50
51
52
def main():
    """Run script: load the input file and save a barplot of PCA explained variance.

    The plot is written to ``<plot_prefix>_explained_var.pdf``.

    Raises:
        KeyError: If the loaded object has no ``uns["pca"]["variance_ratio"]``
            entry.
    """
    parser = build_parser()
    args = parser.parse_args()

    adata = load_file_flex_format(args.adata_fname)

    # Fail with a message naming the file and the missing entry instead of a
    # bare KeyError('pca') when the input carries no PCA results.
    try:
        variance_ratio = adata.uns["pca"]["variance_ratio"]
    except KeyError as err:
        raise KeyError(
            f"{args.adata_fname} has no PCA results under "
            "uns['pca']['variance_ratio']"
        ) from err

    # Keep the first `numdims` entries (presumably ordered from most to least
    # explained variance -- TODO confirm against whatever wrote uns["pca"]).
    var = variance_ratio[: args.numdims]
    var_barplot(var, fname=args.plot_prefix + "_explained_var.pdf")
61
62
63
# Only run the CLI when executed as a script, not when imported.
if __name__ == "__main__":
    main()