[c3444c]: pretrain/model.py


#from transformers import BertConfig, BertPreTrainedModel, BertTokenizer, BertModel
from transformers import AutoConfig
from transformers import AutoModelForPreTraining
from transformers import AutoTokenizer
from transformers import AutoModel
from transformers.modeling_utils import SequenceSummary
from torch import nn
import torch.nn.functional as F
import torch
from loss import AMSoftmax
from pytorch_metric_learning import losses, miners
from trans import TransE
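
# The UMLSPretrainedModel below wraps a Transformer encoder and combines three
# training signals computed from triples of UMLS terms:
#   - a CUI (concept) loss over pooled term features: multi-similarity loss by
#     default, with softmax and AM-Softmax as alternatives;
#   - a semantic-type (STY) classification loss from a linear head;
#   - a relation loss (TransE by default) over the three term features and a
#     learned relation embedding.
# The combined objective is sty_weight * sty_loss + cui_loss + re_weight * re_loss.
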
class UMLSPretrainedModel(nn.Module):
    def __init__(self, device, model_name_or_path,
                 cui_label_count, rel_label_count, sty_label_count,
                 re_weight=1.0, sty_weight=0.1,
                 cui_loss_type="ms_loss",
                 trans_loss_type="TransE", trans_margin=1.0):
        super(UMLSPretrainedModel, self).__init__()

        self.device = device

        # Encoder: feature dimension follows the checkpoint size.
        self.model_name_or_path = model_name_or_path
        if self.model_name_or_path.find("large") >= 0:
            self.feature_dim = 1024
        else:
            self.feature_dim = 768
        self.bert = AutoModel.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.dropout = nn.Dropout(0.1)

        # Relation (RE) head.
        self.rel_label_count = rel_label_count
        self.re_weight = re_weight

        # Semantic-type (STY) classification head.
        self.sty_label_count = sty_label_count
        self.linear_sty = nn.Linear(self.feature_dim, self.sty_label_count)
        self.sty_loss_fn = nn.CrossEntropyLoss()
        self.sty_weight = sty_weight

        # CUI loss: softmax, AM-Softmax, or multi-similarity (metric) loss.
        self.cui_loss_type = cui_loss_type
        self.cui_label_count = cui_label_count
        if self.cui_loss_type == "softmax":
            self.cui_loss_fn = nn.CrossEntropyLoss()
            self.linear = nn.Linear(self.feature_dim, self.cui_label_count)
        if self.cui_loss_type == "am_softmax":
            self.cui_loss_fn = AMSoftmax(
                self.feature_dim, self.cui_label_count)
        if self.cui_loss_type == "ms_loss":
            self.cui_loss_fn = losses.MultiSimilarityLoss(alpha=2, beta=50)
            self.miner = miners.MultiSimilarityMiner(epsilon=0.1)

        # Relation loss: relation labels are embedded and scored with TransE.
        self.trans_loss_type = trans_loss_type
        if self.trans_loss_type == "TransE":
            self.re_loss_fn = TransE(trans_margin)
        self.re_embedding = nn.Embedding(
            self.rel_label_count, self.feature_dim)

        self.standard_dataloader = None

        self.sequence_summary = SequenceSummary(
            AutoConfig.from_pretrained(model_name_or_path))  # Now only used for XLNet

    def softmax(self, logits, label):
        loss = self.cui_loss_fn(logits, label)
        return loss

    def am_softmax(self, pooled_output, label):
        loss, _ = self.cui_loss_fn(pooled_output, label)
        return loss

    def ms_loss(self, pooled_output, label):
        pairs = self.miner(pooled_output, label)
        loss = self.cui_loss_fn(pooled_output, label, pairs)
        return loss

    def calculate_loss(self, pooled_output=None, logits=None, label=None):
        if self.cui_loss_type == "softmax":
            return self.softmax(logits, label)
        if self.cui_loss_type == "am_softmax":
            return self.am_softmax(pooled_output, label)
        if self.cui_loss_type == "ms_loss":
            return self.ms_loss(pooled_output, label)

    def get_sentence_feature(self, input_ids):
        # bert, albert, roberta
        if self.model_name_or_path.find("xlnet") < 0:
            outputs = self.bert(input_ids)
            pooled_output = outputs[1]
            return pooled_output

        # xlnet
        outputs = self.bert(input_ids)
        pooled_output = self.sequence_summary(outputs[0])
        return pooled_output

    # @profile
    def forward(self,
                input_ids_0, input_ids_1, input_ids_2,
                cui_label_0, cui_label_1, cui_label_2,
                sty_label_0, sty_label_1, sty_label_2,
                re_label):
        # Encode the three terms of each triple in a single batch.
        input_ids = torch.cat((input_ids_0, input_ids_1, input_ids_2), 0)
        cui_label = torch.cat((cui_label_0, cui_label_1, cui_label_2))
        sty_label = torch.cat((sty_label_0, sty_label_1, sty_label_2))
        #print(input_ids.shape, cui_label.shape, sty_label.shape)
        use_len = input_ids_0.shape[0]

        pooled_output = self.get_sentence_feature(
            input_ids)  # (3 * batch_size) * feature_dim

        # Semantic-type classification loss.
        logits_sty = self.linear_sty(pooled_output)
        sty_loss = self.sty_loss_fn(logits_sty, sty_label)

        # CUI loss; only the softmax variant needs explicit logits.
        if self.cui_loss_type == "softmax":
            logits = self.linear(pooled_output)
        else:
            logits = None
        cui_loss = self.calculate_loss(pooled_output, logits, cui_label)

        # Relation loss over the three term features and the relation embedding.
        cui_0_output = pooled_output[0:use_len]
        cui_1_output = pooled_output[use_len:2 * use_len]
        cui_2_output = pooled_output[2 * use_len:]
        re_output = self.re_embedding(re_label)
        re_loss = self.re_loss_fn(
            cui_0_output, cui_1_output, cui_2_output, re_output)

        loss = self.sty_weight * sty_loss + cui_loss + self.re_weight * re_loss
        #print(sty_loss.device, cui_loss.device, re_loss.device)

        return loss, (sty_loss, cui_loss, re_loss)
"""
def predict(self, input_ids):
if self.loss_type == "softmax":
return self.predict_by_softmax(input_ids)
if self.loss_type == "am_softmax":
return self.predict_by_amsoftmax(input_ids)
def predict_by_softmax(self, input_ids):
pooled_output = self.get_sentence_feature(input_ids)
logits = self.linear(pooled_output)
return torch.max(logits, dim=1)[1], logits
def predict_by_amsoftmax(self, input_ids):
pooled_output = self.get_sentence_feature(input_ids)
logits = self.loss_fn.predict(pooled_output)
return torch.max(logits, dim=1)[1], logits
"""

    def init_standard_feature(self):
        # Builds L2-normalized features for a set of standard terms; expects
        # self.standard_dataloader (and self.num_label for the assert) to be
        # set externally before calling.
        if self.standard_dataloader is not None:
            for index, batch in enumerate(self.standard_dataloader):
                input_ids = batch[0].to(self.device)
                outputs = self.get_sentence_feature(input_ids)
                normalized_standard_feature = torch.norm(
                    outputs, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                normalized_standard_feature = torch.div(
                    outputs, normalized_standard_feature)
                if index == 0:
                    self.standard_feature = normalized_standard_feature
                else:
                    self.standard_feature = torch.cat(
                        (self.standard_feature, normalized_standard_feature), 0)
            assert self.standard_feature.shape == (
                self.num_label, self.feature_dim), self.standard_feature.shape
        return None

    def predict_by_cosine(self, input_ids):
        # Cosine similarity between L2-normalized query features and the
        # precomputed standard features.
        pooled_output = self.get_sentence_feature(input_ids)
        normalized_feature = torch.norm(
            pooled_output, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        normalized_feature = torch.div(pooled_output, normalized_feature)
        sim_mat = torch.matmul(normalized_feature, torch.t(
            self.standard_feature))  # batch_size * num_label
        return torch.max(sim_mat, dim=1)[1], sim_mat
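

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the training pipeline): shows how the
# model might be instantiated and called on toy data. The checkpoint name,
# label counts, and tensor shapes below are illustrative assumptions, and
# running this also requires the companion loss.py and trans.py modules from
# this directory plus a download of the pretrained encoder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UMLSPretrainedModel(
        device=device,
        model_name_or_path="bert-base-uncased",  # assumed checkpoint
        cui_label_count=100, rel_label_count=10, sty_label_count=20,
    ).to(device)

    batch_size, seq_len = 4, 16
    # Random token ids / labels stand in for tokenized UMLS term triples.
    ids = [torch.randint(0, model.tokenizer.vocab_size,
                         (batch_size, seq_len)).to(device) for _ in range(3)]
    cui = [torch.randint(0, 100, (batch_size,)).to(device) for _ in range(3)]
    sty = [torch.randint(0, 20, (batch_size,)).to(device) for _ in range(3)]
    rel = torch.randint(0, 10, (batch_size,)).to(device)

    loss, (sty_loss, cui_loss, re_loss) = model(
        ids[0], ids[1], ids[2],
        cui[0], cui[1], cui[2],
        sty[0], sty[1], sty[2],
        rel)
    print(loss.item(), sty_loss.item(), cui_loss.item(), re_loss.item())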