--- a/pretrain/model.py
+++ b/pretrain/model.py
@@ -0,0 +1,178 @@
#from transformers import BertConfig, BertPreTrainedModel, BertTokenizer, BertModel
from transformers import AutoConfig
from transformers import AutoModelForPreTraining
from transformers import AutoTokenizer
from transformers import AutoModel
from transformers.modeling_utils import SequenceSummary
from torch import nn
import torch.nn.functional as F
import torch
from loss import AMSoftmax
from pytorch_metric_learning import losses, miners
from trans import TransE


class UMLSPretrainedModel(nn.Module):
    def __init__(self, device, model_name_or_path,
                 cui_label_count, rel_label_count, sty_label_count,
                 re_weight=1.0, sty_weight=0.1,
                 cui_loss_type="ms_loss",
                 trans_loss_type="TransE", trans_margin=1.0):
        super(UMLSPretrainedModel, self).__init__()

        self.device = device
        self.model_name_or_path = model_name_or_path
        # "-large" checkpoints have a 1024-dim hidden state, base checkpoints 768
        if self.model_name_or_path.find("large") >= 0:
            self.feature_dim = 1024
        else:
            self.feature_dim = 768
        self.bert = AutoModel.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.dropout = nn.Dropout(0.1)

        self.rel_label_count = rel_label_count
        self.re_weight = re_weight

        self.sty_label_count = sty_label_count
        self.linear_sty = nn.Linear(self.feature_dim, self.sty_label_count)
        self.sty_loss_fn = nn.CrossEntropyLoss()
        self.sty_weight = sty_weight

        self.cui_loss_type = cui_loss_type
        self.cui_label_count = cui_label_count

        if self.cui_loss_type == "softmax":
            self.cui_loss_fn = nn.CrossEntropyLoss()
            self.linear = nn.Linear(self.feature_dim, self.cui_label_count)
        if self.cui_loss_type == "am_softmax":
            self.cui_loss_fn = AMSoftmax(
                self.feature_dim, self.cui_label_count)
        if self.cui_loss_type == "ms_loss":
            self.cui_loss_fn = losses.MultiSimilarityLoss(alpha=2, beta=50)
            self.miner = miners.MultiSimilarityMiner(epsilon=0.1)

        self.trans_loss_type = trans_loss_type
        if self.trans_loss_type == "TransE":
            self.re_loss_fn = TransE(trans_margin)
            self.re_embedding = nn.Embedding(
                self.rel_label_count, self.feature_dim)

        self.standard_dataloader = None

        self.sequence_summary = SequenceSummary(
            AutoConfig.from_pretrained(model_name_or_path))  # currently only used for XLNet

    def softmax(self, logits, label):
        loss = self.cui_loss_fn(logits, label)
        return loss

    def am_softmax(self, pooled_output, label):
        loss, _ = self.cui_loss_fn(pooled_output, label)
        return loss

    def ms_loss(self, pooled_output, label):
        pairs = self.miner(pooled_output, label)
        loss = self.cui_loss_fn(pooled_output, label, pairs)
        return loss

    def calculate_loss(self, pooled_output=None, logits=None, label=None):
        if self.cui_loss_type == "softmax":
            return self.softmax(logits, label)
        if self.cui_loss_type == "am_softmax":
            return self.am_softmax(pooled_output, label)
        if self.cui_loss_type == "ms_loss":
            return self.ms_loss(pooled_output, label)

    def get_sentence_feature(self, input_ids):
        # bert, albert, roberta: use the pooled output
        if self.model_name_or_path.find("xlnet") < 0:
            outputs = self.bert(input_ids)
            pooled_output = outputs[1]
            return pooled_output

        # xlnet: summarize the last hidden states
        outputs = self.bert(input_ids)
        pooled_output = self.sequence_summary(outputs[0])
        return pooled_output

    # @profile
    def forward(self,
                input_ids_0, input_ids_1, input_ids_2,
                cui_label_0, cui_label_1, cui_label_2,
                sty_label_0, sty_label_1, sty_label_2,
                re_label):
        input_ids = torch.cat((input_ids_0, input_ids_1, input_ids_2), 0)
        cui_label = torch.cat((cui_label_0, cui_label_1, cui_label_2))
        sty_label = torch.cat((sty_label_0, sty_label_1, sty_label_2))
        #print(input_ids.shape, cui_label.shape, sty_label.shape)

        use_len = input_ids_0.shape[0]

        pooled_output = self.get_sentence_feature(
            input_ids)  # (3 * use_len) * feature_dim
        logits_sty = self.linear_sty(pooled_output)
        sty_loss = self.sty_loss_fn(logits_sty, sty_label)

        if self.cui_loss_type == "softmax":
            logits = self.linear(pooled_output)
        else:
            logits = None
        cui_loss = self.calculate_loss(pooled_output, logits, cui_label)

        cui_0_output = pooled_output[0:use_len]
        cui_1_output = pooled_output[use_len:2 * use_len]
        cui_2_output = pooled_output[2 * use_len:]
        re_output = self.re_embedding(re_label)
        re_loss = self.re_loss_fn(
            cui_0_output, cui_1_output, cui_2_output, re_output)

        loss = self.sty_weight * sty_loss + cui_loss + self.re_weight * re_loss
        #print(sty_loss.device, cui_loss.device, re_loss.device)

        return loss, (sty_loss, cui_loss, re_loss)

    """
    def predict(self, input_ids):
        if self.cui_loss_type == "softmax":
            return self.predict_by_softmax(input_ids)
        if self.cui_loss_type == "am_softmax":
            return self.predict_by_amsoftmax(input_ids)

    def predict_by_softmax(self, input_ids):
        pooled_output = self.get_sentence_feature(input_ids)
        logits = self.linear(pooled_output)
        return torch.max(logits, dim=1)[1], logits

    def predict_by_amsoftmax(self, input_ids):
        pooled_output = self.get_sentence_feature(input_ids)
        logits = self.cui_loss_fn.predict(pooled_output)
        return torch.max(logits, dim=1)[1], logits
    """

    def init_standard_feature(self):
        # Expects self.standard_dataloader (and self.num_label, used by the
        # assert below) to be attached externally; otherwise this is a no-op.
        if self.standard_dataloader is not None:
            for index, batch in enumerate(self.standard_dataloader):
                input_ids = batch[0].to(self.device)
                outputs = self.get_sentence_feature(input_ids)
                # L2-normalize the features of each standard term
                normalized_standard_feature = torch.norm(
                    outputs, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                normalized_standard_feature = torch.div(
                    outputs, normalized_standard_feature)
                if index == 0:
                    self.standard_feature = normalized_standard_feature
                else:
                    self.standard_feature = torch.cat(
                        (self.standard_feature, normalized_standard_feature), 0)
            assert self.standard_feature.shape == (
                self.num_label, self.feature_dim), self.standard_feature.shape
        return None

    def predict_by_cosine(self, input_ids):
        pooled_output = self.get_sentence_feature(input_ids)
        # cosine similarity against the pre-computed standard features
        normalized_feature = torch.norm(
            pooled_output, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        normalized_feature = torch.div(pooled_output, normalized_feature)
        sim_mat = torch.matmul(normalized_feature, torch.t(
            self.standard_feature))  # batch_size * num_label
        return torch.max(sim_mat, dim=1)[1], sim_mat
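For reference, a minimal driver sketch for the module above, assuming the repository's pretrain/ directory (with loss.py and trans.py) is importable, a bert-base-uncased backbone, and made-up label counts, labels, and terms; real pretraining would feed batches of UMLS triples from the project's dataloader instead.

# usage_sketch.py -- hypothetical example of calling UMLSPretrainedModel.forward
import torch
from model import UMLSPretrainedModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UMLSPretrainedModel(
    device=device,
    model_name_or_path="bert-base-uncased",  # assumed backbone
    cui_label_count=1000,                    # hypothetical label-space sizes
    rel_label_count=50,
    sty_label_count=40,
).to(device)

def encode(terms, max_length=32):
    # Tokenize a list of term strings into a (batch, max_length) input_ids tensor.
    enc = model.tokenizer(terms, padding="max_length", truncation=True,
                          max_length=max_length, return_tensors="pt")
    return enc["input_ids"].to(device)

def lab(x):
    # Wrap a single integer label as a 1-element tensor on the right device.
    return torch.tensor([x], device=device)

# One toy triple: two synonymous terms plus a related term, with made-up labels.
input_ids_0 = encode(["myocardial infarction"])
input_ids_1 = encode(["heart attack"])
input_ids_2 = encode(["chest pain"])

loss, (sty_loss, cui_loss, re_loss) = model(
    input_ids_0, input_ids_1, input_ids_2,
    cui_label_0=lab(0), cui_label_1=lab(0), cui_label_2=lab(1),
    sty_label_0=lab(3), sty_label_1=lab(3), sty_label_2=lab(7),
    re_label=lab(5),
)
loss.backward()  # combined STY + CUI + relation loss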