a b/medacy/data/annotations.py
1
import logging
2
import os
3
import re
4
from collections import Counter, namedtuple
5
from math import ceil
6
7
8
# A single entity annotation: its label (tag), its start/end character offsets in the source text,
# and the annotated text itself.
EntTuple = namedtuple('EntTuple', 'tag start end text')
9
10
11
class Annotations:
    """
    An Annotations object stores all relevant information needed to manage Annotations over a document.
    The Annotation object is utilized by medaCy to structure input to models and output from models.
    This object wraps a list of tuples representing the entities in a document.

    :ivar ann_path: the path to the .ann file; None when constructed from a list of tuples
    :ivar source_text_path: path to the related .txt file
    :ivar annotations: a list of EntTuple(tag, start, end, text), kept sorted by (start, end)
    """

    # Matches a brat text-bound annotation: "T<id>\t<label> <start> <end>[;<start> <end>...]\t<mention>".
    # Group 2 is the label, group 3 the (possibly ';'-separated, discontinuous) offsets, group 5 the mention.
    brat_pattern = re.compile(r'T(\d+)\t(\S+) ((\d+ \d+;)*\d+ \d+)\t(.+)')

    def __init__(self, annotation_data, source_text_path=None):
        """
        :param annotation_data: a file path to an annotation file, or a list of annotation tuples.
            Construction from a list of tuples is intended for internal use.
        :param source_text_path: optional; path to the text file from which the annotations were derived.
        :raises FileNotFoundError: if annotation_data is not a list of tuples and is not a path to an existing file.
        """
        if isinstance(annotation_data, list) and all(isinstance(e, tuple) for e in annotation_data):
            # There is no backing file; set ann_path anyway so the attribute always exists.
            self.ann_path = None
            self.annotations = annotation_data
            self.source_text_path = source_text_path
            return
        if not os.path.isfile(annotation_data):
            raise FileNotFoundError("annotation_data must be a list of tuples or a valid file path, but is %s" % repr(annotation_data))

        self.ann_path = annotation_data
        self.source_text_path = source_text_path
        self.annotations = self._init_from_file(annotation_data)

    @staticmethod
    def _init_from_file(file_path):
        """
        Creates a list of annotation tuples from a brat .ann file (read as UTF-8).
        Only text that matches brat_pattern is used; anything else in the file is skipped.
        For discontinuous spans, the entity runs from the first start offset to the last end offset.
        :param file_path: the path to an ann file
        :return: a list of EntTuple(tag, start, end, text)
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = f.read()

        annotations = []
        for match in Annotations.brat_pattern.finditer(data):
            tag = match.group(2)
            offsets = re.findall(r'\d+', match.group(3))
            start, end = int(offsets[0]), int(offsets[-1])
            annotations.append(EntTuple(tag, start, end, match.group(5)))

        return annotations

    @property
    def annotations(self):
        return self._annotations

    @annotations.setter
    def annotations(self, value):
        """Ensures that annotations are always sorted"""
        self._annotations = sorted([EntTuple(*e) for e in value], key=lambda x: (x.start, x.end))

    def get_labels(self, as_list=False):
        """
        Get the set of labels from this collection of annotations.
        :param as_list: bool for if to return the results as a list; defaults to False
        :return: The set of labels (or a list of the distinct labels if as_list is True).
        """
        labels = {e[0] for e in self.annotations}

        if as_list:
            return list(labels)
        return labels

    def add_entity(self, label, start, end, text=""):
        """
        Adds an entity to the Annotations, keeping the annotations sorted by (start, end).
        :param label: the label of the annotation you are appending
        :param start: the start index in the document of the annotation you are appending.
        :param end: the end index in the document of the annotation you are appending
        :param text: the raw text of the annotation you are appending
        """
        # Assign through the property setter rather than appending to the returned list,
        # so the sorted-order invariant maintained by the setter is preserved.
        self.annotations = self.annotations + [EntTuple(label, start, end, text)]

    def to_ann(self, write_location=None):
        """
        Formats the Annotations object into a string representing a valid ANN file. Optionally writes the formatted
        string to a destination (UTF-8 encoded, matching how ann files are read).
        :param write_location: path of location to write ann file to
        :return: returns string formatted as an ann file, if write_location is valid path also writes to that path.
            Entities are numbered T1, T2, ... in their sorted order.
        """
        lines = []
        for num, ent in enumerate(self.annotations, 1):
            # A mention must stay on one line for the ann format to be parseable.
            mention = ent.text.replace('\n', ' ')
            lines.append(f"T{num}\t{ent.tag} {ent.start} {ent.end}\t{mention}\n")
        ann_string = ''.join(lines)

        if write_location is not None:
            if os.path.isfile(write_location):
                logging.warning("Overwriting file at: %s", write_location)
            with open(write_location, 'w', encoding='utf-8') as f:
                f.write(ann_string)

        return ann_string

    def _lenient_matches(self, other, leniency):
        """
        Returns the annotations of this object that have a lenient match in 'other'. An annotation
        (label, start, end) matches when some annotation in 'other' has the same label and lies entirely within
        [start - window, end + window], where window = ceil(leniency * (end - start)).
        :param other: Another Annotations object.
        :param leniency: a number in [0, 1].
        :return: a set of this object's matched annotation tuples.
        :raises ValueError: if leniency is outside [0, 1].
        """
        if not 0 <= leniency <= 1:
            raise ValueError("Leniency must be a floating point between [0,1]")

        matches = set()
        for ann in self.annotations:
            label, start, end, _ = ann
            window = ceil(leniency * (end - start))
            for c_label, c_start, c_end, _ in other.annotations:
                if label == c_label and start - window <= c_start and end + window >= c_end:
                    matches.add(ann)
                    break

        return matches

    def difference(self, other, leniency=0):
        """
        Identifies the difference between two Annotations objects. Useful for checking if an unverified annotation
        matches an annotation known to be accurate. This is done by returning all annotations in the operated-on
        Annotations object that do not exist in the passed-in Annotations object. This is a set difference.
        :param other: Another Annotations object.
        :param leniency: a floating point value between [0,1] defining the leniency of the character spans to count
            as different. A value of zero considers only exact matches. A positive value treats an annotation as present
            in 'other' when 'other' has an annotation with the same label lying entirely within this annotation's span
            widened by :code:`ceil(leniency * len(span))` characters on either side.
        :return: A set of tuples of non-matching annotations.
        :raises ValueError: if other is not an Annotations object, or leniency is outside [0, 1].
        """
        if not isinstance(other, Annotations):
            raise ValueError("Annotations.diff() can only accept another Annotations object as an argument.")
        if leniency == 0:
            return set(self.annotations) - set(other.annotations)

        return set(self.annotations) - self._lenient_matches(other, leniency)

    def intersection(self, other, leniency=0):
        """
        Computes the intersection of the operated annotation object with the operand annotation object.
        :param other: Another Annotations object.
        :param leniency: a floating point value between [0,1] defining the leniency of the character spans to count as
            matching. A value of zero considers only exact matches. A positive value keeps an annotation when 'other'
            has an annotation with the same label lying entirely within this annotation's span widened by
            :code:`ceil(leniency * len(span))` characters on either side.
        :return: A set of annotations (taken from this object) that appear in both Annotation objects
        :raises ValueError: if other is not an Annotations object, or leniency is outside [0, 1].
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")
        if leniency == 0:
            return set(self.annotations) & set(other.annotations)

        return self._lenient_matches(other, leniency)

    def compute_ambiguity(self, other):
        """
        Finds occurrences of spans from 'other' that overlap a span from this annotation but do not have that span's
        label. If 'other' comprises a model's predictions, this method provides a strong indicator
        of a model's inability to disambiguate between entities. For a full analysis, compute a confusion matrix.
        :param other: Another Annotations object.
        :return: a dictionary mapping each (label, start, end, text) of this object to the list of every
            overlapping, differently-labelled (label, start, end, text) from 'other'; annotations with no
            such overlap are absent.
        :raises ValueError: if other is not an Annotations object.
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")

        ambiguity_dict = {}

        for label, start, end, text in self.annotations:
            for c_label, c_start, c_end, c_text in other.annotations:
                if label == c_label:
                    continue
                # The spans overlap when their intersection [max(starts), min(ends)) is non-empty.
                if min(end, c_end) - max(start, c_start) > 0:
                    # Accumulate: a single gold span may be overlapped by several predictions.
                    ambiguity_dict.setdefault((label, start, end, text), []).append((c_label, c_start, c_end, c_text))

        return ambiguity_dict

    def compute_confusion_matrix(self, other, entities, leniency=0):
        """
        Computes a confusion matrix representing span level ambiguity between this annotation and the argument annotation.
        An annotation in 'other' is ambiguous if it overlaps with a span in this Annotation but does not have the
        same entity label. The main diagonal of this matrix corresponds to entities in this Annotation that match spans
        in 'other' and have an equivalent class label.
        :param other: Another Annotations object.
        :param entities: a list of entities to use in computing matrix ambiguity.
        :param leniency: leniency to utilize when computing overlapping entities. This is the same definition of leniency as in intersection.
        :return: a square matrix with dimension len(entities) where matrix[i][j] indicates that entities[i] in this annotation was predicted as entities[j] in 'other' matrix[i][j] times.
        :raises ValueError: if other is not an Annotations object or entities is not a list.
        :raises KeyError: if a counted annotation has a label that is not in entities.
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")
        if not isinstance(entities, list):
            raise ValueError("entities must be a list of entities, but is %s" % repr(entities))

        entity_encoding = {entity: i for i, entity in enumerate(entities)}
        size = len(entities)
        confusion_matrix = [[0] * size for _ in range(size)]

        ambiguity_dict = self.compute_ambiguity(other)
        intersection = self.intersection(other, leniency=leniency)

        # Off-diagonal scores: a gold label i overlapped by a differently-labelled annotation j.
        for gold_annotation, ambiguous_annotations in ambiguity_dict.items():
            gold_index = entity_encoding[gold_annotation[0]]
            for ambiguous_annotation in ambiguous_annotations:
                confusion_matrix[gold_index][entity_encoding[ambiguous_annotation[0]]] += 1

        # Diagonal scores (correctly predicted entities with correct spans)
        for matching_annotation in intersection:
            index = entity_encoding[matching_annotation[0]]
            confusion_matrix[index][index] += 1

        return confusion_matrix

    def compute_counts(self):
        """
        Computes counts of each entity type in this annotation.
        :return: a Counter of the entity counts
        """
        return Counter(e[0] for e in self.annotations)

    def __str__(self):
        return str(self.annotations)

    def __len__(self):
        return len(self.annotations)

    def __iter__(self):
        return iter(self.annotations)

    def __or__(self, other):
        """
        Creates an Annotations object containing annotations from two instances
        :param other: the other Annotations instance
        :return: a new Annotations object containing entities from both (duplicates removed)
        """
        new_entities = list(set(self.annotations) | set(other.annotations))
        new_annotations = Annotations(new_entities, source_text_path=self.source_text_path or other.source_text_path)
        # NOTE(review): ann_path is the string 'None', not None; kept for compatibility with existing callers.
        new_annotations.ann_path = 'None'
        return new_annotations

    def __ior__(self, other):
        """
        Merges the annotations of 'other' into this instance in place (duplicates removed).
        :param other: the other Annotations instance
        :return: self
        """
        self.annotations = list(set(self.annotations) | set(other.annotations))
        # NOTE(review): ann_path is the string 'None', not None; kept for compatibility with existing callers.
        self.ann_path = 'None'
        return self