Diff of /pretrain/model.py [000000] .. [c3444c]

#from transformers import BertConfig, BertPreTrainedModel, BertTokenizer, BertModel
from transformers import AutoConfig
from transformers import AutoModelForPreTraining
from transformers import AutoTokenizer
from transformers import AutoModel
from transformers.modeling_utils import SequenceSummary
from torch import nn
import torch.nn.functional as F
import torch
from loss import AMSoftmax
from pytorch_metric_learning import losses, miners
from trans import TransE


class UMLSPretrainedModel(nn.Module):
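    """UMLS pretraining model (summary inferred from the code below):
    a shared encoder trained with three objectives, namely semantic-type
    (STY) classification, a CUI synonym loss (softmax, AM-Softmax, or
    multi-similarity), and a TransE-style relation loss over concept triples.
    """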
    def __init__(self, device, model_name_or_path,
                 cui_label_count, rel_label_count, sty_label_count,
                 re_weight=1.0, sty_weight=0.1,
                 cui_loss_type="ms_loss",
                 trans_loss_type="TransE", trans_margin=1.0):
        super(UMLSPretrainedModel, self).__init__()

        self.device = device
        self.model_name_or_path = model_name_or_path
        # Large checkpoints (e.g. "bert-large-*") use 1024-dim hidden states.
        if self.model_name_or_path.find("large") >= 0:
            self.feature_dim = 1024
        else:
            self.feature_dim = 768
        self.bert = AutoModel.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.dropout = nn.Dropout(0.1)
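        # NOTE: self.dropout is defined here but never applied in forward().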

        self.rel_label_count = rel_label_count
        self.re_weight = re_weight

        self.sty_label_count = sty_label_count
        self.linear_sty = nn.Linear(self.feature_dim, self.sty_label_count)
        self.sty_loss_fn = nn.CrossEntropyLoss()
        self.sty_weight = sty_weight

        self.cui_loss_type = cui_loss_type
        self.cui_label_count = cui_label_count

        if self.cui_loss_type == "softmax":
            self.cui_loss_fn = nn.CrossEntropyLoss()
            self.linear = nn.Linear(self.feature_dim, self.cui_label_count)
        if self.cui_loss_type == "am_softmax":
            self.cui_loss_fn = AMSoftmax(
                self.feature_dim, self.cui_label_count)
        if self.cui_loss_type == "ms_loss":
            self.cui_loss_fn = losses.MultiSimilarityLoss(alpha=2, beta=50)
            self.miner = miners.MultiSimilarityMiner(epsilon=0.1)
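        # The three CUI loss modes: "softmax" trains a plain classifier head
        # over CUI ids, "am_softmax" adds an additive margin, and "ms_loss"
        # does metric learning, with the miner selecting informative
        # positive/negative pairs within each batch.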

        self.trans_loss_type = trans_loss_type
        if self.trans_loss_type == "TransE":
            self.re_loss_fn = TransE(trans_margin)
        self.re_embedding = nn.Embedding(
            self.rel_label_count, self.feature_dim)
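        # TransE treats a relation embedding as a translation in feature
        # space (head + relation should lie close to tail); the trans_margin
        # argument suggests a margin-based ranking loss in trans.TransE.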

        self.standard_dataloader = None

        # Only used for XLNet, which has no pooled output; see
        # get_sentence_feature below.
        self.sequence_summary = SequenceSummary(AutoConfig.from_pretrained(model_name_or_path))

    def softmax(self, logits, label):
        loss = self.cui_loss_fn(logits, label)
        return loss

    def am_softmax(self, pooled_output, label):
        loss, _ = self.cui_loss_fn(pooled_output, label)
        return loss

    def ms_loss(self, pooled_output, label):
        pairs = self.miner(pooled_output, label)
        loss = self.cui_loss_fn(pooled_output, label, pairs)
        return loss

    def calculate_loss(self, pooled_output=None, logits=None, label=None):
        if self.cui_loss_type == "softmax":
            return self.softmax(logits, label)
        if self.cui_loss_type == "am_softmax":
            return self.am_softmax(pooled_output, label)
        if self.cui_loss_type == "ms_loss":
            return self.ms_loss(pooled_output, label)

    def get_sentence_feature(self, input_ids):
        # bert, albert, roberta: use the pooled [CLS] output
        if self.model_name_or_path.find("xlnet") < 0:
            outputs = self.bert(input_ids)
            pooled_output = outputs[1]
            return pooled_output

        # xlnet: no pooler, so summarize the last hidden states instead
        outputs = self.bert(input_ids)
        pooled_output = self.sequence_summary(outputs[0])
        return pooled_output
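    # Note: outputs[1] / outputs[0] index the tuple returned by older
    # transformers releases (pooler output and last hidden state). Newer
    # releases return a ModelOutput object; the equivalent fields would be
    # outputs.pooler_output and outputs.last_hidden_state.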

    # @profile
    def forward(self,
                input_ids_0, input_ids_1, input_ids_2,
                cui_label_0, cui_label_1, cui_label_2,
                sty_label_0, sty_label_1, sty_label_2,
                re_label):
        # Concatenate the three concept batches so the encoder runs once.
        input_ids = torch.cat((input_ids_0, input_ids_1, input_ids_2), 0)
        cui_label = torch.cat((cui_label_0, cui_label_1, cui_label_2))
        sty_label = torch.cat((sty_label_0, sty_label_1, sty_label_2))

        use_len = input_ids_0.shape[0]

        pooled_output = self.get_sentence_feature(
            input_ids)  # (3 * use_len) * feature_dim
        logits_sty = self.linear_sty(pooled_output)
        sty_loss = self.sty_loss_fn(logits_sty, sty_label)

        if self.cui_loss_type == "softmax":
            logits = self.linear(pooled_output)
        else:
            logits = None
        cui_loss = self.calculate_loss(pooled_output, logits, cui_label)

        # Split the features back into the three concept groups for the
        # relation loss.
        cui_0_output = pooled_output[0:use_len]
        cui_1_output = pooled_output[use_len:2 * use_len]
        cui_2_output = pooled_output[2 * use_len:]
        re_output = self.re_embedding(re_label)
        re_loss = self.re_loss_fn(
            cui_0_output, cui_1_output, cui_2_output, re_output)

        loss = self.sty_weight * sty_loss + cui_loss + self.re_weight * re_loss

        return loss, (sty_loss, cui_loss, re_loss)
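    # With the constructor defaults (sty_weight=0.1, re_weight=1.0), the
    # total objective is loss = 0.1 * sty_loss + cui_loss + re_loss.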

    """
    def predict(self, input_ids):
        if self.cui_loss_type == "softmax":
            return self.predict_by_softmax(input_ids)
        if self.cui_loss_type == "am_softmax":
            return self.predict_by_amsoftmax(input_ids)

    def predict_by_softmax(self, input_ids):
        pooled_output = self.get_sentence_feature(input_ids)
        logits = self.linear(pooled_output)
        return torch.max(logits, dim=1)[1], logits

    def predict_by_amsoftmax(self, input_ids):
        pooled_output = self.get_sentence_feature(input_ids)
        logits = self.cui_loss_fn.predict(pooled_output)
        return torch.max(logits, dim=1)[1], logits
    """

    def init_standard_feature(self):
        if self.standard_dataloader is not None:
            for index, batch in enumerate(self.standard_dataloader):
                input_ids = batch[0].to(self.device)
                outputs = self.get_sentence_feature(input_ids)
                feature_norm = torch.norm(
                    outputs, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                normalized_standard_feature = torch.div(outputs, feature_norm)
                if index == 0:
                    self.standard_feature = normalized_standard_feature
                else:
                    self.standard_feature = torch.cat(
                        (self.standard_feature, normalized_standard_feature), 0)
            # The original asserted against self.num_label, which is never
            # defined; the expected row count is the number of CUI labels.
            assert self.standard_feature.shape == (
                self.cui_label_count, self.feature_dim), self.standard_feature.shape
        return None
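    # The clamp-and-divide normalization above matches the already-imported
    # functional helper: F.normalize(outputs, p=2, dim=1, eps=1e-12).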

    def predict_by_cosine(self, input_ids):
        pooled_output = self.get_sentence_feature(input_ids)

        feature_norm = torch.norm(
            pooled_output, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        normalized_feature = torch.div(pooled_output, feature_norm)
        sim_mat = torch.matmul(normalized_feature, torch.t(
            self.standard_feature))  # batch_size * cui_label_count
        return torch.max(sim_mat, dim=1)[1], sim_mat
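

# Minimal usage sketch, not part of the original file: the label counts,
# batch shapes, and "bert-base-uncased" checkpoint below are illustrative
# assumptions, and the random ids/labels stand in for a real UMLS dataloader.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UMLSPretrainedModel(
        device, "bert-base-uncased",
        cui_label_count=1000, rel_label_count=50, sty_label_count=100,
    ).to(device)

    def fake_ids(batch=8, seq_len=32):
        # Random token ids standing in for tokenized concept names.
        return torch.randint(0, model.tokenizer.vocab_size,
                             (batch, seq_len), device=device)

    def fake_labels(high, batch=8):
        return torch.randint(0, high, (batch,), device=device)

    loss, (sty_loss, cui_loss, re_loss) = model(
        fake_ids(), fake_ids(), fake_ids(),
        fake_labels(1000), fake_labels(1000), fake_labels(1000),
        fake_labels(100), fake_labels(100), fake_labels(100),
        fake_labels(50),
    )
    loss.backward()
    print(loss.item(), sty_loss.item(), cui_loss.item(), re_loss.item())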