Diff of /model/NextXVisit.py [000000] .. [bad60c]

import torch
import torch.nn as nn
import pytorch_pretrained_bert as Bert
import numpy as np


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, segment, age and position.
    """

    def __init__(self, config, feature_dict=None):
        super(BertEmbeddings, self).__init__()

        if feature_dict is None:
            self.feature_dict = {
                'word': True,
                'seg': True,
                'age': True,
                'position': True
            }
        else:
            self.feature_dict = feature_dict

        # use self.feature_dict so the defaults set above also apply when feature_dict is None
        if self.feature_dict['word']:
            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

        if self.feature_dict['seg']:
            self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)

        if self.feature_dict['age']:
            self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)

        if self.feature_dict['position']:
            # fixed sinusoidal position embeddings, frozen via from_pretrained
            self.posi_embeddings = nn.Embedding.from_pretrained(
                embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))

        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, age_ids, seg_ids, posi_ids):
        # sum the enabled embeddings (word + segment + age + position), then normalise
        embeddings = self.word_embeddings(word_ids)

        if self.feature_dict['seg']:
            segment_embed = self.segment_embeddings(seg_ids)
            embeddings = embeddings + segment_embed

        if self.feature_dict['age']:
            age_embed = self.age_embeddings(age_ids)
            embeddings = embeddings + age_embed

        if self.feature_dict['position']:
            posi_embeddings = self.posi_embeddings(posi_ids)
            embeddings = embeddings + posi_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        # sinusoidal encoding: sin on even dimensions, cos on odd dimensions,
        # with argument pos / 10000 ** (2 * idx / hidden_size) for dimension idx
        def even_code(pos, idx):
            return np.sin(pos / (10000 ** (2 * idx / hidden_size)))

        def odd_code(pos, idx):
            return np.cos(pos / (10000 ** (2 * idx / hidden_size)))

        # initialize position embedding table
        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)

        # fill the table with hard-coded sinusoidal values
        # set even dimensions
        for pos in range(max_position_embedding):
            for idx in np.arange(0, hidden_size, step=2):
                lookup_table[pos, idx] = even_code(pos, idx)
        # set odd dimensions
        for pos in range(max_position_embedding):
            for idx in np.arange(1, hidden_size, step=2):
                lookup_table[pos, idx] = odd_code(pos, idx)

        return torch.tensor(lookup_table)


class BertModel(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config, feature_dict):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config=config, feature_dict=feature_dict)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids, seg_ids, posi_ids, attention_mask,
                output_all_encoded_layers=True):

        # We extend the 2D attention mask to sizes [batch_size, 1, 1, to_seq_length]
        # so it can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # This mask is simpler than the triangular causal mask used in OpenAI GPT;
        # we only need to prepare the broadcast dimensions here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
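        # Illustrative example (values are placeholders): a padded row with
        # attention_mask = [1., 1., 0.] becomes [0., 0., -10000.] here, so the
        # padded position contributes (almost) nothing after the softmax.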

        embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config, num_labels, feature_dict):
        super(BertForMultiLabelPrediction, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config, feature_dict)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # multi-label head: one logit per label
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            # each label is an independent binary target, hence the multi-label soft-margin loss
            loss_fct = nn.MultiLabelSoftMarginLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
            return loss, logits
        else:
            return logits
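

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original model code). It assumes a
# pytorch_pretrained_bert BertConfig extended with the extra attributes read
# above (seg_vocab_size, age_vocab_size); all sizes below are illustrative
# placeholders only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = Bert.modeling.BertConfig(
        vocab_size_or_config_json_file=300,  # size of the code vocabulary (placeholder)
        hidden_size=288,
        num_hidden_layers=6,
        num_attention_heads=12,
        intermediate_size=512,
        max_position_embeddings=64)
    config.seg_vocab_size = 2     # assumed: two visit segments
    config.age_vocab_size = 144   # assumed: discretised age vocabulary

    feature_dict = {'word': True, 'seg': True, 'age': True, 'position': True}
    model = BertForMultiLabelPrediction(config, num_labels=10, feature_dict=feature_dict)

    batch_size, seq_len = 2, 64
    input_ids = torch.randint(0, 300, (batch_size, seq_len))
    age_ids = torch.randint(0, 144, (batch_size, seq_len))
    seg_ids = torch.randint(0, 2, (batch_size, seq_len))
    posi_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
    attention_mask = torch.ones(batch_size, seq_len)

    logits = model(input_ids, age_ids, seg_ids, posi_ids, attention_mask)
    print(logits.shape)  # torch.Size([2, 10])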