|
a |
|
b/medicalbert/classifiers/classifier_factory.py |
|
|
1 |
# Factory for making new Classifier objects |
|
|
2 |
from classifiers.standard.bert_classifier import BertGeneralClassifier |
|
|
3 |
from classifiers.standard.bert_random_classifier import BertRandomClassifier |
|
|
4 |
from classifiers.standard.fast_text_classifier import FastTextClassifier |
|
|
5 |
from classifiers.standard.bert_mean_pool_classifier import BertMeanPoolClassifier |
|
|
6 |
from classifiers.standard.bert_concat_classifier import BertConcatClassifier |
|
|
7 |
from classifiers.sequential.bert_sequence_classifier import BertSequenceClassifier |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class ClassifierFactory:
    """Build classifier objects by registered name.

    The factory keeps a mapping from a string name to a callable (the
    built-in entries are classifier classes). ``make_classifier`` looks the
    name up and calls the entry with the factory's ``config``.
    """

    def __init__(self, config):
        """Create a factory pre-populated with the built-in classifiers.

        Args:
            config: Object passed unchanged to every classifier constructor
                invoked by ``make_classifier``.
        """
        self._classifiers = {
            "bert-general": BertGeneralClassifier,
            "bert-random": BertRandomClassifier,
            "fast-text": FastTextClassifier,
            "bert-mean-pool": BertMeanPoolClassifier,
            "bert-concat": BertConcatClassifier,
            "bert-seq": BertSequenceClassifier,
        }
        self.config = config

    def register_classifier(self, name, classifier):
        """Register ``classifier`` under ``name``.

        An existing entry with the same name is replaced.

        Args:
            name: Lookup key later passed to ``make_classifier``.
            classifier: Callable taking the config and returning an instance.
        """
        self._classifiers[name] = classifier

    def make_classifier(self, name):
        """Instantiate the classifier registered under ``name``.

        Args:
            name: A key previously registered (built-in or via
                ``register_classifier``).

        Returns:
            The result of calling the registered entry with ``self.config``.

        Raises:
            ValueError: If no classifier is registered under ``name``.
        """
        classifier = self._classifiers.get(name)
        if classifier is None:
            # The original passed the builtin ``format`` function as the
            # message; report the offending name and the valid choices instead.
            available = ", ".join(sorted(map(str, self._classifiers)))
            raise ValueError(
                "Unknown classifier {!r}; available classifiers: {}".format(
                    name, available
                )
            )
        return classifier(self.config)
|
|
28 |
|