# coderpp/train/model.py
# Third-party
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig
from transformers import AutoModel
from transformers import AutoModelForPreTraining
from transformers import AutoTokenizer
from transformers.modeling_utils import SequenceSummary
from pytorch_metric_learning import losses, miners

# Project-local
from loss import AMSoftmax

class UMLSFinetuneModel(nn.Module):
14
    def __init__(self, device, model_name_or_path, cui_label_count, cui_loss_type="ms_loss"):
15
        super(UMLSFinetuneModel, self).__init__()
16
17
        self.device = device
18
        self.model_name_or_path = model_name_or_path
19
        self.config = AutoConfig.from_pretrained(model_name_or_path)
20
        self.bert = AutoModel.from_pretrained(self.model_name_or_path, config=self.config)
21
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
22
        self.dropout = nn.Dropout(0.1)
23
        self.feature_dim = 768
24
25
        self.cui_loss_type = cui_loss_type
26
        self.cui_label_count = cui_label_count
27
28
        if self.cui_loss_type == "softmax":
29
            self.cui_loss_fn = nn.CrossEntropyLoss()
30
            self.linear = nn.Linear(self.feature_dim, self.cui_label_count)
31
        if self.cui_loss_type == "am_softmax":
32
            self.cui_loss_fn = AMSoftmax(self.feature_dim, self.cui_label_count)
33
        if self.cui_loss_type == "ms_loss":
34
            self.cui_loss_fn = losses.MultiSimilarityLoss(alpha=2, beta=50)
35
            self.miner = miners.MultiSimilarityMiner(epsilon=0.1)
36
    
37
    def softmax(self, logits, label):
38
        loss = self.cui_loss_fn(logits, label)
39
        return loss
40
    
41
    def am_softmax(self, pooled_output, label):
42
        loss, _ = self.cui_loss_fn(pooled_output, label)
43
        return loss
44
    
45
    def ms_loss(self, pooled_output, label):
46
        pairs = self.miner(pooled_output, label)
47
        loss = self.cui_loss_fn(pooled_output, label, pairs)
48
        return loss
49
    
50
    def calculate_loss(self, pooled_output=None, logits=None, label=None):
51
        if self.cui_loss_type == "softmax":
52
            return self.softmax(logits, label)
53
        if self.cui_loss_type == "am_softmax":
54
            return self.am_softmax(pooled_output, label)
55
        if self.cui_loss_type == "ms_loss":
56
            return self.ms_loss(pooled_output, label)    
57
    
58
    def get_sentence_feature(self, input_ids, attention_mask):
59
        outputs = self.bert(input_ids, attention_mask)
60
        pooled_output = outputs[1]
61
        return pooled_output
62
63
    def forward(self, input_ids, cui_label, attention_mask):
64
        pooled_output = self.get_sentence_feature(input_ids, attention_mask)
65
        if self.cui_loss_type == "softmax":
66
            logits = self.linear(pooled_output)
67
        else:
68
            logits = None
69
        cui_loss = self.calculate_loss(pooled_output, logits, cui_label)            
70
        loss = cui_loss
71
        return loss