[d129b2]: / medicalbert / classifiers / sequential / bert_sequence_classifier.py

Download this file

20 lines (13 with data), 707 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch
from classifiers.sequential.sequence_classifier import SequenceClassifier
from classifiers.sequential.bert_sequence_wrapper import BertSequenceWrapper
from transformers import BertModel
class BertSequenceClassifier(SequenceClassifier):
    """Sequence classifier backed by a pretrained BERT encoder.

    A ``BertModel`` is loaded with hidden states exposed. It is wrapped in a
    ``BertSequenceWrapper``, and an Adam optimizer is created over every
    parameter of that wrapper.
    """

    def __init__(self, config):
        """Build the wrapped BERT model and its optimizer from *config*.

        Args:
            config: Mapping read for the keys ``'pretrained_model'`` (name or
                path handed to ``BertModel.from_pretrained``),
                ``'num_sections'`` and ``'learning_rate'``.
        """
        # NOTE(review): SequenceClassifier.__init__ is never called here;
        # confirm the base class needs no initialisation of its own.
        self.config = config

        # Hidden states of every layer are requested from the encoder,
        # presumably so the wrapper can consume them — TODO confirm.
        encoder = BertModel.from_pretrained(
            self.config['pretrained_model'],
            output_hidden_states=True,
        )

        # The literal 2 is forwarded unchanged as the wrapper's second
        # argument (presumably the number of output classes — verify against
        # BertSequenceWrapper's signature).
        self.model = BertSequenceWrapper(encoder, 2, config['num_sections'])

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config['learning_rate'],
        )

        # TODO: optional layer removal could be applied to the model here.
        self.epochs = 0  # epoch counter starts at zero for a fresh model

        # Echo the model architecture to stdout.
        print(self.model)