|
a |
|
b/medacy/data/annotations.py |
|
|
1 |
import logging |
|
|
2 |
import os |
|
|
3 |
import re |
|
|
4 |
from collections import Counter, namedtuple |
|
|
5 |
from math import ceil |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
# One entity: its label (``tag``), character offsets ``start``/``end``, and the covered ``text``.
EntTuple = namedtuple('EntTuple', ['tag', 'start', 'end', 'text'])


class Annotations:
    """
    An Annotations object stores all relevant information needed to manage Annotations over a document.
    The Annotation object is utilized by medaCy to structure input to models and output from models.
    This object wraps a list of tuples representing the entities in a document.

    :ivar ann_path: the path to the .ann file; None when the object was built from a list of tuples
    :ivar source_text_path: path to the related .txt file
    :ivar annotations: a list of annotation tuples, always sorted by (start, end)
    """

    brat_pattern = re.compile(r'T(\d+)\t(\S+) ((\d+ \d+;)*\d+ \d+)\t(.+)')

    def __init__(self, annotation_data, source_text_path=None):
        """
        :param annotation_data: a file path to an annotation file, or a list of annotation tuples.
            Construction from a list of tuples is intended for internal use.
        :param source_text_path: optional; path to the text file from which the annotations were derived.
        :raises FileNotFoundError: if annotation_data is neither a list of tuples nor a path to an existing file.
        """
        self.source_text_path = source_text_path

        if isinstance(annotation_data, list) and all(isinstance(e, tuple) for e in annotation_data):
            # Built in memory: there is no backing .ann file, but the documented attribute must still exist.
            self.ann_path = None
            self.annotations = annotation_data
            return

        if not os.path.isfile(annotation_data):
            raise FileNotFoundError("annotation_data must be a list of tuples or a valid file path, but is %s" % repr(annotation_data))

        self.ann_path = annotation_data
        self.annotations = self._init_from_file(annotation_data)

    @staticmethod
    def _init_from_file(file_path):
        """
        Creates a list of annotation tuples from a file path.

        Only text that matches ``brat_pattern`` (``T<id><TAB><label> <offsets><TAB><text>``) yields an entity.
        For discontinuous offsets such as ``0 4;9 12``, only the first start and the last end are kept.

        :param file_path: the path to an ann file
        :return: a list of annotation tuples
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = f.read()

        annotations = []
        for match in Annotations.brat_pattern.finditer(data):
            tag = match.group(2)
            offsets = re.findall(r'\d+', match.group(3))
            start, end = int(offsets[0]), int(offsets[-1])
            annotations.append(EntTuple(tag, start, end, match.group(5)))

        return annotations

    @property
    def annotations(self):
        return self._annotations

    @annotations.setter
    def annotations(self, value):
        """Ensures that annotations are always sorted"""
        self._annotations = sorted([EntTuple(*e) for e in value], key=lambda x: (x.start, x.end))

    def get_labels(self, as_list=False):
        """
        Get the set of labels from this collection of annotations.

        :param as_list: bool for if to return the results as a list; defaults to False
        :return: The set of labels.
        """
        labels = {e.tag for e in self.annotations}
        return list(labels) if as_list else labels

    def add_entity(self, label, start, end, text=""):
        """
        Adds an entity to the Annotations, keeping the entities sorted by (start, end).

        :param label: the label of the annotation you are appending
        :param start: the start index in the document of the annotation you are appending.
        :param end: the end index in the document of the annotation you are appending
        :param text: the raw text of the annotation you are appending
        """
        # Reassign through the setter: appending to the list returned by the getter would leave the
        # new entity out of order and break the sorted invariant.
        self.annotations = self._annotations + [EntTuple(label, start, end, text)]

    def to_ann(self, write_location=None):
        """
        Formats the Annotations object into a string representing a valid ANN file. Optionally writes the formatted
        string to a destination.

        :param write_location: path of location to write ann file to
        :return: returns string formatted as an ann file, if write_location is valid path also writes to that path.
        """
        lines = []
        for num, ent in enumerate(self.annotations, 1):
            # A newline inside the mention would split the ANN record across lines.
            mention = ent.text.replace('\n', ' ')
            lines.append(f"T{num}\t{ent.tag} {ent.start} {ent.end}\t{mention}\n")
        ann_string = "".join(lines)

        if write_location is not None:
            if os.path.isfile(write_location):
                logging.warning("Overwriting file at: %s", write_location)
            # UTF-8 to match the encoding used by _init_from_file when reading ANN files back in.
            with open(write_location, 'w', encoding='utf-8') as f:
                f.write(ann_string)

        return ann_string

    @staticmethod
    def _check_leniency(leniency):
        """Raise ValueError unless leniency lies in [0, 1]."""
        if not 0 <= leniency <= 1:
            raise ValueError("Leniency must be a floating point between [0,1]")

    @staticmethod
    def _spans_match(ann, candidate, leniency):
        """Return True if both entities share a label and each boundary differs by at most the leniency window."""
        label, start, end, _ = ann
        c_label, c_start, c_end, _ = candidate
        if label != c_label:
            return False
        window = ceil(leniency * (end - start) / 2)
        return abs(start - c_start) <= window and abs(end - c_end) <= window

    def _lenient_matches(self, other, leniency):
        """Return the set of entities in this object that have a lenient match in ``other``."""
        return {
            ann for ann in self.annotations
            if any(self._spans_match(ann, candidate, leniency) for candidate in other.annotations)
        }

    def difference(self, other, leniency=0):
        """
        Identifies the difference between two Annotations objects. Useful for checking if an unverified annotation
        matches an annotation known to be accurate. This is done by returning all annotations in the operated-on
        Annotations object that do not exist in the passed-in Annotations object. This is a set difference.

        :param other: Another Annotations object.
        :param leniency: a floating point value in [0, 1] defining the leniency of the character spans to count
            as different. A value of zero considers only exact matches, while a positive value treats an entity in
            ``other`` with the same label as matching when its start and end each differ by at most
            :code:`ceil(leniency * len(span)/2)`, where span is the entity from this object.
        :return: A set of tuples of non-matching annotations.
        :raises ValueError: if other is not an Annotations object or leniency is outside [0, 1].
        """
        if not isinstance(other, Annotations):
            raise ValueError("Annotations.difference() can only accept another Annotations object as an argument.")
        if leniency == 0:
            return set(self.annotations) - set(other.annotations)
        self._check_leniency(leniency)
        return set(self.annotations) - self._lenient_matches(other, leniency)

    def intersection(self, other, leniency=0):
        """
        Computes the intersection of the operated annotation object with the operand annotation object.

        :param other: Another Annotations object.
        :param leniency: a floating point value in [0, 1] defining the leniency of the character spans to count as
            matching. A value of zero considers only exact matches, while a positive value treats an entity in
            ``other`` with the same label as matching when its start and end each differ by at most
            :code:`ceil(leniency * len(span)/2)`, where span is the entity from this object.
        :return: A set of annotations (taken from this object) that appear in both Annotation objects
        :raises ValueError: if other is not an Annotations object or leniency is outside [0, 1].
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")
        if leniency == 0:
            return set(self.annotations) & set(other.annotations)
        self._check_leniency(leniency)
        return self._lenient_matches(other, leniency)

    def compute_ambiguity(self, other):
        """
        Finds spans in 'other' that overlap a span of this annotation but carry a different label.
        If 'other' comprises a model's predictions, this method provides a strong indicator of the model's
        inability to disambiguate between entities. For a full analysis, compute a confusion matrix.

        :param other: Another Annotations object.
        :return: a dictionary mapping each entity of this object (as a plain tuple) to the list of all
            differently-labelled entities in 'other' that overlap it
        :raises ValueError: if other is not an Annotations object.
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")

        ambiguity_dict = {}

        for label, start, end, text in self.annotations:
            for c_label, c_start, c_end, c_text in other.annotations:
                if label == c_label:
                    continue
                # Half-open spans overlap when the later start precedes the earlier end.
                if min(end, c_end) > max(start, c_start):
                    # Accumulate every conflicting span instead of keeping only the last one.
                    ambiguity_dict.setdefault((label, start, end, text), []).append((c_label, c_start, c_end, c_text))

        return ambiguity_dict

    def compute_confusion_matrix(self, other, entities, leniency=0):
        """
        Computes a confusion matrix representing span level ambiguity between this annotation and the argument annotation.
        An annotation in 'other' is ambiguous if it overlaps with a span in this Annotation but does not have the
        same entity label. The main diagonal of this matrix corresponds to entities in this Annotation that match spans
        in 'other' and have an equivalent class label.

        :param other: Another Annotations object.
        :param entities: a list of entities to use in computing matrix ambiguity; labels not in this list are ignored.
        :param leniency: leniency to utilize when computing overlapping entities. This is the same definition of leniency as in intersection.
        :return: a square matrix with dimension len(entities) where matrix[i][j] indicates that entities[i] in this annotation was predicted as entities[j] in 'other' matrix[i][j] times.
        :raises ValueError: if other is not an Annotations object or entities is not a list.
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")
        if not isinstance(entities, list):
            raise ValueError("entities must be a list of entities, but is %s" % repr(entities))

        entity_encoding = {entity: i for i, entity in enumerate(entities)}
        size = len(entities)
        confusion_matrix = [[0] * size for _ in range(size)]

        # Off-diagonal scores: overlapping spans whose labels disagree.
        for gold_annotation, ambiguous_annotations in self.compute_ambiguity(other).items():
            gold_index = entity_encoding.get(gold_annotation[0])
            if gold_index is None:
                continue
            for ambiguous_annotation in ambiguous_annotations:
                predicted_index = entity_encoding.get(ambiguous_annotation[0])
                if predicted_index is not None:
                    confusion_matrix[gold_index][predicted_index] += 1

        # Diagonal scores: correctly predicted entities with (leniently) correct spans.
        for matching_annotation in self.intersection(other, leniency=leniency):
            index = entity_encoding.get(matching_annotation[0])
            if index is not None:
                confusion_matrix[index][index] += 1

        return confusion_matrix

    def compute_counts(self):
        """
        Computes counts of each entity type in this annotation.

        :return: a Counter of the entity counts
        """
        return Counter(e.tag for e in self.annotations)

    def __str__(self):
        return str(self.annotations)

    def __len__(self):
        return len(self.annotations)

    def __iter__(self):
        return iter(self.annotations)

    def __or__(self, other):
        """
        Creates an Annotations object containing annotations from two instances.

        :param other: the other Annotations instance
        :return: a new Annotations object containing entities from both
        """
        new_entities = list(set(self.annotations) | set(other.annotations))
        new_annotations = Annotations(new_entities, source_text_path=self.source_text_path or other.source_text_path)
        # NOTE(review): the string 'None' (not the None object) is kept unchanged; confirm no caller relies on it
        # before switching to None.
        new_annotations.ann_path = 'None'
        return new_annotations

    def __ior__(self, other):
        self.annotations = list(set(self.annotations) | set(other.annotations))
        # NOTE(review): string 'None' kept for consistency with __or__.
        self.ann_path = 'None'
        return self