[6c353a]: / medacy / tests / model / test_multi_model.py

Download this file

79 lines (58 with data), 2.9 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import shutil
import tempfile
import unittest
from medacy.data.dataset import Dataset
from medacy.model.multi_model import MultiModel
from medacy.pipelines.clinical_pipeline import ClinicalPipeline
from medacy.pipelines.testing_pipeline import TestingPipeline
from medacy.tests.sample_data import test_dir
class TestMultiModel(unittest.TestCase):
    """Unit tests for medacy.model.multi_model.MultiModel"""

    @classmethod
    def setUpClass(cls) -> None:
        """Create a temporary directory for predictions and resolve sample paths.

        The dataset and the two pickled sample models are located under the
        package's bundled ``test_dir``.
        """
        cls.temp_dir = tempfile.mkdtemp()
        cls.data_dir = os.path.join(test_dir, 'sample_dataset_1')
        cls.sample_model_1_path = os.path.join(test_dir, 'sample_models', 'sample_clin_pipe.pkl')
        cls.sample_model_2_path = os.path.join(test_dir, 'sample_models', 'sample_test_pipe.pkl')

    @classmethod
    def tearDownClass(cls) -> None:
        """Delete the temporary directory created in setUpClass."""
        shutil.rmtree(cls.temp_dir)

    def test_multi_model(self):
        """Runs all tests for valid uses of MultiModel"""
        data = Dataset(self.data_dir)
        ents_1 = {'Endpoints', 'Species', 'DoseUnits'}
        ents_2 = {'TestArticle', 'Dose', 'Sex'}

        multimodel = MultiModel()
        # Test that *args works (entities passed positionally)
        multimodel.add_model(self.sample_model_1_path, ClinicalPipeline, list(ents_1))
        # Test that **kwargs works (entities passed by keyword)
        multimodel.add_model(self.sample_model_2_path, TestingPipeline, entities=list(ents_2))

        # Test __len__; this also guarantees the zip() below cannot silently
        # skip a model by truncating to the shorter sequence
        self.assertEqual(len(multimodel), 2)

        # Test that each model gets instantiated correctly; subTest names the
        # pipeline so a failure identifies which model is broken
        expected_pipelines = [ClinicalPipeline, TestingPipeline]
        for model, pipeline_class in zip(multimodel, expected_pipelines):
            with self.subTest(pipeline=pipeline_class.__name__):
                current_pipeline = model.pipeline
                self.assertIsInstance(current_pipeline, pipeline_class)
                self.assertGreater(len(current_pipeline.entities), 0)

        # Test predict_directory
        resulting_data = multimodel.predict_directory(data.data_directory, self.temp_dir)
        labeled_items = resulting_data.get_labels()

        # Test that at least one label from each model is predicted. Same check
        # as any(e in ents for e in labeled_items), but on failure the message
        # shows which entity set was missed and what was actually predicted.
        for expected_ents in (ents_1, ents_2):
            self.assertTrue(
                expected_ents.intersection(labeled_items),
                msg='None of {} were predicted; predicted labels: {}'.format(
                    sorted(expected_ents), labeled_items
                ),
            )

        # Test that all files get predicted for
        self.assertEqual(len(resulting_data), len(data))

    def test_errors(self):
        """Tests that invalid inputs raise the appropriate errors"""
        multimodel = MultiModel()

        # Test add_model with a nonexistent model path
        with self.assertRaises(FileNotFoundError):
            multimodel.add_model('notafilepath', ClinicalPipeline)

        # Test add_model without passing a subclass of BasePipeline
        with self.assertRaises(TypeError):
            multimodel.add_model(self.sample_model_1_path, 7)
if __name__ == "__main__":
    # Allow running this test module directly (python test_multi_model.py)
    # in addition to discovery by a test runner.
    unittest.main()