Switch to unified view

a b/medicalbert/classifiers/classifier_factory.py
1
# Factory for making new Classifier objects
2
from classifiers.standard.bert_classifier import BertGeneralClassifier
3
from classifiers.standard.bert_random_classifier import BertRandomClassifier
4
from classifiers.standard.fast_text_classifier import FastTextClassifier
5
from classifiers.standard.bert_mean_pool_classifier import BertMeanPoolClassifier
6
from classifiers.standard.bert_concat_classifier import BertConcatClassifier
7
from classifiers.sequential.bert_sequence_classifier import BertSequenceClassifier
8
9
10
class ClassifierFactory:
11
    def __init__(self, config):
12
        self._classifiers = {"bert-general": BertGeneralClassifier,
13
                             "bert-random": BertRandomClassifier,
14
                             "fast-text": FastTextClassifier,
15
                             "bert-mean-pool": BertMeanPoolClassifier,
16
                             "bert-concat": BertConcatClassifier,
17
                             "bert-seq": BertSequenceClassifier}
18
        self.config = config
19
20
    def register_classifier(self, name, classifier):
21
        self._classifiers[name] = classifier
22
23
    def make_classifier(self, name):
24
        classifier = self._classifiers.get(name)
25
        if not classifier:
26
            raise ValueError(format)
27
        return classifier(self.config)
28