# medacy/tests/model/test_model.py
import os
import shutil
import tempfile
import logging
import unittest

import pkg_resources

from medacy.data.annotations import Annotations
from medacy.data.dataset import Dataset
from medacy.model.model import Model
from medacy.pipelines.testing_pipeline import TestingPipeline
from medacy.tests.sample_data import test_dir


class TestModel(unittest.TestCase):
    """Tests for medacy.model.model.Model"""

    @classmethod
    def setUpClass(cls):
        cls.dataset = Dataset(os.path.join(test_dir, 'sample_dataset_1'))
        cls.entities = cls.dataset.get_labels(as_list=True)
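        # get_labels(as_list=True) collects every entity label that appears in the
        # dataset's .ann files; the exact labels depend on the contents of sample_dataset_1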
        cls.prediction_directory = tempfile.mkdtemp()  # directory to store predictions
        cls.prediction_directory_2 = tempfile.mkdtemp()
        cls.prediction_directory_3 = tempfile.mkdtemp()
        cls.groundtruth_directory = tempfile.mkdtemp()
        cls.groundtruth_2_directory = tempfile.mkdtemp()
        cls.pipeline = TestingPipeline(entities=cls.entities)

    @classmethod
    def tearDownClass(cls):
        pkg_resources.cleanup_resources()
        for d in [cls.prediction_directory, cls.prediction_directory_2,
                  cls.prediction_directory_3, cls.groundtruth_directory, cls.groundtruth_2_directory]:
            shutil.rmtree(d)

    def test_fit_predict_dump_load(self):
        """Fits a model, tests that it predicts correctly, dumps and loads it, then tests that it still predicts"""

        model = Model(self.pipeline)

        # Test attempting to predict before fitting
        with self.assertRaises(RuntimeError):
            model.predict('Lorem ipsum dolor sit amet.')

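        # Passing groundtruth_directory makes fit() also write out medaCy's version of the
        # training annotations (the groundtruth actually used for training), which is
        # verified below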
        model.fit(self.dataset, groundtruth_directory=self.groundtruth_2_directory)
        # Test X and y data are set
        self.assertTrue(model.X_data)
        self.assertTrue(model.y_data)

        # Test that there is at least one prediction
        resulting_ann = model.predict('To exclude the possibility that alterations in PSSD might be a consequence of changes in the volume of reference, we used a subset of the vibratome sections')
        self.assertIsInstance(resulting_ann, Annotations)
        self.assertTrue(resulting_ann)

        # Test prediction over directory
        resulting_dataset = model.predict(self.dataset.data_directory, prediction_directory=self.prediction_directory)
        self.assertIsInstance(resulting_dataset, Dataset)
        self.assertEqual(len(self.dataset), len(resulting_dataset))

        # Test that groundtruth is written
        groundtruth_dataset = Dataset(self.groundtruth_2_directory)
        expected = [d.file_name for d in self.dataset]
        actual = [d.file_name for d in groundtruth_dataset]
        self.assertListEqual(expected, actual)

        # Test that the groundtruth ann files have content
        for ann in groundtruth_dataset.generate_annotations():
            self.assertTrue(ann)

        # Test pickling a model
        pickle_path = os.path.join(self.prediction_directory, 'test.pkl')
        model.dump(pickle_path)
        new_model = Model(self.pipeline)
        new_model.load(pickle_path)
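
        # If dump() and load() round-tripped correctly, the reloaded model should predict
        # just like the original fitted model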
        # Test that there is at least one prediction
        resulting_ann = new_model.predict('To exclude the possibility that alterations in PSSD might be a consequence of changes in the volume of reference, we used a subset of the vibratome sections')
        self.assertIsInstance(resulting_ann, Annotations)
        self.assertTrue(resulting_ann)

    def test_predict(self):
        """
        predict() behaves differently depending on the type of input passed to it; this test
        ensures that each supported input type is handled correctly
        """

        # Init the Model
        pipe = TestingPipeline(entities=self.entities)
        sample_model_path = os.path.join(test_dir, 'sample_models', 'sample_test_pipe.pkl')
        model = Model(pipe)
        model.load(sample_model_path)

        # Test passing a Dataset
        dataset_output = model.predict(self.dataset)
        self.assertIsInstance(dataset_output, Dataset)
        self.assertEqual(len(dataset_output), len(self.dataset))

        # Test passing a directory
        directory_output = model.predict(self.dataset.data_directory)
        self.assertIsInstance(directory_output, Dataset)
        self.assertEqual(len(directory_output), len(self.dataset))

        # Test passing a string
        string_output = model.predict('This is a sample string.')
        self.assertIsInstance(string_output, Annotations)

        # Test that the predictions are written to the expected location when no path is provided
        expected_dir = os.path.join(self.dataset.data_directory, 'predictions')
        self.assertTrue(os.path.isdir(expected_dir))

        # Delete that directory
        shutil.rmtree(expected_dir)

        # Test predicting to a specific directory
        model.predict(self.dataset.data_directory, prediction_directory=self.prediction_directory_2)
        expected_files = os.listdir(self.prediction_directory_2)
        self.assertEqual(6, len(expected_files))
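        # Note: the expected count of 6 assumes sample_dataset_1 yields one .ann file per
        # document; adjust this constant if the sample dataset ever changes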

    def test_cross_validate(self):
        """Ensures that changes made in the package do not prevent cross_validate from running to completion"""
        model = Model(self.pipeline)

        # Test that invalid fold counts raise ValueError
        for num in [-1, 0, 1]:
            with self.assertRaises(ValueError):
                model.cross_validate(self.dataset, num)

        try:
            resulting_data = model.cross_validate(self.dataset, 2)
            # Checking the log can help verify that the results of cross validation are as expected
            logging.debug(resulting_data)
        except Exception:
            self.fail("cross_validate() raised an unexpected exception")

    def test_run_through_pipeline(self):
        """
        Tests that this function runs a document through the pipeline by checking that the result
        has the attributes overlaid by the pipeline
        """
        model = Model(self.pipeline)
        sample_df = list(self.dataset)[0]
        result = model._run_through_pipeline(sample_df)
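
        # _run_through_pipeline() is expected to return a spaCy Doc whose custom extension
        # attributes (the `._.` namespace checked below) were overlaid by the pipeline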
        expected = sample_df.txt_path
        actual = result._.file_name
        self.assertEqual(actual, expected)

        expected = sample_df.ann_path
        actual = result._.gold_annotation_file
        self.assertEqual(actual, expected)

    def test_cross_validate_create_groundtruth_predictions(self):
        """
        Tests that during cross validation, the medaCy groundtruth (that is, the version of the training dataset
        used by medaCy) is written as well as the predictions that are created for each fold
        """
        model = Model(self.pipeline)
        model.cross_validate(
            self.dataset,
            num_folds=2,
            prediction_directory=self.prediction_directory_3,
            groundtruth_directory=self.groundtruth_directory
        )

        prediction_dataset = Dataset(self.prediction_directory_3)
        groundtruth_dataset = Dataset(self.groundtruth_directory)

        for d in [prediction_dataset, groundtruth_dataset]:
            self.assertIsInstance(d, Dataset)

        original_file_names = {d.file_name for d in self.dataset}
        prediction_file_names = {d.file_name for d in prediction_dataset}
        groundtruth_file_names = {d.file_name for d in groundtruth_dataset}

        for n in [prediction_file_names, groundtruth_file_names]:
            self.assertSetEqual(n, original_file_names)

        # Container for all Annotations in all files in all folds
        all_anns_all_folds_actual = Annotations([])

        # Test that fold groundtruth is written to file
        for fold_name in ["fold_1", "fold_2"]:
            fold_dataset = Dataset(groundtruth_dataset.data_directory / fold_name)
            for d in fold_dataset:
                fold_ann = Annotations(d.ann_path)
                groundtruth_ann = groundtruth_dataset[d.file_name]
                # Test that the entities in the fold groundtruth are a subset of the whole for that file
                self.assertTrue(set(fold_ann) <= set(groundtruth_ann))
                all_anns_all_folds_actual |= fold_ann

        # Container for all annotations pulled directly from the groundtruth dataset
        all_groundtruth_tuples = Annotations([])
        for ann in groundtruth_dataset.generate_annotations():
            all_groundtruth_tuples |= ann
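
        # Together with the per-file subset checks above, this set equality confirms that
        # the union of the per-fold groundtruth files reproduces the full groundtruth exactly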
        expected = set(all_groundtruth_tuples)
        actual = set(all_anns_all_folds_actual)
        self.assertSetEqual(expected, actual)


if __name__ == '__main__':
    unittest.main()