b/medicalbert/classifiers/standard/bert_classifier.py
import torch
from classifiers.standard.bert_model import BertForSequenceClassification
from classifiers.standard.classifier import Classifier


class BertGeneralClassifier(Classifier):
    def __init__(self, config):
        self.config = config
        self.model = BertForSequenceClassification.from_pretrained(self.config['pretrained_model'])

        # here, we can do some layer removal if we want to

        # setup the optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['learning_rate'])

        self.epochs = 0
        print(self.model)
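
For context, a minimal usage sketch of this class. The config keys mirror those read in __init__; the checkpoint name "bert-base-uncased" and the learning-rate value are assumptions for illustration only, not part of the diff.

# Hypothetical usage sketch, assuming config is a plain dict with the two
# keys __init__ reads; the checkpoint name is only an example.
config = {
    'pretrained_model': 'bert-base-uncased',
    'learning_rate': 2e-5,
}

classifier = BertGeneralClassifier(config)
# classifier.model and classifier.optimizer are now initialized and ready
# to be used by a training loop.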