--- a +++ b/pytorch_pretrained_bert/module.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import pytorch_pretrained_bert as Bert + + +def sequence_mask(sequence_length, max_len=None, device=None): + sequence_length = torch.tensor(sequence_length) + max_len = torch.tensor(max_len) + if max_len is None: + max_len = sequence_length.data.max() + batch_size = sequence_length.size(0) + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + + if sequence_length.is_cuda: + seq_range_expand = seq_range_expand.to(device) + seq_length_expand = (sequence_length.unsqueeze(1).expand_as(seq_range_expand)) + mask= seq_range_expand < seq_length_expand + return mask.detach().long() + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, segment, age + """ + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size) + self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size) + + self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, word_ids, age_ids=None, seg_ids=None): + if seg_ids is None: + seg_ids = torch.zeros_like(word_ids) + if age_ids is None: + age_ids = torch.zeros_like(word_ids) + + word_embed = self.word_embeddings(word_ids) + segment_embed = self.segment_embeddings(seg_ids) + age_embed = self.age_embeddings(age_ids) + + embeddings = word_embed + segment_embed + age_embed + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertModel(Bert.modeling.BertPreTrainedModel): + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config=config) + self.encoder = Bert.modeling.BertEncoder(config=config) + self.pooler = Bert.modeling.BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, age_ids=None, seg_ids=None, attention_mask=None, output_all_encoded_layers=True): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if age_ids is None: + age_ids = torch.zeros_like(input_ids) + if seg_ids is None: + seg_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, age_ids, seg_ids) + encoded_layers = self.encoder(embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output \ No newline at end of file