Diff of /scripts/utils.py [000000] .. [c0f169]

--- a/scripts/utils.py
+++ b/scripts/utils.py
import re
import os
import pickle

import spacy
from spacy import displacy

import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords

# Use a set for O(1) stopword membership checks
STOP_WORDS = set(stopwords.words('english'))

# Load the fitted Keras tokenizer from file
with open('../data/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
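
# The pickle is assumed to hold a fitted
# tensorflow.keras.preprocessing.text.Tokenizer. A minimal sketch of how the
# preprocessing step would have produced it (illustrative, not part of this
# module):
#
#     with open('../data/tokenizer.pickle', 'wb') as handle:
#         pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)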

def load_data(data_dir):
    """Load padded sequences, labels, and label mappings from data.npz."""
    data = np.load(os.path.join(data_dir, 'data.npz'), allow_pickle=True)

    train_sequences_padded = data['train_sequences_padded']
    train_labels = data['train_labels']

    val_sequences_padded = data['val_sequences_padded']
    val_labels = data['val_labels']

    test_sequences_padded = data['test_sequences_padded']
    test_labels = data['test_labels']

    # use .item() to unwrap the 0-d object arrays back into Python dictionaries
    label_to_index = data['label_to_index'].item()
    index_to_label = data['index_to_label'].item()

    return (train_sequences_padded, train_labels), \
           (val_sequences_padded, val_labels), \
           (test_sequences_padded, test_labels), \
           label_to_index, index_to_label
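
# Example call (illustrative; assumes ../data contains the data.npz written
# by the preprocessing step):
#
#     (train_X, train_y), (val_X, val_y), (test_X, test_y), \
#         label_to_index, index_to_label = load_data('../data')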


def clean_word(word):
    """
    Cleans a word by removing non-alphanumeric characters and extra
    whitespace, converting it to lowercase, and filtering out stopwords.

    Args:
    - word (str): the word to clean

    Returns:
    - str: the cleaned word, or an empty string if it is a stopword
    """
    # remove non-alphanumeric characters and collapse extra whitespace
    word = re.sub(r'[^\w\s]', '', word)
    word = re.sub(r'\s+', ' ', word)

    # convert to lowercase
    word = word.lower()

    if word not in STOP_WORDS:
        return word

    return ''
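
# For example, clean_word('Hello,') returns 'hello', while a stopword such
# as clean_word('The') returns '' (assuming the default NLTK English
# stopword list).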

def tokenize_text(text):
    """
    Tokenizes a text into cleaned words, sentence by sentence.

    Args:
    - text (str): the text to tokenize

    Returns:
    - tokens (list of list of str): the cleaned words for each sentence
    - start_end_ranges (list of list of tuples): the start and end character
      positions for each token
    """
    # match runs of characters that are neither whitespace nor one of the
    # Unicode hyphen/dash variants
    regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'
    tokens = []
    start_end_ranges = []

    # Tokenize the sentences in the text
    sentences = nltk.sent_tokenize(text)
    text_lower = text.lower()

    start = 0
    for sentence in sentences:
        sentence_tokens = re.findall(regex_match, sentence)
        curr_sent_tokens = []
        curr_sent_ranges = []

        for word in sentence_tokens:
            word = clean_word(word)
            if word.strip():
                # locate the cleaned word in the original text; skip tokens
                # that cleaning has altered beyond recognition
                idx = text_lower.find(word, start)
                if idx == -1:
                    continue
                end = idx + len(word)
                curr_sent_ranges.append((idx, end))
                curr_sent_tokens.append(word)
                start = end
        if curr_sent_tokens:
            tokens.append(curr_sent_tokens)
            start_end_ranges.append(curr_sent_ranges)

    return tokens, start_end_ranges
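
# A quick illustration (offsets depend on the exact input):
#
#     tokens, ranges = tokenize_text('Patient denies fever. BP stable.')
#     # tokens -> [['patient', 'denies', 'fever'], ['bp', 'stable']]
#     # ranges -> the matching (start, end) character offsets per token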


def predict(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities in a text using a trained NER model and the
    module-level tokenizer loaded above.

    Args:
    - text (str): the text to predict named entities in
    - model: the trained NER model
    - index_to_label (dict): maps each index in the predicted sequence to a
      named entity label
    - acronyms_to_entities (dict): maps label acronyms to their corresponding
      named entity labels
    - MAX_LENGTH (int): the maximum sequence length for the model

    Returns:
    - str: the displaCy HTML rendering of the predicted entities
    """
    tokens, start_end_ranges = tokenize_text(text)

    # Flatten the per-sentence tokens and character ranges
    all_tokens = []
    all_ranges = []
    for sent_tokens, sent_ranges in zip(tokens, start_end_ranges):
        for token, (start, end) in zip(sent_tokens, sent_ranges):
            all_tokens.append(token)
            all_ranges.append((start, end))

    sequence = tokenizer.texts_to_sequences([' '.join(all_tokens)])
    padded_sequence = pad_sequences(sequence, maxlen=MAX_LENGTH, padding='post')

    # Make the prediction
    prediction = model.predict(padded_sequence)

    # Decode the prediction
    predicted_labels = np.argmax(prediction, axis=-1)
    predicted_labels = [index_to_label[i] for i in predicted_labels[0]]

    # Collect (start, end, entity_type) spans; labels are BIO-style, so the
    # acronym follows the two-character 'B-'/'I-' prefix
    entities = []
    for (start, end), label in zip(all_ranges, predicted_labels):
        if label != 'O':
            entity_type = acronyms_to_entities[label[2:]]
            entities.append((start, end, entity_type))

    # Print the predicted named entities; zip() truncates safely if the text
    # yields more tokens than MAX_LENGTH
    print("Predicted Named Entities:")
    for token, label in zip(all_tokens, predicted_labels):
        if label == 'O':
            print(f"{token}: {label}")
        else:
            print(f"{token}: {acronyms_to_entities[label[2:]]}")

    return display_pred(text, entities)
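
# Example call (illustrative; assumes a trained `model`, the mappings
# returned by load_data, and an `acronyms_to_entities` dict that maps label
# acronyms to full entity names):
#
#     html = predict('Patient reports severe headache.', model,
#                    index_to_label, acronyms_to_entities, MAX_LENGTH=128)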

def display_pred(text, entities):
    """Render the predicted entity spans with displaCy and return the HTML."""
    nlp = spacy.load("en_core_web_sm", disable=['ner'])
    # Generate the entities in spaCy format
    doc = nlp(text)

    # Add the predicted named entities to the Doc object, dropping spans that
    # do not align with token boundaries (char_span returns None) and
    # filtering out overlaps, which doc.ents does not allow
    spans = []
    for start, end, label in entities:
        span = doc.char_span(start, end, label=label)
        if span is not None:
            spans.append(span)
    doc.ents = spacy.util.filter_spans(spans)

    colors = {"Activity": "#f9d5e5",
              "Administration": "#f7a399",
              "Age": "#f6c3d0",
              "Area": "#fde2e4",
              "Biological_attribute": "#d5f5e3",
              "Biological_structure": "#9ddfd3",
              "Clinical_event": "#77c5d5",
              "Color": "#a0ced9",
              "Coreference": "#e3b5a4",
              "Date": "#f1f0d2",
              "Detailed_description": "#ffb347",
              "Diagnostic_procedure": "#c5b4e3",
              "Disease_disorder": "#c4b7ea",
              "Distance": "#bde0fe",
              "Dosage": "#b9e8d8",
              "Duration": "#ffdfba",
              "Family_history": "#e6ccb2",
              "Frequency": "#e9d8a6",
              "Height": "#f2eecb",
              "History": "#e2f0cb",
              "Lab_value": "#f4b3c2",
              "Mass": "#f4c4c3",
              "Medication": "#f9d5e5",
              "Nonbiological_location": "#f7a399",
              "Occupation": "#f6c3d0",
              "Other_entity": "#d5f5e3",
              "Other_event": "#9ddfd3",
              "Outcome": "#77c5d5",
              "Personal_background": "#a0ced9",
              "Qualitative_concept": "#e3b5a4",
              "Quantitative_concept": "#f1f0d2",
              "Severity": "#ffb347",
              "Sex": "#c5b4e3",
              "Shape": "#c4b7ea",
              "Sign_symptom": "#bde0fe",
              "Subject": "#b9e8d8",
              "Texture": "#ffdfba",
              "Therapeutic_procedure": "#e6ccb2",
              "Time": "#e9d8a6",
              "Volume": "#f2eecb",
              "Weight": "#e2f0cb"}
    options = {"compact": True, "bg": "#F8F8F8",
               "ents": list(colors.keys()),
               "colors": colors}

    # Generate and return the HTML visualization
    return displacy.render(doc, style="ent", options=options)
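
# The returned markup can be saved or embedded; a minimal sketch, assuming a
# standalone HTML file is wanted:
#
#     html = display_pred(text, entities)
#     with open('prediction.html', 'w', encoding='utf-8') as f:
#         f.write(html)
#
# (Inside a Jupyter notebook, displacy.render(..., jupyter=True) renders the
# visualization inline instead of returning a string.)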


def predict_multi_line_text(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities sentence by sentence: each sentence becomes its
    own padded sequence, so long texts are not truncated to a single
    MAX_LENGTH window.

    Returns the displaCy HTML rendering of the predicted entities.
    """
    sent_tokens, sent_start_end = tokenize_text(text)

    # Encode each sentence as its own sequence
    sequences = []
    for sentence_tokens in sent_tokens:
        sequence = tokenizer.texts_to_sequences([' '.join(sentence_tokens)])
        sequences.extend(sequence)

    padded_sequence = pad_sequences(sequences, maxlen=MAX_LENGTH, padding='post')

    # Make the prediction (one row per sentence)
    prediction = model.predict(padded_sequence)

    # Decode the prediction
    predicted_labels = np.argmax(prediction, axis=-1)
    predicted_labels = [
        [index_to_label[i] for i in sent_predicted_labels]
        for sent_predicted_labels in predicted_labels
    ]

    # Collect (start, end, entity_type) spans across all sentences
    entities = []
    for tokens, sent_pred_labels, start_end_ranges in zip(sent_tokens, predicted_labels, sent_start_end):
        for (start, end), label in zip(start_end_ranges, sent_pred_labels):
            if label != 'O':
                entity_type = acronyms_to_entities[label[2:]]
                entities.append((start, end, entity_type))

    # Print the predicted named entities, one block per sentence
    print("Predicted Named Entities:")
    for tokens, sent_pred_labels in zip(sent_tokens, predicted_labels):
        for token, label in zip(tokens, sent_pred_labels):
            if label == 'O':
                print(f"{token}: {label}")
            else:
                print(f"{token}: {acronyms_to_entities[label[2:]]}")
        print("\n\n\n")

    return display_pred(text, entities)
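
# Example call (illustrative; same assumptions as predict above):
#
#     with open('report.txt', encoding='utf-8') as f:
#         html = predict_multi_line_text(f.read(), model, index_to_label,
#                                        acronyms_to_entities, MAX_LENGTH=128)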