Diff of /code/dnc_code/DNC/dnc.py [000000] .. [349d16]

# Final DNC version packaging all the modules
import torch
from torch import nn
from .memory import memory_unit
from .processor import processor
from .controller import backward_controller

class DNC_Module(nn.Module):

    def __init__(self, num_inputs, num_outputs, controller_size, controller_layers, num_read_heads, num_write_heads, N, M):

        # Params:
        # num_inputs : Size of the input data
        # num_outputs : Size of the output data
        # controller_size : Size of the LSTM controller output/state
        # controller_layers : Number of layers in the LSTM network
        # num_read_heads : Number of read heads to be created
        # num_write_heads : Number of write heads to be created
        # N : Number of memory cells
        # M : Size of each memory cell

        super(DNC_Module, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.N = N
        self.M = M

        # Creating the DNC modules: memory and processor
        self.memory = memory_unit(self.N, self.M)
        self.processor = processor(self.num_inputs, self.num_outputs, self.M, self.N, num_read_heads, num_write_heads, controller_size, controller_layers)

        # Creating the backward controller (for the reverse pass)
        self.bController = backward_controller(self.num_inputs, controller_size, controller_layers)

    def initialization(self, batch_size):   # Initializing all the modules for a new batch
        self.batch_size = batch_size
        self.memory.reset_memory(batch_size)
        self.previous_state = self.processor.create_new_state(batch_size)
        self.previous_backward_states = self.bController.create_hidden_state(batch_size)

    def backward_prediction(self, X):       # Feeding the input to the backward controller to make the DNC bi-directional
        # X dim: (seq_len x batch_size x self.num_inputs)
        output, _ = self.bController(X, self.previous_backward_states)  # Output dim: (seq_len x batch_size x controller_size)
        return output                                                   # Use the embeddings from last to first (reverse order)

    def forward(self, X=None, backward_embeddings=None):
        if X is None:   # No input provided for this step: feed an all-zero vector instead
            X = torch.zeros(self.batch_size, self.num_inputs)
        out, self.previous_state = self.processor(X, backward_embeddings, self.previous_state, self.memory)
        return out, self.previous_state
    
    '''
    def calculate_num_params(self):     # This may be for model statistics. Adapted
        num_params = 0
        for p in self.parameters():
            num_params += p.data.view(-1).size(0)
        return num_params
    '''
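

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module: it illustrates the
    # expected call order -- initialization(), then backward_prediction() over the
    # full sequence, then forward() once per time step. All sizes here are made up,
    # and the assumption that the processor consumes one time step of shape
    # (batch_size x num_inputs) plus one backward embedding of shape
    # (batch_size x controller_size) is inferred from the comments above, not
    # verified against the processor/memory_unit code. Because of the
    # package-relative imports, run this as a module (e.g. `python -m DNC.dnc`).
    seq_len, batch_size, num_inputs = 5, 2, 8

    dnc = DNC_Module(num_inputs=num_inputs, num_outputs=8, controller_size=64,
                     controller_layers=1, num_read_heads=2, num_write_heads=1, N=16, M=10)
    dnc.initialization(batch_size)                     # Reset memory and controller states for the new batch

    X = torch.randn(seq_len, batch_size, num_inputs)   # Dummy input sequence
    backward_emb = dnc.backward_prediction(X)          # (seq_len x batch_size x controller_size)

    outputs = []
    for t in range(seq_len):
        # Per the comment in backward_prediction, the embeddings are consumed from
        # last to first, so step t is paired with backward_emb[seq_len - 1 - t]
        out, _ = dnc(X[t], backward_emb[seq_len - 1 - t])
        outputs.append(out)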