Diff of /utils.py [000000] .. [e72cf6]

import numpy as np


def train_test_split(adata, test_ratio=0.1, seed=1):
    """Splits adata into a training set and a test set.

    Args:
        adata: the dataset to be split.
        test_ratio: ratio of the test data in adata.
        seed: random seed.

    Returns:
        The training set and the test set, both in AnnData format.
    """

    rng = np.random.default_rng(seed=seed)
    test_indices = rng.choice(adata.n_obs, size=int(test_ratio * adata.n_obs), replace=False)
    train_indices = list(set(range(adata.n_obs)).difference(test_indices))
    train_adata = adata[adata.obs_names[train_indices], :]
    test_adata = adata[adata.obs_names[test_indices], :]

    return train_adata, test_adata


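# Usage sketch (assumes `adata` is an AnnData object; the split is random
# but reproducible for a fixed seed):
#     train_adata, test_adata = train_test_split(adata, test_ratio=0.2, seed=0)
#     # ~80% of the observations land in train_adata, ~20% in test_adata

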
def calc_weight(
        epoch: int,
        n_epochs: int,
        cutoff_ratio: float = 0.,
        warmup_ratio: float = 1 / 3,
        min_weight: float = 0.,
        max_weight: float = 1e-7
) -> float:
    """Calculates the weight of the KL term for the current epoch.

    Args:
        epoch: current epoch.
        n_epochs: the total number of epochs to train the model.
        cutoff_ratio: ratio of cutoff epochs (weight is set to zero) to
            n_epochs.
        warmup_ratio: ratio of warmup epochs to n_epochs.
        min_weight: minimum weight.
        max_weight: maximum weight.

    Returns:
        The current weight of the KL term.
    """

    fully_warmup_epoch = n_epochs * warmup_ratio

    if epoch < n_epochs * cutoff_ratio:
        return 0.
    if warmup_ratio:
        return max(min(1., epoch / fully_warmup_epoch) * max_weight, min_weight)
    else:
        return max_weight


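# Schedule sketch: the weight ramps up linearly until warmup_ratio * n_epochs,
# then stays at max_weight. With the defaults (warmup_ratio=1/3, max_weight=1e-7):
#     calc_weight(epoch=50, n_epochs=300)   # -> 5e-08 (halfway through warmup)
#     calc_weight(epoch=150, n_epochs=300)  # -> 1e-07 (fully warmed up)

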
import math
import os
import shutil
import sys
import time

import torch
import torch.distributions as dist
import torch.nn.functional as F


# Classes
class Constants(object):
    eta = 1e-6
    eps = 1e-8
    log2 = math.log(2)
    log2pi = math.log(2 * math.pi)
    logceilc = 88  # largest cuda v s.t. exp(v) < inf
    logfloorc = -104  # smallest cuda v s.t. exp(v) > 0


# https://stackoverflow.com/questions/14906764/how-to-redirect-stdout-to-both-file-and-console-with-scripting
class Logger(object):
    def __init__(self, filename, mode="a"):
        self.terminal = sys.stdout
        self.log = open(filename, mode)

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass


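# Usage sketch: tee everything printed to stdout into a log file as well
# (the filename is illustrative):
#     sys.stdout = Logger('run.log')

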
class Timer:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.begin = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.elapsed = self.end - self.begin
        self.elapsedH = time.gmtime(self.elapsed)
        print('====> [{}] Time: {:7.3f}s or {}'
              .format(self.name,
                      self.elapsed,
                      time.strftime("%H:%M:%S", self.elapsedH)))


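# Usage sketch: time a block of work and print the elapsed time on exit
# (train() is a hypothetical helper):
#     with Timer('training'):
#         train(model)

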
# Functions
def save_vars(vs, filepath):
    """
    Saves variables to the given filepath in a safe manner.
    """
    if os.path.exists(filepath):
        shutil.copyfile(filepath, '{}.old'.format(filepath))
    torch.save(vs, filepath)


def save_model(model, filepath):
    """
    To load a saved model, simply use
    `model.load_state_dict(torch.load('path-to-saved-model'))`.
    """
    save_vars(model.state_dict(), filepath)
    # if hasattr(model, 'vaes'):
    #    for vae in model.vaes:
    #        fdir, fext = os.path.splitext(filepath)
    #        save_vars(vae.state_dict(), fdir + '_' + vae.modelName + fext)


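# Save/load round trip (the path is illustrative; save_vars first backs up an
# existing checkpoint to '<filepath>.old' before overwriting it):
#     save_model(model, 'runs/exp1/model.pt')
#     model.load_state_dict(torch.load('runs/exp1/model.pt'))

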
def is_multidata(dataB):
    return isinstance(dataB, (list, tuple))


def unpack_data(dataB, device='cuda'):
    # dataB :: (Tensor, Idx) | [(Tensor, Idx)]
    """ Unpacks the data batch object in an appropriate manner to extract data """
    if is_multidata(dataB):
        if torch.is_tensor(dataB[0]):
            if torch.is_tensor(dataB[1]):
                return dataB[0].to(device)  # mnist, svhn, cubI
            elif is_multidata(dataB[1]):
                return dataB[0].to(device), dataB[1][0].to(device)  # cubISft
            else:
                raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[1])))

        elif is_multidata(dataB[0]):
            return [d.to(device) for d in list(zip(*dataB))[0]]  # mnist-svhn, cubIS
        else:
            raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[0])))
    elif torch.is_tensor(dataB):
        return dataB.to(device)
    else:
        raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB)))


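# Behaviour sketch (tensor names are illustrative): for a (data, index) batch,
# only the data tensor is kept and moved to the device; for a list of
# per-modality (data, index) pairs, one tensor per modality is returned:
#     x = unpack_data((images, idx))            # -> images.to('cuda')
#     xs = unpack_data([(x1, i1), (x2, i2)])    # -> [x1.to('cuda'), x2.to('cuda')]

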
def get_mean(d, K=100):
    """
    Extract the `mean` parameter for given distribution.
    If attribute not available, estimate from samples.
    """
    try:
        mean = d.mean
    except NotImplementedError:
        samples = d.rsample(torch.Size([K]))
        mean = samples.mean(0)
    return mean


def log_mean_exp(value, dim=0, keepdim=False):
    return torch.logsumexp(value, dim, keepdim=keepdim) - math.log(value.size(dim))


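# Numerically stable version of value.exp().mean(dim).log(), computed in
# log-space via logsumexp:
#     x = torch.randn(100, 8)
#     log_mean_exp(x, dim=0)  # close to x.exp().mean(0).log()

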
def kl_divergence(d1, d2, K=100):
    """Computes closed-form KL if available, else computes a MC estimate."""
    if (type(d1), type(d2)) in torch.distributions.kl._KL_REGISTRY:
        return torch.distributions.kl_divergence(d1, d2)
    else:
        samples = d1.rsample(torch.Size([K]))
        return (d1.log_prob(samples) - d2.log_prob(samples)).mean(0)


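# For pairs with a registered closed form this defers to
# torch.distributions.kl_divergence; otherwise it averages
# log d1(z) - log d2(z) over K reparameterized samples from d1:
#     kl_divergence(dist.Normal(0., 1.), dist.Normal(1., 2.))  # closed form

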
def vade_kld_uni(model, zs):
    n_centroids = model.params.n_centroids
    gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)  # get_gamma adds Constants.eta to pi and var_c

    mu, var = model._qz_x_params
    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
    var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
    lpz_c = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) + \
                                               torch.log(var_c) + \
                                               var_expand / var_c + \
                                               (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)  # log p(z|c)
    lpc = torch.sum(gamma * torch.log(pi), dim=1)  # log p(c); log(pi) could be -inf without the eta offset
    lqz_x = -0.5 * torch.sum(1 + torch.log(var) + math.log(2 * math.pi), dim=1)  # see VaDE paper # log q(z|x)
    lqc_x = torch.sum(gamma * lgamma, dim=1)  # log q(c|x)

    kld = -lpz_c - lpc + lqz_x + lqc_x

    return kld


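# The returned quantity is the VaDE KL term in expectation form,
#     KL = E_q[log q(z|x) + log q(c|x) - log p(z|c) - log p(c)],
# where p(z|c) = N(z; mu_c, var_c), p(c) = Cat(pi), and gamma = q(c|x) are the
# soft cluster responsibilities returned by get_gamma.

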
def vade_kld(model, zs, r):
    n_centroids = model.params.n_centroids
    gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)  # get_gamma adds Constants.eta to pi and var_c

    mu, var = model.vaes[r]._qz_x_params
    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
    var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
    lpz_c = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) + \
                                               torch.log(var_c) + \
                                               var_expand / var_c + \
                                               (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)  # log p(z|c)
    lpc = torch.sum(gamma * torch.log(pi), dim=1)  # log p(c); log(pi) could be -inf without the eta offset
    lqz_x = -0.5 * torch.sum(1 + torch.log(var) + math.log(2 * math.pi), dim=1)  # see VaDE paper # log q(z|x)
    lqc_x = torch.sum(gamma * lgamma, dim=1)  # log q(c|x)

    kld = -lpz_c - lpc + lqz_x + lqc_x

    return kld


def pdist(sample_1, sample_2, eps=1e-5):
    """Compute the matrix of all pairwise Euclidean distances. Code
    adapted from the torch-two-sample library (added batching).
    You can find the original implementation of this function here:
    https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(batch_size, n_1, d)``
        (a 2D ``(n_1, d)`` input is treated as a single batch).
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(batch_size, n_2, d)``.
    eps : float
        Small constant added before the square root for numerical stability.

    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (batch_size, n_1, n_2). The [i, j]-th entry is equal to
        ``|| sample_1[:, i, :] - sample_2[:, j, :] ||_2``."""
    if len(sample_1.shape) == 2:
        sample_1, sample_2 = sample_1.unsqueeze(0), sample_2.unsqueeze(0)
    B, n_1, n_2 = sample_1.size(0), sample_1.size(1), sample_2.size(1)
    norms_1 = torch.sum(sample_1 ** 2, dim=-1, keepdim=True)
    norms_2 = torch.sum(sample_2 ** 2, dim=-1, keepdim=True)
    norms = (norms_1.expand(B, n_1, n_2)
             + norms_2.transpose(1, 2).expand(B, n_1, n_2))
    distances_squared = norms - 2 * sample_1.matmul(sample_2.transpose(1, 2))
    return torch.sqrt(eps + torch.abs(distances_squared)).squeeze()  # batch x n_1 x n_2


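# Shape sketch: 2D inputs are treated as a single batch, and the eps term
# keeps the square root differentiable (exact zeros become ~sqrt(eps)):
#     a = torch.randn(10, 3)
#     b = torch.randn(20, 3)
#     pdist(a, b).shape  # torch.Size([10, 20])

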
def NN_lookup(emb_h, emb, data):
    indices = pdist(emb.to(emb_h.device), emb_h).argmin(dim=0)
    # indices = torch.tensor(cosine_similarity(emb, emb_h.cpu().numpy()).argmax(0)).to(emb_h.device).squeeze()
    return data[indices]


class FakeCategorical(dist.Distribution):
    support = dist.constraints.real
    has_rsample = True

    def __init__(self, locs):
        self.logits = locs
        self._batch_shape = self.logits.shape

    @property
    def mean(self):
        return self.logits

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            return self.rsample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.logits.expand([*sample_shape, *self.logits.shape]).contiguous()

    def log_prob(self, value):
        # value of shape (K, B, D)
        lpx_z = -F.cross_entropy(input=self.logits.view(-1, self.logits.size(-1)),
                                 target=value.expand(self.logits.size()[:-1]).long().view(-1),
                                 reduction='none',
                                 ignore_index=0)

        return lpx_z.view(*self.logits.shape[:-1])
        # The word-embedding dimension is inevitably summed over in the
        # cross-entropy loss ($\sum_i -gt_i \log p_i$, with most gt_i = 0).
        # We adopt the operational equivalent here, which is summing over
        # the sentence dimension in the objective.


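# Shape sketch: with logits of shape (batch, seq_len, vocab) and integer
# targets of shape (batch, seq_len), log_prob returns per-token
# log-likelihoods of shape (batch, seq_len); target index 0 (padding) is
# ignored:
#     d = FakeCategorical(torch.randn(4, 16, 1000))
#     d.log_prob(torch.randint(1000, (4, 16))).shape  # torch.Size([4, 16])

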
# from github Bjarten/early-stopping-pytorch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, runPath):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, runPath)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, runPath)  # runPath argument added
            self.counter = 0

    def save_checkpoint(self, val_loss, model, runPath):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        # torch.save(model.state_dict(), 'checkpoint.pt')
        save_model(model, runPath + '/model.rar')  # ported from mmvae
        self.val_loss_min = val_loss


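# Training-loop sketch (validate() is a hypothetical helper returning the
# current validation loss):
#     early_stopping = EarlyStopping(patience=10, verbose=True)
#     for epoch in range(n_epochs):
#         val_loss = validate(model)
#         early_stopping(val_loss, model, runPath)
#         if early_stopping.early_stop:
#             break

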
class EarlyStopping_nosave:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # None, so the first call always records a best score
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, runPath):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0