medacy/model/model.py

import importlib
import logging
import os
from itertools import cycle
from pathlib import Path
from shutil import copyfile
from statistics import mean
from typing import List, Tuple, Dict, Iterable

import joblib
import numpy as np
from sklearn_crfsuite import metrics
from tabulate import tabulate

from medacy.data.annotations import Annotations, EntTuple
from medacy.data.dataset import Dataset
from medacy.pipeline_components.feature_extractors import FeatureTuple
from medacy.pipelines.base.base_pipeline import BasePipeline

DEFAULT_NUM_FOLDS = 10


def create_folds(y, num_folds=DEFAULT_NUM_FOLDS) -> List[Tuple[List[int], List[int]]]:
    """
    Partitions a data set of sequence labels and classifications into a number of stratified folds. Each partition
    should have an evenly distributed representation of sequence labels. Without stratification, under-represented
    labels may not appear in some folds. Returns a list [(train_indices, test_indices), ...] where each element
    contains the indices of the train and test set for the particular testing fold.

    See Dietterich, 1997 "Approximate Statistical Tests for Comparing Supervised Classification
    Algorithms" for in-depth analysis.

    :param y: a collection of sequence labels
    :param num_folds: the number of folds (defaults to 10, must be >= 2)
    :return: a list of (train_indices, test_indices) tuples, one per fold
    """
    if not isinstance(num_folds, int) or num_folds < 2:
        raise ValueError(f"'num_folds' must be an int >= 2, but is {repr(num_folds)}")

    # np.unique returns the labels in sorted order; reverse that ordering before distributing sequences
    labels = np.unique([label for sequence in y for label in sequence])
    labels = np.flip(labels)

    # Tracks which sequences have not yet been assigned to a partition
    added = np.ones(len(y), dtype=bool)
    partitions = [[] for _ in range(num_folds)]
    partition_cycler = cycle(partitions)

    for label in labels:
        possible_sequences = [index for index, sequence in enumerate(y) if label in sequence]
        for index in possible_sequences:
            if added[index]:
                partition = next(partition_cycler)
                partition.append(index)
                added[index] = False

    train_test_array = []

    for i, test_partition in enumerate(partitions):
        train_partition = []
        for j, partition in enumerate(partitions):
            if i != j:
                train_partition += partition

        train_test_array.append((train_partition, test_partition))

    return train_test_array


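# Example usage (illustrative sketch; the tag sequences below are hypothetical):
#
#   y = [['O', 'Drug'], ['O', 'O'], ['Dose', 'O'], ['O', 'Drug'], ['O', 'O'], ['Dose', 'O']]
#   for train_indices, test_indices in create_folds(y, num_folds=2):
#       y_train = [y[i] for i in train_indices]
#       y_test = [y[i] for i in test_indices]

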
def sequence_to_ann(X: List[FeatureTuple], y: List[List[str]], file_names: Iterable[str]) -> Dict[str, Annotations]:
    """
    Creates a dictionary of document-level Annotations objects for a given sequence
    :param X: A list of sentence-level zipped (features, indices, document_name) tuples
    :param y: A list of sentence-level lists of tags
    :param file_names: A list of file names that are used by these sequences
    :return: A dictionary mapping txt file names (the whole path) to their Annotations objects, where the
    Annotations are constructed from the X and y data given here.
    """
    # Flatten the sentence-level structures into parallel token-level lists
    anns = {filename: Annotations([]) for filename in file_names}
    tuples_by_doc = {filename: [] for filename in file_names}
    document_indices = []
    span_indices = []

    for sequence in X:
        document_indices += [sequence.file_name] * len(sequence.features)
        span_indices.extend(sequence.indices)

    groundtruth = [element for sentence in y for element in sentence]

    # Map each labeled sequence back to its corresponding document
    i = 0

    while i < len(groundtruth):
        if groundtruth[i] == 'O':
            i += 1
            continue

        entity = groundtruth[i]
        document = document_indices[i]
        first_start, first_end = span_indices[i]
        # Ensure that consecutive tokens with the same label are merged
        while i < len(groundtruth) - 1 and groundtruth[i + 1] == entity:  # If inside entity, keep incrementing
            i += 1

        last_start, last_end = span_indices[i]
        tuples_by_doc[document].append((entity, first_start, last_end))
        i += 1

    # Create the Annotations objects
    for file_name, tups in tuples_by_doc.items():
        ann_tups = []
        with open(file_name) as f:
            text = f.read()
        for tup in tups:
            entity, start, end = tup
            ent_text = text[start:end]
            new_tup = EntTuple(entity, start, end, ent_text)
            ann_tups.append(new_tup)
        anns[file_name].annotations = ann_tups

    return anns


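# Sketch of the expected inputs (illustrative; the paths and values are hypothetical, and the
# FeatureTuple field names are inferred from how they are accessed in this module):
#
#   X = [FeatureTuple(features=[{...}, {...}], indices=[(0, 7), (11, 16)], file_name='/data/note_1.txt')]
#   y = [['Drug', 'O']]
#   anns = sequence_to_ann(X, y, file_names={'/data/note_1.txt'})
#   anns['/data/note_1.txt'].to_ann('/tmp/note_1.ann')

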
def write_ann_dicts(output_dir: Path, dict_list: List[Dict[str, Annotations]]) -> Dict[str, Annotations]:
    """
    Merges a list of dicts of Annotations into one dict representing all the individual ann files and writes the
    ann data to file for both the individual fold Annotations and the combined ones.
    :param output_dir: Path object of the output directory (a subdirectory is made for each fold)
    :param dict_list: a list of file_name: Annotations dictionaries
    :return: the merged Annotations dict
    """
    file_names = set()
    for d in dict_list:
        file_names |= set(d.keys())

    all_annotations_dict = {filename: Annotations([]) for filename in file_names}
    for i, fold_dict in enumerate(dict_list, 1):
        fold_dir = output_dir / f"fold_{i}"
        os.mkdir(fold_dir)
        for file_name, ann in fold_dict.items():
            # Write the Annotations from the individual fold to file;
            # Note that this is written to the fold_dir, which is a subfolder of the output_dir
            ann.to_ann(fold_dir / (os.path.basename(file_name).rstrip("txt") + "ann"))
            # Merge the Annotations from the fold into the inter-fold Annotations
            all_annotations_dict[file_name] |= ann

    # Write the Annotations that are the combination of all folds to file
    for file_name, ann in all_annotations_dict.items():
        output_file_path = output_dir / (os.path.basename(file_name).rstrip("txt") + "ann")
        ann.to_ann(output_file_path)

    return all_annotations_dict


class Model:
    """
    A medaCy Model allows the fitting of a named entity recognition model to a given dataset according to the
    configuration of a given medaCy pipeline. Once fitted, Model instances can be used to predict over documents.
    Also included is a function for cross validating over a dataset for measuring the performance of a pipeline.

    :ivar pipeline: a medaCy pipeline, must be a subclass of BasePipeline (see medacy.pipelines.base.BasePipeline)
    :ivar model: weights, if the model has been fitted
    :ivar X_data: X_data from the pipeline; primarily for internal use
    :ivar y_data: y_data from the pipeline; primarily for internal use
    """

    def __init__(self, medacy_pipeline, model=None):

        if not isinstance(medacy_pipeline, BasePipeline):
            raise TypeError("Pipeline must be a medaCy pipeline that subclasses medacy.pipelines.base.BasePipeline")

        self.pipeline = medacy_pipeline
        self.model = model

        # These arrays will store the sequences of features and sequences of corresponding labels
        self.X_data = []
        self.y_data = []

        # Run an initializing document through the pipeline to register all token extensions.
        # This allows the gathering of pipeline information prior to fitting with live data.
        doc = self.pipeline(medacy_pipeline.spacy_pipeline.make_doc("Initialize"), predict=True)
        if doc is None:
            raise IOError("Model could not be initialized with the set pipeline.")

    def preprocess(self, dataset):
        """
        Preprocess dataset into a list of sequences and tags.
        :param dataset: Dataset object to preprocess.
        """
        self.X_data = []
        self.y_data = []
        # Run all Docs through the pipeline before extracting features, allowing for pipeline components
        # that require inter-dependent doc objects
        docs = [self._run_through_pipeline(data_file) for data_file in dataset if data_file.txt_path]
        for doc in docs:
            features, labels = self._extract_features(doc)
            self.X_data += features
            self.y_data += labels

    def fit(self, dataset: Dataset, groundtruth_directory: Path = None):
        """
        Runs dataset through the designated pipeline, extracts features, and fits the pipeline's learner.
        :param dataset: Instance of Dataset.
        :param groundtruth_directory: optional directory to write the extracted ground truth ann files to
        :return model: the fitted learner (e.g. a sklearn_crfsuite.CRF instance for CRF-based pipelines).
        """

        groundtruth_directory = Path(groundtruth_directory) if groundtruth_directory else False

        report = self.pipeline.get_report()
        self.preprocess(dataset)

        if groundtruth_directory:
            logging.info(f"Writing dataset groundtruth to {groundtruth_directory}")
            for file_path, ann in sequence_to_ann(self.X_data, self.y_data, {x[2] for x in self.X_data}).items():
                ann.to_ann(groundtruth_directory / (os.path.splitext(os.path.basename(file_path))[0] + ".ann"))

        learner_name, learner = self.pipeline.get_learner()
        logging.info(f"Training: {learner_name}")

        train_data = [x[0] for x in self.X_data]
        learner.fit(train_data, self.y_data)
        logging.info(f"Successfully Trained: {learner_name}\n{report}")

        self.model = learner
        return self.model

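    # Typical usage (illustrative sketch; the pipeline class, its import path, its constructor
    # arguments, and the dataset path are assumptions rather than fixed requirements):
    #
    #   from medacy.data.dataset import Dataset
    #   from medacy.pipelines.clinical_pipeline import ClinicalPipeline
    #
    #   dataset = Dataset('/path/to/annotated_data')
    #   model = Model(ClinicalPipeline(entities=['Drug', 'Dose']))
    #   model.fit(dataset)
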
    def _predict_document(self, doc):
        """
        Generates an Annotations object of predictions of the given model over the corresponding document. The passed
        document is assumed to be annotated by the same pipeline utilized when training the model.
        :param doc: A spacy document
        :return: an Annotations object containing the model predictions
        """

        feature_extractor = self.pipeline.get_feature_extractor()

        features, indices = feature_extractor.get_features_with_span_indices(doc)
        predictions = self.model.predict(features)
        predictions = [element for sentence in predictions for element in sentence]  # flatten 2d list
        span_indices = [element for sentence in indices for element in sentence]  # parallel array containing indices
        annotations = []

        i = 0
        while i < len(predictions):
            if predictions[i] == 'O':
                i += 1
                continue

            entity = predictions[i]
            first_start, first_end = span_indices[i]

            # Ensure that consecutive tokens with the same label are merged
            while i < len(predictions) - 1 and predictions[i + 1] == entity:  # If inside entity, keep incrementing
                i += 1

            last_start, last_end = span_indices[i]
            labeled_text = doc.text[first_start:last_end]
            new_ent = EntTuple(entity, first_start, last_end, labeled_text)
            annotations.append(new_ent)

            logging.debug(f"{doc._.file_name}: Predicted {entity} at ({first_start}, {last_end}) {labeled_text}")

            i += 1

        return Annotations(annotations)

    def predict(self, input_data, prediction_directory=None):
        """
        Generates predictions over a string, a Dataset, or a directory of text files, utilizing the pipeline equipped to the instance.

        :param input_data: a string, Dataset, or directory path to predict over
        :param prediction_directory: The directory to write predictions if doing bulk prediction
            (default: */predictions* sub-directory of Dataset)
        :return: if input_data is a str, returns an Annotations of the predictions;
            if input_data is a Dataset or a valid directory path, returns a Dataset of the predictions.

        Note that if input_data is supposed to be a directory path but the directory is not found, it will be predicted
        over as a string. This can be prevented by validating inputs with os.path.isdir().
        """

        if self.model is None:
            raise RuntimeError("Must fit or load a pickled model before predicting")

        if isinstance(input_data, str) and not os.path.isdir(input_data):
            doc = self.pipeline.spacy_pipeline.make_doc(input_data)
            doc.set_extension('file_name', default=None, force=True)
            doc._.file_name = 'STRING_INPUT'
            doc = self.pipeline(doc, predict=True)
            annotations = self._predict_document(doc)
            return annotations

        if isinstance(input_data, Dataset):
            input_files = [d.txt_path for d in input_data]
            # Change input_data to point to the Dataset's directory path so that we can use it
            # to create the prediction directory
            input_data = input_data.data_directory
        elif os.path.isdir(input_data):
            input_files = [os.path.join(input_data, f) for f in os.listdir(input_data) if f.endswith('.txt')]
        else:
            raise ValueError(f"'input_data' must be a string (which can be a directory path) or a Dataset, but is {repr(input_data)}")

        if prediction_directory is None:
            prediction_directory = os.path.join(input_data, 'predictions')
            if os.path.isdir(prediction_directory):
                logging.warning("Overwriting existing predictions at %s", prediction_directory)
            else:
                os.mkdir(prediction_directory)

        for file_path in input_files:
            file_name = os.path.splitext(os.path.basename(file_path))[0]
            logging.info("Predicting file: %s", file_path)

            with open(file_path, 'r') as f:
                doc = self.pipeline.spacy_pipeline.make_doc(f.read())

            doc.set_extension('file_name', default=None, force=True)
            doc._.file_name = file_name

            # run through the pipeline
            doc = self.pipeline(doc, predict=True)

            # Predict, creating a new Annotations object
            annotations = self._predict_document(doc)
            logging.debug("Writing to: %s", os.path.join(prediction_directory, file_name + ".ann"))
            annotations.to_ann(write_location=os.path.join(prediction_directory, file_name + ".ann"))

            # Copy the txt file so that the output will also be a Dataset
            copyfile(file_path, os.path.join(prediction_directory, file_name + ".txt"))

        return Dataset(prediction_directory)

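    # Example calls (illustrative sketch; the input text and paths are hypothetical):
    #
    #   ann = model.predict("The patient was given 50 mg of ibuprofen.")  # returns Annotations
    #   predictions = model.predict(Dataset('/path/to/txt_files'))        # returns Dataset
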
    def cross_validate(self, training_dataset, num_folds=DEFAULT_NUM_FOLDS, prediction_directory=None, groundtruth_directory=None):
        """
        Performs k-fold stratified cross-validation using our model and pipeline.

        If groundtruth_directory and prediction_directory are passed, the ground truth and the intermediate predictions
        from each fold of cross validation are written to those directories. This allows one to construct a confusion
        matrix or to compute the prediction ambiguity with the methods present in the Dataset class to support pipeline
        development without a designated evaluation set.

        :param training_dataset: Dataset that is being cross validated
        :param num_folds: number of folds to split training data into for cross validation; defaults to 10
        :param prediction_directory: directory to write predictions of cross validation to
        :param groundtruth_directory: directory to write the ground truth medaCy evaluates on
        :return: a dictionary of statistics averaged over all folds for each entity and for the whole system;
            the per-fold and overall metrics are also logged or printed
        """

        if num_folds <= 1:
            raise ValueError("Number of folds for cross validation must be greater than 1, but is %s" % repr(num_folds))

        groundtruth_directory = Path(groundtruth_directory) if groundtruth_directory else False
        prediction_directory = Path(prediction_directory) if prediction_directory else False

        for d in [groundtruth_directory, prediction_directory]:
            if d and not d.exists():
                raise NotADirectoryError(f"Options groundtruth_directory and prediction_directory must be existing directories, but one is {d}")

        pipeline_report = self.pipeline.get_report()

        self.preprocess(training_dataset)

        if not (self.X_data and self.y_data):
            raise RuntimeError("Must have features and labels extracted for cross validation")

        tags = sorted(self.pipeline.entities)
        logging.info(f'Tagset: {tags}')

        eval_stats = {}

        # Lists of per-fold dicts mapping file names to their groundtruth/prediction Annotations
        fold_groundtruth_dicts = []
        fold_prediction_dicts = []
        file_names = {x.file_name for x in self.X_data}

        folds = create_folds(self.y_data, num_folds)

        for fold_num, fold_data in enumerate(folds, 1):
            train_indices, test_indices = fold_data
            fold_statistics = {}
            learner_name, learner = self.pipeline.get_learner()

            X_train = [self.X_data[index] for index in train_indices]
            y_train = [self.y_data[index] for index in train_indices]

            X_test = [self.X_data[index] for index in test_indices]
            y_test = [self.y_data[index] for index in test_indices]

            logging.info("Training Fold %i", fold_num)
            train_data = [x[0] for x in X_train]
            test_data = [x[0] for x in X_test]
            learner.fit(train_data, y_train)
            y_pred = learner.predict(test_data)

            if groundtruth_directory:
                ann_dict = sequence_to_ann(X_test, y_test, file_names)
                fold_groundtruth_dicts.append(ann_dict)

            if prediction_directory:
                ann_dict = sequence_to_ann(X_test, y_pred, file_names)
                fold_prediction_dicts.append(ann_dict)

            # Compute the metrics for this fold.
            for label in tags:
                fold_statistics[label] = {
                    "recall": metrics.flat_recall_score(y_test, y_pred, average='weighted', labels=[label]),
                    "precision": metrics.flat_precision_score(y_test, y_pred, average='weighted', labels=[label]),
                    "f1": metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=[label])
                }

            # add averages
            fold_statistics['system'] = {
                "recall": metrics.flat_recall_score(y_test, y_pred, average='weighted', labels=tags),
                "precision": metrics.flat_precision_score(y_test, y_pred, average='weighted', labels=tags),
                "f1": metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=tags)
            }

            table_data = [
                [label,
                 format(fold_statistics[label]['precision'], ".3f"),
                 format(fold_statistics[label]['recall'], ".3f"),
                 format(fold_statistics[label]['f1'], ".3f")
                 ] for label in tags + ['system']
            ]

            logging.info('\n' + tabulate(table_data, headers=['Entity', 'Precision', 'Recall', 'F1'], tablefmt='orgtbl'))

            eval_stats[fold_num] = fold_statistics

        statistics_all_folds = {}

        for label in tags + ['system']:
            statistics_all_folds[label] = {
                'precision_average': mean(eval_stats[fold][label]['precision'] for fold in eval_stats),
                'precision_max': max(eval_stats[fold][label]['precision'] for fold in eval_stats),
                'precision_min': min(eval_stats[fold][label]['precision'] for fold in eval_stats),
                'recall_average': mean(eval_stats[fold][label]['recall'] for fold in eval_stats),
                'recall_max': max(eval_stats[fold][label]['recall'] for fold in eval_stats),
                'f1_average': mean(eval_stats[fold][label]['f1'] for fold in eval_stats),
                'f1_max': max(eval_stats[fold][label]['f1'] for fold in eval_stats),
                'f1_min': min(eval_stats[fold][label]['f1'] for fold in eval_stats),
            }

        entity_counts = training_dataset.compute_counts()
        entity_counts['system'] = sum(v for k, v in entity_counts.items() if k in self.pipeline.entities)

        table_data = [
            [f"{label} ({entity_counts[label]})",  # Entity (Count)
             format(statistics_all_folds[label]['precision_average'], ".3f"),
             format(statistics_all_folds[label]['recall_average'], ".3f"),
             format(statistics_all_folds[label]['f1_average'], ".3f"),
             format(statistics_all_folds[label]['f1_min'], ".3f"),
             format(statistics_all_folds[label]['f1_max'], ".3f")
             ] for label in tags + ['system']
        ]

        # Combine the pipeline report and the resulting data, then log it or print it (whichever ensures that it prints)

        output_str = '\n' + pipeline_report + '\n\n' + tabulate(
            table_data,
            headers=['Entity (Count)', 'Precision', 'Recall', 'F1', 'F1_Min', 'F1_Max'],
            tablefmt='orgtbl'
        )

        if logging.root.level > logging.INFO:
            print(output_str)
        else:
            logging.info(output_str)

        # Write groundtruth and predictions to file
        if groundtruth_directory:
            write_ann_dicts(groundtruth_directory, fold_groundtruth_dicts)
        if prediction_directory:
            write_ann_dicts(prediction_directory, fold_prediction_dicts)

        return statistics_all_folds

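    # Example call (illustrative sketch; the directories are hypothetical and must already exist):
    #
    #   stats = model.cross_validate(dataset, num_folds=5,
    #                                groundtruth_directory='/tmp/gt', prediction_directory='/tmp/pred')
    #   print(stats['system']['f1_average'])
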
    def _run_through_pipeline(self, data_file):
        """
        Runs a DataFile through the pipeline, returning the resulting Doc object
        :param data_file: instance of DataFile
        :return: a Doc object
        """
        nlp = self.pipeline.spacy_pipeline
        logging.info("Processing file: %s", data_file.file_name)

        with open(data_file.txt_path, 'r', encoding='utf-8') as f:
            doc = nlp.make_doc(f.read())

        # Link ann_path to doc
        doc.set_extension('gold_annotation_file', default=None, force=True)
        doc.set_extension('file_name', default=None, force=True)

        doc._.gold_annotation_file = data_file.ann_path
        doc._.file_name = data_file.txt_path

        # run 'er through
        return self.pipeline(doc)

    def _extract_features(self, doc):
        """
        Extracts features from a Doc
        :param doc: an instance of Doc
        :return: a tuple of the feature sequences and the corresponding label sequences
        """

        feature_extractor = self.pipeline.get_feature_extractor()
        features, labels = feature_extractor(doc)

        logging.info(f"{doc._.file_name}: Feature Extraction Completed (num_sequences={len(labels)})")
        return features, labels

    def load(self, path):
        """
        Loads a pickled model.

        :param path: File path to the saved model that should be loaded
        :return:
        """
        model_name, model = self.pipeline.get_learner()

        if model_name == 'BiLSTM+CRF' or model_name == 'BERT':
            model.load(path)
            self.model = model
        else:
            self.model = joblib.load(path)

    def dump(self, path):
        """
        Dumps a model into a pickle file

        :param path: File path to dump the model to
        :return:
        """
        if self.model is None:
            raise RuntimeError("Must fit model before dumping.")

        model_name, _ = self.pipeline.get_learner()

        if model_name == 'BiLSTM+CRF' or model_name == 'BERT':
            self.model.save(path)
        else:
            joblib.dump(self.model, path)

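    # Round-trip sketch (illustrative; the path and the `pipeline` variable are hypothetical). The same
    # pipeline should be supplied when loading, since the learner type is taken from the pipeline:
    #
    #   model.dump('/models/clinical_model.pkl')
    #   restored = Model(pipeline)
    #   restored.load('/models/clinical_model.pkl')
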
    @staticmethod
    def load_external(package_name):
        """
        Loads an external medaCy-compatible Model. Requires the model's package to be installed.
        Alternatively, you can import the package directly and call its .load() method.

        :param package_name: the package name of the model
        :return: an instance of Model that is configured and loaded - ready for prediction.
        """
        if importlib.util.find_spec(package_name) is None:
            raise ImportError("Package not installed: %s" % package_name)
        return importlib.import_module(package_name).load()
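    # Example (illustrative sketch; the package name is a placeholder for an installed medaCy model package):
    #
    #   model = Model.load_external('medacy_model_clinical_notes')
    #   annotations = model.predict("Some clinical text.")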