Diff of /src/vae.py [000000] .. [9e1f38]

import torch
import torch.nn as nn
import pytorch_pretrained_bert as Bert

# borrowed boilerplate VAE code from https://github.com/AntixK/PyTorch-VAE


class VAE(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config):
        super(VAE, self).__init__(config)

        self.unsuplist = config.unsupSize
        self.vaelatentdim = config.vaelatentdim  # kept on the module so sample() can reach it

        vaelatentdim = config.vaelatentdim
        vaehidden = [config.poolingSize]
        self.linearFC = nn.Linear(config.hidden_size, config.poolingSize)
        self.activ = nn.ReLU()

        # Build encoder
        self.fc_mu = nn.Linear(vaehidden[-1], vaelatentdim)
        self.fc_var = nn.Linear(vaehidden[-1], vaelatentdim)

        # Build decoder
        self.decoder1 = nn.Linear(vaelatentdim, vaehidden[-1])
        self.decoder2 = nn.Linear(vaehidden[-1], vaehidden[-1])

        self.logSoftmax = nn.LogSoftmax(dim=1)

        # One linear classification head per unsupervised output; el[0] is that head's class count.
        self.linearOut = nn.ModuleList([nn.Linear(vaehidden[-1], el[0]) for el in self.unsuplist])
        self.BetaD = config.BetaD

        self.apply(self.init_bert_weights)

    def encode(self, input: torch.Tensor):
        """
        Encodes the input by passing it through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to the encoder [B x poolingSize]
        :return: (List[Tensor]) Mean and log-variance of the latent Gaussian, each [B x D]
        """
        # result = self.activ(self.linearFC(input))

        mu = self.fc_mu(input)
        log_var = self.fc_var(input)

        return [mu, log_var]

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Maps the given latent codes onto the output space: one log-softmax
        distribution per unsupervised head, concatenated along dim 1.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x sum of head sizes]
        """
        result = self.activ(self.decoder1(z))
        result = self.activ(self.decoder2(result))

        # Run each head on the shared hidden state and concatenate the log-probabilities.
        outs = [self.logSoftmax(head(result)) for head in self.linearOut]
        return torch.cat(outs, dim=1)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: sample from N(mu, var) via eps ~ N(0, 1),
        so that gradients can flow through mu and logvar.
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: torch.Tensor, label: torch.Tensor):
        # The BetaD and standard configurations share the same forward pass;
        # they only differ in loss_function.
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), label, mu, log_var]

    def loss_function(self, dictout) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param dictout: list returned by forward(): [recons, labels, mu, log_var]
        :return: (dict) loss terms plus per-head predictions and labels
        """
        recons = dictout[0].transpose(1, 0)  # [sum of head sizes x B]
        input = dictout[1].transpose(1, 0)   # [num heads x B]

        mu = dictout[2]
        log_var = dictout[3]
        if not self.BetaD:

            kld_weight = self.config.klpar  # KL-annealing weight for the minibatch
            reconsloss = 0
            startindx = 0

            outs = []
            labs = []
            for outputiter, output in enumerate(self.unsuplist):
                elementssize = output[0]
                # Slice this head's log-probabilities back out of the concatenated
                # decoder output and score them against the head's labels (-1 is ignored).
                chunkrecons = recons[startindx:startindx + elementssize].transpose(1, 0)
                labels = input[outputiter]
                lossF = nn.NLLLoss(reduction='none', ignore_index=-1)
                temploss = lossF(chunkrecons, labels).sum()
                reconsloss = reconsloss + temploss

                outs.append(chunkrecons)
                labs.append(labels)
                startindx = startindx + elementssize

            # Docstring formula summed over latent dimensions, then over the batch.
            kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

            loss = (reconsloss + kld_weight * kld_loss) / len(dictout[0])

            # Linear KL-annealing schedule: ramp the KL weight toward 1.
            if self.config.klpar < 1:
                self.config.klpar = self.config.klpar + 1e-5

            return {'loss': loss, 'Reconstruction_Loss': reconsloss, 'KLD': -kld_loss, 'outs': outs, 'labs': labs}
        else:
            # Beta-disentangled loss is not implemented.
            return 0

    def sample(self, num_samples: int, current_device: int, **kwargs) -> torch.Tensor:
        """
        Samples from the latent prior and returns the decoded outputs.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model on
        :return: (Tensor)
        """
        z = torch.randn(num_samples, self.vaelatentdim)
        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: torch.Tensor, label: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Given an input x, returns its reconstruction.
        forward() also needs the labels, so they are passed through here.
        :param x: (Tensor) [B x poolingSize]
        :return: (Tensor) [B x sum of head sizes]
        """
        return self.forward(x, label)[0]
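

# --- Hypothetical usage sketch (not part of the original commit) ---
# A minimal smoke test, assuming a pytorch_pretrained_bert BertConfig carrying the
# extra attributes this class reads (unsupSize, vaelatentdim, poolingSize, BetaD,
# klpar). The vocab size, head sizes, and batch size below are illustrative only.
if __name__ == "__main__":
    config = Bert.modeling.BertConfig(vocab_size_or_config_json_file=30522)
    config.unsupSize = [(10,), (5,)]     # one (num_classes,) entry per output head
    config.vaelatentdim = 32
    config.poolingSize = config.hidden_size
    config.BetaD = False
    config.klpar = 1e-5                  # initial KL-annealing weight

    model = VAE(config)
    x = torch.randn(4, config.poolingSize)  # encoder input [B x poolingSize]
    # One integer label per head, stacked to [B x num heads].
    labels = torch.stack([torch.randint(0, el[0], (4,)) for el in config.unsupSize], dim=1)

    out = model(x, labels)
    losses = model.loss_function(out)
    print(losses['loss'].item(), losses['Reconstruction_Loss'].item())

    samples = model.sample(num_samples=2, current_device='cpu')  # draw from the prior
    print(samples.shape)  # [2 x sum of head sizes]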