Switch to unified view

a b/medacy/tools/json_to_pipeline.py
1
import json
2
import os
3
4
import spacy
5
6
from medacy.pipeline_components.feature_extractors.discrete_feature_extractor import FeatureExtractor
7
from medacy.pipeline_components.feature_extractors.text_extractor import TextExtractor
8
from medacy.pipeline_components.feature_overlayers.metamap.metamap import MetaMap
9
from medacy.pipeline_components.feature_overlayers.metamap.metamap_all_types_component import MetaMapAllTypesOverlayer
10
from medacy.pipeline_components.feature_overlayers.metamap.metamap_component import MetaMapOverlayer
11
from medacy.pipeline_components.learners.bert_learner import BertLearner
12
from medacy.pipeline_components.learners.bilstm_crf_learner import BiLstmCrfLearner
13
from medacy.pipeline_components.learners.crf_learner import get_crf
14
from medacy.pipeline_components.tokenizers.character_tokenizer import CharacterTokenizer
15
from medacy.pipeline_components.tokenizers.clinical_tokenizer import ClinicalTokenizer
16
from medacy.pipeline_components.tokenizers.systematic_review_tokenizer import SystematicReviewTokenizer
17
from medacy.pipelines.base.base_pipeline import BasePipeline
18
19
20
required_keys = [
21
    'learner',
22
    'spacy_pipeline',
23
]
24
25
26
def json_to_pipeline(json_path):
27
    """
28
    Constructs a custom pipeline from a json file
29
30
    The json must have the following keys:
31
32
    'learner': 'CRF', 'BiLSTM', or 'BERT'
33
    'spacy_pipeline': the spaCy model to use
34
35
    The following keys are optional:
36
    'spacy_features': a list of features that exist as spaCy token annotations
37
    'window_size': the number of words +/- the target word whose features should be used along with the target word; defaults to 0
38
    'tokenizer': 'clinical', 'systematic_review', or 'character'; defaults to the spaCy model's tokenizer
39
    'metamap': the path to the MetaMap binary; MetaMap will only be used if this key is present
40
        if 'metamap' is a key, 'semantic_types' must also be a key, with value 'all', 'none', or
41
        a list of semantic type strings
42
43
    :param json_path: the path to the json file, or a dict of what that json would be
44
    :return: a custom pipeline class
45
    """
46
47
    if isinstance(json_path, (str, os.PathLike)):
48
        with open(json_path, 'rb') as f:
49
            input_json = json.load(f)
50
    elif isinstance(json_path, dict):
51
        input_json = json_path
52
53
    missing_keys = [key for key in required_keys if key not in input_json.keys()]
54
    if missing_keys:
55
        raise ValueError(f"Required key(s) '{missing_keys}' was/were not found in the json file.")
56
57
    class CustomPipeline(BasePipeline):
58
        """A custom pipeline configured from a JSON file"""
59
60
        def __init__(self, entities, **kwargs):
61
            super().__init__(entities, spacy_pipeline=spacy.load(input_json['spacy_pipeline']))
62
63
            if 'metamap' in input_json.keys():
64
                if 'semantic_types' not in input_json.keys():
65
                    raise ValueError("'semantic_types' must be a key when 'metamap' is a key.")
66
67
                metamap = MetaMap(input_json['metamap'])
68
69
                if input_json['semantic_types'] == 'all':
70
                    self.add_component(MetaMapAllTypesOverlayer, metamap)
71
                elif input_json['semantic_types'] == 'none':
72
                    self.add_component(MetaMapOverlayer, metamap, semantic_type_labels=[])
73
                elif isinstance(input_json['semantic_types'], list):
74
                    self.add_component(MetaMapOverlayer, metamap, semantic_type_labels=input_json['semantic_types'])
75
                else:
76
                    raise ValueError("'semantic_types' must be 'all', 'none', or a list of strings")
77
78
            # BERT values
79
            self.cuda_device = kwargs['cuda_device'] if 'cuda_device' in kwargs else -1
80
            self.batch_size = kwargs['batch_size'] if 'batch_size' in kwargs else 8
81
            self.learning_rate = kwargs['learning_rate'] if 'learning_rate' in kwargs else 1e-5
82
            self.epochs = kwargs['epochs'] if 'epochs' in kwargs else 3
83
            self.pretrained_model = kwargs['pretrained_model'] if 'pretrained_model' in kwargs else 'bert-large-cased'
84
            self.using_crf = kwargs['using_crf'] if 'using_crf' in kwargs else False
85
86
            # BiLSTM value
87
            if input_json['learner'] == 'BiLSTM':
88
                if 'word_embeddings' not in kwargs:
89
                    raise ValueError("BiLSTM learner requires word embeddings; use the parameter '--word_embeddings' "
90
                                     "to specify an embedding path")
91
            self.word_embeddings = kwargs['word_embeddings']
92
93
        def get_tokenizer(self):
94
            if 'tokenizer' not in input_json.keys():
95
                return None
96
97
            selection = input_json['tokenizer']
98
            options = {
99
                'clinical': ClinicalTokenizer,
100
                'systematic_review': SystematicReviewTokenizer,
101
                'character': CharacterTokenizer
102
            }
103
104
            if selection not in options:
105
                raise ValueError(f"Tokenizer selection '{selection}' not an option")
106
107
            Tokenizer = options[selection]
108
            return Tokenizer(self.spacy_pipeline)
109
110
        def get_learner(self):
111
            learner_selection = input_json['learner']
112
113
            if learner_selection == 'CRF':
114
                return "CRF_l2sgd", get_crf()
115
            if learner_selection == 'BiLSTM':
116
                return 'BiLSTM+CRF', BiLstmCrfLearner(self.word_embeddings, self.cuda_device)
117
            if learner_selection == 'BERT':
118
                learner = BertLearner(
119
                    self.cuda_device,
120
                    pretrained_model=self.pretrained_model,
121
                    batch_size=self.batch_size,
122
                    learning_rate=self.learning_rate,
123
                    epochs=self.epochs,
124
                    using_crf=self.using_crf
125
                )
126
                return 'BERT', learner
127
            else:
128
                raise ValueError(f"'learner' must be 'CRF', 'BiLSTM', or 'BERT', but is {learner_selection}")
129
130
        def get_feature_extractor(self):
131
            if input_json['learner'] == 'BERT':
132
                return TextExtractor()
133
134
            return FeatureExtractor(
135
                window_size=input_json['window_size'] if 'window_size' in input_json else 0,
136
                spacy_features=input_json['spacy_features'] if 'spacy_features' in input_json else ['text']
137
            )
138
139
        def get_report(self):
140
            report = super().get_report() + '\n'
141
            report += f"Pipeline configured from a JSON: {json.dumps(input_json)}\nJSON path: {json_path}"
142
            return report
143
144
    return CustomPipeline