scripts/ann_to_bio_conversion.py

import os
import re
import json
import nltk
import string
#nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords

def tag_token(tags, token_pos, tag):
    """Assigns a BIO tag for `tag` to the token at `token_pos`, appending with
    ';' when the token already carries tags for other entity types."""
    if token_pos > 0 and tag in tags[token_pos - 1]:
        # The previous token already carries this entity type, so the current
        # token continues the span.
        if tags[token_pos] == 'O':
            tags[token_pos] = f'I-{tag}'
        elif f'I-{tag}' not in tags[token_pos]:
            tags[token_pos] += f';I-{tag}'
    else:
        # Otherwise the current token begins a new span.
        if tags[token_pos] == 'O':
            tags[token_pos] = f'B-{tag}'
        elif f'B-{tag}' not in tags[token_pos]:
            tags[token_pos] += f';B-{tag}'
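
# For example (illustrative): with tags == ['O', 'O'], tag_token(tags, 0, 'Disease')
# sets tags[0] to 'B-Disease', and a following tag_token(tags, 1, 'Disease')
# sets tags[1] to 'I-Disease', since tags[0] already carries the same type.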

def remove_trailing_punctuation(token):
    """Strips trailing punctuation and trailing apostrophes from a token."""
    while token and (re.search(r'[^\w\s\']', token[-1]) or (len(token) > 1 and token[-1] == "'")):
        token = token[:-1]
    return token
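
# For example (illustrative): remove_trailing_punctuation("p53),") returns
# "p53", and remove_trailing_punctuation("patients'") returns "patients".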

class AnnToBioConverter:
    def __init__(self, input_dir=None, txt_dir=None, ann_dir=None, output_dir=None, filtered_entities=None):
        """
        Initializes an instance of the AnnToBioConverter class.

        :param input_dir: (str) Directory containing both the text and annotation files.
        :param txt_dir: (str) Directory containing the text files.
        :param ann_dir: (str) Directory containing the annotation files.
        :param output_dir: (str) Directory where the output .bio files will be saved.
        :param filtered_entities: (list) Entity labels to keep; all labels are kept when empty.
        """
        if input_dir is not None:
            self.txt_dir = input_dir
            self.ann_dir = input_dir
        elif txt_dir is not None and ann_dir is not None:
            self.txt_dir = txt_dir
            self.ann_dir = ann_dir
        else:
            raise ValueError("Either input_dir or both txt_dir and ann_dir must be provided.")
        if output_dir is not None:
            self.output_dir = output_dir
        else:
            raise ValueError("output_dir must be provided.")
        self.data = {}
        # Avoid a mutable default argument by falling back to a fresh list.
        self.filtered_entities = filtered_entities if filtered_entities is not None else []

    def _splitting_tokens(self, file_id, start, end, hyphen_split):
        """
        Splits a multi-word token into separate tokens and returns a list of tokens
        and their respective start and end indices.

        :param file_id: (str) The ID of the file containing the token.
        :param start: (int) The starting index of the token in the text.
        :param end: (int) The ending index of the token in the text.
        :param hyphen_split: (bool) Whether to also split on hyphen-like characters.
        :return: (tuple) A tuple containing a list of tokens and a list of their start/end indices.
        """
        text = self.data[file_id]['text']
        token = text[start:end]
        extra_sep = ['\u200a']  # hair space
        if hyphen_split:
            # Hyphen plus the unicode hyphens, dashes, minus sign, and fullwidth hyphen-minus.
            extra_sep += ['-', '\u2010', '\u2011', '\u2012', '\u2013', '\u2014', '\u2015', '\u2212', '\uff0d']
        new_range = []
        tokens = []
        curr = start
        new_start = None
        # The appended space flushes the final token at the end of the span.
        for c in token + " ":
            if c.isspace() or c in extra_sep:
                if new_start is not None:  # `is not None`: offset 0 is a valid start
                    new_range.append([new_start, curr])
                    tokens.append(text[new_start:curr])
                    new_start = None
            elif new_start is None:
                new_start = curr
            curr += 1
        return tokens, new_range
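
    # For example (illustrative): with hyphen_split=True, a span covering
    # "IL-2" yields tokens ['IL', '2'] together with their character ranges.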

    def _load_txt(self):
        """
        Loads the text files into the instance's data dictionary.
        """
        for file_name in os.listdir(self.txt_dir):
            if file_name.endswith(".txt"):
                with open(os.path.join(self.txt_dir, file_name), "r") as f:
                    text = f.read()
                file_id = file_name.split('.')[0]
                self.data[file_id] = {
                    "text": text,
                    "annotations": []
                }

    def _load_ann(self, hyphen_split):
        """
        Loads the .ann files and converts each text-bound ("T") annotation into
        {label, start, end} records in the instance's data dictionary.
        """
        for file_name in os.listdir(self.ann_dir):
            if file_name.endswith(".ann"):
                with open(os.path.join(self.ann_dir, file_name), "r") as f:
                    file_id = file_name.split('.')[0]
                    annotations = []
                    for line in f:
                        if line.startswith("T"):
                            fields = line.strip().split("\t")
                            if len(fields[1].split(" ")) > 1:
                                label = fields[1].split(" ")[0]
                                # Extract start/end indices (a few annotations contain
                                # more than one disjoint range, separated by ';').
                                start_end_range = [
                                    list(map(int, start_end.split()))
                                    for start_end in ' '.join(fields[1].split(" ")[1:]).split(';')
                                ]
                                start_end_range_fixed = []
                                for start, end in start_end_range:
                                    tokens, start_end_split_list = self._splitting_tokens(
                                        file_id, start, end, hyphen_split)
                                    start_end_range_fixed.extend(start_end_split_list)
                                # Add label, start, and end to the annotations.
                                for start, end in start_end_range_fixed:
                                    annotations.append({
                                        "label": label,
                                        "start": start,
                                        "end": end
                                    })
                    # Sort annotations by start index (then label) before storing them.
                    annotations = sorted(annotations, key=lambda x: (x['start'], x['label']))
                    self.data[file_id]["annotations"] = annotations
        self._manual_fix()
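
    # Illustrative brat standoff line handled by _load_ann (tab-separated):
    #   T1\tDisease 10 18\tglaucoma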

    def split_text(self, file_id):
        """
        Tokenizes the text of a file on whitespace and hyphen-like characters,
        strips trailing punctuation from each token, and returns the tokens,
        their character ranges, and the token index where each sentence ends.
        """
        text = self.data[file_id]['text']
        # Match runs of characters that are not whitespace or hyphen-like.
        regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'
        tokens = []
        start_end_ranges = []
        sentence_breaks = []
        start_idx = 0
        for sentence in text.split('\n'):
            words = [match.group(0) for match in re.finditer(regex_match, sentence)]
            processed_words = list(map(remove_trailing_punctuation, words))
            sentence_indices = [(match.start(), match.start() + len(token)) for match, token in
                                zip(re.finditer(regex_match, sentence), processed_words)]
            # Update the indices to account for the current sentence's position in the entire text.
            sentence_indices = [(start_idx + start, start_idx + end) for start, end in sentence_indices]
            start_end_ranges.extend(sentence_indices)
            tokens.extend(processed_words)
            sentence_breaks.append(len(tokens))
            start_idx += len(sentence) + 1  # +1 for the '\n' consumed by split()
        return tokens, start_end_ranges, sentence_breaks
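
    # For example (illustrative): for the text "Early-onset glaucoma.\n",
    # split_text returns tokens ['Early', 'onset', 'glaucoma'] (the trailing
    # period is stripped), their character ranges, and sentence_breaks == [3].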

    def write_bio_files(self):
        """
        Aligns annotations with tokens and writes one .bio file per input file,
        one token<TAB>tag pair per line, with a blank line between sentences.
        """
        for file_id in self.data:
            text = self.data[file_id]['text']
            annotations = self.data[file_id]['annotations']
            # Tokenizing
            tokens, token2text, sentence_breaks = self.split_text(file_id)
            # Initialize the tags
            tags = ['O'] * len(tokens)
            ann_pos = 0
            token_pos = 0
            while ann_pos < len(annotations) and token_pos < len(tokens):
                label = annotations[ann_pos]['label']
                start = annotations[ann_pos]['start']
                end = annotations[ann_pos]['end']
                if self.filtered_entities and label not in self.filtered_entities:
                    # Skip labels outside the requested entity set.
                    ann_pos += 1
                    continue
                ann_word = text[start:end]
                # Find the next token that starts at or after the annotation start.
                while token_pos < len(tokens) and token2text[token_pos][0] < start:
                    token_pos += 1
                if token_pos >= len(tokens):
                    break
                if ann_word in tokens[token_pos] or \
                        re.sub(r'\W+', '', ann_word) in re.sub(r'\W+', '', tokens[token_pos]):
                    tag_token(tags, token_pos, label)
                elif ann_word in tokens[token_pos - 1] or \
                        re.sub(r'\W+', '', ann_word) in re.sub(r'\W+', '', tokens[token_pos - 1]):
                    # The annotation start falls inside the previous token.
                    tag_token(tags, token_pos - 1, label)
                else:
                    # Log unmatched annotations for manual inspection.
                    print(tokens[token_pos], tokens[token_pos - 1], ann_word, label)
                # Move on to the next annotation.
                ann_pos += 1
            # Write the tags to a .bio file
            with open(os.path.join(self.output_dir, f"{file_id}.bio"), 'w') as f:
                for i in range(len(tokens)):
                    token = tokens[i].strip()
                    if token:
                        if i in sentence_breaks:
                            f.write("\n")
                        f.write(f"{tokens[i]}\t{tags[i]}\n")

    def _manual_fix(self):
        """
        Applies hard-coded corrections for known bad annotation offsets,
        keyed by file id and original start index.
        """
        fix = {
            '19214295': {
                425: 424
            }
        }
        for file_id in self.data:
            if file_id in fix:
                for ann in self.data[file_id]['annotations']:
                    if ann['start'] in fix[file_id]:
                        ann['start'] = fix[file_id][ann['start']]

    def load_data(self, reload=False, hyphen_split=False):
        if not self.data or reload:
            self.data = {}
            self._load_txt()
            self._load_ann(hyphen_split)

    def create_json(self, data_dir, reload=False, hyphen_split=False):
        if not self.data or reload:
            self.load_data(reload=reload, hyphen_split=hyphen_split)
        # Write the dictionary to a JSON file
        with open(os.path.join(data_dir, "data.json"), "w") as f:
            json.dump(self.data, f)

    def load(self, hyphen_split=False):
        """Loads both the text and annotation files."""
        self._load_txt()
        self._load_ann(hyphen_split)

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails; read 'data' through
        # __dict__ to avoid recursing back into __getattr__.
        if name == 'data':
            return self.__dict__.get('data', {})
        raise AttributeError(f"'AnnToBioConverter' object has no attribute '{name}'")
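
# Minimal usage sketch (illustrative; the directory paths below are
# assumptions, not part of the original script):
if __name__ == "__main__":
    os.makedirs("data/bio", exist_ok=True)
    converter = AnnToBioConverter(
        input_dir="data/brat",   # hypothetical directory of .txt/.ann pairs
        output_dir="data/bio",   # hypothetical directory for the .bio output
    )
    converter.load_data(hyphen_split=True)
    converter.write_bio_files()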