|
a |
|
b/src/vae.py |
|
|
1 |
from typing import Optional

import pytorch_pretrained_bert as Bert
import torch
import torch.nn as nn
|
|
4 |
|
|
|
5 |
# borrowed boilerplate vae code from https://github.com/AntixK/PyTorch-VAE |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class VAE(Bert.modeling.BertPreTrainedModel):
    """Variational auto-encoder with one categorical output head per target.

    The encoder is a pair of linear layers producing the mean and the
    log-variance of a diagonal Gaussian.  The decoder is two ReLU-activated
    linear layers followed by one linear + log-softmax head for every entry
    of ``config.unsupSize``; the first element of each entry is used as the
    number of classes of that head.

    Configuration attributes read by this class: ``unsupSize``,
    ``vaelatentdim``, ``poolingSize``, ``hidden_size``, ``BetaD`` and
    ``klpar`` (the latter through ``self.config``).
    """

    # Increment applied to ``config.klpar`` after each loss computation while
    # it is still below 1 (KL weight warm-up).
    KL_ANNEAL_STEP = 1e-5

    def __init__(self, config):
        """
        Build the encoder/decoder layers from ``config``.

        :param config: configuration object exposing the attributes listed in
            the class docstring; it is forwarded to the parent constructor.
        """
        super().__init__(config)

        self.unsuplist = config.unsupSize
        # Stored on the instance because ``sample`` needs the latent width.
        self.vaelatentdim = config.vaelatentdim
        pooled_dim = config.poolingSize

        # NOTE: ``linearFC`` is created (and therefore part of the state dict)
        # but ``encode`` does not apply it.
        self.linearFC = nn.Linear(config.hidden_size, pooled_dim)
        self.activ = nn.ReLU()

        # Encoder: mean and log-variance heads.
        self.fc_mu = nn.Linear(pooled_dim, self.vaelatentdim)
        self.fc_var = nn.Linear(pooled_dim, self.vaelatentdim)

        # Decoder: two hidden layers, then one classification head per target.
        self.decoder1 = nn.Linear(self.vaelatentdim, pooled_dim)
        self.decoder2 = nn.Linear(pooled_dim, pooled_dim)

        self.logSoftmax = nn.LogSoftmax(dim=1)

        self.linearOut = nn.ModuleList(
            [nn.Linear(pooled_dim, el[0]) for el in self.unsuplist]
        )
        self.BetaD = config.BetaD

        self.apply(self.init_bert_weights)

    def encode(self, input: torch.Tensor):
        """
        Map the input onto the parameters of the latent Gaussian.

        :param input: (Tensor) features fed directly to the ``fc_mu`` and
            ``fc_var`` linear layers, so the last dimension must equal
            ``config.poolingSize``
        :return: (list) ``[mu, log_var]``
        """
        # ``self.linearFC`` is intentionally not applied here: the input is
        # consumed as-is by both heads.
        mu = self.fc_mu(input)
        log_var = self.fc_var(input)
        return [mu, log_var]

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Map latent codes onto the concatenated per-target log-probabilities.

        :param z: (Tensor) [B x D] latent codes
        :return: (Tensor) the log-softmax output of every head in
            ``self.linearOut``, concatenated along dimension 1
        """
        hidden = self.activ(self.decoder1(z))
        hidden = self.activ(self.decoder2(hidden))
        head_outputs = [self.logSoftmax(head(hidden)) for head in self.linearOut]
        return torch.cat(head_outputs, dim=1)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: draw from N(mu, var) using noise from N(0, 1).

        :param mu: (Tensor) mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: torch.Tensor, label: Optional[torch.Tensor] = None):
        """
        Encode, sample a latent code and decode it.

        :param input: (Tensor) encoder input, see ``encode``
        :param label: (Tensor, optional) targets; returned untouched in
            position 1 so that ``loss_function`` can consume the result
        :return: (list) ``[decoded, label, mu, log_var]``
        """
        # Both settings of ``self.BetaD`` run the exact same computation.
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), label, mu, log_var]

    def loss_function(self, dictout) -> dict:
        r"""
        Compute the VAE loss: summed per-target NLL plus weighted KL term.

        KL(N(\mu, \sigma), N(0, 1)) =
            \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

        Side effect: while ``self.config.klpar`` is below 1 it is increased by
        ``KL_ANNEAL_STEP`` on every call (capped at 1).

        :param dictout: sequence laid out like the return value of
            ``forward``: ``[decoded, label, mu, log_var]``.  Column ``i`` of
            ``label`` holds the targets of head ``i``; targets equal to -1 are
            ignored.
        :return: dict with keys ``'loss'`` (total divided by
            ``len(dictout[0])``), ``'Reconstruction_Loss'``, ``'KLD'`` (the
            negated KL term), ``'outs'`` and ``'labs'`` (per-head
            log-probabilities and targets).  When ``self.BetaD`` is truthy no
            loss is implemented and the integer ``0`` is returned instead.
        """
        if self.BetaD:
            # Placeholder kept from the original implementation.
            return 0

        decoded = dictout[0]
        targets = dictout[1]
        mu = dictout[2]
        log_var = dictout[3]

        kld_weight = self.config.klpar
        nll = nn.NLLLoss(reduction='none', ignore_index=-1)

        reconsloss = 0
        outs = []
        labs = []
        start = 0
        for head_index, head_spec in enumerate(self.unsuplist):
            stop = start + head_spec[0]
            # Columns [start, stop) of the decoder output belong to this head.
            head_log_probs = decoded[:, start:stop]
            head_targets = targets[:, head_index]

            reconsloss = reconsloss + nll(head_log_probs, head_targets).sum()

            outs.append(head_log_probs)
            labs.append(head_targets)
            start = stop

        kld_loss = torch.sum(
            -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1),
            dim=0,
        )

        loss = (reconsloss + kld_weight * kld_loss) / len(decoded)

        # KL weight warm-up, applied after the current weight has been used.
        if self.config.klpar < 1:
            self.config.klpar = min(1.0, self.config.klpar + self.KL_ANNEAL_STEP)

        return {
            'loss': loss,
            'Reconstruction_Loss': reconsloss,
            'KLD': -kld_loss,
            'outs': outs,
            'labs': labs,
        }

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> torch.Tensor:
        """
        Draw latent codes from N(0, 1) and decode them.

        :param num_samples: (Int) number of samples
        :param current_device: (Int) device the latent codes are moved to
        :return: (Tensor) output of ``decode`` for the sampled codes
        """
        z = torch.randn(num_samples, self.vaelatentdim)
        z = z.to(current_device)
        return self.decode(z)

    def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Return the decoder output for the given input.

        :param x: (Tensor) encoder input, see ``encode``
        :return: (Tensor) first element of ``forward(x)``
        """
        return self.forward(x)[0]
|
|
171 |
|
|
|
172 |
|
|
|
173 |
|