code/dnc_code/DNC/controller.py

# Read and Write Head controller based on LSTM.
# Note: Derived from GitHub user loudinthecloud's NTM implementation.

import torch
from torch import nn
from torch.nn import Parameter
import numpy as np

class controller(nn.Module):    # LSTM Controller
    def __init__(self, num_inputs, num_outputs, num_layers):
        super(controller, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers

        self.lstm_network = nn.LSTM(input_size=self.num_inputs, hidden_size=self.num_outputs, num_layers=self.num_layers)

        # Learned initial states of the LSTM. The hidden state serves as the output of our network.
        self.h_init = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)   # Hidden state initialization
        self.c_init = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)   # Cell state initialization

        # Initialization of the LSTM parameters: biases to zero, weights uniform with a
        # Xavier-style bound (the factor of 5 follows the reference NTM implementation).
        for p in self.lstm_network.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / np.sqrt(self.num_inputs + self.num_outputs)
                nn.init.uniform_(p, -stdev, stdev)

    def create_hidden_state(self, batch_size):  # Output: (num_layers x batch_size x num_outputs)
        h = self.h_init.clone().repeat(1, batch_size, 1)
        c = self.c_init.clone().repeat(1, batch_size, 1)
        return h, c

    def network_size(self):
        return self.num_inputs, self.num_outputs

    def forward(self, inp, prev_state):
        inp = inp.unsqueeze(0)                              # inp dimension after unsqueeze: (1 x batch_size x num_inputs)
        output, state = self.lstm_network(inp, prev_state)  # Input to the LSTM must be of shape (seq_len x batch_size x input_size) in PyTorch. Here, seq_len = 1.
        return output.squeeze(0), state

class backward_controller(nn.Module):   # Backward LSTM to make the DNC bidirectional. Identical to 'controller' except that forward() consumes a whole sequence in reverse.
    def __init__(self, num_inputs, num_outputs, num_layers):
        super(backward_controller, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers

        self.lstm_network = nn.LSTM(input_size=self.num_inputs, hidden_size=self.num_outputs, num_layers=self.num_layers)

        # Learned initial states of the LSTM. The hidden state serves as the output of our network.
        self.h_init = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)   # Hidden state initialization
        self.c_init = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)   # Cell state initialization

        # Initialization of the LSTM parameters: biases to zero, weights uniform with a
        # Xavier-style bound (the factor of 5 follows the reference NTM implementation).
        for p in self.lstm_network.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / np.sqrt(self.num_inputs + self.num_outputs)
                nn.init.uniform_(p, -stdev, stdev)

    def create_hidden_state(self, batch_size):  # Output: (num_layers x batch_size x num_outputs)
        h = self.h_init.clone().repeat(1, batch_size, 1)
        c = self.c_init.clone().repeat(1, batch_size, 1)
        return h, c

    def network_size(self):
        return self.num_inputs, self.num_outputs

    def forward(self, inp, prev_states):                                    # inp dimension: (seq_len x batch_size x input_size)
        inp = inp[torch.arange(inp.shape[0] - 1, -1, -1), :, :]             # Reverse the input along the time axis for the backward direction
        output, state = self.lstm_network(inp, prev_states)                 # Unlike the forward controller, the whole sequence is processed in one call
        # output = output[torch.arange(output.shape[0] - 1, -1, -1), :, :]  # Uncomment to flip the output back to the original time order
        return output, state                                                # Output shape is (seq_len x batch_size x hidden_size), as per the PyTorch docs
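
# Minimal usage sketch (illustrative only, not part of the original module; all
# sizes below are made-up assumptions). It runs dummy data through both
# controllers to show the tensor shapes the comments above describe.
if __name__ == "__main__":
    batch_size, seq_len, num_inputs, num_outputs, num_layers = 4, 6, 10, 20, 1

    # Forward controller: consumes one timestep at a time.
    fwd = controller(num_inputs, num_outputs, num_layers)
    state = fwd.create_hidden_state(batch_size)
    out, state = fwd(torch.randn(batch_size, num_inputs), state)
    print(out.shape)       # torch.Size([4, 20])

    # Backward controller: consumes the whole sequence at once, reversed internally.
    bwd = backward_controller(num_inputs, num_outputs, num_layers)
    bwd_out, _ = bwd(torch.randn(seq_len, batch_size, num_inputs),
                     bwd.create_hidden_state(batch_size))
    print(bwd_out.shape)   # torch.Size([6, 4, 20])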