src/ml_models/bilstm/encoders.py
# coding: utf-8
"""
RNN encoders

Source: https://github.com/joeynmt/joeynmt/blob/main/joeynmt/encoders.py
"""

# Base Dependencies
# -----------------
from typing import Tuple

# Local Dependencies
# -------------------
from utils import freeze_params
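# NOTE (assumption): `freeze_params` is expected to set requires_grad = False on
# every parameter of the module it receives, as joeynmt.helpers.freeze_params does.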


# PyTorch Dependencies
# --------------------
import torch
from torch import Tensor, nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Encoder(nn.Module):
    """
    Base encoder class
    """

    # pylint: disable=abstract-method
    @property
    def output_size(self):
        """
        Returns the output size
        """
        return self._output_size


class RecurrentEncoder(Encoder):
    """Encodes a sequence of word embeddings"""

    # pylint: disable=unused-argument
    def __init__(
        self,
        rnn_type: str = "gru",
        hidden_size: int = 1,
        emb_size: int = 1,
        num_layers: int = 1,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        bidirectional: bool = True,
        freeze: bool = False,
        **kwargs,
    ) -> None:
        """Create a new recurrent encoder.

        Args:
            rnn_type (str): RNN type: `gru` or `lstm`.
            hidden_size (int): Size of the RNN hidden state.
            emb_size (int): Size of the word embeddings.
            num_layers (int): Number of encoder RNN layers.
            dropout (float): Dropout applied between RNN layers.
            emb_dropout (float): Dropout applied to the RNN input (word embeddings).
            bidirectional (bool): Use a bi-directional RNN.
            freeze (bool): Freeze the parameters of the encoder during training.
            kwargs:
        """
        super().__init__()

        self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False)
        self.type = rnn_type
        self.emb_size = emb_size

        rnn = nn.GRU if rnn_type == "gru" else nn.LSTM

        self.rnn = rnn(
            emb_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0,
        )

        self._output_size = 2 * hidden_size if bidirectional else hidden_size
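        # For example (illustrative values, not from the source): hidden_size=64
        # with bidirectional=True exposes output_size == 128, since the final
        # forward and backward states are concatenated along the feature axis.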

        if freeze:
            freeze_params(self)

    def _check_shapes_input_forward(
        self, embed_src: Tensor, src_length: Tensor
    ) -> None:
        """
        Make sure the shapes of the inputs to `self.forward` are correct.
        Same input semantics as `self.forward`.

        Args:
            embed_src (Tensor): embedded source tokens
            src_length (Tensor): source lengths
        """
        # pylint: disable=unused-argument
        assert embed_src.shape[0] == src_length.shape[0]
        assert embed_src.shape[2] == self.emb_size
        assert len(src_length.shape) == 1

    def forward(
        self, embed_src: Tensor, src_length: Tensor, **kwargs
    ) -> Tuple[Tensor, Tensor]:
        """
        Applies a bidirectional RNN to a sequence of embeddings x.
        The input mini-batch x needs to be sorted by src length (descending).

        Args:
            embed_src: embedded src inputs, shape (batch_size, src_len, embed_size)
            src_length: length of src inputs
                (counting tokens before padding), shape (batch_size)
            kwargs:

        Returns:
            output: hidden states with shape (batch_size, max_length, directions*hidden),
            hidden_concat: last hidden state with shape (batch_size, directions*hidden)
        """
        self._check_shapes_input_forward(embed_src=embed_src, src_length=src_length)
        total_length = embed_src.size(1)

        # apply dropout to the rnn input
        embed_src = self.emb_dropout(embed_src)

        packed = pack_padded_sequence(
            embed_src, src_length.cpu(), batch_first=True, enforce_sorted=True
        )
        output, hidden = self.rnn(packed)

        if isinstance(hidden, tuple):
            hidden, memory_cell = hidden  # pylint: disable=unused-variable

        output, _ = pad_packed_sequence(
            output, batch_first=True, total_length=total_length
        )
        # hidden: dir*layers x batch x hidden
        # output: batch x max_length x directions*hidden
        batch_size = hidden.size()[1]
        # separate final hidden states by layer and direction
        hidden_layerwise = hidden.view(
            self.rnn.num_layers,
            2 if self.rnn.bidirectional else 1,
            batch_size,
            self.rnn.hidden_size,
        )
        # hidden_layerwise: layers x directions x batch x hidden

        # concatenate the final states of the last layer for each direction
        # thanks to pack_padded_sequence, the final states don't include padding
        fwd_hidden_last = hidden_layerwise[-1:, 0]
        bwd_hidden_last = hidden_layerwise[-1:, 1]

        # only feed the final state of the top-most layer to the decoder
        # pylint: disable=no-member
        hidden_concat = torch.cat([fwd_hidden_last, bwd_hidden_last], dim=2).squeeze(0)
        # hidden_concat: batch x directions*hidden

        assert hidden_concat.size(0) == output.size(0), (
            hidden_concat.size(),
            output.size(),
        )
        return output, hidden_concat

    def __repr__(self):
        return f"{self.__class__.__name__}(rnn={self.rnn})"
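

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original module; all values below are assumed,
# illustrative choices). It shows how RecurrentEncoder might be driven; the
# batch must be sorted by decreasing source length because pack_padded_sequence
# is called with enforce_sorted=True.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = RecurrentEncoder(
        rnn_type="lstm",
        emb_size=32,
        hidden_size=64,
        num_layers=2,
        dropout=0.1,
        emb_dropout=0.1,
        bidirectional=True,
    )

    batch_size, max_len = 4, 10
    embed_src = torch.randn(batch_size, max_len, 32)  # embedded source tokens
    src_length = torch.tensor([10, 8, 5, 3])          # lengths, sorted descending

    output, hidden_concat = encoder(embed_src, src_length)
    # output:        (batch_size, max_len, 2 * hidden_size) -> (4, 10, 128)
    # hidden_concat: (batch_size, 2 * hidden_size)           -> (4, 128)
    print(output.shape, hidden_concat.shape)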