--- a/dl/models/transformer.py
+++ b/dl/models/transformer.py
@@ -0,0 +1,326 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+if torch.cuda.is_available():
+    dtype = {'float': torch.cuda.FloatTensor, 'long': torch.cuda.LongTensor, 'byte': torch.cuda.ByteTensor}
+else:
+    dtype = {'float': torch.FloatTensor, 'long': torch.LongTensor, 'byte': torch.ByteTensor}
+
+
+class MultiheadAttention(nn.Module):
+    """Multi-head scaled dot-product attention.
+
+    Each head projects the input to keys and values. Queries are the keys
+    themselves (self-attention), a projection of a separate input ``q``
+    through the key weights, or, if ``query_in_dim`` is given, through a
+    dedicated query projection. Supports an optional causal mask
+    (``mask=True``) and optional top-k sparsification of the attention
+    logits (``knn``).
+    """
+    def __init__(self, in_dim, out_dim, key_dim, value_dim, num_heads=1, mask=False,
+                 query_in_dim=None, knn=None):
+        super(MultiheadAttention, self).__init__()
+        self.key_dim = key_dim
+        self.keys = nn.ModuleList([nn.Linear(in_dim, key_dim) for i in range(num_heads)])
+        if query_in_dim is not None:
+            self.keys_query = nn.ModuleList([nn.Linear(query_in_dim, key_dim)
+                                             for i in range(num_heads)])
+        self.values = nn.ModuleList([nn.Linear(in_dim, value_dim) for i in range(num_heads)])
+        self.out = nn.Linear(value_dim*num_heads, out_dim)
+        self.mask = mask
+        self.knn = knn
+
+    def forward(self, x, q=None, return_graph=False):
+        y = []
+        if return_graph:
+            graph = []
+        if q is not None:
+            #assert self.knn is None or self.knn <= x.size(-2) # found bug here, not clear why yet
+            size_x = x.size()
+            size_q = q.size()
+            x = x.contiguous().view(size_x[0], -1, size_x[-1])
+            q = q.contiguous().view(size_q[0], -1, size_q[-1])
+        # causal masking only applies to self-attention; keep this local instead of
+        # overwriting self.mask inside forward
+        mask = self.mask and q is None
+        for i, (K, V) in enumerate(zip(self.keys, self.values)):
+            key = K(x)
+            value = V(x)
+            if q is None:
+                query = key
+            else:
+                if hasattr(self, 'keys_query'):
+                    query = self.keys_query[i](q)
+                else:
+                    query = K(q)
+            # dot products between every query/key pair, scaled by sqrt(key_dim)
+            att_unnorm = (query.unsqueeze(-2)*key.unsqueeze(-3)).sum(-1) / np.sqrt(self.key_dim)
+            if return_graph:
+                graph.append(nn.functional.softmax(att_unnorm, dim=-1))
+            if mask:  # mask right side; useful for decoder with sequential output
+                seq_len = att_unnorm.size(-2)
+                if att_unnorm.dim() == 3:
+                    for t in range(seq_len-1):
+                        att_unnorm[:, t, (t+1):] = float('-inf')
+                elif att_unnorm.dim() == 4:
+                    for t in range(seq_len-1):
+                        att_unnorm[:, :, t, (t+1):] = float('-inf')
+                else:
+                    raise ValueError('Expected attention of dim 3 or 4, got {0}'.format(att_unnorm.dim()))
+            if isinstance(self.knn, int):
+                # keep only the top-k logits per query; everything else is set to -inf
+                knn = min(self.knn, att_unnorm.size(-1))
+                att_topk, idx = att_unnorm.topk(knn, dim=-1)
+                att_ = Variable(torch.zeros(att_unnorm.size()).fill_(float('-inf')).type(dtype['float']))
+                att_.scatter_(-1, idx, att_topk)
+                att_unnorm = att_
+            att = nn.functional.softmax(att_unnorm, dim=-1)
+            # weighted sum of the values: (..., L_q, L_k, 1) * (..., 1, L_k, d_v), summed over L_k
+            cur_y = (att.unsqueeze(-1) * value.unsqueeze(-3)).sum(-2)
+            if q is not None:
+                cur_y = cur_y.contiguous().view(*size_q[:-1], cur_y.size(-1))
+            y.append(cur_y)
+        y = torch.cat(y, -1)
+        y = self.out(y)
+        if return_graph:
+            graph = torch.stack(graph).mean(0)
+            return y, graph
+        return y
+
+
+class EncoderAttention(nn.Module):
+    """Self-attention followed by a position-wise feed-forward layer, with
+    optional residual connections and normalization.
+    """
+    def __init__(self, in_dim, out_dim, key_dim, value_dim, fc_dim, num_heads=1, residual=True,
+                 normalization=None, nonlinearity=nn.ReLU(), mask=False, query_in_dim=None, knn=None):
+        super(EncoderAttention, self).__init__()
+        if residual:
+            # residual connections require matching dimensions (mirrors DecoderAttention)
+            assert in_dim == out_dim
+        self.attention = MultiheadAttention(in_dim, out_dim, key_dim, value_dim, num_heads, mask=mask,
+                                            query_in_dim=query_in_dim, knn=knn)
+        self.residual = residual
+        self.normalization = normalization
+        self.fc = nn.Sequential(nn.Linear(out_dim, fc_dim),
+                                nonlinearity,
+                                nn.Linear(fc_dim, out_dim))
+
+    def forward(self, x, q=None, return_graph=False):
+        if return_graph:
+            out, graph = self.attention(x, q, return_graph=True)
+        else:
+            out = self.attention(x, q)
+        if self.residual:
+            out = out + x
+        if isinstance(self.normalization, nn.Module):
+            out = self.normalization(out)
+        x = self.fc(out)
+        if self.residual:
+            x = x + out
+        if isinstance(self.normalization, nn.Module):
+            x = self.normalization(x)
+        if return_graph:
+            return x, graph
+        return x
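+
+
+# Usage sketch (added for illustration; the sizes below are arbitrary):
+# a single EncoderAttention block doing self-attention over a batch of
+# sequences, optionally returning the head-averaged attention weights.
+#
+#     layer = EncoderAttention(in_dim=8, out_dim=8, key_dim=4, value_dim=4,
+#                              fc_dim=16, num_heads=2)
+#     x = Variable(torch.randn(3, 5, 8))        # (batch, seq_len, in_dim)
+#     y = layer(x)                              # (batch, seq_len, out_dim)
+#     y, graph = layer(x, return_graph=True)    # graph: (batch, seq_len, seq_len)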
+
+
+class DecoderAttention(nn.Module):
+    """Masked self-attention over the decoded prefix, followed by attention
+    over the encoder output and a position-wise feed-forward layer, with
+    optional residual connections and normalization.
+    """
+    def __init__(self, in_dim, out_dim, key_dim, value_dim, fc_dim, num_heads=1, residual=True,
+                 normalization=None, nonlinearity=nn.ReLU(), mask=True, query_key=False, knn=None):
+        super(DecoderAttention, self).__init__()
+        if residual:
+            assert in_dim == out_dim
+        self.attention_mask = MultiheadAttention(in_dim, out_dim, key_dim, value_dim, num_heads, mask=mask, knn=knn)
+        self.attention_encoder = MultiheadAttention(in_dim, out_dim, key_dim, value_dim, num_heads,
+                                                    mask=False,
+                                                    query_in_dim=out_dim if query_key else None, knn=knn)
+        self.residual = residual
+        self.normalization = normalization
+        self.fc = nn.Sequential(nn.Linear(out_dim, fc_dim),
+                                nonlinearity,
+                                nn.Linear(fc_dim, out_dim))
+
+    def forward(self, x, input, return_graph=False):
+        if return_graph:
+            out, graph = self.attention_mask(x, return_graph=True)
+        else:
+            out = self.attention_mask(x)
+        if self.residual:
+            out = out + x
+        if isinstance(self.normalization, nn.Module):
+            out = self.normalization(out)
+        # attend over the encoder output, using the masked self-attention result as query
+        x = self.attention_encoder(input, out)
+        out = (out + x) if self.residual else x
+        if isinstance(self.normalization, nn.Module):
+            out = self.normalization(out)
+        x = self.fc(out)
+        if self.residual:
+            x = x + out
+        if isinstance(self.normalization, nn.Module):
+            x = self.normalization(x)
+        if return_graph:
+            return x, graph
+        return x
+
+
+def get_uniq_topk(rank, history):
+    """For each row of ``rank`` (ranked indices), return the first index that
+    does not yet appear in the corresponding row of ``history``, together with
+    the history extended by that index.
+    """
+    res = []
+    if history is None:
+        res = rank[:, 0]
+        history = rank[:, :1]
+    else:
+        for r, h in zip(rank.data, history.data):
+            for i in r:
+                if i in h:
+                    continue
+                else:
+                    res.append(i)
+                    break
+        res = Variable(dtype['long'](res))
+        history = torch.cat([history, res.unsqueeze(-1)], -1)
+    return res, history
+
+
+def get_target(s, t):
+    """For each predicted sequence in ``s``, keep tokens that also occur in the
+    corresponding target sequence of ``t`` and replace the rest with randomly
+    chosen target tokens that were not predicted.
+    """
+    return Variable(dtype['long'](np.array([[
+        k if k in set(j).intersection(i) else np.random.choice(list(set(j).difference(i)))
+        for idx, k in enumerate(i)] for i, j in zip(s.data, t.data)])))
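+
+
+# Illustration (added for clarity; values are made up) of how get_uniq_topk
+# walks a ranking while avoiding repeats; each call returns one new index per row:
+#
+#     rank = Variable(dtype['long']([[2, 0, 1], [0, 1, 2]]))
+#     res, history = get_uniq_topk(rank, None)      # res: [2, 0], history: [[2], [0]]
+#     res, history = get_uniq_topk(rank, history)   # res: [0, 1], history: [[2, 0], [0, 1]]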
+
+
+class Transformer(nn.Module):
+    """Encoder-decoder transformer over integer token sequences.
+
+    The output embedding has ``out_voc_size + 2`` entries: index
+    ``out_voc_size`` is used as the start-of-sequence token for sequential
+    decoding, and index ``out_voc_size + 1`` as the placeholder input for
+    non-sequential decoding. ``duplicated_attention`` shares the weights of a
+    single encoder/decoder block across the stack, ``knn`` enables top-k
+    sparse attention, and ``unique_output`` keeps greedy decoding from
+    emitting the same token twice.
+    """
+    def __init__(self, in_dim, key_dim, value_dim, fc_dim, linear_dim, in_voc_size,
+                 out_voc_size, in_seq_len, out_seq_len, encode_input_position=True,
+                 encode_output_position=False, num_heads=1, num_attention=1, residual=True,
+                 normalization=None, nonlinearity=nn.ReLU(), duplicated_attention=False, mask=True,
+                 unique_output=False, knn=None):
+        super(Transformer, self).__init__()
+        self.in_dim = in_dim
+        self.out_seq_len = out_seq_len
+        self.out_voc_size = out_voc_size
+        self.in_embed = nn.Embedding(in_voc_size, in_dim)
+        self.out_embed = nn.Embedding(out_voc_size+2, in_dim)
+        self.encode_input_position = encode_input_position
+        if self.encode_input_position:
+            # learned mixing weights for a fixed sinusoidal position encoding
+            self.input_pos_weight = nn.Parameter(torch.randn(2))
+            self.input_pos_vec = Variable(torch.Tensor([[np.sin(i/in_seq_len**(j/in_dim)) if j % 2 == 0
+                                                         else np.cos(i/in_seq_len**(j/in_dim))
+                                                         for j in range(in_dim)] for i in range(in_seq_len)]).type(dtype['float']))
+        self.encode_output_position = encode_output_position
+        if self.encode_output_position:
+            self.output_pos_weight = nn.Parameter(torch.randn(2))
+            self.output_pos_vec = Variable(torch.Tensor([[np.sin(i/out_seq_len**(j/in_dim)) if j % 2 == 0
+                                                          else np.cos(i/out_seq_len**(j/in_dim))
+                                                          for j in range(in_dim)] for i in range(out_seq_len)]).type(dtype['float']))
+
+        if duplicated_attention:
+            # the same block instance is repeated, i.e. weights are shared across the stack
+            self.encoders = nn.ModuleList([EncoderAttention(
+                in_dim, in_dim, key_dim, value_dim, fc_dim, num_heads, residual, normalization, nonlinearity, knn=knn)]
+                * num_attention)
+            self.decoders = nn.ModuleList([DecoderAttention(
+                in_dim, in_dim, key_dim, value_dim, fc_dim, num_heads, residual, normalization,
+                nonlinearity, mask, knn=knn)] * num_attention)
+        else:
+            self.encoders = nn.ModuleList()
+            self.decoders = nn.ModuleList()
+            for i in range(num_attention):
+                self.encoders.append(EncoderAttention(
+                    in_dim, in_dim, key_dim, value_dim, fc_dim, num_heads, residual, normalization, nonlinearity, knn=knn))
+                self.decoders.append(DecoderAttention(
+                    in_dim, in_dim, key_dim, value_dim, fc_dim, num_heads, residual, normalization,
+                    nonlinearity, mask, knn=knn))
+
+        self.linear = nn.Linear(in_dim, out_voc_size+1)
+        self.unique_output = unique_output
+        self.knn = knn
+
+    def forward(self, x, out=None, sequential=True, last_output_only=True):
+        if sequential:
+            assert self.knn is None
+        if x.dim() == 2:
+            x = self.in_embed(x)
+        else:
+            size_x = x.size()
+            x = self.in_embed(x.contiguous().view(-1, size_x[-1])).contiguous().view(*size_x, self.in_dim)
+        if self.encode_input_position:
+            pos_weight = nn.functional.softmax(self.input_pos_weight, dim=0)
+            x = x*pos_weight[0] + self.input_pos_vec*pos_weight[1]
+        for encoder in self.encoders:
+            x = encoder(x)
+
+        if not sequential:
+            # This does not work well
+            if out is None:
+                out = Variable(dtype['long']([[self.out_voc_size+1]*self.out_seq_len]*x.size(0)))
+            out = self.out_embed(out)
+            if self.encode_output_position:
+                pos_weight = nn.functional.softmax(self.output_pos_weight, dim=0)
+                out = out*pos_weight[0] + self.output_pos_vec*pos_weight[1]
+            # chain the decoder stack: each decoder consumes the previous decoder's output
+            cur_out = out
+            for decoder in self.decoders:
+                cur_out = decoder(cur_out, x)
+            y = self.linear(cur_out)
+
+        else:
+            # start greedy decoding from the start-of-sequence token (index out_voc_size)
+            cur_out = self.out_embed(Variable(dtype['long']([self.out_voc_size]*x.size(0)).unsqueeze(-1)))
+            y = []
+            if self.unique_output:
+                self.seq_generated = None
+            for i in range(self.out_seq_len):
+                for decoder in self.decoders:
+                    cur_out = decoder(cur_out, x)
+                cur_y = self.linear(cur_out)[:, -1]
+                y.append(cur_y)
+                if self.unique_output:
+                    assert self.out_seq_len <= self.out_voc_size+1
+                    rank = cur_y.topk(self.out_seq_len, dim=-1)[1]
+                    idx, self.seq_generated = get_uniq_topk(rank, self.seq_generated)
+                    next_out = self.out_embed.weight[idx]
+                else:
+                    # squeeze only the top-k dimension so a batch of size 1 keeps its batch dim
+                    next_out = self.out_embed.weight[cur_y.topk(1, dim=-1)[1].squeeze(-1)]
+                if self.encode_output_position:
+                    pos_weight = nn.functional.softmax(self.output_pos_weight, dim=0)
+                    next_out = next_out*pos_weight[0] + self.output_pos_vec[i]*pos_weight[1]
+                cur_out = torch.cat([cur_out, next_out.unsqueeze(-2)], dim=-2)
+            y = torch.stack(y, dim=-2)
+        return y
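+
+
+# Usage sketch (added for illustration; the sizes and token values are arbitrary):
+# greedy sequential decoding versus teacher-forced, non-sequential decoding.
+# With CUDA available, also call model.cuda() so the dtype['long'] inputs match.
+#
+#     model = Transformer(in_dim=8, key_dim=4, value_dim=4, fc_dim=16, linear_dim=8,
+#                         in_voc_size=10, out_voc_size=10, in_seq_len=5, out_seq_len=5,
+#                         num_heads=2, num_attention=2)
+#     tokens = Variable(dtype['long']([[1, 4, 2, 0, 3], [5, 5, 1, 9, 7], [0, 2, 2, 8, 6]]))
+#     logits = model(tokens)                                  # (batch, out_seq_len, out_voc_size+1)
+#     teacher = model(tokens, out=tokens, sequential=False)   # same shape, decoder sees `out`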
+
+
+class StackedEncoder(nn.Module):
+    """Stack of EncoderAttention blocks followed by a linear classifier.
+
+    Examples:
+
+    model = StackedEncoder(in_dim=4, key_dim=3, value_dim=5, fc_dim=6, linear_dim=7, num_cls=8, num_heads=2, num_attention=2,
+                           knn=None, residual=True, normalization=None, nonlinearity=nn.ReLU(), duplicated_attention=False, mask=False,
+                           return_graph=True, return_all=True)
+    x = Variable(torch.randn(4, 4))
+    out, graphs = model(x)
+    """
+    def __init__(self, in_dim, key_dim, value_dim, fc_dim, linear_dim, num_cls, num_heads=1, num_attention=1,
+                 knn=None, residual=True, normalization=None, nonlinearity=nn.ReLU(), duplicated_attention=False, mask=False,
+                 return_graph=False, return_all=False):
+        super(StackedEncoder, self).__init__()
+
+        if duplicated_attention:
+            # the same block instance is repeated, i.e. weights are shared across the stack
+            self.encoders = nn.ModuleList([EncoderAttention(
+                in_dim, in_dim, key_dim, value_dim, fc_dim, num_heads, residual, normalization, nonlinearity,
+                mask=mask, knn=knn)]
+                * num_attention)
+        else:
+            self.encoders = nn.ModuleList()
+            for i in range(num_attention):
+                self.encoders.append(EncoderAttention(
+                    in_dim, in_dim, key_dim, value_dim, fc_dim, num_heads, residual, normalization, nonlinearity,
+                    mask=mask, knn=knn))
+        self.linear = nn.Linear(in_dim, num_cls)
+        self.return_graph = return_graph
+        self.return_all = return_all
+
+    def forward(self, x):
+        return_graph = self.return_graph
+        return_all = self.return_all
+        if return_graph and return_all:
+            graphs = []
+        for encoder in self.encoders:
+            if return_graph:
+                x, graph = encoder(x, return_graph=True)
+                if return_all:
+                    graphs.append(graph)
+            else:
+                x = encoder(x)
+        out = self.linear(x)
+        if return_graph:
+            if return_all:
+                return out, graphs
+            else:
+                return out, graph
+        else:
+            return out
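+
+
+if __name__ == '__main__':
+    # Minimal smoke test, added for illustration; the sizes and token values are
+    # arbitrary and not taken from any experiment.
+    model = Transformer(in_dim=8, key_dim=4, value_dim=4, fc_dim=16, linear_dim=8,
+                        in_voc_size=10, out_voc_size=10, in_seq_len=5, out_seq_len=5,
+                        num_heads=2, num_attention=2)
+    if torch.cuda.is_available():
+        model = model.cuda()
+    tokens = Variable(dtype['long']([[1, 4, 2, 0, 3], [5, 5, 1, 9, 7], [0, 2, 2, 8, 6]]))
+    logits = model(tokens)
+    print(logits.size())  # expected: (3, out_seq_len, out_voc_size + 1)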