Diff of /model/MLM.py [000000] .. [bad60c]

import torch.nn as nn
import pytorch_pretrained_bert as Bert
import numpy as np
import torch


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, segment, age and position."""

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)
        self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)
        # Fixed sinusoidal position table; nn.Embedding.from_pretrained freezes it by default.
        self.posi_embeddings = nn.Embedding.from_pretrained(
            embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))

        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, age=True):
        # Missing id streams default to index 0.
        if seg_ids is None:
            seg_ids = torch.zeros_like(word_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(word_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(word_ids)

        word_embed = self.word_embeddings(word_ids)
        segment_embed = self.segment_embeddings(seg_ids)
        age_embed = self.age_embeddings(age_ids)
        posi_embed = self.posi_embeddings(posi_ids)

        # Sum the component embeddings; the age stream can be switched off.
        if age:
            embeddings = word_embed + segment_embed + age_embed + posi_embed
        else:
            embeddings = word_embed + segment_embed + posi_embed
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        # Transformer-style fixed encoding: sine on even dimensions,
        # cosine on odd dimensions.
        def even_code(pos, idx):
            return np.sin(pos / (10000 ** (2 * idx / hidden_size)))

        def odd_code(pos, idx):
            return np.cos(pos / (10000 ** (2 * idx / hidden_size)))

        # initialize position embedding table
        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)

        # reset table parameters with hard encoding
        # set even dimensions
        for pos in range(max_position_embedding):
            for idx in np.arange(0, hidden_size, step=2):
                lookup_table[pos, idx] = even_code(pos, idx)
        # set odd dimensions
        for pos in range(max_position_embedding):
            for idx in np.arange(1, hidden_size, step=2):
                lookup_table[pos, idx] = odd_code(pos, idx)

        return torch.tensor(lookup_table)


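# Usage sketch (illustrative): BertEmbeddings only reads a handful of config
# fields, so a SimpleNamespace can stand in for a full BertConfig here; the
# helper name and every hyper-parameter value below are hypothetical.
def _embeddings_sketch():
    from types import SimpleNamespace
    config = SimpleNamespace(vocab_size=100, seg_vocab_size=2, age_vocab_size=120,
                             hidden_size=64, max_position_embeddings=32,
                             hidden_dropout_prob=0.1)
    embeddings = BertEmbeddings(config)
    word_ids = torch.randint(0, config.vocab_size, (2, 10))  # [batch, seq_len]
    out = embeddings(word_ids)  # age/seg/posi ids default to zeros
    assert out.shape == (2, 10, config.hidden_size)
    # The position table is frozen (from_pretrained defaults to freeze=True);
    # at position 0 it holds sin(0)=0 on even dims and cos(0)=1 on odd dims.
    assert not embeddings.posi_embeddings.weight.requires_grad
    assert embeddings.posi_embeddings.weight[0, 0].item() == 0.0
    assert embeddings.posi_embeddings.weight[0, 1].item() == 1.0

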
class BertModel(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        # Note: pytorch_pretrained_bert's init_bert_weights re-initializes every
        # nn.Linear and nn.Embedding, so it also overwrites the sinusoidal
        # position table built in BertEmbeddings (it stays frozen, but its
        # values become randomly initialized).
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None,
                output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(input_ids)
        if seg_ids is None:
            seg_ids = torch.zeros_like(input_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal
        # attention used in OpenAI GPT; we just need to prepare the broadcast
        # dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


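# Usage sketch (illustrative): BertPreTrainedModel insists on a real BertConfig,
# so the extra seg/age vocab sizes are attached to one by hand; the helper name
# and all hyper-parameter values below are hypothetical.
def _bert_model_sketch():
    config = Bert.modeling.BertConfig(vocab_size_or_config_json_file=100,
                                      hidden_size=64, num_hidden_layers=2,
                                      num_attention_heads=4, intermediate_size=128,
                                      max_position_embeddings=32)
    config.seg_vocab_size = 2
    config.age_vocab_size = 120
    model = BertModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 10))
    # With output_all_encoded_layers=False the first return value is the last
    # hidden state itself, not a list of per-layer states.
    sequence_output, pooled_output = model(input_ids, output_all_encoded_layers=False)
    assert sequence_output.shape == (2, 10, config.hidden_size)
    assert pooled_output.shape == (2, config.hidden_size)

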
class BertForMaskedLM(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)
        self.bert = BertModel(config)
        # The MLM head ties its decoder to the input word-embedding matrix.
        self.cls = Bert.modeling.BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None,
                masked_lm_labels=None):
        sequence_output, _ = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask,
                                       output_all_encoded_layers=False)
        prediction_scores = self.cls(sequence_output)

        if masked_lm_labels is not None:
            # Positions labelled -1 are excluded from the loss.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            return masked_lm_loss, prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)
        else:
            return prediction_scores
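

# Training-step sketch (illustrative): with masked_lm_labels given, the model
# returns (loss, flattened logits, flattened labels) and ignores positions
# labelled -1. Config values, the helper name and the learning rate are
# hypothetical.
def _mlm_step_sketch():
    config = Bert.modeling.BertConfig(vocab_size_or_config_json_file=100,
                                      hidden_size=64, num_hidden_layers=2,
                                      num_attention_heads=4, intermediate_size=128,
                                      max_position_embeddings=32)
    config.seg_vocab_size = 2
    config.age_vocab_size = 120
    model = BertForMaskedLM(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

    input_ids = torch.randint(0, config.vocab_size, (2, 10))
    labels = input_ids.clone()
    labels[:, :8] = -1  # predict only the last two tokens of each sequence

    loss, logits, flat_labels = model(input_ids, masked_lm_labels=labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    assert logits.shape == (2 * 10, config.vocab_size)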