medicalbert/classifiers/standard/bert_model.py
|
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import BertPreTrainedModel, BertModel


class BertForSequenceClassification(BertPreTrainedModel):
    """BERT encoder with a dropout + linear classification head over the pooled output."""

    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        # Softmax head used to turn the classifier logits into class probabilities.
        self.head = nn.Softmax(dim=1)
        self.init_weights()
|
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

        # outputs[1] is the pooled [CLS] representation produced by BertModel.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        probs = self.head(logits)
        outputs = (probs,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            # CrossEntropyLoss applies log-softmax internally, so the loss is computed
            # on the raw logits rather than on the softmax probabilities.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), probs, (hidden_states), (attentions)
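
A minimal usage sketch (not part of the repository) showing how the classifier above can be instantiated and called. It assumes the file is importable as medicalbert.classifiers.standard.bert_model and that a transformers version compatible with the tuple-style BertModel outputs used above is installed; the config values, dummy batch and labels are illustrative only.

import torch
from transformers import BertConfig

from medicalbert.classifiers.standard.bert_model import BertForSequenceClassification

# Hypothetical setup: a freshly initialised binary classifier (no pretrained weights loaded).
config = BertConfig(num_labels=2)
model = BertForSequenceClassification(config)
model.eval()

# Dummy batch: two sequences of 16 token ids with a full attention mask.
input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([0, 1])

with torch.no_grad():
    loss, probs = model(input_ids, attention_mask=attention_mask, labels=labels)[:2]

print(loss.item(), probs.shape)  # scalar loss and a (batch, num_labels) probability tensor

In practice the encoder would normally be initialised from pretrained weights via BertForSequenceClassification.from_pretrained(...) rather than from a bare config as in this sketch.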