--- /dev/null
+++ b/src/vae.py
@@ -0,0 +1,173 @@
+from typing import List
+
+import torch
+import torch.nn as nn
+import pytorch_pretrained_bert as Bert
+
+# VAE boilerplate adapted from https://github.com/AntixK/PyTorch-VAE
+
+
+class VAE(Bert.modeling.BertPreTrainedModel):
+    def __init__(self, config):
+        super(VAE, self).__init__(config)
+
+        # One (num_classes, ...) tuple per unsupervised output head.
+        self.unsuplist = config.unsupSize
+        self.vaelatentdim = config.vaelatentdim
+        # Kept from the reference implementation; unused in this module.
+        self.vaeinchannels = config.vaeinchannels
+
+        vaehidden = [config.poolingSize]
+        self.linearFC = nn.Linear(config.hidden_size, config.poolingSize)
+        self.activ = nn.ReLU()
+
+        # Encoder: project the (already pooled) input to mean and log-variance.
+        self.fc_mu = nn.Linear(vaehidden[-1], self.vaelatentdim)
+        self.fc_var = nn.Linear(vaehidden[-1], self.vaelatentdim)
+
+        # Decoder: two hidden layers followed by one linear head per output.
+        self.decoder1 = nn.Linear(self.vaelatentdim, vaehidden[-1])
+        self.decoder2 = nn.Linear(vaehidden[-1], vaehidden[-1])
+
+        self.logSoftmax = nn.LogSoftmax(dim=1)
+        self.linearOut = nn.ModuleList(
+            [nn.Linear(vaehidden[-1], el[0]) for el in self.unsuplist])
+        self.BetaD = config.BetaD
+
+        self.apply(self.init_bert_weights)
+
+    def encode(self, input: torch.Tensor) -> List[torch.Tensor]:
+        """
+        Encodes the input and returns the parameters of the latent Gaussian.
+        :param input: (Tensor) Input tensor to the encoder [B x poolingSize]
+        :return: (List[Tensor]) Mean and log-variance, each [B x D]
+        """
+        # result = self.activ(self.linearFC(input))  # optional pooling layer, disabled
+
+        mu = self.fc_mu(input)
+        log_var = self.fc_var(input)
+
+        return [mu, log_var]
+
+    def decode(self, z: torch.Tensor) -> torch.Tensor:
+        """
+        Maps the given latent codes onto the output space: one log-softmax
+        distribution per head, concatenated along dim 1.
+        :param z: (Tensor) [B x D]
+        :return: (Tensor) [B x sum(num_classes per head)]
+        """
+        result = self.activ(self.decoder1(z))
+        result = self.activ(self.decoder2(result))
+
+        outs = [self.logSoftmax(linout(result)) for linout in self.linearOut]
+        return torch.cat(outs, dim=1)
+
+    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+        """
+        Reparameterization trick to sample from N(mu, var) using N(0, 1).
+        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
+        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
+        :return: (Tensor) [B x D]
+        """
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return eps * std + mu
+
+    def forward(self, input: torch.Tensor, label: torch.Tensor = None):
+        # The Beta-VAE (BetaD) variant shares the same forward pass; the two
+        # settings only diverge in loss_function.
+        mu, log_var = self.encode(input)
+        z = self.reparameterize(mu, log_var)
+        return [self.decode(z), label, mu, log_var]
+
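+    # Loss layout: decode() concatenates per-head log-probabilities on dim 1;
+    # loss_function slices them back per head, sums the NLL terms, and adds
+    # the closed-form Gaussian KL, annealed through config.klpar.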
+    def loss_function(self, dictout) -> dict:
+        r"""
+        Computes the VAE loss function:
+        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
+        :param dictout: output of forward(): [recons, labels, mu, log_var]
+        :return: (dict) loss terms plus the per-head outputs and labels
+        """
+        recons = dictout[0].transpose(1, 0)
+        input = dictout[1].transpose(1, 0)
+
+        mu = dictout[2]
+        log_var = dictout[3]
+
+        if not self.BetaD:
+            kld_weight = self.config.klpar  # accounts for the minibatch samples from the dataset
+            reconsloss = 0
+            startindx = 0
+
+            outs = []
+            labs = []
+            lossF = nn.NLLLoss(reduction='none', ignore_index=-1)
+            for outputiter, output in enumerate(self.unsuplist):
+                # Slice this head's log-probabilities out of the concatenated output.
+                elementssize = output[0]
+                chunkrecons = recons[startindx:startindx + elementssize].transpose(1, 0)
+                labels = input[outputiter]
+                reconsloss = reconsloss + lossF(chunkrecons, labels).sum()
+
+                outs.append(chunkrecons)
+                labs.append(labels)
+                startindx = startindx + elementssize
+
+            kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
+
+            loss = (reconsloss + kld_weight * kld_loss) / len(dictout[0])
+
+            # Anneal the KL weight towards 1.
+            if self.config.klpar < 1:
+                self.config.klpar = self.config.klpar + 1e-5
+
+            return {'loss': loss, 'Reconstruction_Loss': reconsloss,
+                    'KLD': -kld_loss, 'outs': outs, 'labs': labs}
+        else:
+            # Beta-VAE objective not implemented yet.
+            raise NotImplementedError
+
+    def sample(self, num_samples: int, current_device: int, **kwargs) -> torch.Tensor:
+        """
+        Samples from the latent space and returns the corresponding
+        decoded outputs.
+        :param num_samples: (Int) Number of samples
+        :param current_device: (Int) Device to run the model on
+        :return: (Tensor)
+        """
+        z = torch.randn(num_samples, self.vaelatentdim)
+        z = z.to(current_device)
+
+        return self.decode(z)
+
+    def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Given an input x, returns its reconstruction.
+        :param x: (Tensor) [B x poolingSize]
+        :return: (Tensor) [B x sum(num_classes per head)]
+        """
+        return self.forward(x)[0]
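Usage sketch (not part of the diff above): a minimal smoke test of the module. It assumes a pytorch_pretrained_bert BertConfig extended with the extra attributes the class reads (poolingSize, vaelatentdim, vaeinchannels, unsupSize, klpar, BetaD); the import path follows the file location in the diff, and all concrete values are illustrative placeholders.

import torch
import pytorch_pretrained_bert as Bert

from src.vae import VAE

# BertConfig plus the extra attributes VAE reads; values are placeholders.
config = Bert.modeling.BertConfig(vocab_size_or_config_json_file=30522)
config.poolingSize = 64            # input width expected by fc_mu / fc_var
config.vaelatentdim = 16           # size of the latent code z
config.vaeinchannels = 1           # kept for parity with the original code
config.unsupSize = [(10,), (5,)]   # one (num_classes,) tuple per output head
config.klpar = 1e-4                # initial KL annealing weight
config.BetaD = False

model = VAE(config)
x = torch.randn(8, config.poolingSize)   # inputs already pooled to poolingSize
labels = torch.randint(0, 5, (8, 2))     # one label column per head; -1 = ignore
out = model(x, labels)                   # [recons, labels, mu, log_var]
losses = model.loss_function(out)
print(losses['loss'].item(), losses['KLD'].item())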