[121e55]: / src / nichecompass / modules / vgaemodulemixin.py

Download this file

42 lines (36 with data), 1.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""
This module contains generic VGAE functionalities, added as a Mixin to the
Variational Gene Program Graph Autoencoder neural network module.
"""
import torch
class VGAEModuleMixin:
    """
    Mixin bundling functionality shared by VGAE-style modules.

    The host class must provide a boolean ``training`` attribute, which
    ``reparameterize`` reads to decide between stochastic and deterministic
    output (presumably supplied by ``torch.nn.Module`` in the host class).
    """

    def reparameterize(self,
                       mu: torch.Tensor,
                       logstd: torch.Tensor) -> torch.Tensor:
        """
        Draw latent features via the reparameterization trick.

        In training mode, standard-normal noise with the shape of ``mu`` is
        scaled by ``exp(logstd)`` and shifted by ``mu``. This keeps the
        sample differentiable with respect to both ``mu`` and ``logstd``.
        Outside training mode no noise is drawn and ``mu`` itself (the same
        tensor object) is returned.

        Parameters
        ----------
        mu:
            Means of the latent normal distribution (dim: n_obs, n_gps).
        logstd:
            Log standard deviations of the latent normal distribution
            (dim: n_obs, n_gps).

        Returns
        -------
        rep:
            Reparameterized latent features (dim: n_obs, n_gps).
        """
        # Evaluation: deterministic, just hand back the means.
        if not self.training:
            return mu

        # Training: sample mu + sigma * noise, with noise ~ N(0, 1).
        sigma = logstd.exp()
        noise = torch.randn_like(mu)
        return noise * sigma + mu