# Source listing: medacy/tests/data/test_annotation.py (revision 6c353a)
# Original listing size: 159 lines (129 with data), 6.3 kB
import os
import shutil
import tempfile
import unittest
import pkg_resources
from medacy.data.annotations import Annotations
from medacy.data.dataset import Dataset
from medacy.tests.sample_data import test_dir
class TestAnnotation(unittest.TestCase):
    """Tests for medacy.data.annotations.Annotations"""

    @classmethod
    def setUpClass(cls):
        """Loads the sample dataset and sets up a temporary directory for IO tests"""
        # Load the dataset before creating the temp directory: unittest does not run tearDownClass
        # when setUpClass raises, so a failure while loading would otherwise leave the directory behind.
        cls.sample_data_dir = os.path.join(test_dir, 'sample_dataset_1')
        cls.dataset = Dataset(cls.sample_data_dir)
        cls.entities = cls.dataset.get_labels(as_list=True)
        cls.ann_path_1 = cls.dataset.data_files[0].ann_path
        cls.ann_path_2 = cls.dataset.data_files[1].ann_path

        # Note that cls.test_dir (the temp directory) is distinct from the imported sample-data test_dir
        cls.test_dir = tempfile.mkdtemp()
        # NOTE(review): no test in this class currently reads broken_ann_file.ann
        with open(os.path.join(cls.test_dir, "broken_ann_file.ann"), 'w') as f:
            f.write("This is clearly not a valid ann file")

    @classmethod
    def tearDownClass(cls):
        """Removes test temp directory and deletes all files"""
        pkg_resources.cleanup_resources()
        shutil.rmtree(cls.test_dir)

    def _test_is_sorted(self, ann):
        """Asserts that iterating over ann yields tuples in ascending order of their second element

        :param ann: an iterable of indexable items; in this class, an Annotations object whose tuples
            have the form (label, start, end, text)
        """
        actual = [e[1] for e in ann]
        self.assertListEqual(actual, sorted(actual))

    def test_init_from_ann_file(self):
        """Tests initialization from valid ann file"""
        ann = Annotations(self.ann_path_1)
        self._test_is_sorted(ann)

    def test_init_from_invalid_ann(self):
        """Tests that initialization from a path that does not exist raises FileNotFoundError"""
        with self.assertRaises(FileNotFoundError):
            Annotations("not_a_file_path")

    def test_init_tuples(self):
        """Tests the creation of individual annotation tuples, including ones with non-contiguous spans"""
        temp_path = os.path.join(self.test_dir, 'tuples.ann')

        # Each sample pairs one ann line with the tuple expected from it; for the multi-span lines the
        # expected tuple runs from the first span's start to the last span's end
        samples = [
            ("T1\tObject 66 77\tthis is some text\n", ('Object', 66, 77, 'this is some text')),
            ("T2\tEntity 44 55;66 77\tI love NER\n", ('Entity', 44, 77, 'I love NER')),
            ("T3\tThingy 66 77;88 99;100 188\tthis is some sample text\n", ('Thingy', 66, 188, 'this is some sample text'))
        ]

        for string, expected in samples:
            # subTest reports which sample failed and lets the remaining samples still run
            with self.subTest(line=string):
                with open(temp_path, 'w') as f:
                    f.write(string)

                resulting_ann = Annotations(temp_path)
                actual = resulting_ann.annotations[0]
                self.assertTupleEqual(actual, expected)

    def test_ann_conversions(self):
        """Tests converting and un-converting a valid Annotations object to an ANN file."""
        self.maxDiff = None  # show the full diff if the two annotation lists differ
        annotations = Annotations(self.ann_path_1)
        temp_path = os.path.join(self.test_dir, "intermediary.ann")
        annotations.to_ann(write_location=temp_path)
        annotations2 = Annotations(temp_path)
        self.assertListEqual(annotations.annotations, annotations2.annotations)

    def test_difference(self):
        """Tests that when a given Annotations object uses the difference() method with itself, the result
        is empty."""
        ann = Annotations(self.ann_path_1)
        result = ann.difference(ann)
        self.assertFalse(result)

    def test_different_file_diff(self):
        """Tests that when two different files are used in the difference method, the output contains at
        least one value."""
        ann_1 = Annotations(self.ann_path_1)
        ann_2 = Annotations(self.ann_path_2)
        result = ann_1.difference(ann_2)
        self.assertGreater(len(result), 0)

    def test_compute_ambiguity(self):
        """Tests that compute_ambiguity reports one more item after one entity in a copy is relabelled"""
        ann_1 = Annotations(self.ann_path_1)
        ann_1_copy = Annotations(self.ann_path_1)
        ambiguity = ann_1.compute_ambiguity(ann_1_copy)
        # The number of overlapping spans for the selected ann file is known to be 25
        self.assertEqual(25, len(ambiguity))

        # Manually introduce ambiguity by changing the name of an entity in the copy
        first_tuple = ann_1_copy.annotations[0]
        ann_1_copy.annotations[0] = ('different_name', first_tuple[1], first_tuple[2], first_tuple[3])
        ambiguity = ann_1.compute_ambiguity(ann_1_copy)
        # See if this increased the ambiguity score by one
        self.assertEqual(26, len(ambiguity))

    def test_confusion_matrix(self):
        """Tests that the confusion matrix has one row per entity and that its first row has one item
        per entity"""
        ann_1 = Annotations(self.ann_path_1)
        ann_2 = Annotations(self.ann_path_2)
        ann_1.add_entity(*ann_2.annotations[0])

        # Compute the matrix once and check both dimensions against the same result
        matrix = ann_1.compute_confusion_matrix(ann_2, self.entities)
        self.assertEqual(len(matrix[0]), len(self.entities))
        self.assertEqual(len(matrix), len(self.entities))

    def test_intersection(self):
        """Tests that after two entities of another Annotations are added, the intersection with it is
        exactly those two entities"""
        ann_1 = Annotations(self.ann_path_1)
        ann_2 = Annotations(self.ann_path_2)
        ann_1.add_entity(*ann_2.annotations[0])
        ann_1.add_entity(*ann_2.annotations[1])
        expected = {ann_2.annotations[0], ann_2.annotations[1]}
        actual = ann_1.intersection(ann_2)
        self.assertSetEqual(actual, expected)

    def test_compute_counts(self):
        """Tests that compute_counts returns a dict"""
        ann_1 = Annotations(self.ann_path_1)
        self.assertIsInstance(ann_1.compute_counts(), dict)

    def test_or(self):
        """
        Tests that the pipe operator correctly merges two Annotations and retains the source text path of
        the left operand
        """
        tup_1 = ('Object', 66, 77, 'this is some text')
        tup_2 = ('Entity', 44, 77, 'I love NER')
        tup_3 = ('Thingy', 66, 188, 'this is some sample text')
        file_name = 'some_file'

        ann_1 = Annotations([tup_1, tup_2], source_text_path=file_name)
        ann_2 = Annotations([tup_3])

        for a in [ann_1, ann_2]:
            self._test_is_sorted(a)

        # Test __or__
        result = ann_1 | ann_2
        expected = {tup_1, tup_2, tup_3}
        actual = set(result)
        self.assertSetEqual(actual, expected)
        self.assertEqual(file_name, result.source_text_path)
        self._test_is_sorted(result)

        # Test __ior__
        ann_1 |= ann_2
        actual = set(ann_1)
        self.assertSetEqual(actual, expected)
        self._test_is_sorted(ann_1)
# Run the unittest command-line runner when this file is executed as a script
if __name__ == "__main__":
    unittest.main()