Diff of /bin/entropy.py [000000] .. [d01132]

Switch to unified view

a b/bin/entropy.py
1
"""
2
Calculate the entropy of the given h5ad
3
"""
4
5
import os
6
import sys
7
import argparse
8
import logging
9
import numpy as np
10
from entropy_estimators import continuous as ce
11
from pyitlib import discrete_random_variable as drv
12
import scipy.stats
13
14
import anndata as ad
15
16
# Locate the "babel" directory one level above this script's directory and
# put it on sys.path so that the `utils` import below can resolve.
SRC_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel",
)
if not os.path.isdir(SRC_DIR):
    # Raised explicitly (instead of via `assert`) so the check is not stripped
    # when Python runs with -O; the exception type is kept for compatibility.
    raise AssertionError(f"Expected source directory not found: {SRC_DIR}")
sys.path.append(SRC_DIR)
import utils  # expected to be found via SRC_DIR appended above

logging.basicConfig(level=logging.INFO)
24
25
"""
26
NOTES
27
Estimator for continuous variables
28
https://github.com/paulbrodersen/entropy_estimators
29
- This estimator is NOT symmetric: transposing the input changes the estimate (see get_h(x) vs get_h(x.T) below)
30
>>> x = np.random.randn(3000, 30000)
31
>>> x
32
array([[ 1.01757666,  0.14706194,  0.17207894, ..., -0.5776106 ,
33
         1.27110965, -0.80688082],
34
       [-0.46566731, -1.65503883,  0.34362236, ..., -0.56790773,
35
         1.58161324,  0.6875425 ],
36
       [ 0.21598618,  0.15462247, -0.66670242, ..., -1.28547741,
37
        -0.1731192 ,  0.19815154],
38
       ...,
39
       [ 0.30699781,  0.24104934,  0.30279376, ...,  1.95658979,
40
         0.78125961,  0.26259683],
41
       [-1.94023222, -0.79838041, -0.10267371, ..., -0.67825156,
42
         0.75047044,  0.773398  ],
43
       [ 0.73951081,  0.3485434 , -0.17277407, ..., -0.32622845,
44
        -0.59264903,  1.27659335]])
45
>>> x.shape
46
(3000, 30000)
47
>>> h = continuous.get_h(x)
48
>>> h
49
69901.37779787864
50
>>> h = continuous.get_h(x.T)
51
>>> h
52
6346.646780095286
53
54
(Simple) estimator for discrete variables
55
https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
56
57
For a binary variable, we can calculate (base e) entropy in several ways
58
We can specify a torch distribution, and get entropy per dimension
59
We can ask scipy to calculate this from an input of unnormalized probs
60
Both give us the same results
61
>>> b = torch.distributions.bernoulli.Bernoulli(torch.tensor([0.1, 0.9, 0.00001, 0.5]))
62
>>> b.entropy()
63
tensor([3.2508e-01, 3.2508e-01, 1.2541e-04, 6.9315e-01])
64
>>> scipy.stats.entropy([0.1, 0.9])
65
0.3250829733914482
66
>>> scipy.stats.entropy([1, 1])  # scipy normalizes the input probs
67
0.6931471805599453
68
69
Another estimator for discrete variables
70
https://github.com/pafoster/pyitlib
71
- This supports calculation of joint entropy
72
>>> x
73
array([[1, 1, 1, 0],
74
       [0, 0, 0, 1]])
75
>>> drv.entropy_joint(x)
76
0.8112781244591328
77
>>> drv.entropy_joint(x.T)
78
1.0
79
"""
80
81
82
def build_parser() -> argparse.ArgumentParser:
    """Build CLI parser

    Returns:
        An ``argparse.ArgumentParser`` that uses the module docstring as its
        description and accepts:

        * ``h5ad`` (positional): the .h5ad file to evaluate entropy for
        * ``--discrete`` (flag, default False): use the discrete calculation
    """
    cli = argparse.ArgumentParser(
        description=__doc__,
        # Show argument defaults in the --help output
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument("h5ad", type=str, help=".h5ad file to evaluate entropy for")
    cli.add_argument(
        "--discrete",
        action="store_true",
        help="Use discrete calculation for entropy",
    )
    return cli
92
93
94
def main():
    """Run script

    Parses the command line, reads the given .h5ad file, and logs the joint
    entropy (base e) of ``adata.X`` computed with pyitlib's discrete
    estimator.

    Raises:
        NotImplementedError: if ``--discrete`` is not given, since only the
            discrete calculation is implemented.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Fail fast: reject the unsupported mode before loading the input file,
    # which avoids reading the whole file only to raise afterwards.
    if not args.discrete:
        raise NotImplementedError(
            "Only the discrete entropy calculation is implemented; "
            "rerun with --discrete"
        )

    adata = ad.read_h5ad(args.h5ad)
    logging.info(f"Read {args.h5ad} for adata of {adata.shape}")

    # Use the discrete algorithm from pyitlib
    # https://pafoster.github.io/pyitlib/#discrete_random_variable.entropy_joint
    # https://github.com/pafoster/pyitlib/blob/master/pyitlib/discrete_random_variable.py#L3535
    # pyitlib indexes successive realisations of a random variable by the last
    # axis; multiple random variables are specified using the preceding axes.
    # In other words, variables are axis 0 and samples are axis 1, which is
    # contrary to the usual ML layout (samples on axis 0, variables on axis 1).
    # Therefore we must transpose.
    # NOTE(review): utils.ensure_arr presumably returns a dense ndarray for
    # adata.X -- confirm against babel/utils.
    input_arr = utils.ensure_arr(adata.X).T
    h = drv.entropy_joint(input_arr, base=np.e)
    logging.info(f"Found discrete joint entropy of {h:.6f}")
115
116
117
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()