b/code/dnc_code/DNC/dnc.py
|
# Final DNC version packaging all the modules

import torch
from torch import nn
from .memory import memory_unit
from .processor import processor
from .controller import backward_controller


class DNC_Module(nn.Module):

    def __init__(self, num_inputs, num_outputs, controller_size, controller_layers, num_read_heads, num_write_heads, N, M):

        # Params:
        # num_inputs        : Size of input data
        # num_outputs       : Size of output data
        # controller_size   : Size of LSTM controller output/state
        # controller_layers : Number of layers in the LSTM network
        # num_read_heads    : Number of read heads to be created
        # num_write_heads   : Number of write heads to be created
        # N                 : Number of memory cells
        # M                 : Size of each memory cell

        super(DNC_Module, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.N = N
        self.M = M
|
        # Creating the memory and processor modules
        self.memory = memory_unit(self.N, self.M)
        self.processor = processor(self.num_inputs, self.num_outputs, self.M, self.N, num_read_heads, num_write_heads, controller_size, controller_layers)
|
        # Creating the reverse controller
        self.bController = backward_controller(self.num_inputs, controller_size, controller_layers)

    def initialization(self, batch_size):  # Initializing all the modules
        self.batch_size = batch_size
        self.memory.reset_memory(batch_size)
        self.previous_state = self.processor.create_new_state(batch_size)
        self.previous_backward_states = self.bController.create_hidden_state(batch_size)
|
    def backward_prediction(self, X):  # Feeds the input to the backward controller, making the DNC bi-directional
        # X dim: (seq_len x batch_size x self.num_inputs)
        output, _ = self.bController(X, self.previous_backward_states)  # Output dim: (seq_len x batch_size x controller_size)
        return output  # Use these embeddings from last to first (in reverse order)
|
    def forward(self, X=None, backward_embeddings=None):
        if X is None:  # No external input for this timestep: feed zeros instead
            X = torch.zeros(self.batch_size, self.num_inputs)
        out, self.previous_state = self.processor(X, backward_embeddings, self.previous_state, self.memory)
        return out, self.previous_state
|
    '''
    def calculate_num_params(self):  # This may be for model statistics. Adapted
        num_params = 0
        for p in self.parameters():
            num_params += p.data.view(-1).size(0)
        return num_params
    '''
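

# Usage sketch: a minimal, hypothetical driving loop, assuming memory_unit,
# processor and backward_controller expose the interfaces used above (none of
# the sizes below come from the original code). The backward controller runs
# over the whole sequence once; its embeddings are then consumed from last to
# first while forward() is stepped one timestep at a time, as the comment in
# backward_prediction() suggests.
if __name__ == '__main__':
    seq_len, batch_size, num_inputs = 10, 4, 8
    dnc = DNC_Module(num_inputs=num_inputs, num_outputs=8, controller_size=64,
                     controller_layers=1, num_read_heads=1, num_write_heads=1,
                     N=32, M=16)
    dnc.initialization(batch_size)  # Resets memory and recurrent states

    X = torch.rand(seq_len, batch_size, num_inputs)   # (seq_len x batch_size x num_inputs)
    backward_embeddings = dnc.backward_prediction(X)  # (seq_len x batch_size x controller_size)

    outputs = []
    for t in range(seq_len):
        # Assumed pairing: timestep t uses the backward embedding at
        # position (seq_len - 1 - t), i.e. from last to first
        out, _ = dnc(X[t], backward_embeddings[seq_len - 1 - t])
        outputs.append(out)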