medicalbert/classifiers/standard/bert_model.py
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import BertPreTrainedModel, BertModel


class BertForSequenceClassification(BertPreTrainedModel):
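    """BERT encoder with a dropout and linear head for sequence classification."""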
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

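        # Encoder whose pooled [CLS] representation feeds the classification head.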
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
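        # Softmax over the label dimension, used only to expose class probabilities;
        # the training loss in forward() is computed from the raw logits.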
        self.head = nn.Softmax(dim=1)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
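        """Encode the inputs and classify the pooled output; the loss is prepended when labels are given."""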

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

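        # outputs[1] is the pooled [CLS] representation produced by BertModel.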
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # CrossEntropyLoss below expects raw logits, so softmax is applied separately
        # and only the resulting probabilities are returned to the caller.
        probs = self.head(logits)
        outputs = (probs,) + outputs[2:]  # add hidden states and attentions if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Compute the loss from the raw logits, not the softmax probabilities.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), probs, (hidden_states), (attentions)
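

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative addition, not part of the original module):
    # builds a small, randomly initialised model from a BertConfig and runs one forward pass.
    # The tiny config values below are arbitrary choices to keep the example fast.
    import torch
    from transformers import BertConfig

    config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                        intermediate_size=128, num_labels=2)
    model = BertForSequenceClassification(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    labels = torch.tensor([1])
    with torch.no_grad():
        loss, probs = model(input_ids, labels=labels)[:2]
    print(loss.item(), probs.shape)  # scalar loss and probabilities of shape (1, 2)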