--- a
+++ b/scripts/ann_to_bio_conversion.py
@@ -0,0 +1,264 @@
+import os
+import re
+import json
+
+
+def tag_token(tags, token_pos, tag):
+	"""
+	Assigns a BIO tag for `tag` to the token at position `token_pos`, appending
+	to any existing tag with ';' so that overlapping entities are preserved.
+	"""
+	if token_pos > 0 and tag in tags[token_pos - 1]:
+		# The previous token already carries this entity type: continue the span.
+		if tags[token_pos] == 'O':
+			tags[token_pos] = f'I-{tag}'
+		elif f'I-{tag}' not in tags[token_pos]:
+			tags[token_pos] += f';I-{tag}'
+	else:
+		# Otherwise this token begins a new entity span.
+		if tags[token_pos] == 'O':
+			tags[token_pos] = f'B-{tag}'
+		elif f'B-{tag}' not in tags[token_pos]:
+			tags[token_pos] += f';B-{tag}'
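+
+# Illustrative example (hypothetical label, not from the original annotations):
+# calling tag_token(tags, 0, "Disease") and then tag_token(tags, 1, "Disease")
+# on tags = ["O", "O", "O"] yields ["B-Disease", "I-Disease", "O"].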
+
+
+def remove_trailing_punctuation(token):
+	"""Strips trailing punctuation from a token; a trailing apostrophe is removed only when it is not the whole token."""
+	while token and (re.search(r'[^\w\s\']', token[-1]) or (len(token) > 1 and token[-1] == "'")):
+		token = token[:-1]
+	return token
+
+
+class AnnToBioConverter:
+	
+	def __init__(self, input_dir=None, txt_dir=None, ann_dir=None, output_dir=None, filtered_entities=None):
+		"""
+		Initializes an instance of the AnnToBioConverter class.
+		:param input_dir: (str) Directory containing both the text and annotation files.
+		:param txt_dir: (str) Directory containing the text files.
+		:param ann_dir: (str) Directory containing the annotation files.
+		:param output_dir: (str) Directory where the .bio files will be saved.
+		:param filtered_entities: (list) Entity labels to keep; if empty or None, all labels are kept.
+		"""
+		if input_dir is not None:
+			self.txt_dir = input_dir
+			self.ann_dir = input_dir
+		elif txt_dir is not None and ann_dir is not None:
+			self.txt_dir = txt_dir
+			self.ann_dir = ann_dir
+		else:
+			raise ValueError("Either input_dir or both txt_dir and ann_dir must be provided.")
+		
+		if output_dir is not None:
+			self.output_dir = output_dir
+		else:
+			raise ValueError("output_dir must be provided.")
+		self.data = {}
+		
+		self.filtered_entities = filtered_entities if filtered_entities is not None else []
+	
+	def _splitting_tokens(self, file_id, start, end, hyphen_split):
+		"""
+		Splits a multi-word token into separate tokens and returns a list of tokens and their respective start and end indices.
+		:param file_id: (str) The ID of the file containing the token.
+		:param start: (int) The starting index of the token in the text.
+		:param end: (int) The ending index of the token in the text.
+		:param hyphen_split: (bool) Whether to also split on hyphen/dash characters.
+		:return: (tuple) A tuple containing a list of tokens and their respective start and end indices.
+		"""
+		
+		text = self.data[file_id]['text']
+		token = text[start:end]
+		
+		extra_sep = ['\u200a']
+		if hyphen_split:
+			extra_sep += ['-', '\u2010', '\u2011', '\u2012', '\u2013', '\u2014', '\u2015', '\u2212', '\uff0d']
+		
+		new_range = []
+		tokens = []
+		
+		curr = start
+		new_start = None
+		
+		# A trailing space is appended so the final token is flushed by the loop.
+		for c in token + " ":
+			if c.isspace() or c in extra_sep:
+				if new_start is not None:
+					new_range.append([new_start, curr])
+					tokens.append(text[new_start:curr])
+					new_start = None
+			elif new_start is None:
+				new_start = curr
+			curr += 1
+		
+		return tokens, new_range
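+		# Illustrative example (hypothetical text, not from the data): for a
+		# span covering "non-small cell", hyphen_split=True returns the tokens
+		# ["non", "small", "cell"] with their character ranges in the full
+		# text, while hyphen_split=False returns ["non-small", "cell"].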
+	
+	def _load_txt(self):
+		"""
+		Loads the text files into the instance's data dictionary.
+		"""
+		for file_name in os.listdir(self.txt_dir):
+			if file_name.endswith(".txt"):
+				with open(os.path.join(self.txt_dir, file_name), "r") as f:
+					text = f.read()
+				file_id = file_name.split('.')[0]
+				self.data[file_id] = {
+					"text": text,
+					"annotations": []
+				}
+	
+	def _load_ann(self, hyphen_split):
+		"""
+		Loads the .ann annotation files, splitting each annotated span into
+		word-level (start, end) ranges stored under the matching file ID.
+		"""
+		for file_name in os.listdir(self.ann_dir):
+			
+			if file_name.endswith(".ann"):
+				with open(os.path.join(self.ann_dir, file_name), "r") as f:
+					
+					file_id = file_name.split('.')[0]
+					annotations = []
+					
+					for line in f:
+						if line.startswith("T"):
+							fields = line.strip().split("\t")
+							if len(fields[1].split(" ")) > 1:
+								label = fields[1].split(" ")[0]
+								
+								# Extract the start/end offsets (a few annotations contain more than one disjoint range)
+								start_end_range = [
+									list(map(int, start_end.split()))
+									for start_end in ' '.join(fields[1].split(" ")[1:]).split(';')
+								]
+								
+								start_end_range_fixed = []
+								for start, end in start_end_range:
+									tokens, start_end_split_list = self._splitting_tokens(file_id, start, end,
+																						  hyphen_split)
+									start_end_range_fixed.extend(start_end_split_list)
+								
+								# Add the label and word-level offsets to the annotations
+								for start, end in start_end_range_fixed:
+									annotations.append({
+										"label": label,
+										"start": start,
+										"end": end
+									})
+					# Sort annotations by start offset (then label) before storing them
+					annotations = sorted(annotations, key=lambda x: (x['start'], x['label']))
+					self.data[file_id]["annotations"] = annotations
+		self._manual_fix()
+	
+	def split_text(self, file_id):
+		"""
+		Tokenizes the document text and returns the tokens, their (start, end)
+		character ranges in the text, and the cumulative token count at the end
+		of each line (used as sentence breaks).
+		"""
+		text = self.data[file_id]['text']
+		regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'  # split on whitespace, hair space and Unicode hyphen/dash variants
+		
+		tokens = []
+		start_end_ranges = []
+		
+		sentence_breaks = []
+		
+		start_idx = 0
+		
+		for sentence in text.split('\n'):
+			matches = list(re.finditer(regex_match, sentence))
+			processed_words = [remove_trailing_punctuation(match.group(0)) for match in matches]
+			sentence_indices = [(match.start(), match.start() + len(token))
+								for match, token in zip(matches, processed_words)]
+			
+			# Update the indices to account for the current sentence's position in the entire text
+			sentence_indices = [(start_idx + start, start_idx + end) for start, end in sentence_indices]
+			
+			start_end_ranges.extend(sentence_indices)
+			tokens.extend(processed_words)
+			
+			sentence_breaks.append(len(tokens))
+			
+			start_idx += len(sentence) + 1
+		return tokens, start_end_ranges, sentence_breaks
+	
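+	# Example of the .bio format written below (entity label is illustrative;
+	# columns are tab-separated and sentences are separated by a blank line):
+	#
+	#   Breast    B-Disease
+	#   cancer    I-Disease
+	#   patients  O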
+	def write_bio_files(self):
+		"""
+		Aligns annotations with the tokenized text and writes one tab-separated
+		.bio file per document to the output directory.
+		"""
+		for file_id in self.data:
+			text = self.data[file_id]['text']
+			annotations = self.data[file_id]['annotations']
+			
+			# Tokenizing
+			tokens, token2text, sentence_breaks = self.split_text(file_id)
+			
+			# Initialize the tags
+			tags = ['O'] * len(tokens)
+			
+			ann_pos = 0
+			token_pos = 0
+			
+			while ann_pos < len(annotations) and token_pos < len(tokens):
+				
+				label = annotations[ann_pos]['label']
+				start = annotations[ann_pos]['start']
+				end = annotations[ann_pos]['end']
+				
+				if self.filtered_entities:
+					if label not in self.filtered_entities:
+						# increment to access next annotation
+						ann_pos += 1
+						continue
+				
+				ann_word = text[start:end]
+				
+				# Advance to the first token that does not start before the annotation.
+				while token_pos < len(tokens) and token2text[token_pos][0] < start:
+					token_pos += 1
+				if token_pos >= len(tokens):
+					break
+				
+				if tokens[token_pos] == ann_word or \
+					ann_word in tokens[token_pos] or \
+					re.sub(r'\W+', '', ann_word) in re.sub(r'\W+', '', tokens[token_pos]):
+					tag_token(tags, token_pos, label)
+				elif ann_word in tokens[token_pos - 1] or \
+					re.sub(r'\W+', '', ann_word) in re.sub(r'\W+', '', tokens[token_pos - 1]):
+					tag_token(tags, token_pos - 1, label)
+				else:
+					print(f"Could not align annotation '{ann_word}' ({label}) with "
+						  f"tokens '{tokens[token_pos - 1]}' / '{tokens[token_pos]}'")
+				
+				# increment to access next annotation
+				ann_pos += 1
+			
+			# Write the tags to a .bio file
+			with open(os.path.join(self.output_dir, f"{file_id}.bio"), 'w') as f:
+				for i in range(len(tokens)):
+					token = tokens[i].strip()
+					if token:
+						if i in sentence_breaks:
+							# A blank line marks a sentence boundary.
+							f.write("\n")
+						f.write(f"{token}\t{tags[i]}\n")
+	
+	def _manual_fix(self):
+		"""Applies hard-coded corrections to specific annotation start offsets."""
+		fix = {
+			'19214295': {
+				425: 424
+			}
+		}
+		for file_id in self.data:
+			if file_id in fix:
+				for ann in self.data[file_id]['annotations']:
+					if ann['start'] in fix[file_id]:
+						ann['start'] = fix[file_id][ann['start']]
+	
+	def load_data(self, reload=False, hyphen_split=False):
+		if not self.data or reload:
+			self.data = {}
+			self._load_txt()
+			self._load_ann(hyphen_split)
+	
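+	# self.data (and hence the JSON written by create_json) has the shape,
+	# with illustrative offsets:
+	# {"<file_id>": {"text": "...",
+	#                "annotations": [{"label": "...", "start": 10, "end": 16}, ...]}}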
+	def create_json(self, data_dir, reload=False, hyphen_split=False):
+		if not self.data or reload:
+			self.load_data(reload=reload, hyphen_split=hyphen_split)
+		
+		# Write the dictionary to a JSON file
+		with open(os.path.join(data_dir, "data.json"), "w") as f:
+			json.dump(self.data, f)
+	
+	def load(self):
+		# Convenience wrapper around the private loaders.
+		self._load_txt()
+		self._load_ann(hyphen_split=False)
+	
+	def __getattr__(self, name):
+		# __getattr__ is only invoked when normal attribute lookup fails, so
+		# returning self.data here would recurse; 'data' is always set in
+		# __init__, so anything reaching this point is genuinely missing.
+		raise AttributeError(f"'AnnToBioConverter' object has no attribute '{name}'")
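+
+
+if __name__ == "__main__":
+	# Minimal usage sketch; the directory paths and entity label below are
+	# placeholders, not taken from the original project.
+	converter = AnnToBioConverter(
+		input_dir="data/brat", output_dir="data/bio", filtered_entities=["Disease"])
+	converter.load_data(hyphen_split=True)
+	converter.write_bio_files()
+	converter.create_json("data")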