[349d16]: / code / dnc_code / DNC_GPU / memory.py

# DNC memory module
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np


class memory_unit(nn.Module):
    def __init__(self, N, M, memory_init=None):
        super(memory_unit, self).__init__()
        self.N = N  # Number of memory cells
        self.M = M  # Length of a single memory cell
        # Register the memory initialization matrix in the state_dict
        self.register_buffer('memory_init', torch.Tensor(N, M))
        # Memory initialization
        stdev = 1.0 / (np.sqrt(N + M))
        # The following lets the user supply explicit initialization values for the memory;
        # otherwise it is initialized automatically
        if memory_init is None:
            nn.init.uniform_(self.memory_init, -stdev, stdev)  # Memory size is (N, M)
        else:
            self.memory_init = memory_init.clone()

    def memory_size(self):
        return self.N, self.M

    def reset_memory(self, batch_size):
        self.batch_size = batch_size
        self.memory = self.memory_init.clone().repeat(batch_size, 1, 1).cuda()

    def memory_read(self, W):  # Assumes the memory has shape (batch_size, N, M)
        """
        Input :
            W : Read weights -> (batch_size x N)
        """
        return torch.bmm(W.unsqueeze(1), self.memory).squeeze(1)  # Out : (batch_size x M) vector
    def memory_write(self, W, e, a):
        """
        Input :
            W : Write weights -> (batch_size x N)
            e : Erase vector -> (batch_size x M)
            a : Write vector -> (batch_size x M)
        """
        erase_mat = torch.bmm(W.unsqueeze(-1), e.unsqueeze(1))  # Out : (batch_size x N x M) matrix
        add_mat = torch.bmm(W.unsqueeze(-1), a.unsqueeze(1))  # Out : (batch_size x N x M) matrix
        self.memory = self.memory * (1 - erase_mat) + add_mat  # Out : (batch_size x N x M) matrix
    def access_memory_write(self, k, beta, g_w, g_a, alloc_weights):  # Returns the weight vector for the write head
        """
        Input :
            k : Key vector for matching -> (batch_size x M)
            beta : Focus strength constant -> (batch_size x 1)
            g_w : Interpolation write gate -> (batch_size x 1)
            g_a : Allocation gate -> (batch_size x 1)
            alloc_weights : Allocation weights -> (batch_size x N)
        """
        # Content-based addressing
        W_c = self._content_focusing(k, beta)  # Out : (batch_size x N) vector
        # Location-based addressing
        W = self._gating(g_w, g_a, alloc_weights, W_c)  # Out : (batch_size x N) vector
        return W  # Out : (batch_size x N) vector

    def _read_mode_interpolation(self, f, b, W_c, r_mode):  # Interpolates forward, backward and content weights into read weights
        """
        Input :
            f : Forward temporal weights -> 'num_write_heads'-sized list of (batch_size x N) tensors
            b : Backward temporal weights -> 'num_write_heads'-sized list of (batch_size x N) tensors
            W_c : Content similarity weights -> (batch_size x N)
            r_mode : Read mode vector -> (batch_size x num_read_mode)
        """
        i = 0
        W = torch.zeros(W_c.shape).cuda()
        for forward in f:
            W = W + r_mode[:, i].unsqueeze(-1) * forward
            i += 1
        W = W + r_mode[:, i].unsqueeze(-1) * W_c
        i += 1
        for backward in b:
            W = W + r_mode[:, i].unsqueeze(-1) * backward
            i += 1
        return W  # Out : (batch_size x N) vector
    def access_memory_read(self, k, beta, f, b, r_mode):  # Returns the weight vector for a read head
        """
        Input :
            k : Key vector for matching -> (batch_size x M)
            beta : Focus strength constant -> (batch_size x 1)
            f : Forward temporal weights -> 'num_write_heads'-sized list of (batch_size x N) tensors
            b : Backward temporal weights -> 'num_write_heads'-sized list of (batch_size x N) tensors
            r_mode : Read mode vector -> (batch_size x num_read_mode)
        """
        # Content-based addressing
        W_c = self._content_focusing(k, beta)  # Out : (batch_size x N) vector
        # Read mode interpolation
        W = self._read_mode_interpolation(f, b, W_c, r_mode)  # Out : (batch_size x N) vector
        return W  # Out : (batch_size x N) vector

    def _content_focusing(self, key, beta):
        """
        Input :
            key : Key vector for matching -> (batch_size x M)
            beta : Focus strength constant -> (batch_size x 1)
        """
        # A small offset is added to both inputs so that all-zero rows cannot
        # produce a zero norm inside the cosine similarity
        similarity_vector = F.cosine_similarity(key.unsqueeze(1) + 1e-16, self.memory + 1e-16, dim=2)
        temp_vec = beta * similarity_vector  # similarity_vector -> dim : (batch_size x N)
        return F.softmax(temp_vec, dim=1)
    def _gating(self, g_w, g_a, alloc_weights, W_c):  # Performs gating/interpolation
        """
        Input :
            g_w : Interpolation write gate -> (batch_size x 1)
            g_a : Allocation gate -> (batch_size x 1)
            alloc_weights : Allocation weights -> (batch_size x N)
            W_c : Content similarity weights -> (batch_size x N)
        """
        return g_w * (g_a * alloc_weights + (1 - g_a) * W_c)
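
# --------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): drives one
# write and one content-based read through memory_unit. The tensor values
# below are hypothetical stand-ins for controller outputs, and a CUDA device
# is assumed because reset_memory() calls .cuda().
if __name__ == '__main__':
    batch_size, N, M = 2, 16, 8
    mem = memory_unit(N, M)
    mem.reset_memory(batch_size)

    # Write: content focusing with the allocation path gated off (g_a = 0)
    k = torch.randn(batch_size, M).cuda()       # write key
    beta = torch.ones(batch_size, 1).cuda()     # focus strength
    g_w = torch.ones(batch_size, 1).cuda()      # write gate fully open
    g_a = torch.zeros(batch_size, 1).cuda()     # use content addressing only
    alloc = torch.zeros(batch_size, N).cuda()   # unused here since g_a = 0
    W_w = mem.access_memory_write(k, beta, g_w, g_a, alloc)
    e = torch.sigmoid(torch.randn(batch_size, M)).cuda()  # erase vector in (0, 1)
    a = torch.randn(batch_size, M).cuda()                 # write vector
    mem.memory_write(W_w, e, a)

    # Read back with the same key; with one write head, r_mode has three
    # entries, ordered [forward, content, backward] in this implementation
    f = [torch.zeros(batch_size, N).cuda()]
    b = [torch.zeros(batch_size, N).cuda()]
    r_mode = F.softmax(torch.tensor([[0.0, 10.0, 0.0]]).repeat(batch_size, 1), dim=1).cuda()
    W_r = mem.access_memory_read(k, beta, f, b, r_mode)
    r = mem.memory_read(W_r)
    print(r.shape)  # torch.Size([2, 8])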