|
a |
|
b/medacy/tools/entity.py |
|
|
1 |
import os |
|
|
2 |
from typing import Match |
|
|
3 |
|
|
|
4 |
from medacy.data.data_file import DataFile |
|
|
5 |
from medacy.data.annotations import Annotations |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class Entity: |
|
|
9 |
""" |
|
|
10 |
Representation of an individual entity in an annotation document. This abstraction is not used in the Annotations |
|
|
11 |
class, but can be used to keep track of what entities are present in a document during dataset manipulation. |
|
|
12 |
|
|
|
13 |
:ivar tag: the tag of this Entity |
|
|
14 |
:ivar start: the start index |
|
|
15 |
:ivar end: the end index |
|
|
16 |
:ivar text: the text of the Entity |
|
|
17 |
""" |
|
|
18 |
|
|
|
19 |
t = 1 |
|
|
20 |
|
|
|
21 |
def __init__(self, tag: str, start: int, end: int, text: str, num: int = 0): |
|
|
22 |
self.num = num |
|
|
23 |
self.tag = tag |
|
|
24 |
self.start = start |
|
|
25 |
self.end = end |
|
|
26 |
self.text = text |
|
|
27 |
|
|
|
28 |
def __eq__(self, other): |
|
|
29 |
return self.start == other.start and self.end == other.end and self.tag == other.tag |
|
|
30 |
|
|
|
31 |
def __hash__(self): |
|
|
32 |
return hash((self.start, self.end, self.text)) |
|
|
33 |
|
|
|
34 |
def __str__(self): |
|
|
35 |
"""Returns the BRAT representation of this Entity, without a new-line character""" |
|
|
36 |
return f"T{self.num}\t{self.tag} {self.start} {self.end}\t{self.text}" |
|
|
37 |
|
|
|
38 |
def __repr__(self): |
|
|
39 |
"""Return the constructor in string form""" |
|
|
40 |
return f"{type(self).__name__}({self.tag}, {self.start}, {self.end}, {self.text}, {self.num})" |
|
|
41 |
|
|
|
42 |
@classmethod |
|
|
43 |
def reset_t(cls): |
|
|
44 |
""" |
|
|
45 |
Resest the T counter for this class to 1 |
|
|
46 |
:return: The previous value of t |
|
|
47 |
""" |
|
|
48 |
previous = cls.t |
|
|
49 |
cls.t = 1 |
|
|
50 |
return previous |
|
|
51 |
|
|
|
52 |
@classmethod |
|
|
53 |
def init_from_re_match(cls, match: Match, ent_class, num=None, increment_t=False): |
|
|
54 |
""" |
|
|
55 |
Creates a new Entity from a regex Match. |
|
|
56 |
:param match: A Match object |
|
|
57 |
:param ent_class: The type of entity this is |
|
|
58 |
:param num: The number for this entity; defaults to the current entity count held by the class. |
|
|
59 |
:param increment_t: Whether or not to increment the T number |
|
|
60 |
:return: A new Entity |
|
|
61 |
""" |
|
|
62 |
if not isinstance(match, Match): |
|
|
63 |
raise TypeError("Argument is not a Match object.") |
|
|
64 |
|
|
|
65 |
new_entity = cls( |
|
|
66 |
num=cls.t if num is None else num, |
|
|
67 |
tag=ent_class, |
|
|
68 |
start=match.start(), |
|
|
69 |
end=match.end(), |
|
|
70 |
text=match.string[match.start():match.end()], |
|
|
71 |
) |
|
|
72 |
|
|
|
73 |
if num is None and increment_t: |
|
|
74 |
# Increment the counter |
|
|
75 |
cls.t += 1 |
|
|
76 |
|
|
|
77 |
return new_entity |
|
|
78 |
|
|
|
79 |
@classmethod |
|
|
80 |
def init_from_doc(cls, doc): |
|
|
81 |
""" |
|
|
82 |
Creates a list of Entities for all entity annotations in a document. |
|
|
83 |
:param doc: can be a DataFile or str of a file path |
|
|
84 |
:return: a list of Entities |
|
|
85 |
""" |
|
|
86 |
if isinstance(doc, DataFile): |
|
|
87 |
ann = Annotations(doc.ann_path, doc.txt_path) |
|
|
88 |
elif isinstance(doc, (str, os.PathLike)): |
|
|
89 |
ann = Annotations(doc) |
|
|
90 |
else: |
|
|
91 |
raise ValueError(f"'doc'' must be DataFile, str, or os.PathLike, but is '{type(doc)}'") |
|
|
92 |
|
|
|
93 |
entities = [] |
|
|
94 |
|
|
|
95 |
for ent in ann: |
|
|
96 |
# Entities are a tuple of (label, start, end, text) |
|
|
97 |
new_ent = cls( |
|
|
98 |
tag=ent[0], |
|
|
99 |
start=ent[1], |
|
|
100 |
end=ent[2], |
|
|
101 |
text=ent[3] |
|
|
102 |
) |
|
|
103 |
entities.append(new_ent) |
|
|
104 |
|
|
|
105 |
return entities |
|
|
106 |
|
|
|
107 |
def set_t(self): |
|
|
108 |
"""Sets the T value based on the class's counter and increments the counter""" |
|
|
109 |
self.num = self.__class__.t |
|
|
110 |
self.__class__.t += 1 |
|
|
111 |
|
|
|
112 |
def equals(self, other, mode='strict'): |
|
|
113 |
""" |
|
|
114 |
Determines if two Entities match, based on if the spans match and the tag is the same. |
|
|
115 |
If mode is set to 'lenient', two Entities match if the other span is fully within or fully |
|
|
116 |
without the first Entity and the tag is the same. |
|
|
117 |
:param other: another instance of Entity |
|
|
118 |
:param mode: 'strict' or 'lenient'; defaults to 'strict' |
|
|
119 |
:return: True or False |
|
|
120 |
""" |
|
|
121 |
if not isinstance(other, Entity): |
|
|
122 |
raise ValueError(f"'other' must be another instance of Entity, but is '{type(other)}'") |
|
|
123 |
|
|
|
124 |
if mode == 'strict': |
|
|
125 |
return self == other |
|
|
126 |
if mode != 'lenient': |
|
|
127 |
raise ValueError(f"'mode' must be 'strict' or 'lenient', but is '{mode}'") |
|
|
128 |
|
|
|
129 |
# Lenient |
|
|
130 |
return ((self.end > other.start and self.start < other.end) or (self.start < other.end and other.start < self.end)) and self.tag == other.tag |
|
|
131 |
|
|
|
132 |
|
|
|
133 |
def sort_entities(entities): |
|
|
134 |
""" |
|
|
135 |
Sorts a list of Entity instances, adjusting the num value of each one |
|
|
136 |
:param entities: a list of Entities |
|
|
137 |
:return: a sorted list; all instances have ascending num values starting at 1 |
|
|
138 |
""" |
|
|
139 |
if not all(isinstance(e, Entity) for e in entities): |
|
|
140 |
raise ValueError("At least one item in entities is not an Entity") |
|
|
141 |
|
|
|
142 |
entities = entities.copy() |
|
|
143 |
entities.sort(key=lambda x: (x.start, x.end)) |
|
|
144 |
|
|
|
145 |
for i, e in enumerate(entities, 1): |
|
|
146 |
e.num = i |
|
|
147 |
|
|
|
148 |
return entities |