--- /dev/null
+++ b/model/NextXVisit.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import pytorch_pretrained_bert as Bert
+import numpy as np
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word, segment, age and position embeddings.
+    """
+
+    def __init__(self, config, feature_dict=None):
+        super(BertEmbeddings, self).__init__()
+
+        if feature_dict is None:
+            self.feature_dict = {
+                'word': True,
+                'seg': True,
+                'age': True,
+                'position': True
+            }
+        else:
+            self.feature_dict = feature_dict
+
+        if self.feature_dict['word']:
+            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+
+        if self.feature_dict['seg']:
+            self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)
+
+        if self.feature_dict['age']:
+            self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)
+
+        if self.feature_dict['position']:
+            self.posi_embeddings = nn.Embedding.from_pretrained(
+                embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))
+
+        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, word_ids, age_ids, seg_ids, posi_ids):
+        embeddings = self.word_embeddings(word_ids)
+
+        if self.feature_dict['seg']:
+            segment_embed = self.segment_embeddings(seg_ids)
+            embeddings = embeddings + segment_embed
+
+        if self.feature_dict['age']:
+            age_embed = self.age_embeddings(age_ids)
+            embeddings = embeddings + age_embed
+
+        if self.feature_dict['position']:
+            posi_embeddings = self.posi_embeddings(posi_ids)
+            embeddings = embeddings + posi_embeddings
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def _init_posi_embedding(self, max_position_embedding, hidden_size):
+        def even_code(pos, idx):
+            return np.sin(pos / (10000 ** (2 * idx / hidden_size)))
+
+        def odd_code(pos, idx):
+            return np.cos(pos / (10000 ** (2 * idx / hidden_size)))
+
+        # initialize position embedding table
+        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)
+
+        # fill the table with the fixed sinusoidal encoding
+        # even dimensions use sin
+        for pos in range(max_position_embedding):
+            for idx in np.arange(0, hidden_size, step=2):
+                lookup_table[pos, idx] = even_code(pos, idx)
+        # odd dimensions use cos
+        for pos in range(max_position_embedding):
+            for idx in np.arange(1, hidden_size, step=2):
+                lookup_table[pos, idx] = odd_code(pos, idx)
+
+        return torch.tensor(lookup_table)
+
+
+class BertModel(Bert.modeling.BertPreTrainedModel):
+    def __init__(self, config, feature_dict):
+        super(BertModel, self).__init__(config)
+        self.embeddings = BertEmbeddings(config=config, feature_dict=feature_dict)
+        self.encoder = Bert.modeling.BertEncoder(config=config)
+        self.pooler = Bert.modeling.BertPooler(config)
+        self.apply(self.init_bert_weights)
+
+    def forward(self, input_ids, age_ids, seg_ids, posi_ids, attention_mask,
+                output_all_encoded_layers=True):
+
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
+        # This attention mask is simpler than the triangular masking of causal attention
+        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
+        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)
+        encoded_layers = self.encoder(embedding_output,
+                                      extended_attention_mask,
+                                      output_all_encoded_layers=output_all_encoded_layers)
+        sequence_output = encoded_layers[-1]
+        pooled_output = self.pooler(sequence_output)
+        if not output_all_encoded_layers:
+            encoded_layers = encoded_layers[-1]
+        return encoded_layers, pooled_output
+
+
+class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):
+    def __init__(self, config, num_labels, feature_dict):
+        super(BertForMultiLabelPrediction, self).__init__(config)
+        self.num_labels = num_labels
+        self.bert = BertModel(config, feature_dict)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, num_labels)
+        self.apply(self.init_bert_weights)
+
+    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):
+        _, pooled_output = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask,
+                                     output_all_encoded_layers=False)
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        if labels is not None:
+            loss_fct = nn.MultiLabelSoftMarginLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
+            return loss, logits
+        else:
+            return logits
\ No newline at end of file
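
For a quick sanity check of the file added above, a minimal forward pass could look like the sketch below. It is only an illustration: EHRConfig and every hyperparameter value in it are made-up stand-ins (the repository's real config and training scripts are not part of this diff), and the stock BertConfig from pytorch_pretrained_bert has no seg_vocab_size or age_vocab_size, so the sketch attaches them itself. Note that age_ids, seg_ids, posi_ids and attention_mask all have to be passed explicitly, because BertModel.forward does not substitute defaults for them.

import torch
import pytorch_pretrained_bert as Bert
from model.NextXVisit import BertForMultiLabelPrediction

class EHRConfig(Bert.modeling.BertConfig):
    # Hypothetical config wrapper: the stock BertConfig has no seg_vocab_size /
    # age_vocab_size, so we attach them here purely for illustration.
    def __init__(self, vocab_size, seg_vocab_size, age_vocab_size, **kwargs):
        super(EHRConfig, self).__init__(vocab_size_or_config_json_file=vocab_size, **kwargs)
        self.seg_vocab_size = seg_vocab_size
        self.age_vocab_size = age_vocab_size

config = EHRConfig(vocab_size=3000, seg_vocab_size=2, age_vocab_size=120,  # illustrative sizes only
                   hidden_size=288, num_hidden_layers=6, num_attention_heads=12,
                   intermediate_size=512, max_position_embeddings=256)

feature_dict = {'word': True, 'seg': True, 'age': True, 'position': True}
num_labels = 301  # e.g. number of codes to predict for the next visit (illustrative)
model = BertForMultiLabelPrediction(config, num_labels=num_labels, feature_dict=feature_dict)

batch_size, seq_len = 4, 64
input_ids = torch.randint(0, 3000, (batch_size, seq_len))
age_ids = torch.randint(0, 120, (batch_size, seq_len))
seg_ids = torch.randint(0, 2, (batch_size, seq_len))
posi_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
attention_mask = torch.ones(batch_size, seq_len)
labels = torch.randint(0, 2, (batch_size, num_labels)).float()  # multi-hot targets

# with labels the model returns (loss, logits); without labels it returns logits only
loss, logits = model(input_ids, age_ids, seg_ids, posi_ids, attention_mask, labels)
print(loss.item(), logits.shape)  # logits: (batch_size, num_labels)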