[d129b2]: / medicalbert / classifiers / standard / bert_classifier.py

Download this file

17 lines (12 with data), 584 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import torch
from classifiers.standard.bert_model import BertForSequenceClassification
from classifiers.standard.classifier import Classifier
class BertGeneralClassifier(Classifier):
    """BERT sequence classifier wired up with an Adam optimizer.

    The underlying ``BertForSequenceClassification`` is loaded from the
    checkpoint named by ``config['pretrained_model']``. The optimizer updates
    all of that model's parameters at learning rate
    ``config['learning_rate']``.
    """

    def __init__(self, config):
        """Create the model and its optimizer from *config*.

        Args:
            config: mapping that must supply the ``'pretrained_model'`` and
                ``'learning_rate'`` keys (indexed directly, so a missing key
                raises ``KeyError``).
        """
        self.config = config
        settings = self.config

        checkpoint = settings['pretrained_model']
        self.model = BertForSequenceClassification.from_pretrained(checkpoint)

        # Any layer removal / model surgery belongs here, before the optimizer
        # captures the model's parameter list below.
        trainable = self.model.parameters()
        self.optimizer = torch.optim.Adam(trainable, lr=settings['learning_rate'])

        # Starts at 0; presumably advanced by training code elsewhere in the
        # Classifier hierarchy -- TODO confirm.
        self.epochs = 0

        # Echo the model architecture to stdout.
        print(self.model)