--- a
+++ b/utils.py
@@ -0,0 +1,381 @@
+import numpy as np
+
+
+def train_test_split(adata, test_ratio=0.1, seed=1):
+    """Splits the adata into a training set and a test set.
+
+    Args:
+        adata: the dataset to be split.
+        test_ratio: ratio of the test data in adata.
+        seed: random seed.
+    Returns:
+        the training set and the test set, both in AnnData format.
+    """
+
+    rng = np.random.default_rng(seed=seed)
+    test_indices = rng.choice(adata.n_obs, size=int(test_ratio * adata.n_obs), replace=False)
+    train_indices = list(set(range(adata.n_obs)).difference(test_indices))
+    train_adata = adata[adata.obs_names[train_indices], :]
+    test_adata = adata[adata.obs_names[test_indices], :]
+
+    return train_adata, test_adata
+
+
+def calc_weight(
+        epoch: int,
+        n_epochs: int,
+        cutoff_ratio: float = 0.,
+        warmup_ratio: float = 1 / 3,
+        min_weight: float = 0.,
+        max_weight: float = 1e-7
+) -> float:
+    """Calculates the KL annealing weight for the current epoch.
+
+    Args:
+        epoch: current epoch.
+        n_epochs: the total number of epochs to train the model.
+        cutoff_ratio: ratio of cutoff epochs (during which the weight is zero)
+            to n_epochs.
+        warmup_ratio: ratio of warmup epochs to n_epochs.
+        min_weight: minimum weight.
+        max_weight: maximum weight.
+    Returns:
+        The current weight of the KL term.
+    """
+
+    fully_warmup_epoch = n_epochs * warmup_ratio
+
+    if epoch < n_epochs * cutoff_ratio:
+        return 0.
+    if warmup_ratio:
+        return max(min(1., epoch / fully_warmup_epoch) * max_weight, min_weight)
+    else:
+        return max_weight
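+
+# Illustrative sketch (not part of the original module): with the defaults above,
+# calc_weight produces a linear KL warm-up. Assuming n_epochs=90 and
+# warmup_ratio=1/3, the weight climbs from 0 to max_weight over the first 30
+# epochs and then stays flat:
+#
+#     >>> [calc_weight(e, n_epochs=90, max_weight=1e-7) for e in (0, 15, 30, 60)]
+#     [0.0, 5e-08, 1e-07, 1e-07]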
+
+
+import math
+import os
+import shutil
+import sys
+import time
+
+import torch
+import torch.distributions as dist
+import torch.nn.functional as F
+
+
+# Classes
+class Constants(object):
+    eta = 1e-6
+    eps = 1e-8
+    log2 = math.log(2)
+    log2pi = math.log(2 * math.pi)
+    logceilc = 88  # largest cuda v s.t. exp(v) < inf
+    logfloorc = -104  # smallest cuda v s.t. exp(v) > 0
+
+
+# https://stackoverflow.com/questions/14906764/how-to-redirect-stdout-to-both-file-and-console-with-scripting
+class Logger(object):
+    def __init__(self, filename, mode="a"):
+        self.terminal = sys.stdout
+        self.log = open(filename, mode)
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        # this flush method is needed for python 3 compatibility.
+        # this handles the flush command by doing nothing.
+        # you might want to specify some extra behavior here.
+        pass
+
+
+class Timer:
+    def __init__(self, name):
+        self.name = name
+
+    def __enter__(self):
+        self.begin = time.time()
+        return self
+
+    def __exit__(self, *args):
+        self.end = time.time()
+        self.elapsed = self.end - self.begin
+        self.elapsedH = time.gmtime(self.elapsed)
+        print('====> [{}] Time: {:7.3f}s or {}'
+              .format(self.name,
+                      self.elapsed,
+                      time.strftime("%H:%M:%S", self.elapsedH)))
+
+
+# Functions
+def save_vars(vs, filepath):
+    """
+    Saves variables to the given filepath in a safe manner.
+    """
+    if os.path.exists(filepath):
+        shutil.copyfile(filepath, '{}.old'.format(filepath))
+    torch.save(vs, filepath)
+
+
+def save_model(model, filepath):
+    """
+    Saves the model's state_dict to the given filepath.
+    To load a saved model, simply use
+    `model.load_state_dict(torch.load('path-to-saved-model'))`.
+    """
+    save_vars(model.state_dict(), filepath)
+    # if hasattr(model, 'vaes'):
+    #     for vae in model.vaes:
+    #         fdir, fext = os.path.splitext(filepath)
+    #         save_vars(vae.state_dict(), fdir + '_' + vae.modelName + fext)
+
+
+def is_multidata(dataB):
+    return isinstance(dataB, (list, tuple))
+
+
+def unpack_data(dataB, device='cuda'):
+    # dataB :: (Tensor, Idx) | [(Tensor, Idx)]
+    """Unpacks the data batch object in an appropriate manner to extract data."""
+    if is_multidata(dataB):
+        if torch.is_tensor(dataB[0]):
+            if torch.is_tensor(dataB[1]):
+                return dataB[0].to(device)  # mnist, svhn, cubI
+            elif is_multidata(dataB[1]):
+                return dataB[0].to(device), dataB[1][0].to(device)  # cubISft
+            else:
+                raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[1])))
+        elif is_multidata(dataB[0]):
+            return [d.to(device) for d in list(zip(*dataB))[0]]  # mnist-svhn, cubIS
+        else:
+            raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[0])))
+    elif torch.is_tensor(dataB):
+        return dataB.to(device)
+    else:
+        raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB)))
+
+
+def get_mean(d, K=100):
+    """
+    Extract the `mean` parameter for the given distribution.
+    If the attribute is not available, estimate it from samples.
+    """
+    try:
+        mean = d.mean
+    except NotImplementedError:
+        samples = d.rsample(torch.Size([K]))
+        mean = samples.mean(0)
+    return mean
+
+
+def log_mean_exp(value, dim=0, keepdim=False):
+    return torch.logsumexp(value, dim, keepdim=keepdim) - math.log(value.size(dim))
+
+
+def kl_divergence(d1, d2, K=100):
+    """Computes the closed-form KL if available, else computes a MC estimate."""
+    if (type(d1), type(d2)) in torch.distributions.kl._KL_REGISTRY:
+        return torch.distributions.kl_divergence(d1, d2)
+    else:
+        samples = d1.rsample(torch.Size([K]))
+        return (d1.log_prob(samples) - d2.log_prob(samples)).mean(0)
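+
+# Illustrative sketch (not part of the original module): kl_divergence falls back
+# to a Monte Carlo estimate only when no closed form is registered. For two
+# torch.distributions.Normal instances the registered analytic KL is used, e.g.
+#
+#     >>> p = dist.Normal(torch.zeros(3), torch.ones(3))
+#     >>> q = dist.Normal(torch.ones(3), torch.ones(3))
+#     >>> kl_divergence(p, q)  # closed form, 0.5 per dimension
+#     tensor([0.5000, 0.5000, 0.5000])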
+
+
+def vade_kld_uni(model, zs):
+    n_centroids = model.params.n_centroids
+    gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)  # pi and var_c already have Constants.eta added inside get_gamma
+
+    mu, var = model._qz_x_params
+    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
+    var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
+    lpz_c = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) +
+                                               torch.log(var_c) +
+                                               var_expand / var_c +
+                                               (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)  # log p(z|c)
+    lpc = torch.sum(gamma * torch.log(pi), dim=1)  # log p(c); note log(pi) can underflow to -inf
+    lqz_x = -0.5 * torch.sum(1 + torch.log(var) + math.log(2 * math.pi), dim=1)  # see VaDE paper  # log q(z|x)
+    lqc_x = torch.sum(gamma * lgamma, dim=1)  # log q(c|x)
+
+    kld = -lpz_c - lpc + lqz_x + lqc_x
+
+    return kld
+
+
+def vade_kld(model, zs, r):
+    n_centroids = model.params.n_centroids
+    gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)  # pi and var_c already have Constants.eta added inside get_gamma
+
+    mu, var = model.vaes[r]._qz_x_params
+    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
+    var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
+    lpz_c = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) +
+                                               torch.log(var_c) +
+                                               var_expand / var_c +
+                                               (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)  # log p(z|c)
+    lpc = torch.sum(gamma * torch.log(pi), dim=1)  # log p(c); note log(pi) can underflow to -inf
+    lqz_x = -0.5 * torch.sum(1 + torch.log(var) + math.log(2 * math.pi), dim=1)  # see VaDE paper  # log q(z|x)
+    lqc_x = torch.sum(gamma * lgamma, dim=1)  # log q(c|x)
+
+    kld = -lpz_c - lpc + lqz_x + lqc_x
+
+    return kld
+
+
+def pdist(sample_1, sample_2, eps=1e-5):
+    """Compute the matrix of all pairwise Euclidean distances. Code
+    adapted from the torch-two-sample library (added batching).
+    You can find the original implementation of this function here:
+    https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py
+
+    Arguments
+    ---------
+    sample_1 : torch.Tensor or Variable
+        The first sample, of shape ``(batch_size, n_1, d)`` or ``(n_1, d)``.
+    sample_2 : torch.Tensor or Variable
+        The second sample, of shape ``(batch_size, n_2, d)`` or ``(n_2, d)``.
+    eps : float
+        Small constant added under the square root for numerical stability.
+
+    Returns
+    -------
+    torch.Tensor or Variable
+        Matrix of shape ``(batch_size, n_1, n_2)`` (the batch dimension is
+        squeezed out for 2-D inputs). The ``[..., i, j]``-th entry is equal to
+        ``|| sample_1[..., i, :] - sample_2[..., j, :] ||_2``."""
+    if len(sample_1.shape) == 2:
+        sample_1, sample_2 = sample_1.unsqueeze(0), sample_2.unsqueeze(0)
+    B, n_1, n_2 = sample_1.size(0), sample_1.size(1), sample_2.size(1)
+    norms_1 = torch.sum(sample_1 ** 2, dim=-1, keepdim=True)
+    norms_2 = torch.sum(sample_2 ** 2, dim=-1, keepdim=True)
+    norms = (norms_1.expand(B, n_1, n_2)
+             + norms_2.transpose(1, 2).expand(B, n_1, n_2))
+    distances_squared = norms - 2 * sample_1.matmul(sample_2.transpose(1, 2))
+    return torch.sqrt(eps + torch.abs(distances_squared)).squeeze()  # (batch x) n_1 x n_2
+
+
+def NN_lookup(emb_h, emb, data):
+    indices = pdist(emb.to(emb_h.device), emb_h).argmin(dim=0)
+    # indices = torch.tensor(cosine_similarity(emb, emb_h.cpu().numpy()).argmax(0)).to(emb_h.device).squeeze()
+    return data[indices]
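+
+# Illustrative sketch (not part of the original module): for unbatched 2-D inputs
+# pdist is expected to agree with torch.cdist up to the eps smoothing term, and
+# NN_lookup then indexes `data` with the nearest row of emb for each row of emb_h:
+#
+#     >>> x, y = torch.randn(4, 8), torch.randn(6, 8)
+#     >>> torch.allclose(pdist(x, y), torch.cdist(x, y), atol=1e-2)
+#     True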
+
+
+class FakeCategorical(dist.Distribution):
+    support = dist.constraints.real
+    has_rsample = True
+
+    def __init__(self, locs):
+        self.logits = locs
+        self._batch_shape = self.logits.shape
+
+    @property
+    def mean(self):
+        return self.logits
+
+    def sample(self, sample_shape=torch.Size()):
+        with torch.no_grad():
+            return self.rsample(sample_shape)
+
+    def rsample(self, sample_shape=torch.Size()):
+        return self.logits.expand([*sample_shape, *self.logits.shape]).contiguous()
+
+    def log_prob(self, value):
+        # value of shape (K, B, D)
+        lpx_z = -F.cross_entropy(input=self.logits.view(-1, self.logits.size(-1)),
+                                 target=value.expand(self.logits.size()[:-1]).long().view(-1),
+                                 reduction='none',
+                                 ignore_index=0)
+
+        return lpx_z.view(*self.logits.shape[:-1])
+        # The word-embedding dimension is inevitably summed over in the
+        # cross-entropy loss ($\sum_i -gt_i \log(p_i)$, with most gt_i = 0).
+        # We adopt the operational equivalent here, which is to sum over the
+        # sentence dimension in the objective.
+
+
+# from github Bjarten/early-stopping-pytorch
+import numpy as np
+import torch
+
+
+class EarlyStopping:
+    """Early stops the training if validation loss doesn't improve after a given patience."""
+
+    def __init__(self, patience=7, verbose=False, delta=0):
+        """
+        Args:
+            patience (int): How long to wait after last time validation loss improved.
+                            Default: 7
+            verbose (bool): If True, prints a message for each validation loss improvement.
+                            Default: False
+            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+                            Default: 0
+        """
+        self.patience = patience
+        self.verbose = verbose
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.val_loss_min = np.inf
+        self.delta = delta
+
+    def __call__(self, val_loss, model, runPath):
+
+        score = -val_loss
+
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, runPath)
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, runPath)  # runPath argument added
+            self.counter = 0
+
+    def save_checkpoint(self, val_loss, model, runPath):
+        '''Saves the model when the validation loss decreases.'''
+        if self.verbose:
+            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
+        # torch.save(model.state_dict(), 'checkpoint.pt')
+        save_model(model, runPath + '/model.rar')  # ported from mmvae
+        self.val_loss_min = val_loss
+
+
+class EarlyStopping_nosave:
+    """Early stops the training if validation loss doesn't improve after a given patience."""
+
+    def __init__(self, patience=7, verbose=False, delta=0):
+        """
+        Args:
+            patience (int): How long to wait after last time validation loss improved.
+                            Default: 7
+            verbose (bool): If True, prints a message for each validation loss improvement.
+                            Default: False
+            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+                            Default: 0
+        """
+        self.patience = patience
+        self.verbose = verbose
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.val_loss_min = np.inf
+        self.delta = delta
+
+    def __call__(self, val_loss, model, runPath):
+
+        score = -val_loss
+
+        if self.best_score is None:
+            self.best_score = score
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.counter = 0
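+
+# Illustrative sketch (not part of the original module): typical training-loop
+# wiring for EarlyStopping. `train_one_epoch`, `evaluate`, `model` and `runPath`
+# are placeholders for whatever the surrounding training script defines.
+#
+#     early_stopping = EarlyStopping(patience=10, verbose=True)
+#     for epoch in range(n_epochs):
+#         train_one_epoch(model)                    # hypothetical helper
+#         val_loss = evaluate(model)                # hypothetical helper
+#         early_stopping(val_loss, model, runPath)  # checkpoints on improvement
+#         if early_stopping.early_stop:
+#             print('Early stopping triggered')
+#             break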