scripts/ann_to_bio_conversion.py
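"""Convert standoff annotation (.ann) and text (.txt) file pairs into token-level BIO files.

The AnnToBioConverter below loads each text file and its annotations, aligns annotated
character spans with whitespace/hyphen-delimited tokens, and writes one <file_id>.bio
file per document, plus an optional combined data.json dump.
"""
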
import os
import re
import json

import nltk
import string
# nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords


def tag_token(tags, token_pos, tag):
    """Apply the BIO tag for entity type `tag` to the token at `token_pos`.

    A token continues a span (I-) when the previous token already carries the
    same entity type, otherwise it begins a new span (B-). Multiple entity
    types on one token are joined with ';'.
    """
    if token_pos > 0 and tag in tags[token_pos - 1]:
        if tags[token_pos] == 'O':
            tags[token_pos] = f'I-{tag}'
        elif f'I-{tag}' not in tags[token_pos]:
            tags[token_pos] += f';I-{tag}'
    else:
        if tags[token_pos] == 'O':
            tags[token_pos] = f'B-{tag}'
        elif f'B-{tag}' not in tags[token_pos]:
            tags[token_pos] += f';B-{tag}'


def remove_trailing_punctuation(token):
    """Strip trailing punctuation (and trailing apostrophes) from a token."""
    while token and (re.search(r'[^\w\s\']', token[-1]) or (len(token) > 1 and token[-1] == "'")):
        token = token[:-1]
    return token


class AnnToBioConverter:

    def __init__(self, input_dir=None, txt_dir=None, ann_dir=None, output_dir=None, filtered_entities=None):
        """
        Initializes an instance of the AnnToBioConverter class.
        :param input_dir: (str) Directory containing both the text and annotation files.
        :param txt_dir: (str) Directory containing the text files.
        :param ann_dir: (str) Directory containing the annotation files.
        :param output_dir: (str) Directory where the .bio files will be written.
        :param filtered_entities: (list) Entity labels to keep; if empty or None, all labels are kept.
        """
        if input_dir is not None:
            self.txt_dir = input_dir
            self.ann_dir = input_dir
        elif txt_dir is not None and ann_dir is not None:
            self.txt_dir = txt_dir
            self.ann_dir = ann_dir
        else:
            raise ValueError("Either input_dir or both txt_dir and ann_dir must be provided.")

        if output_dir is not None:
            self.output_dir = output_dir
        else:
            raise ValueError("output_dir must be provided.")

        self.data = {}

        # Use None instead of a mutable default argument; an empty list disables filtering.
        self.filtered_entities = filtered_entities if filtered_entities is not None else []

    def _splitting_tokens(self, file_id, start, end, hyphen_split):
        """
        Splits a multi-word annotation span into separate tokens and returns the tokens with their start and end indices.
        :param file_id: (str) The ID of the file containing the span.
        :param start: (int) The starting index of the span in the text.
        :param end: (int) The ending index of the span in the text.
        :param hyphen_split: (bool) Whether to also split on hyphen-like characters.
        :return: (tuple) A list of tokens and a list of their respective [start, end] indices.
        """
        text = self.data[file_id]['text']
        token = text[start:end]

        extra_sep = ['\u200a']
        if hyphen_split:
            extra_sep += ['-', '\u2010', '\u2011', '\u2012', '\u2013', '\u2014', '\u2015', '\u2212', '\uff0d']

        new_range = []
        tokens = []

        curr = start
        new_start = None

        # A trailing space is appended so the final token is flushed on the last iteration.
        for c in token + " ":
            if c.isspace() or c in extra_sep:
                # Compare against None explicitly: offset 0 is a valid token start.
                if new_start is not None:
                    new_range.append([new_start, curr])
                    tokens.append(text[new_start:curr])
                    new_start = None
            elif new_start is None:
                new_start = curr
            curr += 1

        return tokens, new_range

    def _load_txt(self):
        """
        Loads the text files into the instance's data dictionary.
        """
        for file_name in os.listdir(self.txt_dir):
            if file_name.endswith(".txt"):
                with open(os.path.join(self.txt_dir, file_name), "r") as f:
                    text = f.read()
                file_id = file_name.split('.')[0]
                self.data[file_id] = {
                    "text": text,
                    "annotations": []
                }

    def _load_ann(self, hyphen_split):
        """
        Loads the annotation (.ann) files into the instance's data dictionary.
        :param hyphen_split: (bool) Whether annotated spans are additionally split on hyphen-like characters.
        """
        for file_name in os.listdir(self.ann_dir):

            if file_name.endswith(".ann"):
                with open(os.path.join(self.ann_dir, file_name), "r") as f:

                    file_id = file_name.split('.')[0]
                    annotations = []
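                    # The loop below assumes brat-style standoff entity lines: tab-separated fields
                    # whose second field is "<Label> <start> <end>", with disjoint ranges separated
                    # by ';'. Illustrative example (label and offsets are made up):
                    #   T1\tDisease 2560 2571\tlung cancer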
                    for line in f:
                        if line.startswith("T"):
                            fields = line.strip().split("\t")
                            if len(fields[1].split(" ")) > 1:
                                label = fields[1].split(" ")[0]

                                # Extract start/end offsets (a few annotations contain more than one disjoint range)
                                start_end_range = [
                                    list(map(int, start_end.split()))
                                    for start_end in ' '.join(fields[1].split(" ")[1:]).split(';')
                                ]

                                start_end_range_fixed = []
                                for start, end in start_end_range:
                                    _, start_end_split_list = self._splitting_tokens(file_id, start, end,
                                                                                     hyphen_split)
                                    start_end_range_fixed.extend(start_end_split_list)

                                # Add the label, start and end of every split sub-span to the annotations
                                for start, end in start_end_range_fixed:
                                    annotations.append({
                                        "label": label,
                                        "start": start,
                                        "end": end
                                    })
                    # Sort annotations by start offset (then label) before adding them to the dataset
                    annotations = sorted(annotations, key=lambda x: (x['start'], x['label']))
                    self.data[file_id]["annotations"] = annotations
        self._manual_fix()

    def split_text(self, file_id):
        """
        Tokenizes a file's text and returns the tokens, their character ranges,
        and the cumulative token counts at each newline-delimited sentence break.
        """
        text = self.data[file_id]['text']
        regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'  # split on whitespace and hyphen-like characters

        tokens = []
        start_end_ranges = []

        sentence_breaks = []

        start_idx = 0

        for sentence in text.split('\n'):
            matches = list(re.finditer(regex_match, sentence))
            words = [match.group(0) for match in matches]
            processed_words = list(map(remove_trailing_punctuation, words))
            sentence_indices = [(match.start(), match.start() + len(token)) for match, token in
                                zip(matches, processed_words)]

            # Update the indices to account for the current sentence's position in the entire text
            sentence_indices = [(start_idx + start, start_idx + end) for start, end in sentence_indices]

            start_end_ranges.extend(sentence_indices)
            tokens.extend(processed_words)

            sentence_breaks.append(len(tokens))

            start_idx += len(sentence) + 1
        return tokens, start_end_ranges, sentence_breaks

    def write_bio_files(self):
        """
        Aligns annotations with tokens and writes one <file_id>.bio file per document:
        one token and its tag(s) per line, with a blank line between sentences.
        """
        for file_id in self.data:
            text = self.data[file_id]['text']
            annotations = self.data[file_id]['annotations']

            # Tokenizing
            tokens, token2text, sentence_breaks = self.split_text(file_id)

            # Initialize the tags
            tags = ['O'] * len(tokens)

            ann_pos = 0
            token_pos = 0

            while ann_pos < len(annotations) and token_pos < len(tokens):

                label = annotations[ann_pos]['label']
                start = annotations[ann_pos]['start']
                end = annotations[ann_pos]['end']

                if self.filtered_entities and label not in self.filtered_entities:
                    # skip labels that are not in the filter list
                    ann_pos += 1
                    continue

                ann_word = text[start:end]

                # Advance to the first token that starts at or after the annotation start
                while token_pos < len(tokens) and token2text[token_pos][0] < start:
                    token_pos += 1
                if token_pos >= len(tokens):
                    # The annotation starts after the last token; nothing left to tag.
                    break

                if tokens[token_pos] == ann_word or \
                    ann_word in tokens[token_pos] or \
                    re.sub(r'\W+', '', ann_word) in re.sub(r'\W+', '', tokens[token_pos]):
                    tag_token(tags, token_pos, label)
                elif tokens[token_pos - 1] == ann_word or \
                    ann_word in tokens[token_pos - 1] or \
                    re.sub(r'\W+', '', ann_word) in re.sub(r'\W+', '', tokens[token_pos - 1]):
                    tag_token(tags, token_pos - 1, label)
                else:
                    print(f"Could not align annotation '{ann_word}' ({label}) with tokens "
                          f"'{tokens[token_pos]}' / '{tokens[token_pos - 1]}'")

                # increment to access next annotation
                ann_pos += 1

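            # Illustrative .bio output produced by the block below (token and tag separated by a
            # tab, blank line between sentences); the tokens and the "Disease" label are examples:
            #   Patients    O
            #   with        O
            #   lung        B-Disease
            #   cancer      I-Disease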
            # Write the tags to a .bio file
            with open(os.path.join(self.output_dir, f"{file_id}.bio"), 'w') as f:
                for i in range(len(tokens)):
                    token = tokens[i].strip()
                    if token:
                        if i in sentence_breaks:
                            f.write("\n")
                        f.write(f"{tokens[i]}\t{tags[i]}\n")

    def _manual_fix(self):
        """Applies hard-coded start-offset corrections for known annotation errors."""
        fix = {
            '19214295': {
                425: 424
            }
        }
        for file_id in self.data:
            if file_id in fix:
                for ann in self.data[file_id]['annotations']:
                    if ann['start'] in fix[file_id]:
                        ann['start'] = fix[file_id][ann['start']]

    def load_data(self, reload=False, hyphen_split=False):
        """Loads the .txt and .ann files unless data is already loaded (or reload is requested)."""
        if not self.data or reload:
            self.data = {}
            self._load_txt()
            self._load_ann(hyphen_split)

    def create_json(self, data_dir, reload=False, hyphen_split=False):
        """Dumps the loaded data dictionary to <data_dir>/data.json."""
        if not self.data or reload:
            # Pass the options by keyword; the positional call previously swapped reload and hyphen_split.
            self.load_data(reload=reload, hyphen_split=hyphen_split)

        # Write the dictionary to a JSON file
        with open(os.path.join(data_dir, "data.json"), "w") as f:
            json.dump(self.data, f)

    def load(self):
        """Loads text and annotation files with the default (no hyphen split) behaviour."""
        self._load_txt()
        self._load_ann(hyphen_split=False)

    def __getattr__(self, name):
        # __getattr__ is only invoked when normal attribute lookup fails, so returning
        # self.data here would recurse forever if 'data' were ever missing; raise instead.
        raise AttributeError(f"'AnnToBioConverter' object has no attribute '{name}'")
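

if __name__ == "__main__":
    # Minimal usage sketch. The directory names below are placeholders and must point to
    # existing folders: .txt/.ann pairs in "txt"/"ann" and an output directory "bio".
    converter = AnnToBioConverter(txt_dir="txt", ann_dir="ann", output_dir="bio")
    converter.load_data(hyphen_split=True)
    converter.write_bio_files()
    converter.create_json("bio")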