# medacy/model/model.py
import importlib
import logging
import os
from itertools import cycle
from pathlib import Path
from shutil import copyfile
from statistics import mean
from typing import List, Tuple, Dict, Iterable

import joblib
import numpy as np
from sklearn_crfsuite import metrics
from tabulate import tabulate

from medacy.data.annotations import Annotations, EntTuple
from medacy.data.dataset import Dataset
from medacy.pipeline_components.feature_extractors import FeatureTuple
from medacy.pipelines.base.base_pipeline import BasePipeline

DEFAULT_NUM_FOLDS = 10


def create_folds(y, num_folds=DEFAULT_NUM_FOLDS) -> List[Tuple[FeatureTuple, List]]:
    """
    Partitions a data set of sequence labels and classifications into a number of stratified folds. Each partition
    should have an evenly distributed representation of sequence labels. Without stratification, under-represented
    labels may not appear in some folds. Returns an iterable [(X*, y*), ...] where each element contains the indices
    of the train and test set for the particular testing fold.

    See Dietterich, 1997 "Approximate Statistical Tests for Comparing Supervised Classification
    Algorithms" for in-depth analysis.

    :param y: a collection of sequence labels
    :param num_folds: the number of folds (defaults to 10; must be >= 2)
    :return: an iterable of (train_indices, test_indices) tuples
    """
    if not isinstance(num_folds, int) or num_folds < 2:
        raise ValueError(f"'num_folds' must be an int >= 2, but is {repr(num_folds)}")

    # np.unique returns the labels in sorted order; reverse that order before partitioning
    labels = np.unique([label for sequence in y for label in sequence])
    labels = np.flip(labels)

    added = np.ones(len(y), dtype=bool)
    partitions = [[] for _ in range(num_folds)]
    partition_cycler = cycle(partitions)

    for label in labels:
        possible_sequences = [index for index, sequence in enumerate(y) if label in sequence]
        for index in possible_sequences:
            if added[index]:
                partition = next(partition_cycler)
                partition.append(index)
                added[index] = False

    train_test_array = []

    for i, y in enumerate(partitions):
        X = []
        for j, partition in enumerate(partitions):
            if i != j:
                X += partition

        train_test_array.append((X, y))

    return train_test_array
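
# A quick sketch of how the (train_indices, test_indices) pairs returned above are
# typically consumed; the toy tag sequences here are hypothetical, not medaCy data:
#
#     y = [['O', 'Drug'], ['O', 'O'], ['Dose', 'O'], ['Drug', 'Dose']]
#     for train_indices, test_indices in create_folds(y, num_folds=2):
#         y_train = [y[i] for i in train_indices]
#         y_test = [y[i] for i in test_indices]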


def sequence_to_ann(X: List[FeatureTuple], y: List[List[str]], file_names: Iterable[str]) -> Dict[str, Annotations]:
    """
    Creates a dictionary of document-level Annotations objects for a given sequence
    :param X: A list of sentence-level FeatureTuples of (features, indices, document_name)
    :param y: A list of sentence-level lists of tags
    :param file_names: A list of file names that are used by these sequences
    :return: A dictionary mapping txt file names (the whole path) to their Annotations objects, where the
        Annotations are constructed from the X and y data given here.
    """
    # Flatten the nested structures into parallel 2D lists
    anns = {filename: Annotations([]) for filename in file_names}
    tuples_by_doc = {filename: [] for filename in file_names}
    document_indices = []
    span_indices = []

    for sequence in X:
        document_indices += [sequence.file_name] * len(sequence.features)
        span_indices.extend(sequence.indices)

    groundtruth = [element for sentence in y for element in sentence]

    # Map the labeled sequences to their corresponding documents
    i = 0

    while i < len(groundtruth):
        if groundtruth[i] == 'O':
            i += 1
            continue

        entity = groundtruth[i]
        document = document_indices[i]
        first_start, first_end = span_indices[i]
        # Ensure that consecutive tokens with the same label are merged
        while i < len(groundtruth) - 1 and groundtruth[i + 1] == entity:  # If inside entity, keep incrementing
            i += 1

        last_start, last_end = span_indices[i]
        tuples_by_doc[document].append((entity, first_start, last_end))
        i += 1

    # Create the Annotations objects
    for file_name, tups in tuples_by_doc.items():
        ann_tups = []
        with open(file_name) as f:
            text = f.read()
        for tup in tups:
            entity, start, end = tup
            ent_text = text[start:end]
            new_tup = EntTuple(entity, start, end, ent_text)
            ann_tups.append(new_tup)
        anns[file_name].annotations = ann_tups

    return anns
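
# For intuition, a hypothetical sentence whose tokens are tagged
# ['Drug', 'Drug', 'O'] with character spans [(0, 7), (8, 15), (16, 20)] would be
# collapsed by the merging loop above into the single entity tuple ('Drug', 0, 15),
# whose text is then read back out of the document as text[0:15].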


def write_ann_dicts(output_dir: Path, dict_list: List[Dict[str, Annotations]]) -> Dict[str, Annotations]:
    """
    Merges a list of dicts of Annotations into one dict representing all the individual ann files, writing the
    ann data for both the individual fold Annotations and the combined Annotations.
    :param output_dir: Path object of the output directory (a subdirectory is made for each fold)
    :param dict_list: a list of file_name: Annotations dictionaries
    :return: the merged Annotations dict
    """
    file_names = set()
    for d in dict_list:
        file_names |= set(d.keys())

    all_annotations_dict = {filename: Annotations([]) for filename in file_names}
    for i, fold_dict in enumerate(dict_list, 1):
        fold_dir = output_dir / f"fold_{i}"
        os.mkdir(fold_dir)
        for file_name, ann in fold_dict.items():
            # Write the Annotations from the individual fold to file;
            # note that this is written to fold_dir, which is a subdirectory of output_dir
            ann.to_ann(fold_dir / (os.path.splitext(os.path.basename(file_name))[0] + ".ann"))
            # Merge the Annotations from the fold into the inter-fold Annotations
            all_annotations_dict[file_name] |= ann

    # Write the Annotations that are the combination of all folds to file
    for file_name, ann in all_annotations_dict.items():
        output_file_path = output_dir / (os.path.splitext(os.path.basename(file_name))[0] + ".ann")
        ann.to_ann(output_file_path)

    return all_annotations_dict
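
# Given fold dicts produced by cross_validate(), the resulting layout under output_dir
# looks roughly like this (file names are placeholders):
#
#     output_dir/
#         doc1.ann          <- combination of every fold's annotations for doc1.txt
#         fold_1/doc1.ann   <- annotations produced only by fold 1
#         fold_2/doc1.ann
#         ...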


class Model:
    """
    A medaCy Model allows the fitting of a named entity recognition model to a given dataset according to the
    configuration of a given medaCy pipeline. Once fitted, Model instances can be used to predict over documents.
    Also included is a function for cross validating over a dataset for measuring the performance of a pipeline.

    :ivar pipeline: a medaCy pipeline, must be a subclass of BasePipeline (see medacy.pipelines.base.BasePipeline)
    :ivar model: weights, if the model has been fitted
    :ivar X_data: X_data from the pipeline; primarily for internal use
    :ivar y_data: y_data from the pipeline; primarily for internal use
    """

    def __init__(self, medacy_pipeline, model=None):

        if not isinstance(medacy_pipeline, BasePipeline):
            raise TypeError("Pipeline must be a medaCy pipeline that interfaces medacy.pipelines.base.BasePipeline")

        self.pipeline = medacy_pipeline
        self.model = model

        # These arrays will store the sequences of features and sequences of corresponding labels
        self.X_data = []
        self.y_data = []

        # Run an initializing document through the pipeline to register all token extensions.
        # This allows the gathering of pipeline information prior to fitting with live data.
        doc = self.pipeline(medacy_pipeline.spacy_pipeline.make_doc("Initialize"), predict=True)
        if doc is None:
            raise IOError("Model could not be initialized with the set pipeline.")

    def preprocess(self, dataset):
        """
        Preprocess dataset into a list of sequences and tags.
        :param dataset: Dataset object to preprocess.
        """
        self.X_data = []
        self.y_data = []
        # Run all Docs through the pipeline before extracting features, allowing for pipeline components
        # that require inter-dependent doc objects
        docs = [self._run_through_pipeline(data_file) for data_file in dataset if data_file.txt_path]
        for doc in docs:
            features, labels = self._extract_features(doc)
            self.X_data += features
            self.y_data += labels

    def fit(self, dataset: Dataset, groundtruth_directory: Path = None):
        """
        Runs dataset through the designated pipeline, extracts features, and fits a conditional random field.
        :param dataset: Instance of Dataset.
        :param groundtruth_directory: optional directory in which to write the tokenized ground truth as ann files
        :return model: a trained instance of a sklearn_crfsuite.CRF model.
        """

        groundtruth_directory = Path(groundtruth_directory) if groundtruth_directory else False

        report = self.pipeline.get_report()
        self.preprocess(dataset)

        if groundtruth_directory:
            logging.info(f"Writing dataset groundtruth to {groundtruth_directory}")
            for file_path, ann in sequence_to_ann(self.X_data, self.y_data, {x[2] for x in self.X_data}).items():
                ann.to_ann(groundtruth_directory / (os.path.splitext(os.path.basename(file_path))[0] + ".ann"))

        learner_name, learner = self.pipeline.get_learner()
        logging.info(f"Training: {learner_name}")

        train_data = [x[0] for x in self.X_data]
        learner.fit(train_data, self.y_data)
        logging.info(f"Successfully Trained: {learner_name}\n{report}")

        self.model = learner
        return self.model

    def _predict_document(self, doc):
        """
        Generates the predictions of the model over the corresponding document. The passed document
        is assumed to be annotated by the same pipeline utilized when training the model.
        :param doc: A spacy document
        :return: an Annotations object containing the model predictions
        """

        feature_extractor = self.pipeline.get_feature_extractor()

        features, indices = feature_extractor.get_features_with_span_indices(doc)
        predictions = self.model.predict(features)
        predictions = [element for sentence in predictions for element in sentence]  # flatten 2d list
        span_indices = [element for sentence in indices for element in sentence]  # parallel array containing indices
        annotations = []

        i = 0
        while i < len(predictions):
            if predictions[i] == 'O':
                i += 1
                continue

            entity = predictions[i]
            first_start, first_end = span_indices[i]

            # Ensure that consecutive tokens with the same label are merged
            while i < len(predictions) - 1 and predictions[i + 1] == entity:  # If inside entity, keep incrementing
                i += 1

            last_start, last_end = span_indices[i]
            labeled_text = doc.text[first_start:last_end]
            new_ent = EntTuple(entity, first_start, last_end, labeled_text)
            annotations.append(new_ent)

            logging.debug(f"{doc._.file_name}: Predicted {entity} at ({first_start}, {last_end}) {labeled_text}")

            i += 1

        return Annotations(annotations)

    def predict(self, input_data, prediction_directory=None):
        """
        Generates predictions over a string or a dataset utilizing the pipeline equipped to the instance.

        :param input_data: a string, Dataset, or directory path to predict over
        :param prediction_directory: The directory to write predictions to if doing bulk prediction
            (default: */prediction* sub-directory of Dataset)
        :return: if input_data is a str, returns an Annotations of the predictions;
            if input_data is a Dataset or a valid directory path, returns a Dataset of the predictions.

        Note that if input_data is supposed to be a directory path but the directory is not found, it will be predicted
        over as a string. This can be prevented by validating inputs with os.path.isdir().
        """

        if self.model is None:
            raise RuntimeError("Must fit or load a pickled model before predicting")

        if isinstance(input_data, str) and not os.path.isdir(input_data):
            doc = self.pipeline.spacy_pipeline.make_doc(input_data)
            doc.set_extension('file_name', default=None, force=True)
            doc._.file_name = 'STRING_INPUT'
            doc = self.pipeline(doc, predict=True)
            annotations = self._predict_document(doc)
            return annotations

        if isinstance(input_data, Dataset):
            input_files = [d.txt_path for d in input_data]
            # Change input_data to point to the Dataset's directory path so that we can use it
            # to create the prediction directory
            input_data = input_data.data_directory
        elif os.path.isdir(input_data):
            input_files = [os.path.join(input_data, f) for f in os.listdir(input_data) if f.endswith('.txt')]
        else:
            raise ValueError(f"'input_data' must be a string (which can be a directory path) or a Dataset, but is {repr(input_data)}")

        if prediction_directory is None:
            prediction_directory = os.path.join(input_data, 'predictions')
        if os.path.isdir(prediction_directory):
            logging.warning("Overwriting existing predictions at %s", prediction_directory)
        else:
            os.mkdir(prediction_directory)

        for file_path in input_files:
            file_name = os.path.splitext(os.path.basename(file_path))[0]
            logging.info("Predicting file: %s", file_path)

            with open(file_path, 'r') as f:
                doc = self.pipeline.spacy_pipeline.make_doc(f.read())

            doc.set_extension('file_name', default=None, force=True)
            doc._.file_name = file_name

            # run through the pipeline
            doc = self.pipeline(doc, predict=True)

            # Predict, creating a new Annotations object
            annotations = self._predict_document(doc)
            logging.debug("Writing to: %s", os.path.join(prediction_directory, file_name + ".ann"))
            annotations.to_ann(write_location=os.path.join(prediction_directory, file_name + ".ann"))

            # Copy the txt file so that the output will also be a Dataset
            copyfile(file_path, os.path.join(prediction_directory, file_name + ".txt"))

        return Dataset(prediction_directory)
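
    # For bulk prediction over an existing directory of .txt files (the path is a placeholder):
    #
    #     predictions = model.predict('/path/to/txt/files')
    #
    # The returned Dataset points at the prediction directory, which holds one .ann file
    # and one copied .txt file per input document.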

    def cross_validate(self, training_dataset, num_folds=DEFAULT_NUM_FOLDS, prediction_directory=None, groundtruth_directory=None):
        """
        Performs k-fold stratified cross-validation using our model and pipeline.

        If the training dataset, groundtruth_directory, and prediction_directory are passed, intermediate
        predictions made during cross validation are written to prediction_directory. This allows one to construct
        a confusion matrix or to compute the prediction ambiguity with the methods present in the Dataset class to
        support pipeline development without a designated evaluation set.

        :param training_dataset: Dataset that is being cross validated
        :param num_folds: number of folds to split training data into for cross validation; defaults to 10
        :param prediction_directory: directory to write predictions of cross validation to
        :param groundtruth_directory: directory to write the ground truth medaCy evaluates on
        :return: a dictionary of per-entity and system-level statistics averaged over all folds (also logged or printed)
        """

        if num_folds <= 1:
            raise ValueError("Number of folds for cross validation must be greater than 1, but is %s" % repr(num_folds))

        groundtruth_directory = Path(groundtruth_directory) if groundtruth_directory else False
        prediction_directory = Path(prediction_directory) if prediction_directory else False

        for d in [groundtruth_directory, prediction_directory]:
            if d and not d.exists():
                raise NotADirectoryError(f"Options groundtruth_directory and prediction_directory must be existing directories, but one is {d}")

        pipeline_report = self.pipeline.get_report()

        self.preprocess(training_dataset)

        if not (self.X_data and self.y_data):
            raise RuntimeError("Must have features and labels extracted for cross validation")

        tags = sorted(self.pipeline.entities)
        logging.info(f'Tagset: {tags}')

        eval_stats = {}

        # Dicts for storing the mapping of sequences to their corresponding files
        fold_groundtruth_dicts = []
        fold_prediction_dicts = []
        file_names = {x.file_name for x in self.X_data}

        folds = create_folds(self.y_data, num_folds)

        for fold_num, fold_data in enumerate(folds, 1):
            train_indices, test_indices = fold_data
            fold_statistics = {}
            learner_name, learner = self.pipeline.get_learner()

            X_train = [self.X_data[index] for index in train_indices]
            y_train = [self.y_data[index] for index in train_indices]

            X_test = [self.X_data[index] for index in test_indices]
            y_test = [self.y_data[index] for index in test_indices]

            logging.info("Training Fold %i", fold_num)
            train_data = [x[0] for x in X_train]
            test_data = [x[0] for x in X_test]
            learner.fit(train_data, y_train)
            y_pred = learner.predict(test_data)

            if groundtruth_directory:
                ann_dict = sequence_to_ann(X_test, y_test, file_names)
                fold_groundtruth_dicts.append(ann_dict)

            if prediction_directory:
                ann_dict = sequence_to_ann(X_test, y_pred, file_names)
                fold_prediction_dicts.append(ann_dict)

            # Write the metrics for this fold.
            for label in tags:
                fold_statistics[label] = {
                    "recall": metrics.flat_recall_score(y_test, y_pred, average='weighted', labels=[label]),
                    "precision": metrics.flat_precision_score(y_test, y_pred, average='weighted', labels=[label]),
                    "f1": metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=[label])
                }

            # add averages
            fold_statistics['system'] = {
                "recall": metrics.flat_recall_score(y_test, y_pred, average='weighted', labels=tags),
                "precision": metrics.flat_precision_score(y_test, y_pred, average='weighted', labels=tags),
                "f1": metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=tags)
            }

            table_data = [
                [label,
                 format(fold_statistics[label]['precision'], ".3f"),
                 format(fold_statistics[label]['recall'], ".3f"),
                 format(fold_statistics[label]['f1'], ".3f")
                 ] for label in tags + ['system']
            ]

            logging.info('\n' + tabulate(table_data, headers=['Entity', 'Precision', 'Recall', 'F1'], tablefmt='orgtbl'))

            eval_stats[fold_num] = fold_statistics

        statistics_all_folds = {}

        for label in tags + ['system']:
            statistics_all_folds[label] = {
                'precision_average': mean(eval_stats[fold][label]['precision'] for fold in eval_stats),
                'precision_max': max(eval_stats[fold][label]['precision'] for fold in eval_stats),
                'precision_min': min(eval_stats[fold][label]['precision'] for fold in eval_stats),
                'recall_average': mean(eval_stats[fold][label]['recall'] for fold in eval_stats),
                'recall_max': max(eval_stats[fold][label]['recall'] for fold in eval_stats),
                'f1_average': mean(eval_stats[fold][label]['f1'] for fold in eval_stats),
                'f1_max': max(eval_stats[fold][label]['f1'] for fold in eval_stats),
                'f1_min': min(eval_stats[fold][label]['f1'] for fold in eval_stats),
            }

        entity_counts = training_dataset.compute_counts()
        entity_counts['system'] = sum(v for k, v in entity_counts.items() if k in self.pipeline.entities)

        table_data = [
            [f"{label} ({entity_counts[label]})",  # Entity (Count)
             format(statistics_all_folds[label]['precision_average'], ".3f"),
             format(statistics_all_folds[label]['recall_average'], ".3f"),
             format(statistics_all_folds[label]['f1_average'], ".3f"),
             format(statistics_all_folds[label]['f1_min'], ".3f"),
             format(statistics_all_folds[label]['f1_max'], ".3f")
             ] for label in tags + ['system']
        ]

        # Combine the pipeline report and the resulting data, then log it or print it (whichever ensures that it prints)
        output_str = '\n' + pipeline_report + '\n\n' + tabulate(
            table_data,
            headers=['Entity (Count)', 'Precision', 'Recall', 'F1', 'F1_Min', 'F1_Max'],
            tablefmt='orgtbl'
        )

        if logging.root.level > logging.INFO:
            print(output_str)
        else:
            logging.info(output_str)

        # Write groundtruth and predictions to file
        if groundtruth_directory:
            write_ann_dicts(groundtruth_directory, fold_groundtruth_dicts)
        if prediction_directory:
            write_ann_dicts(prediction_directory, fold_prediction_dicts)

        return statistics_all_folds
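
    # A sketch of a typical cross-validation call; the directories are placeholders and,
    # when passed, must already exist (both are optional):
    #
    #     stats = model.cross_validate(dataset, num_folds=10,
    #                                  groundtruth_directory='/tmp/gt',
    #                                  prediction_directory='/tmp/preds')
    #     print(stats['system']['f1_average'])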

    def _run_through_pipeline(self, data_file):
        """
        Runs a DataFile through the pipeline, returning the resulting Doc object
        :param data_file: instance of DataFile
        :return: a Doc object
        """
        nlp = self.pipeline.spacy_pipeline
        logging.info("Processing file: %s", data_file.file_name)

        with open(data_file.txt_path, 'r', encoding='utf-8') as f:
            doc = nlp.make_doc(f.read())

        # Link ann_path to doc
        doc.set_extension('gold_annotation_file', default=None, force=True)
        doc.set_extension('file_name', default=None, force=True)

        doc._.gold_annotation_file = data_file.ann_path
        doc._.file_name = data_file.txt_path

        # run 'er through
        return self.pipeline(doc)

    def _extract_features(self, doc):
        """
        Extracts features from a Doc
        :param doc: an instance of Doc
        :return: a tuple of the feature sequences and their corresponding label sequences
        """

        feature_extractor = self.pipeline.get_feature_extractor()
        features, labels = feature_extractor(doc)

        logging.info(f"{doc._.file_name}: Feature Extraction Completed (num_sequences={len(labels)})")
        return features, labels

    def load(self, path):
        """
        Loads a pickled model.

        :param path: File path to the saved model to load
        :return: None
        """
        model_name, model = self.pipeline.get_learner()

        if model_name == 'BiLSTM+CRF' or model_name == 'BERT':
            model.load(path)
            self.model = model
        else:
            self.model = joblib.load(path)

    def dump(self, path):
        """
        Dumps a model into a pickle file

        :param path: path to dump the model to
        :return: None
        """
        if self.model is None:
            raise RuntimeError("Must fit model before dumping.")

        model_name, _ = self.pipeline.get_learner()

        if model_name == 'BiLSTM+CRF' or model_name == 'BERT':
            self.model.save(path)
        else:
            joblib.dump(self.model, path)
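
    # A save/load round trip with the default CRF learner might look like this
    # (the .pkl path is a placeholder):
    #
    #     model.dump('/path/to/crf_model.pkl')
    #     ...
    #     restored = Model(ClinicalPipeline(entities=['Drug', 'Dose']))  # same assumed pipeline as above
    #     restored.load('/path/to/crf_model.pkl')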

    @staticmethod
    def load_external(package_name):
        """
        Loads an external medaCy-compatible model. Requires the model's package to be installed.
        Alternatively, you can import the package directly and call its load() method.

        :param package_name: the package name of the model
        :return: an instance of Model that is configured and loaded - ready for prediction.
        """
        if importlib.util.find_spec(package_name) is None:
            raise ImportError("Package not installed: %s" % package_name)
        return importlib.import_module(package_name).load()
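
    # For example, if a model were distributed as an installable package (the package
    # name below is purely illustrative):
    #
    #     model = Model.load_external('medacy_model_clinical_notes')
    #     annotations = model.predict("Patient was prescribed 81 mg aspirin daily.")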