b/src/utils/dataloader.py

import numpy as np
import pandas as pd
import torch
import re
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForTokenClassification
import random
import nltk
from nltk.tokenize import sent_tokenize

nltk.download('punkt')

def shuffle_sentences_and_entities(text, entities):
    sentences = sent_tokenize(text)
    entity_tokens = entities.split()  # align with words in text

    # identify start and end indices of sentences in terms of word counts
    word_counts = [len(sentence.split()) for sentence in sentences]
    start_indices = [sum(word_counts[:i]) for i in range(len(word_counts))]
    end_indices = [sum(word_counts[:i+1]) for i in range(len(word_counts))]

    # split entities into groups (corresponding to sentences)
    sentence_entities = [entity_tokens[start:end] for start, end in zip(start_indices, end_indices)]

    # shuffle sentence-entity pairs
    combined = list(zip(sentences, sentence_entities))
    random.seed(42)
    random.shuffle(combined)
    shuffled_sentences, shuffled_sentence_entities = zip(*combined)

    # reconstruction
    augmented_text = ' '.join(shuffled_sentences)
    augmented_entities = ' '.join([' '.join(entity_group) for entity_group in shuffled_sentence_entities])

    return augmented_text, augmented_entities
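
# Illustrative example (not part of the original code): shuffle_sentences_and_entities
# assumes `entities` holds exactly one IOB tag per whitespace-separated word of `text`.
# For instance (hypothetical MEDCOND data):
#   text     = "The patient has asthma. She was given salbutamol."
#   entities = "O O O B-MEDCOND O O O O"
# The call shuffle_sentences_and_entities(text, entities) reorders the two sentences
# (deterministically, since the seed is fixed) and reorders the per-sentence tag groups
# "O O O B-MEDCOND" and "O O O O" in exactly the same way.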


class Dataloader():
    """
    Dataloader for loading the dataset used in this project. Also provides a framework for automatic
    tokenization of the data.
    """

    def __init__(self, label_to_ids, ids_to_label, transfer_learning, max_tokens, type):
        self.label_to_ids = label_to_ids
        self.ids_to_label = ids_to_label
        self.max_tokens = max_tokens
        self.transfer_learning = transfer_learning
        self.type = type

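    # Note (assumption, not from the original source): the label maps are expected to
    # look roughly like
    #   label_to_ids = {'B-MEDCOND': 0, 'I-MEDCOND': 1, 'O': 2}
    #   ids_to_label = {0: 'B-MEDCOND', 1: 'I-MEDCOND', 2: 'O'}
    # tokenize_and_preserve_labels below relies on each I- tag having the id directly
    # after its B- counterpart.
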
    def load_dataset(self, full=False, augment=False):
        """
        Loads the dataset and automatically initializes a tokenizer for the Custom_Dataset initialization.

        Parameters:
            full (bool): Whether the whole dataset should be returned in one piece. If False, a
                train-val-test split is returned (70% training data; the remaining 30% is divided
                between validation and test).
            augment (bool): Whether the training data should be extended with augmented instances,
                i.e. copies of the texts in which the sentences are randomly reordered.

        Returns:
            if full:
                dataset (Custom_Dataset): The full dataset in one piece.
            else:
                tuple:
                    - train_dataset (Custom_Dataset): Dataset used for training.
                    - val_dataset (Custom_Dataset): Dataset used for validation.
                    - test_dataset (Custom_Dataset): Dataset used for testing.
        """
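        # Assumed CSV layout (inferred from the read_csv call below, not documented in the
        # original source): each row is `text|entity`, where `entity` holds one IOB tag per
        # whitespace-separated word of `text`, e.g.
        #   The patient has asthma.|O O O B-MEDCOND
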
        if self.transfer_learning:
            data = pd.read_csv("../datasets/labelled_data/MEDCOND/all.csv", names=['text', 'entity'], header=None, sep="|")
            tokenizer = BertTokenizer.from_pretrained('alvaroalon2/biobert_diseases_ner')
            data['entity'] = data['entity'].apply(lambda x: x.replace('B-MEDCOND', 'B-DISEASE'))
            data['entity'] = data['entity'].apply(lambda x: x.replace('I-MEDCOND', 'I-DISEASE'))
        else:
            data = pd.read_csv(f"../datasets/labelled_data/{self.type}/all.csv", names=['text', 'entity'], header=None, sep="|")
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            tokenizer.add_tokens(['B-' + self.type, 'I-' + self.type, 'O'])

        if not full:
            # sample the training split first, then derive validation and test splits from the
            # remaining rows; indices are only reset after dropping, so the splits stay disjoint
            train_data = data.sample(frac=0.7, random_state=7)
            remaining_data = data.drop(train_data.index).reset_index(drop=True)
            train_data = train_data.reset_index(drop=True)

            val_data = remaining_data.sample(frac=0.2857, random_state=7)
            test_data = remaining_data.drop(val_data.index).reset_index(drop=True)
            val_data = val_data.reset_index(drop=True)

            if augment:
                augmented_rows = [shuffle_sentences_and_entities(text, entities) for text, entities in zip(train_data['text'], train_data['entity'])]
                augmented_texts, augmented_entities = zip(*augmented_rows)

                augmented_data = pd.DataFrame({'text': augmented_texts, 'entity': augmented_entities})
                train_data = pd.concat([train_data, augmented_data]).reset_index(drop=True)

            train_dataset = Custom_Dataset(train_data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)
            val_dataset = Custom_Dataset(val_data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)
            test_dataset = Custom_Dataset(test_data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)

            return train_dataset, val_dataset, test_dataset
        else:
            dataset = Custom_Dataset(data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)
            return dataset

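    # Minimal usage sketch (assumption; label names and values are illustrative only):
    #   label_to_ids = {'B-MEDCOND': 0, 'I-MEDCOND': 1, 'O': 2}
    #   ids_to_label = {v: k for k, v in label_to_ids.items()}
    #   loader = Dataloader(label_to_ids, ids_to_label, transfer_learning=False,
    #                       max_tokens=128, type='MEDCOND')
    #   train_ds, val_ds, test_ds = loader.load_dataset(augment=True)
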
    def load_custom(self, data):
        """
        Loads the given dataframe into a Custom_Dataset, but with entities swapped from MEDCOND
        to DISEASE (if transfer learning is enabled).

        Parameters:
            data (dataframe): Data extracted from a csv file.

        Returns:
            dataset (Custom_Dataset): Dataset changed accordingly.
        """
        if self.transfer_learning:
            tokenizer = BertTokenizer.from_pretrained('alvaroalon2/biobert_diseases_ner')
            data['entity'] = data['entity'].apply(lambda x: x.replace('B-MEDCOND', 'B-DISEASE'))
            data['entity'] = data['entity'].apply(lambda x: x.replace('I-MEDCOND', 'I-DISEASE'))
        else:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            tokenizer.add_tokens(['B-' + self.type, 'I-' + self.type, 'O'])
        dataset = Custom_Dataset(data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)
        return dataset

    def convert_id_to_label(self, ids):
        # maps the first row of a batch of label-id tensors back to label strings
        return [self.ids_to_label.get(x) for x in ids.numpy()[0]]


def tokenize_and_preserve_labels(sentence, text_labels, tokenizer, label_to_ids, ids_to_label, max_tokens):
    """
    Tokenizes each word separately. This may take longer, but increases accuracy. Preserves the label
    of each word, adhering to the B- and I- prefixes.

    Parameters:
        sentence (list): Words of the sentence to be tokenized.
        text_labels (list): Labels corresponding to the words of the sentence.
        tokenizer (BertTokenizer): Tokenizer used for tokenizing sentences.
        label_to_ids (dict): Dictionary containing label-id mappings.
        ids_to_label (dict): Dictionary containing id-label mappings.
        max_tokens (int): The maximum number of tokens allowed (input size of the BERT model).

    Returns:
        tuple:
            - tokenized_sentence (list): All subword tokens of the given sentence.
            - labels (list): The corresponding labels of the tokens.
    """
    tokenized_sentence = []
    labels = []

    for word, label in zip(sentence, text_labels):
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)

        if len(tokenized_sentence) >= max_tokens:  # truncate
            return tokenized_sentence, labels

        tokenized_sentence.extend(tokenized_word)

        if label.startswith("B-"):
            # only the first subword keeps the B- label; the remaining subwords receive the
            # matching I- label (assumes the I- id directly follows the B- id in label_to_ids)
            labels.extend([label])
            labels.extend([ids_to_label.get(label_to_ids.get(label) + 1)] * (n_subwords - 1))
        else:
            labels.extend([label] * n_subwords)

    return tokenized_sentence, labels
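
# Illustrative example (assumption; the actual split depends on the vocabulary of the chosen
# tokenizer): if "nephropathy" is labelled B-MEDCOND and is split into the subwords
# ["ne", "##phro", "##pathy"], the returned labels for those three tokens would be
# ["B-MEDCOND", "I-MEDCOND", "I-MEDCOND"].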


class Custom_Dataset(Dataset):
    """
    Dataset used for loading and tokenizing sentences on-the-fly.
    """

    def __init__(self, data, tokenizer, label_to_ids, ids_to_label, max_tokens):
        self.data = data
        self.tokenizer = tokenizer
        self.label_to_ids = label_to_ids
        self.ids_to_label = ids_to_label
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Takes the current sentence with its labels and tokenizes it on-the-fly.

        Returns:
            item (dict): Dictionary of tensors (encoded inputs plus entity labels) which can be
                fed into the model.
        """
        sentence = re.findall(r"\w+|\w+(?='s)|'s|['\".,!?;]", self.data['text'][idx].strip(), re.UNICODE)
        word_labels = self.data['entity'][idx].split(" ")
        t_sen, t_labl = tokenize_and_preserve_labels(sentence, word_labels, self.tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)

        sen_code = self.tokenizer.encode_plus(t_sen,
                                              add_special_tokens=True,     # adds [CLS] and [SEP]
                                              max_length=self.max_tokens,  # maximum tokens of a sentence
                                              padding='max_length',
                                              return_attention_mask=True,  # generates the attention mask
                                              truncation=True
                                              )

        # shift labels by one position (to account for the leading [CLS] token)
        labels = [-100] * self.max_tokens  # -100 is the ignore index of the loss
        for i, tok in enumerate(t_labl):
            if tok is not None and i < self.max_tokens - 1:
                labels[i + 1] = self.label_to_ids.get(tok)

        item = {key: torch.as_tensor(val) for key, val in sen_code.items()}
        item['entity'] = torch.as_tensor(labels)

        return item
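
# Minimal usage sketch (assumption; batch size and shuffling are illustrative only): the
# Custom_Dataset instances returned by Dataloader.load_dataset can be wrapped in the
# torch DataLoader imported above, e.g.
#   train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
#   for batch in train_loader:
#       input_ids = batch['input_ids']
#       labels = batch['entity']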