|
a |
|
b/medacy/tests/model/test_model.py |
|
|
1 |
import os |
|
|
2 |
import shutil |
|
|
3 |
import tempfile |
|
|
4 |
import logging |
|
|
5 |
import unittest |
|
|
6 |
|
|
|
7 |
import pkg_resources |
|
|
8 |
|
|
|
9 |
from medacy.data.annotations import Annotations |
|
|
10 |
from medacy.data.dataset import Dataset |
|
|
11 |
from medacy.model.model import Model |
|
|
12 |
from medacy.pipelines.testing_pipeline import TestingPipeline |
|
|
13 |
from medacy.tests.sample_data import test_dir |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
class TestModel(unittest.TestCase):
    """Tests for medacy.model.model.Model"""

    @classmethod
    def setUpClass(cls):
        """Build fixtures shared by all tests: the sample dataset, its label set,
        a testing pipeline over those labels, and temp directories for predictions
        and groundtruth output."""
        cls.dataset = Dataset(os.path.join(test_dir, 'sample_dataset_1'))
        cls.entities = cls.dataset.get_labels(as_list=True)
        cls.prediction_directory = tempfile.mkdtemp()  # directory to store predictions
        cls.prediction_directory_2 = tempfile.mkdtemp()
        cls.prediction_directory_3 = tempfile.mkdtemp()
        cls.groundtruth_directory = tempfile.mkdtemp()
        cls.groundtruth_2_directory = tempfile.mkdtemp()
        cls.pipeline = TestingPipeline(entities=cls.entities)

    @classmethod
    def tearDownClass(cls):
        """Remove every temp directory created in setUpClass."""
        pkg_resources.cleanup_resources()
        for d in [cls.prediction_directory, cls.prediction_directory_2,
                  cls.prediction_directory_3, cls.groundtruth_directory, cls.groundtruth_2_directory]:
            shutil.rmtree(d)

    def test_fit_predict_dump_load(self):
        """Fits a model, tests that it predicts correctly, dumps and loads it, then tests that it still predicts"""

        model = Model(self.pipeline)

        # Test attempting to predict before fitting
        with self.assertRaises(RuntimeError):
            model.predict('Lorem ipsum dolor sit amet.')

        model.fit(self.dataset, groundtruth_directory=self.groundtruth_2_directory)
        # Test X and y data are set
        self.assertTrue(model.X_data)
        self.assertTrue(model.y_data)

        # Test that there is at least one prediction
        resulting_ann = model.predict('To exclude the possibility that alterations in PSSD might be a consequence of changes in the volume of reference, we used a subset of the vibratome sections')
        self.assertIsInstance(resulting_ann, Annotations)
        self.assertTrue(resulting_ann)

        # Test prediction over directory
        resulting_dataset = model.predict(self.dataset.data_directory, prediction_directory=self.prediction_directory)
        self.assertIsInstance(resulting_dataset, Dataset)
        self.assertEqual(len(self.dataset), len(resulting_dataset))

        # Test that groundtruth is written
        groundtruth_dataset = Dataset(self.groundtruth_2_directory)
        expected = [d.file_name for d in self.dataset]
        actual = [d.file_name for d in groundtruth_dataset]
        self.assertListEqual(expected, actual)

        # Test that the groundtruth ann files have content
        for ann in groundtruth_dataset.generate_annotations():
            self.assertTrue(ann)

        # Test pickling a model
        pickle_path = os.path.join(self.prediction_directory, 'test.pkl')
        model.dump(pickle_path)
        new_model = Model(self.pipeline)
        new_model.load(pickle_path)

        # Test that there is at least one prediction
        resulting_ann = new_model.predict('To exclude the possibility that alterations in PSSD might be a consequence of changes in the volume of reference, we used a subset of the vibratome sections')
        self.assertIsInstance(resulting_ann, Annotations)
        self.assertTrue(resulting_ann)

    def test_predict(self):
        """
        predict() has different functionality depending on what is passed to it; therefore this test
        ensures that each type of input is handled correctly
        """

        # Init the Model
        pipe = TestingPipeline(entities=self.entities)
        sample_model_path = os.path.join(test_dir, 'sample_models', 'sample_test_pipe.pkl')
        model = Model(pipe)
        model.load(sample_model_path)

        # Test passing a Dataset
        dataset_output = model.predict(self.dataset)
        self.assertIsInstance(dataset_output, Dataset)
        self.assertEqual(len(dataset_output), len(self.dataset))

        # Test passing a directory
        directory_output = model.predict(self.dataset.data_directory)
        self.assertIsInstance(directory_output, Dataset)
        self.assertEqual(len(directory_output), len(self.dataset))

        # Test passing a string
        string_output = model.predict('This is a sample string.')
        self.assertIsInstance(string_output, Annotations)

        # Test that the predictions are written to the expected location when no path is provided
        expected_dir = os.path.join(self.dataset.data_directory, 'predictions')
        self.assertTrue(os.path.isdir(expected_dir))

        # Delete that directory
        shutil.rmtree(expected_dir)

        # Test predicting to a specific directory
        model.predict(self.dataset.data_directory, prediction_directory=self.prediction_directory_2)
        expected_files = os.listdir(self.prediction_directory_2)
        # NOTE(review): the count 6 is tied to the contents of sample_dataset_1 —
        # update this if the sample data changes
        self.assertEqual(6, len(expected_files))

    def test_cross_validate(self):
        """Ensures that changes made in the package do not prevent cross_validate from running to completion"""
        model = Model(self.pipeline)

        # Test that invalid fold counts raise ValueError
        for num in [-1, 0, 1]:
            with self.assertRaises(ValueError):
                model.cross_validate(self.dataset, num)

        try:
            resulting_data = model.cross_validate(self.dataset, 2)
            # Checking the log can help verify that the results of cross validation are expectable
            logging.debug(resulting_data)
        except Exception as e:
            # Previously a bare `except: self.assertTrue(False)`, which also caught
            # KeyboardInterrupt/SystemExit and hid the actual cause of failure.
            self.fail("cross_validate() raised an unexpected exception: %s" % e)

    def test_run_through_pipeline(self):
        """
        Tests that this function runs a document through the pipeline by testing that it has attributes
        overlayed by the pipeline
        """
        model = Model(self.pipeline)
        sample_df = list(self.dataset)[0]
        result = model._run_through_pipeline(sample_df)

        expected = sample_df.txt_path
        actual = result._.file_name
        self.assertEqual(actual, expected)

        expected = sample_df.ann_path
        actual = result._.gold_annotation_file
        self.assertEqual(actual, expected)

    def test_cross_validate_create_groundtruth_predictions(self):
        """
        Tests that during cross validation, the medaCy groundtruth (that is, the version of the training dataset
        used by medaCy) is written as well as the predictions that are created for each fold
        """
        model = Model(self.pipeline)
        model.cross_validate(
            self.dataset,
            num_folds=2,
            prediction_directory=self.prediction_directory_3,
            groundtruth_directory=self.groundtruth_directory
        )

        prediction_dataset = Dataset(self.prediction_directory_3)
        groundtruth_dataset = Dataset(self.groundtruth_directory)

        for d in [prediction_dataset, groundtruth_dataset]:
            self.assertIsInstance(d, Dataset)

        original_file_names = {d.file_name for d in self.dataset}
        prediction_file_names = {d.file_name for d in prediction_dataset}
        groundtruth_file_names = {d.file_name for d in groundtruth_dataset}

        for n in [prediction_file_names, groundtruth_file_names]:
            self.assertSetEqual(n, original_file_names)

        # Container for all Annotations in all files in all folds
        all_anns_all_folds_actual = Annotations([])

        # Test that fold groundtruth is written to file
        for fold_name in ["fold_1", "fold_2"]:
            fold_dataset = Dataset(groundtruth_dataset.data_directory / fold_name)
            for d in fold_dataset:
                fold_ann = Annotations(d.ann_path)
                groundtruth_ann = groundtruth_dataset[d.file_name]
                # Test that the entities in the fold groundtruth are a subset of the whole for that file
                self.assertTrue(set(fold_ann) <= set(groundtruth_ann))
                all_anns_all_folds_actual |= fold_ann

        # Container for all annotations pulled directly from the groundtruth dataset
        all_groundtruth_tuples = Annotations([])
        for ann in groundtruth_dataset.generate_annotations():
            all_groundtruth_tuples |= ann

        expected = set(all_groundtruth_tuples)
        actual = set(all_anns_all_folds_actual)
        self.assertSetEqual(expected, actual)
|
|
200 |
|
|
|
201 |
|
|
|
202 |
def _main():
    """Run this module's test suite via the unittest CLI runner."""
    unittest.main()


if __name__ == '__main__':
    _main()