Switch to unified view

a b/medicalbert/classifiers/standard/bert_classifier.py
1
import torch
2
from classifiers.standard.bert_model import BertForSequenceClassification
3
from classifiers.standard.classifier import Classifier
4
5
6
class BertGeneralClassifier(Classifier):
7
    def __init__(self, config):
8
        self.config = config
9
        self.model = BertForSequenceClassification.from_pretrained(self.config['pretrained_model'])
10
11
        # here, we can do some layer removal if we want to
12
13
        # setup the optimizer
14
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.config['learning_rate'])
15
16
        self.epochs = 0
17
        print(self.model)