Switch to unified view

a b/medacy/tests/data/test_annotation.py
1
import os
2
import shutil
3
import tempfile
4
import unittest
5
6
import pkg_resources
7
8
from medacy.data.annotations import Annotations
9
from medacy.data.dataset import Dataset
10
from medacy.tests.sample_data import test_dir
11
12
13
class TestAnnotation(unittest.TestCase):
    """Tests for medacy.data.annotations.Annotations"""

    @classmethod
    def setUpClass(cls):
        """Loads sample dataset and sets up a temporary directory for IO tests"""
        # NOTE: cls.test_dir is the scratch directory created here; the bare name
        # ``test_dir`` used below is the module-level import that locates the sample data.
        cls.test_dir = tempfile.mkdtemp()
        try:
            cls.sample_data_dir = os.path.join(test_dir, 'sample_dataset_1')
            cls.dataset = Dataset(cls.sample_data_dir)
            cls.entities = cls.dataset.get_labels(as_list=True)

            # Fixture file with malformed contents; no test in this class currently reads it.
            with open(os.path.join(cls.test_dir, "broken_ann_file.ann"), 'w') as f:
                f.write("This is clearly not a valid ann file")

            cls.ann_path_1 = cls.dataset.data_files[0].ann_path
            cls.ann_path_2 = cls.dataset.data_files[1].ann_path
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so the
            # temporary directory must be removed here to avoid leaking it.
            shutil.rmtree(cls.test_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Removes test temp directory and deletes all files"""
        pkg_resources.cleanup_resources()
        shutil.rmtree(cls.test_dir)

    def _test_is_sorted(self, ann):
        """Asserts that the element at index 1 of each tuple yielded by ``ann`` is in ascending order

        :param ann: an iterable of indexable items (e.g. an Annotations object or a list of tuples)
        """
        actual = [e[1] for e in ann]
        expected = sorted(actual)
        self.assertListEqual(actual, expected)

    def test_init_from_ann_file(self):
        """Tests initialization from valid ann file"""
        ann = Annotations(self.ann_path_1)
        self._test_is_sorted(ann)

    def test_init_from_invalid_ann(self):
        """Tests that initialization from a path that does not exist raises FileNotFoundError"""
        with self.assertRaises(FileNotFoundError):
            Annotations("not_a_file_path")

    def test_init_tuples(self):
        """Tests the creation of individual annotation tuples, including ones with non-contiguous spans"""
        temp_path = os.path.join(self.test_dir, 'tuples.ann')

        # Each sample pairs one line of ann text with the tuple expected from parsing it.
        samples = [
            ("T1\tObject 66 77\tthis is some text\n", ('Object', 66, 77, 'this is some text')),
            ("T2\tEntity 44 55;66 77\tI love NER\n", ('Entity', 44, 77, 'I love NER')),
            ("T3\tThingy 66 77;88 99;100 188\tthis is some sample text\n", ('Thingy', 66, 188, 'this is some sample text')),
        ]

        for string, expected in samples:
            # subTest reports which sample failed and lets the remaining samples still run
            with self.subTest(line=string):
                with open(temp_path, 'w') as f:
                    f.write(string)

                resulting_ann = Annotations(temp_path)
                actual = resulting_ann.annotations[0]
                self.assertTupleEqual(actual, expected)

    def test_ann_conversions(self):
        """Tests converting and un-converting a valid Annotations object to an ANN file."""
        self.maxDiff = None  # show the full diff if the round trip does not match
        annotations = Annotations(self.ann_path_1)
        temp_path = os.path.join(self.test_dir, "intermediary.ann")
        annotations.to_ann(write_location=temp_path)
        annotations2 = Annotations(temp_path)
        self.assertListEqual(annotations.annotations, annotations2.annotations)

    def test_difference(self):
        """Tests that when a given Annotations object uses the difference() method with itself,
        the result is empty (falsy)."""
        ann = Annotations(self.ann_path_1)
        result = ann.difference(ann)
        self.assertFalse(result)

    def test_different_file_diff(self):
        """Tests that when two different files are used in the difference method, the output contains
        at least one value."""
        ann_1 = Annotations(self.ann_path_1)
        ann_2 = Annotations(self.ann_path_2)
        result = ann_1.difference(ann_2)
        self.assertGreater(len(result), 0)

    def test_compute_ambiguity(self):
        """Tests that compute_ambiguity() reports the known count for identical annotations and that
        relabeling one entity in the copy increases that count by one."""
        ann_1 = Annotations(self.ann_path_1)
        ann_1_copy = Annotations(self.ann_path_1)
        ambiguity = ann_1.compute_ambiguity(ann_1_copy)
        # The number of overlapping spans for the selected ann file is known to be 25
        self.assertEqual(25, len(ambiguity))
        # Manually introduce ambiguity by changing the name of an entity in the copy
        first_tuple = ann_1_copy.annotations[0]
        ann_1_copy.annotations[0] = ('different_name', first_tuple[1], first_tuple[2], first_tuple[3])
        ambiguity = ann_1.compute_ambiguity(ann_1_copy)
        # See if this increased the ambiguity score by one
        self.assertEqual(26, len(ambiguity))

    def test_confusion_matrix(self):
        """Tests that the confusion matrix has one row and one column per entity label"""
        ann_1 = Annotations(self.ann_path_1)
        ann_2 = Annotations(self.ann_path_2)
        ann_1.add_entity(*ann_2.annotations[0])
        # Compute once and check both dimensions of the same result
        matrix = ann_1.compute_confusion_matrix(ann_2, self.entities)
        self.assertEqual(len(matrix), len(self.entities))
        self.assertEqual(len(matrix[0]), len(self.entities))

    def test_intersection(self):
        """Tests that intersection() returns exactly the two annotations copied from the second
        file into the first."""
        ann_1 = Annotations(self.ann_path_1)
        ann_2 = Annotations(self.ann_path_2)
        ann_1.add_entity(*ann_2.annotations[0])
        ann_1.add_entity(*ann_2.annotations[1])
        expected = {ann_2.annotations[0], ann_2.annotations[1]}
        actual = ann_1.intersection(ann_2)
        self.assertSetEqual(actual, expected)

    def test_compute_counts(self):
        """Tests that compute_counts() returns a dict"""
        ann_1 = Annotations(self.ann_path_1)
        self.assertIsInstance(ann_1.compute_counts(), dict)

    def test_or(self):
        """
        Tests that the pipe operator correctly merges two Annotations and retains the source text path of
        the left operand
        """
        tup_1 = ('Object', 66, 77, 'this is some text')
        tup_2 = ('Entity', 44, 77, 'I love NER')
        tup_3 = ('Thingy', 66, 188, 'this is some sample text')
        file_name = 'some_file'

        ann_1 = Annotations([tup_1, tup_2], source_text_path=file_name)
        ann_2 = Annotations([tup_3])

        for a in [ann_1, ann_2]:
            self._test_is_sorted(a)

        # Test __or__
        result = ann_1 | ann_2
        expected = {tup_1, tup_2, tup_3}
        actual = set(result)
        self.assertSetEqual(actual, expected)
        self.assertEqual(file_name, result.source_text_path)
        self._test_is_sorted(result)

        # Test __ior__
        ann_1 |= ann_2
        actual = set(ann_1)
        self.assertSetEqual(actual, expected)
        self._test_is_sorted(ann_1)
155
156
157
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()