Diff of /gpsa/util/util.py [000000] .. [5c09f6]

--- a/gpsa/util/util.py
+++ b/gpsa/util/util.py
import numpy as np
import pandas as pd
import numpy.random as npr
import torch
from scipy.special import xlogy

def rbf_kernel(
    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
):
    # Hyperparameters are stored on the log scale; exp maps them back to
    # the positive reals.
    lengthscale = torch.exp(lengthscale_unconstrained)
    output_variance = torch.exp(output_variance_unconstrained)

    if diag:
        # Pointwise differences: only the diagonal of the Gram matrix.
        diffs = x1 - x2
    else:
        # Pairwise differences via broadcasting: (..., n1, n2, d).
        diffs = x1.unsqueeze(-2) - x2.unsqueeze(-3)

    K = output_variance * torch.exp(
        -0.5 * torch.sum(torch.square(diffs / lengthscale), dim=-1)
    )
    return K

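# A minimal usage sketch (hypothetical helper; values are illustrative):
# unconstrained hyperparameters of 0.0 give lengthscale = output_variance =
# exp(0) = 1, and inputs of shapes (5, 2) and (7, 2) broadcast to a (5, 7)
# Gram matrix.
def _example_rbf_kernel():
    x1 = torch.randn(5, 2)
    x2 = torch.randn(7, 2)
    K = rbf_kernel(x1, x2, torch.tensor(0.0), torch.tensor(0.0))
    assert K.shape == (5, 7)
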
def rbf_kernel_numpy(x, xp, kernel_params):
    # kernel_params[0] is the log output scale; kernel_params[1:] are
    # per-dimension log lengthscales (ARD).
    output_scale = np.exp(kernel_params[0])
    lengthscales = np.exp(kernel_params[1:])
    diffs = np.expand_dims(x / lengthscales, 1) - np.expand_dims(xp / lengthscales, 0)
    return output_scale * np.exp(-0.5 * np.sum(diffs**2, axis=2))

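# Parameter-packing sketch for the NumPy variant (hypothetical helper):
# all-zero parameters give a unit output scale and unit lengthscales.
def _example_rbf_kernel_numpy():
    x = np.random.randn(5, 2)
    xp = np.random.randn(7, 2)
    kernel_params = np.zeros(3)  # 1 log output scale + 2 log lengthscales
    K = rbf_kernel_numpy(x, xp, kernel_params)
    assert K.shape == (5, 7)
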
def matern12_kernel(
    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
):
    lengthscale = torch.exp(lengthscale_unconstrained)
    output_variance = torch.exp(output_variance_unconstrained)

    if diag:
        diffs = x1 - x2
    else:
        diffs = x1.unsqueeze(-2) - x2.unsqueeze(-3)
    # Small jitter keeps the gradient of sqrt finite at zero distance.
    eps = 1e-10
    dists = torch.sqrt(torch.sum(torch.square(diffs), dim=-1) + eps)

    # Matern-1/2 (exponential) kernel; the extra 0.5 factor relative to the
    # textbook exp(-d / lengthscale) amounts to a rescaled lengthscale.
    return output_variance * torch.exp(-0.5 * dists / lengthscale)

def matern32_kernel(
    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
):
    lengthscale = torch.exp(lengthscale_unconstrained)
    output_variance = torch.exp(output_variance_unconstrained)

    if diag:
        diffs = x1 - x2
    else:
        diffs = x1.unsqueeze(-2) - x2.unsqueeze(-3)
    eps = 1e-10
    dists = torch.sqrt(torch.sum(torch.square(diffs), dim=-1) + eps)

    # Matern-3/2: (1 + sqrt(3) d / l) * exp(-sqrt(3) d / l).
    inner_term = np.sqrt(3.0) * dists / lengthscale
    K = output_variance * (1 + inner_term) * torch.exp(-inner_term)
    return K

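# The Matern kernels share the rbf_kernel calling convention. With diag=True
# the inputs must have matching shapes and only the pointwise (co)variances
# are returned (hypothetical helper; values are illustrative).
def _example_matern_diag():
    x = torch.randn(6, 2)
    k12 = matern12_kernel(x, x, torch.tensor(0.0), torch.tensor(0.0), diag=True)
    k32 = matern32_kernel(x, x, torch.tensor(0.0), torch.tensor(0.0), diag=True)
    assert k12.shape == k32.shape == (6,)
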
def polar_warp(X, r, theta):
    # Shift each 2D point by the polar displacement (r, theta).
    return np.array([X[:, 0] + r * np.cos(theta), X[:, 1] + r * np.sin(theta)]).T

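# polar_warp moves each point by a displacement of length r at angle theta;
# r and theta may be scalars or per-point arrays (hypothetical helper).
def _example_polar_warp():
    X = np.zeros((4, 2))
    warped = polar_warp(X, r=1.0, theta=np.pi / 2)
    assert np.allclose(warped, [[0.0, 1.0]] * 4)
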
def get_st_coordinates(df):
    """
    Extracts spatial coordinates from ST data whose index is in 'AxB' format.

    Returns: numpy array of coordinates, one row per spot.
    """
    coor = []
    for spot in df.index:
        coordinates = spot.split("x")
        coordinates = [float(i) for i in coordinates]
        coor.append(coordinates)
    return np.array(coor)

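# Spatial Transcriptomics spot names encode the array position, e.g. the
# index entry "10x12" maps to the coordinate (10.0, 12.0). A toy frame
# (hypothetical helper):
def _example_get_st_coordinates():
    df = pd.DataFrame(index=["1x2", "3x4"])
    coords = get_st_coordinates(df)
    assert np.array_equal(coords, [[1.0, 2.0], [3.0, 4.0]])
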
def compute_distance(X1, X2):
    # Mean Euclidean distance between corresponding rows of X1 and X2.
    return np.mean(np.sqrt(np.sum((X1 - X2) ** 2, axis=1)))

def make_pinwheel(
    radial_std, tangential_std, num_classes, num_per_class, rate, rs=npr.RandomState(0)
):
    """Based on code by Ryan P. Adams."""
    rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

    features = rs.randn(num_classes * num_per_class, 2) * np.array(
        [radial_std, tangential_std]
    )
    features[:, 0] += 1
    labels = np.repeat(np.arange(num_classes), num_per_class)

    angles = rads[labels] + rate * np.exp(features[:, 0])
    rotations = np.stack(
        [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]
    )
    rotations = np.reshape(rotations.T, (-1, 2, 2))

    return np.einsum("ti,tij->tj", features, rotations)

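# make_pinwheel draws num_classes * num_per_class 2D points arranged in
# spiral "arms"; a quick shape check (hypothetical helper; parameter values
# are illustrative):
def _example_make_pinwheel():
    data = make_pinwheel(
        radial_std=0.3, tangential_std=0.1, num_classes=3, num_per_class=50, rate=0.25
    )
    assert data.shape == (150, 2)
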
class ConvergenceChecker(object):
    def __init__(self, span, dtp="float64"):
        self.span = span
        # Orthonormal basis (via SVD) for cubic polynomials over the window;
        # projecting onto it smooths the trace.
        x = np.arange(span, dtype=dtp)
        x -= x.mean()
        X = np.column_stack((np.ones(shape=x.shape), x, x**2, x**3))
        self.U = np.linalg.svd(X, full_matrices=False)[0]

    def smooth(self, y):
        return self.U @ (self.U.T @ y)

    def subset(self, y, idx=-1):
        span = self.U.shape[0]
        lo = idx - span + 1
        if idx == -1:
            return y[lo:]
        else:
            return y[lo : (idx + 1)]

    def relative_change(self, y, idx=-1, smooth=True):
        y = self.subset(y, idx=idx)
        if smooth:
            y = self.smooth(y)
        prev = y[-2]
        return (y[-1] - prev) / (0.1 + abs(prev))

    def converged(self, y, tol=1e-4, **kwargs):
        return abs(self.relative_change(y, **kwargs)) < tol

    def relative_change_all(self, y, smooth=True):
        n = len(y)
        span = self.U.shape[0]
        cc = np.full(n, np.nan)
        for i in range(span, n):
            cc[i] = self.relative_change(y, idx=i, smooth=smooth)
        return cc

    def converged_all(self, y, tol=1e-4, smooth=True):
        cc = self.relative_change_all(y, smooth=smooth)
        return np.abs(cc) < tol

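# ConvergenceChecker sketch (hypothetical helper; tolerances are
# illustrative): it smooths the last `span` values of a trace with a cubic
# fit and reports the relative change, so a flattening loss curve eventually
# passes converged().
def _example_convergence_checker():
    checker = ConvergenceChecker(span=10)
    loss_trace = 1.0 / np.arange(1, 101)  # steadily flattening curve
    assert not checker.converged(loss_trace[:20], tol=1e-6)
    assert checker.converged(loss_trace, tol=1e-3)
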
# Function for computing size factors
def compute_size_factors(m):
    # Given a matrix m with samples (cells) in the columns, compute
    # per-sample size factors.
    sz = np.sum(m.values, axis=0)  # column sums (sum of counts in each cell)
    lsz = np.log(sz)

    # Normalize so the geometric mean of the size factors is 1 (Poisson case).
    sz_poisson = np.exp(lsz - np.mean(lsz))
    return sz_poisson

def poisson_deviance(X, sz):
    # Divide each cell's counts by its size factor (broadcast over columns).
    LP = X.values / sz
    LP[LP > 0] = np.log(LP[LP > 0])  # log transform nonzero elements only

    # Transpose to put features in columns, observations in rows
    X = X.T
    ll_sat = np.sum(np.multiply(X, LP.T), axis=0)
    feature_sums = np.sum(X, axis=0)
    ll_null = feature_sums * np.log(feature_sums / np.sum(sz))
    return 2 * (ll_sat - ll_null)

def deviance_feature_selection(X):
    # Remove genes (rows) with no counts in any cell
    X = X[np.sum(X, axis=1) > 0]

    # Compute size factors
    sz = compute_size_factors(X)

    # Compute deviances
    devs = poisson_deviance(X, sz)

    # Get associated gene names
    gene_names = X.index.values

    assert gene_names.shape[0] == devs.values.shape[0]

    return devs.values, gene_names

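# End-to-end sketch of the deviance pipeline on a toy genes-by-cells count
# DataFrame (hypothetical helper; the row index holds the gene names).
def _example_deviance_feature_selection():
    counts = pd.DataFrame(
        np.random.poisson(5.0, size=(20, 8)),
        index=["gene_%d" % i for i in range(20)],
    )
    devs, gene_names = deviance_feature_selection(counts)
    top_genes = gene_names[np.argsort(-devs)][:5]  # highest-deviance genes
    assert len(top_genes) == 5
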
def deviance_residuals(x, theta, mu=None):
    """Computes deviance residuals for an NB model with a fixed theta."""
    if mu is None:
        counts_sum0 = np.sum(x, axis=0, keepdims=True)
        counts_sum1 = np.sum(x, axis=1, keepdims=True)
        counts_sum = np.sum(x)
        # Rank-one estimate of the means under independence
        mu = counts_sum1 @ counts_sum0 / counts_sum

    def remove_negatives(sqrt_term):
        # Numerical error can push the term under the sqrt slightly negative.
        negatives_idx = sqrt_term < 0
        if np.any(negatives_idx):
            n_negatives = np.sum(negatives_idx)
            print(
                "Setting %u negative sqrt term values to 0 (%f%%)"
                % (n_negatives, 100 * n_negatives / np.prod(sqrt_term.shape))
            )
            sqrt_term[negatives_idx] = 0

    if np.isinf(theta):  ### POISSON
        x_minus_mu = x - mu
        sqrt_term = 2 * (
            xlogy(x, x / mu) - x_minus_mu
        )  # xlogy(x, x / mu) computes x * log(x / mu) and returns 0 if x = 0
        remove_negatives(sqrt_term)
        dev = np.sign(x_minus_mu) * np.sqrt(sqrt_term)
    else:  ### NEG BIN
        x_plus_theta = x + theta
        sqrt_term = 2 * (
            xlogy(x, x / mu) - (x_plus_theta) * np.log(x_plus_theta / (mu + theta))
        )  # xlogy(x, x / mu) computes x * log(x / mu) and returns 0 if x = 0
        remove_negatives(sqrt_term)
        dev = np.sign(x - mu) * np.sqrt(sqrt_term)

    return dev

def pearson_residuals(counts, theta, clipping=True):
    """Computes analytical Pearson residuals for an NB model with a fixed
    theta, clipping outlier residuals to +/- sqrt(N)."""
    counts_sum0 = np.sum(counts, axis=0, keepdims=True)
    counts_sum1 = np.sum(counts, axis=1, keepdims=True)
    counts_sum = np.sum(counts)

    # Rank-one estimate of the means under independence
    mu = counts_sum1 @ counts_sum0 / counts_sum
    z = (counts - mu) / np.sqrt(mu + mu**2 / theta)

    # Clip residuals to +/- sqrt(n), with n the number of rows (observations)
    if clipping:
        n = counts.shape[0]
        z[z > np.sqrt(n)] = np.sqrt(n)
        z[z < -np.sqrt(n)] = -np.sqrt(n)

    return z

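# Pearson-residual sketch (hypothetical helper): counts with observations in
# rows, a fixed NB theta of 100, and clipping at sqrt(number of rows).
def _example_pearson_residuals():
    counts = np.random.poisson(5.0, size=(30, 10)).astype(float)
    z = pearson_residuals(counts, theta=100.0)
    assert z.shape == counts.shape
    assert np.all(np.abs(z) <= np.sqrt(counts.shape[0]))
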
class LossNotDecreasingChecker:
    def __init__(self, max_epochs, atol=1e-2, window_size=10):
        self.max_epochs = max_epochs
        self.atol = atol
        self.window_size = window_size
        self.decrease_in_loss = np.zeros(max_epochs)
        self.average_decrease_in_loss = np.zeros(max_epochs)

    def check_loss(self, iternum, loss_trace):
        if iternum >= 1:
            self.decrease_in_loss[iternum] = (
                loss_trace[iternum - 1] - loss_trace[iternum]
            )
            if iternum >= self.window_size:
                # Average the decreases over the last `window_size` steps,
                # including the current one (the original slice stopped one
                # short of the current decrease).
                self.average_decrease_in_loss[iternum] = np.mean(
                    self.decrease_in_loss[iternum - self.window_size + 1 : iternum + 1]
                )
                has_converged = self.average_decrease_in_loss[iternum] < self.atol
                return has_converged

        return False
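# Early-stopping sketch (hypothetical helper): check_loss returns True once
# the average per-step decrease over the window falls below atol.
def _example_loss_checker():
    max_epochs = 100
    checker = LossNotDecreasingChecker(max_epochs=max_epochs, atol=1e-2)
    loss_trace = np.exp(-0.1 * np.arange(max_epochs))  # flattens out quickly
    for iternum in range(max_epochs):
        if checker.check_loss(iternum, loss_trace):
            break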