Switch to unified view

a b/src/nichecompass/modules/vgaemodulemixin.py
1
"""
2
This module contains generic VGAE functionalities, added as a Mixin to the
3
Variational Gene Program Graph Autoencoder neural network module.
4
"""
5
6
import torch
7
8
9
class VGAEModuleMixin:
10
    """
11
    VGAE module mix in class containing universal VGAE module 
12
    functionalities.
13
    """
14
    def reparameterize(self,
15
                       mu: torch.Tensor,
16
                       logstd: torch.Tensor) -> torch.Tensor:
17
        """
18
        Use reparameterization trick for latent space normal distribution.
19
        
20
        Parameters
21
        ----------
22
        mu:
23
            Expected values of the latent space distribution (dim: n_obs, 
24
            n_gps).
25
        logstd:
26
            Log standard deviations of the latent space distribution (dim: n_obs,
27
            n_gps).
28
29
        Returns
30
        ----------
31
        rep:
32
            Reparameterized latent features (dim: n_obs, n_gps).
33
        """
34
        if self.training:
35
            std = torch.exp(logstd)
36
            eps = torch.randn_like(mu)
37
            rep = eps.mul(std).add(mu)
38
            return rep
39
        else:
40
            rep = mu
41
            return rep
42