Diff of /src/utils/dataloader.py [000000] .. [0eda78]

import numpy as np
import pandas as pd
import torch
import re
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForTokenClassification
import random
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')


def shuffle_sentences_and_entities(text, entities):
    """
    Shuffles the sentences of a text while keeping each sentence's word-level entity tags attached to it.
    """
    sentences = sent_tokenize(text)
    entity_tokens = entities.split()  # one tag per whitespace-separated word in text

    # identify start and end indices of sentences in terms of word counts
    word_counts = [len(sentence.split()) for sentence in sentences]
    start_indices = [sum(word_counts[:i]) for i in range(len(word_counts))]
    end_indices = [sum(word_counts[:i+1]) for i in range(len(word_counts))]

    # split entities into groups (corresponding to sentences)
    sentence_entities = [entity_tokens[start:end] for start, end in zip(start_indices, end_indices)]

    # shuffle sentence-entity pairs; the fixed seed keeps the augmentation deterministic across runs
    combined = list(zip(sentences, sentence_entities))
    random.seed(42)
    random.shuffle(combined)
    shuffled_sentences, shuffled_sentence_entities = zip(*combined)

    # reconstruction
    augmented_text = ' '.join(shuffled_sentences)
    augmented_entities = ' '.join([' '.join(entity_group) for entity_group in shuffled_sentence_entities])

    return augmented_text, augmented_entities
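

# A minimal usage sketch (illustrative, not part of the original module): shows that shuffling
# keeps each sentence's tag group aligned with it. The example text and the O/B-MEDCOND tags
# below are assumptions for demonstration only.
def _demo_shuffle_sentences_and_entities():
    text = "John has diabetes. He visits the clinic weekly."
    entities = "O O B-MEDCOND O O O O O"  # one tag per whitespace-separated word
    augmented_text, augmented_entities = shuffle_sentences_and_entities(text, entities)
    # the sentence order may change, but the total number of tags is preserved
    assert len(augmented_entities.split()) == len(entities.split())
    return augmented_text, augmented_entities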


class Dataloader():
    """
    Dataloader used for loading the dataset used in this project. Also provides a framework for
    automatic tokenization of the data.
    """

    def __init__(self, label_to_ids, ids_to_label, transfer_learning, max_tokens, type):
        self.label_to_ids = label_to_ids
        self.ids_to_label = ids_to_label
        self.max_tokens = max_tokens
        self.transfer_learning = transfer_learning
        self.type = type

    def load_dataset(self, full=False, augment=False):
        """
        Loads the dataset and automatically initializes a tokenizer for the Custom_Dataset initialization.

        Parameters:
        full (bool): Whether the whole dataset should be returned in one. If False, a train-val-test
                     split of roughly 70:9:21 is returned, so that train+val versus test approximates
                     the Pareto principle (80:20).
        augment (bool): Whether the training data should be extended with augmented copies in which
                        the sentences are randomly shuffled (see shuffle_sentences_and_entities).

        Returns:
        if full:
            dataset (Custom_Dataset): the full dataset in one.
        else:
            tuple:
                - train_dataset (Custom_Dataset): Dataset used for training.
                - val_dataset (Custom_Dataset): Dataset used for validation.
                - test_dataset (Custom_Dataset): Dataset used for testing.
        """

        if self.transfer_learning:
            data = pd.read_csv("../datasets/labelled_data/MEDCOND/all.csv", names=['text', 'entity'], header=None, sep="|")
            tokenizer = BertTokenizer.from_pretrained('alvaroalon2/biobert_diseases_ner')
            data['entity'] = data['entity'].apply(lambda x: x.replace('B-MEDCOND', 'B-DISEASE'))
            data['entity'] = data['entity'].apply(lambda x: x.replace('I-MEDCOND', 'I-DISEASE'))
        else:
            data = pd.read_csv(f"../datasets/labelled_data/{self.type}/all.csv", names=['text', 'entity'], header=None, sep="|")
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            tokenizer.add_tokens(['B-' + self.type, 'I-' + self.type, 'O'])

        if not full:
            #train_data = data.sample((int) (len(data)*0.8), random_state=7).reset_index(drop=True)
            #test_data = data.drop(train_data.index).reset_index(drop=True)

            # sample before resetting the index, so that drop() removes exactly the sampled rows
            # and the splits stay disjoint
            train_data = data.sample(frac=0.7, random_state=7)
            remaining_data = data.drop(train_data.index).reset_index(drop=True)
            train_data = train_data.reset_index(drop=True)

            val_data = remaining_data.sample(frac=0.2857, random_state=7)
            test_data = remaining_data.drop(val_data.index).reset_index(drop=True)
            val_data = val_data.reset_index(drop=True)

            if augment:
                augmented_rows = [shuffle_sentences_and_entities(text, entities) for text, entities in zip(train_data['text'], train_data['entity'])]
                augmented_texts, augmented_entities = zip(*augmented_rows)

                augmented_data = pd.DataFrame({'text': augmented_texts, 'entity': augmented_entities})
                train_data = pd.concat([train_data, augmented_data]).reset_index(drop=True)

            train_dataset = Custom_Dataset(train_data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)
            val_dataset = Custom_Dataset(val_data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)
            test_dataset = Custom_Dataset(test_data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)

            return train_dataset, val_dataset, test_dataset
        else:
            dataset = Custom_Dataset(data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)
            return dataset

    def load_custom(self, data):
        """
        Loads a custom dataset from an already-read dataframe, with entities swapped from MEDCOND
        to DISEASE if transfer learning is enabled.

        Parameters:
        data (dataframe): Data extracted from a csv file.

        Returns:
        dataset (Custom_Dataset): Dataset changed accordingly.
        """
        if self.transfer_learning:
            tokenizer = BertTokenizer.from_pretrained('alvaroalon2/biobert_diseases_ner')
            data['entity'] = data['entity'].apply(lambda x: x.replace('B-MEDCOND', 'B-DISEASE'))
            data['entity'] = data['entity'].apply(lambda x: x.replace('I-MEDCOND', 'I-DISEASE'))
        else:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            tokenizer.add_tokens(['B-' + self.type, 'I-' + self.type, 'O'])
        dataset = Custom_Dataset(data, tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)
        return dataset

    def convert_id_to_label(self, ids):
        # expects a batched tensor of label ids and converts the first sequence back to label strings
        return [self.ids_to_label.get(x) for x in ids.numpy()[0]]
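

# A minimal usage sketch (illustrative, not part of the original module): constructs the Dataloader
# for the non-transfer-learning case. The label mappings, max_tokens value and type below are
# assumptions; the project may use different settings.
def _demo_dataloader_usage():
    label_to_ids = {'B-MEDCOND': 0, 'I-MEDCOND': 1, 'O': 2}
    ids_to_label = {v: k for k, v in label_to_ids.items()}
    loader = Dataloader(label_to_ids, ids_to_label,
                        transfer_learning=False, max_tokens=128, type='MEDCOND')
    # roughly 70:9:21 split, with sentence-shuffling augmentation applied to the train set
    train_dataset, val_dataset, test_dataset = loader.load_dataset(full=False, augment=True)
    return train_dataset, val_dataset, test_dataset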


def tokenize_and_preserve_labels(sentence, text_labels, tokenizer, label_to_ids, ids_to_label, max_tokens):
    """
    Tokenizes each word separately. This may take longer, but increases accuracy. Preserves the labels
    of each word, adhering to the B- and I- prefixes.

    Parameters:
    sentence (list): Words of the sentence to be tokenized.
    text_labels (list): Contains the labels of the words.
    tokenizer (BertTokenizer): Tokenizer used for tokenizing sentences.
    label_to_ids (dict): Dictionary containing label-id mappings.
    ids_to_label (dict): Dictionary containing id-label mappings.
    max_tokens (int): The maximum tokens allowed (input size of BERT model).

    Returns:
        tuple:
            - tokenized_sentence (list): List containing all tokens of the given sentence.
            - labels (list): List containing the corresponding labels of the tokens.
    """
    tokenized_sentence = []
    labels = []

    for word, label in zip(sentence, text_labels):
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)

        if len(tokenized_sentence) >= max_tokens:  # truncate
            return tokenized_sentence, labels

        tokenized_sentence.extend(tokenized_word)

        if label.startswith("B-"):
            # the first subword keeps the B- label; every following subword receives the corresponding
            # I- label, whose id is assumed to directly follow the B- label's id
            labels.extend([label])
            labels.extend([ids_to_label.get(label_to_ids.get(label) + 1)] * (n_subwords - 1))
        else:
            labels.extend([label] * n_subwords)

    return tokenized_sentence, labels
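

# A minimal usage sketch (illustrative, not part of the original module): shows how a B- label is
# carried onto subword pieces as the corresponding I- label. The id layout below is an assumption,
# chosen so that each I- id directly follows its B- id, which label_to_ids.get(label) + 1 relies on.
def _demo_tokenize_and_preserve_labels():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    label_to_ids = {'B-MEDCOND': 0, 'I-MEDCOND': 1, 'O': 2}
    ids_to_label = {v: k for k, v in label_to_ids.items()}
    words = ['Patient', 'has', 'hypothyroidism']
    labels = ['O', 'O', 'B-MEDCOND']
    tokens, token_labels = tokenize_and_preserve_labels(
        words, labels, tokenizer, label_to_ids, ids_to_label, max_tokens=128)
    # 'hypothyroidism' will typically be split into several word pieces; the first piece keeps
    # 'B-MEDCOND' and each following piece receives 'I-MEDCOND'
    return tokens, token_labels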


class Custom_Dataset(Dataset):
    """
    Dataset used for loading and tokenizing sentences on-the-fly.
    """

    def __init__(self, data, tokenizer, label_to_ids, ids_to_label, max_tokens):
        self.data = data
        self.tokenizer = tokenizer
        self.label_to_ids = label_to_ids
        self.ids_to_label = ids_to_label
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Takes the current sentence with its labels and tokenizes it on-the-fly.

        Returns:
        item (dict): Dictionary of tensors (input ids, attention mask, ... plus the shifted
                     'entity' labels) which can be fed into the model.
        """
        sentence = re.findall(r"\w+|\w+(?='s)|'s|['\".,!?;]", self.data['text'][idx].strip(), re.UNICODE)
        word_labels = self.data['entity'][idx].split(" ")
        t_sen, t_labl = tokenize_and_preserve_labels(sentence, word_labels, self.tokenizer, self.label_to_ids, self.ids_to_label, self.max_tokens)

        sen_code = self.tokenizer.encode_plus(t_sen,
            add_special_tokens=True,  # adds [CLS] and [SEP]
            max_length=self.max_tokens,  # maximum tokens of a sentence
            padding='max_length',
            return_attention_mask=True,  # generates the attention mask
            truncation=True
            )

        # shift labels by one position (due to the added [CLS]/[SEP] special tokens); -100 is the ignore index
        labels = [-100] * self.max_tokens
        for i, tok in enumerate(t_labl):
            if tok is not None and i < self.max_tokens - 1:
                labels[i + 1] = self.label_to_ids.get(tok)

        item = {key: torch.as_tensor(val) for key, val in sen_code.items()}
        item['entity'] = torch.as_tensor(labels)

        return item
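

# A minimal usage sketch (illustrative, not part of the original module): wraps a Custom_Dataset in
# a PyTorch DataLoader. The label mappings, batch size and other settings below are assumptions for
# demonstration only, and the dataset csv files must exist on disk.
def _demo_batching():
    label_to_ids = {'B-MEDCOND': 0, 'I-MEDCOND': 1, 'O': 2}
    ids_to_label = {v: k for k, v in label_to_ids.items()}
    loader = Dataloader(label_to_ids, ids_to_label,
                        transfer_learning=False, max_tokens=128, type='MEDCOND')
    train_dataset, _, _ = loader.load_dataset()
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    batch = next(iter(train_loader))
    # each batch contains input_ids, token_type_ids, attention_mask and the shifted
    # 'entity' labels, all padded/truncated to max_tokens
    return batch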