--- a
+++ b/utils.py
@@ -0,0 +1,381 @@
+import math
+import os
+import shutil
+import sys
+import time
+
+import numpy as np
+import torch
+import torch.distributions as dist
+import torch.nn.functional as F
+
+
+def train_test_split(adata, test_ratio=0.1, seed=1):
+    """Splits the adata into a training set and a test set.
+    Args:
+        adata: the dataset to be split.
+        test_ratio: ratio of the test data in adata.
+        seed: random seed.
+    Returns:
+        the training set and the test set, both in AnnData format.
+    """
+
+    rng = np.random.default_rng(seed=seed)
+    test_indices = rng.choice(adata.n_obs, size=int(test_ratio * adata.n_obs), replace=False)
+    train_indices = list(set(range(adata.n_obs)).difference(test_indices))
+    train_adata = adata[adata.obs_names[train_indices], :]
+    test_adata = adata[adata.obs_names[test_indices], :]
+
+    return train_adata, test_adata
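+
+# Example usage (sketch): assumes an AnnData object `adata`, e.g. loaded with
+# `anndata.read_h5ad(...)`.
+#
+#   train_adata, test_adata = train_test_split(adata, test_ratio=0.1, seed=1)
+#   assert train_adata.n_obs + test_adata.n_obs == adata.n_obs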
+
+def calc_weight(
+        epoch: int,
+        n_epochs: int,
+        cutoff_ratio: float = 0.,
+        warmup_ratio: float = 1 / 3,
+        min_weight: float = 0.,
+        max_weight: float = 1e-7
+) -> float:
+    """Calculates weights.
+    Args:
+        epoch: current epoch.
+        n_epochs: the total number of epochs to train the model.
+        cutoff_ratio: ratio of cutoff epochs (during which the weight is zero)
+            to n_epochs.
+        warmup_ratio: ratio of warmup epochs to n_epochs.
+        min_weight: minimum weight.
+        max_weight: maximum weight.
+    Returns:
+        The current weight of the KL term.
+    """
+
+    fully_warmup_epoch = n_epochs * warmup_ratio
+
+    if epoch < n_epochs * cutoff_ratio:
+        return 0.
+    if warmup_ratio:
+        return max(min(1., epoch / fully_warmup_epoch) * max_weight, min_weight)
+    else:
+        return max_weight
+
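+# Example (sketch): with n_epochs=300, warmup_ratio=1/3 and max_weight=1e-7,
+# the weight ramps linearly from 0 to max_weight over the first 100 epochs and
+# then stays at max_weight:
+#   calc_weight(0,   300)  -> 0.0
+#   calc_weight(50,  300)  -> 5e-8
+#   calc_weight(100, 300)  -> 1e-7
+#   calc_weight(200, 300)  -> 1e-7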
+
+
+# Classes
+class Constants(object):
+    eta = 1e-6
+    eps = 1e-8
+    log2 = math.log(2)
+    log2pi = math.log(2 * math.pi)
+    logceilc = 88  # largest cuda v s.t. exp(v) < inf
+    logfloorc = -104  # smallest cuda v s.t. exp(v) > 0
+
+
+# https://stackoverflow.com/questions/14906764/how-to-redirect-stdout-to-both-file-and-console-with-scripting
+class Logger(object):
+    def __init__(self, filename, mode="a"):
+        self.terminal = sys.stdout
+        self.log = open(filename, mode)
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        # flush both streams so buffered output is not lost
+        self.terminal.flush()
+        self.log.flush()
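+
+# Example usage (sketch): mirror everything printed to stdout into a log file.
+#   sys.stdout = Logger('run.log')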
+
+
+class Timer:
+    def __init__(self, name):
+        self.name = name
+
+    def __enter__(self):
+        self.begin = time.time()
+        return self
+
+    def __exit__(self, *args):
+        self.end = time.time()
+        self.elapsed = self.end - self.begin
+        self.elapsedH = time.gmtime(self.elapsed)
+        print('====> [{}] Time: {:7.3f}s or {}'
+              .format(self.name,
+                      self.elapsed,
+                      time.strftime("%H:%M:%S", self.elapsedH)))
+
+
+# Functions
+def save_vars(vs, filepath):
+    """
+    Saves variables to the given filepath in a safe manner.
+    """
+    if os.path.exists(filepath):
+        shutil.copyfile(filepath, '{}.old'.format(filepath))
+    torch.save(vs, filepath)
+
+
+def save_model(model, filepath):
+    """
+    To load a saved model, simply use
+    `model.load_state_dict(torch.load('path-to-saved-model'))`.
+    """
+    save_vars(model.state_dict(), filepath)
+    # if hasattr(model, 'vaes'):
+    #    for vae in model.vaes:
+    #        fdir, fext = os.path.splitext(filepath)
+    #        save_vars(vae.state_dict(), fdir + '_' + vae.modelName + fext)
+
+
+def is_multidata(dataB):
+    return isinstance(dataB, (list, tuple))
+
+
+def unpack_data(dataB, device='cuda'):
+    # dataB :: (Tensor, Idx) | [(Tensor, Idx)]
+    """ Unpacks the data batch object in an appropriate manner to extract data """
+    if is_multidata(dataB):
+        if torch.is_tensor(dataB[0]):
+            if torch.is_tensor(dataB[1]):
+                return dataB[0].to(device)  # mnist, svhn, cubI
+            elif is_multidata(dataB[1]):
+                return dataB[0].to(device), dataB[1][0].to(device)  # cubISft
+            else:
+                raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[1])))
+
+        elif is_multidata(dataB[0]):
+            return [d.to(device) for d in list(zip(*dataB))[0]]  # mnist-svhn, cubIS
+        else:
+            raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[0])))
+    elif torch.is_tensor(dataB):
+        return dataB.to(device)
+    else:
+        raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB)))
+
+
+def get_mean(d, K=100):
+    """
+    Extract the `mean` parameter of the given distribution.
+    If the attribute is not available, estimate it from samples.
+    """
+    try:
+        mean = d.mean
+    except NotImplementedError:
+        samples = d.rsample(torch.Size([K]))
+        mean = samples.mean(0)
+    return mean
+
+
+def log_mean_exp(value, dim=0, keepdim=False):
+    return torch.logsumexp(value, dim, keepdim=keepdim) - math.log(value.size(dim))
+
+
+def kl_divergence(d1, d2, K=100):
+    """Computes closed-form KL if available, else computes a MC estimate."""
+    if (type(d1), type(d2)) in torch.distributions.kl._KL_REGISTRY:
+        return torch.distributions.kl_divergence(d1, d2)
+    else:
+        samples = d1.rsample(torch.Size([K]))
+        return (d1.log_prob(samples) - d2.log_prob(samples)).mean(0)
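+
+# Example (sketch): Normal/Normal has a registered closed-form KL, so no
+# sampling is needed; unregistered pairs fall back to the K-sample MC estimate.
+#   p = dist.Normal(torch.zeros(3), torch.ones(3))
+#   q = dist.Normal(torch.ones(3), torch.ones(3))
+#   kl_divergence(p, q)  # tensor([0.5000, 0.5000, 0.5000])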
+
+
+def vade_kld_uni(model, zs):
+    n_centroids = model.params.n_centroids
+    gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)  # Constants.eta is added to pi and var_c inside get_gamma
+
+    mu, var = model._qz_x_params
+    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
+    var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
+    lpz_c = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) + \
+                                               torch.log(var_c) + \
+                                               var_expand / var_c + \
+                                               (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)  # log p(z|c)
+    lpc = torch.sum(gamma * torch.log(pi), dim=1)  # log p(c); note: log(pi) can be -inf
+    lqz_x = -0.5 * torch.sum(1 + torch.log(var) + math.log(2 * math.pi), dim=1)  # see VaDE paper # log q(z|x)
+    lqc_x = torch.sum(gamma * (lgamma), dim=1)  # log q(c|x)
+
+    kld = -lpz_c - lpc + lqz_x + lqc_x
+
+    return kld
+
+
+def vade_kld(model, zs, r):
+    n_centroids = model.params.n_centroids
+    gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)  # Constants.eta is added to pi and var_c inside get_gamma
+
+    mu, var = model.vaes[r]._qz_x_params
+    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
+    var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
+    lpz_c = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) + \
+                                               torch.log(var_c) + \
+                                               var_expand / var_c + \
+                                               (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)  # log p(z|c)
+    lpc = torch.sum(gamma * torch.log(pi), dim=1)  # log p(c); note: log(pi) can be -inf
+    lqz_x = -0.5 * torch.sum(1 + torch.log(var) + math.log(2 * math.pi), dim=1)  # see VaDE paper # log q(z|x)
+    lqc_x = torch.sum(gamma * (lgamma), dim=1)  # log q(c|x)
+
+    kld = -lpz_c - lpc + lqz_x + lqc_x
+
+    return kld
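+
+# Note: both VaDE KL helpers above compute, per sample,
+#   KL = -E_q[log p(z|c)] - E_q[log p(c)] + E_q[log q(z|x)] + E_q[log q(c|x)],
+# where the expectation over clusters c uses the responsibilities `gamma` from
+# model.get_gamma(zs); they differ only in where the q(z|x) parameters
+# (mu, var) are taken from.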
+
+
+def pdist(sample_1, sample_2, eps=1e-5):
+    """Compute the matrix of all squared pairwise distances. Code
+    adapted from the torch-two-sample library (added batching).
+    You can find the original implementation of this function here:
+    https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py
+
+    Arguments
+    ---------
+    sample_1 : torch.Tensor or Variable
+        The first sample, should be of shape ``(batch_size, n_1, d)``.
+    sample_2 : torch.Tensor or Variable
+        The second sample, should be of shape ``(batch_size, n_2, d)``.
+    norm : float
+        The l_p norm to be used.
+    batched : bool
+        whether data is batched
+
+    Returns
+    -------
+    torch.Tensor or Variable
+        Matrix of shape (batch_size, n_1, n_2). The [i, j]-th entry is equal to
+        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""
+    if len(sample_1.shape) == 2:
+        sample_1, sample_2 = sample_1.unsqueeze(0), sample_2.unsqueeze(0)
+    B, n_1, n_2 = sample_1.size(0), sample_1.size(1), sample_2.size(1)
+    norms_1 = torch.sum(sample_1 ** 2, dim=-1, keepdim=True)
+    norms_2 = torch.sum(sample_2 ** 2, dim=-1, keepdim=True)
+    norms = (norms_1.expand(B, n_1, n_2)
+             + norms_2.transpose(1, 2).expand(B, n_1, n_2))
+    distances_squared = norms - 2 * sample_1.matmul(sample_2.transpose(1, 2))
+    return torch.sqrt(eps + torch.abs(distances_squared)).squeeze()  # batch x K x latent
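+
+# Example usage (sketch): pairwise distances between two small point sets.
+#   a = torch.randn(5, 3)
+#   b = torch.randn(7, 3)
+#   d = pdist(a, b)  # shape (5, 7); d[i, j] approximates ||a[i] - b[j]||_2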
+
+
+def NN_lookup(emb_h, emb, data):
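+    """Return the rows of `data` whose reference embeddings `emb` are nearest
+    (in Euclidean distance) to each query embedding in `emb_h`."""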
+    indices = pdist(emb.to(emb_h.device), emb_h).argmin(dim=0)
+    # indices = torch.tensor(cosine_similarity(emb, emb_h.cpu().numpy()).argmax(0)).to(emb_h.device).squeeze()
+    return data[indices]
+
+
+class FakeCategorical(dist.Distribution):
+    support = dist.constraints.real
+    has_rsample = True
+
+    def __init__(self, locs):
+        self.logits = locs
+        self._batch_shape = self.logits.shape
+
+    @property
+    def mean(self):
+        return self.logits
+
+    def sample(self, sample_shape=torch.Size()):
+        with torch.no_grad():
+            return self.rsample(sample_shape)
+
+    def rsample(self, sample_shape=torch.Size()):
+        return self.logits.expand([*sample_shape, *self.logits.shape]).contiguous()
+
+    def log_prob(self, value):
+        # value of shape (K, B, D)
+        lpx_z = -F.cross_entropy(input=self.logits.view(-1, self.logits.size(-1)),
+                                 target=value.expand(self.logits.size()[:-1]).long().view(-1),
+                                 reduction='none',
+                                 ignore_index=0)
+
+        return lpx_z.view(*self.logits.shape[:-1])
+        # The word-embedding dimension is inevitably summed over in the
+        # cross-entropy loss ($\sum_i -gt_i \log(p_i)$, with most $gt_i = 0$).
+        # We adopt the operational equivalent here, which is summing over the
+        # sentence dimension in the objective.
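+        # Shape note (sketch): with logits of shape (K, B, L, V) and integer
+        # targets broadcastable to (K, B, L), log_prob returns one value per
+        # token, of shape (K, B, L); ignore_index=0 treats index 0 as padding.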
+
+
+# from github Bjarten/early-stopping-pytorch
+class EarlyStopping:
+    """Early stops the training if validation loss doesn't improve after a given patience."""
+
+    def __init__(self, patience=7, verbose=False, delta=0):
+        """
+        Args:
+            patience (int): How long to wait after last time validation loss improved.
+                            Default: 7
+            verbose (bool): If True, prints a message for each validation loss improvement.
+                            Default: False
+            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+                            Default: 0
+        """
+        self.patience = patience
+        self.verbose = verbose
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.val_loss_min = np.inf
+        self.delta = delta
+
+    def __call__(self, val_loss, model, runPath):
+
+        score = -val_loss
+
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, runPath)
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, runPath)  # runPath argument added
+            self.counter = 0
+
+    def save_checkpoint(self, val_loss, model, runPath):
+        '''Saves model when validation loss decreases.'''
+        if self.verbose:
+            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
+        # torch.save(model.state_dict(), 'checkpoint.pt')
+        save_model(model, runPath + '/model.rar')  # ported over from mmvae
+        self.val_loss_min = val_loss
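+
+# Example usage (sketch): `model` and `runPath` come from the surrounding
+# training script; `validate` is a hypothetical validation function.
+#   early_stopping = EarlyStopping(patience=7, verbose=True)
+#   for epoch in range(n_epochs):
+#       val_loss = validate(model)
+#       early_stopping(val_loss, model, runPath)
+#       if early_stopping.early_stop:
+#           break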
+
+
+class EarlyStopping_nosave:
+    """Early stops the training if validation loss doesn't improve after a given patience."""
+
+    def __init__(self, patience=7, verbose=False, delta=0):
+        """
+        Args:
+            patience (int): How long to wait after last time validation loss improved.
+                            Default: 7
+            verbose (bool): If True, prints a message for each validation loss improvement.
+                            Default: False
+            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+                            Default: 0
+        """
+        self.patience = patience
+        self.verbose = verbose
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.val_loss_min = np.inf
+        self.delta = delta
+
+    def __call__(self, val_loss, model, runPath):
+
+        score = -val_loss
+
+        if self.best_score is None:
+            self.best_score = score
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.counter = 0
\ No newline at end of file