# b/model/MLM.py
import torch.nn as nn
import pytorch_pretrained_bert as Bert
import numpy as np
import torch


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, segment, age and position embeddings."""

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)
        self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)
        # Fixed sinusoidal position table; from_pretrained freezes it by default,
        # so these embeddings are not updated during training.
        self.posi_embeddings = nn.Embedding.from_pretrained(
            embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))

        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, age=True):
        # Missing id tensors default to all zeros, i.e. index 0 of each table.
        if seg_ids is None:
            seg_ids = torch.zeros_like(word_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(word_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(word_ids)

        word_embed = self.word_embeddings(word_ids)
        segment_embed = self.segment_embeddings(seg_ids)
        age_embed = self.age_embeddings(age_ids)
        posi_embeddings = self.posi_embeddings(posi_ids)

        # Sum the component embeddings; the age term can be switched off.
        if age:
            embeddings = word_embed + segment_embed + age_embed + posi_embeddings
        else:
            embeddings = word_embed + segment_embed + posi_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        def even_code(pos, idx):
            return np.sin(pos / (10000 ** (2 * idx / hidden_size)))

        def odd_code(pos, idx):
            return np.cos(pos / (10000 ** (2 * idx / hidden_size)))

        # initialize position embedding table
        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)

        # reset table parameters with hard encoding
        # set even dimensions
        for pos in range(max_position_embedding):
            for idx in np.arange(0, hidden_size, step=2):
                lookup_table[pos, idx] = even_code(pos, idx)
        # set odd dimensions
        for pos in range(max_position_embedding):
            for idx in np.arange(1, hidden_size, step=2):
                lookup_table[pos, idx] = odd_code(pos, idx)

        return torch.tensor(lookup_table)

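
# ---------------------------------------------------------------------------
# Note (illustration only, not part of the original module): the lookup table
# built in _init_posi_embedding encodes position `pos` at dimension `idx` as
# sin(pos / 10000 ** (2 * idx / hidden_size)) when idx is even and
# cos(pos / 10000 ** (2 * idx / hidden_size)) when idx is odd. An equivalent
# vectorised construction, assuming nothing beyond numpy, would be:
#
#     pos = np.arange(max_position_embedding)[:, None]            # (P, 1)
#     idx = np.arange(hidden_size)[None, :]                       # (1, H)
#     angle = pos / (10000 ** (2 * idx / hidden_size))            # (P, H)
#     table = np.where(idx % 2 == 0, np.sin(angle), np.cos(angle)).astype(np.float32)
# ---------------------------------------------------------------------------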


class BertModel(Bert.modeling.BertPreTrainedModel):
    """BERT encoder over the combined word/segment/age/position embeddings."""

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None,
                output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(input_ids)
        if seg_ids is None:
            seg_ids = torch.zeros_like(input_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


class BertForMaskedLM(Bert.modeling.BertPreTrainedModel):
    """Masked-language-model head on top of BertModel, with the output decoder tied to the word embeddings."""

    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = Bert.modeling.BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, masked_lm_labels=None):
        sequence_output, _ = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask,
                                       output_all_encoded_layers=False)
        prediction_scores = self.cls(sequence_output)

        if masked_lm_labels is not None:
            # Positions labelled -1 (non-masked tokens) are ignored by the loss.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            return masked_lm_loss, prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)
        else:
            return prediction_scores
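

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes a throwaway BertConfig subclass that carries the two extra
# vocabulary sizes this file reads (seg_vocab_size, age_vocab_size); the real
# project is expected to define its own config class for this purpose.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class DemoConfig(Bert.modeling.BertConfig):
        # Hypothetical config for illustration; only the extra fields used above are added.
        def __init__(self, vocab_size, seg_vocab_size, age_vocab_size, **kwargs):
            super(DemoConfig, self).__init__(vocab_size_or_config_json_file=vocab_size, **kwargs)
            self.seg_vocab_size = seg_vocab_size
            self.age_vocab_size = age_vocab_size

    config = DemoConfig(vocab_size=100, seg_vocab_size=2, age_vocab_size=110,
                        hidden_size=72, num_hidden_layers=2, num_attention_heads=4,
                        intermediate_size=128, max_position_embeddings=64)
    model = BertForMaskedLM(config)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    # Label only the first token of each sequence; -1 elsewhere is ignored by the loss.
    masked_lm_labels = torch.full((batch_size, seq_len), -1, dtype=torch.long)
    masked_lm_labels[:, 0] = input_ids[:, 0]

    loss, scores, _ = model(input_ids, masked_lm_labels=masked_lm_labels)
    print(loss.item(), scores.shape)  # scalar loss, (batch_size * seq_len, vocab_size) scores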