Diff of /torchdrug/layers/flow.py [000000] .. [36b44b]

Switch to unified view

a b/torchdrug/layers/flow.py
1
import torch
2
from torch import nn
3
from torch.nn import functional as F
4
5
from torchdrug import layers
6
7
8
class ConditionalFlow(nn.Module):
    """
    Conditional flow transformation from `Masked Autoregressive Flow for Density Estimation`_.

    .. _Masked Autoregressive Flow for Density Estimation:
        https://arxiv.org/pdf/1705.07057.pdf

    Parameters:
        input_dim (int): input & output dimension
        condition_dim (int): condition dimension
        hidden_dims (list of int, optional): hidden dimensions
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, condition_dim, hidden_dims=None, activation="relu"):
        super(ConditionalFlow, self).__init__()
        self.input_dim = input_dim
        self.output_dim = input_dim

        if hidden_dims is None:
            hidden_dims = []
        # The MLP predicts both the (pre-tanh) log-scale and the bias from the
        # condition, hence the final layer emits 2 * input_dim values.
        self.mlp = layers.MLP(condition_dim, list(hidden_dims) + [input_dim * 2], activation)
        # Learnable magnitude for the tanh-squashed log-scale. Initialized to zero
        # so the flow starts as the identity map with zero log-determinant.
        self.rescale = nn.Parameter(torch.zeros(1))

    def forward(self, input, condition):
        """
        Transform data into latent representations.

        Parameters:
            input (Tensor): input representations
            condition (Tensor): conditional representations

        Returns:
            (Tensor, Tensor): latent representations, log-likelihood of the transformation
        """
        scale, bias = self.mlp(condition).chunk(2, dim=-1)
        # Bound the log-scale via tanh, with a learnable amplitude, for numerical
        # stability. torch.tanh replaces the deprecated F.tanh.
        scale = torch.tanh(scale) * self.rescale
        output = (input + bias) * scale.exp()
        # log|det J| of x -> (x + b) * exp(s) is s, element-wise.
        log_det = scale
        return output, log_det

    def reverse(self, latent, condition):
        """
        Transform latent representations into data.

        Parameters:
            latent (Tensor): latent representations
            condition (Tensor): conditional representations

        Returns:
            (Tensor, Tensor): input representations, log-likelihood of the transformation
        """
        scale, bias = self.mlp(condition).chunk(2, dim=-1)
        scale = torch.tanh(scale) * self.rescale
        # Exact inverse of forward: y / exp(s) - b.
        output = latent / scale.exp() - bias
        # NOTE(review): log_det mirrors forward (returns s rather than -s);
        # callers are expected to negate as appropriate. Kept unchanged to
        # preserve the existing contract.
        log_det = scale
        return output, log_det