Diff of /src/Parser/ops.py [000000] .. [f87529]

import re
import copy
import json
import time

import numpy as np
import xml.etree.ElementTree as ElTree

from datetime import datetime, timezone
from operator import itemgetter


tokenize_regex = re.compile(r'([0-9a-zA-Z]+|[^0-9a-zA-Z])')

def json_to_sent(data):
    '''data: list of dicts, one per document: [{pmid, abstract, title}, ...]'''
    out = dict()
    for paper in data:
        sentences = list()

        if len(CoNLL_tokenizer(paper['title'])) < 50:
            title = [paper['title']]
        else:
            title = sentence_split(paper['title'])
        if len(title) != 1 or len(title[0].strip()) > 0:
            sentences.extend(title)

        if len(paper['abstract']) > 0:
            abst = sentence_split(paper['abstract'])
            if len(abst) != 1 or len(abst[0].strip()) > 0:
                sentences.extend(abst)
        out[paper['pmid']] = dict()
        out[paper['pmid']]['sentence'] = sentences
    return out

def input_form(sent_data):
    '''sent_data: dict keyed by pmid, {pmid: {'sentence': [sent, sent, ...]}, ...}'''
    for pmid in sent_data:
        sent_data[pmid]['words'] = list()
        sent_data[pmid]['wordPos'] = list()
        doc_piv = 0
        for sent in sent_data[pmid]['sentence']:
            wids = list()
            wpos = list()
            sent_piv = 0
            tok = CoNLL_tokenizer(sent)

            for w in tok:
                if len(w) > 20:  # truncate very long tokens
                    wids.append(w[:10])
                else:
                    wids.append(w)

                start = doc_piv + sent_piv + sent[sent_piv:].find(w)
                end = start + len(w) - 1
                sent_piv = end - doc_piv + 1
                wpos.append((start, end))
            doc_piv += len(sent)
            sent_data[pmid]['words'].append(wids)
            sent_data[pmid]['wordPos'].append(wpos)

    return sent_data
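
# Illustrative usage sketch (added; not part of the original module). It shows
# how json_to_sent and input_form are meant to compose; the document values
# below are invented.
def _example_sentence_preprocessing():
    docs = [{'pmid': '12345',
             'title': 'BRCA1 mutations in breast cancer.',
             'abstract': 'We study BRCA1. Mutations were frequent.'}]
    sent_data = input_form(json_to_sent(docs))
    # sent_data['12345']['sentence'] : list of sentence strings
    # sent_data['12345']['words']    : CoNLL-style tokens per sentence
    # sent_data['12345']['wordPos']  : (start, end) character offsets per token
    return sent_data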

def softmax(logits):
    # For each token's logit vector, return only the maximum class probability.
    out = list()
    for logit in logits:
        temp = np.subtract(logit, np.max(logit))
        p = np.exp(temp) / np.sum(np.exp(temp))
        out.append(np.max(p))
    return out

def CoNLL_tokenizer(text):
    rawTok = [t for t in tokenize_regex.split(text) if t]
    assert ''.join(rawTok) == text
    tok = [t for t in rawTok if t != ' ']
    return tok
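
# Example (added for illustration):
# CoNLL_tokenizer('BRCA1-mediated repair') -> ['BRCA1', '-', 'mediated', 'repair']
# i.e. alphanumeric runs stay together, every other character becomes its own
# token, and plain spaces are dropped.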

def sentence_split(text):
    '''Split text into sentences at '.', '?' and '!' using simple heuristics;
    any sentence longer than 100 CoNLL tokens is further chunked every
    200 raw tokens.'''
    sentences = list()
    sent = ''
    piv = 0
    for idx, char in enumerate(text):
        if char in "?!":
            if idx > len(text) - 3:
                sent = text[piv:]
                piv = -1
            else:
                sent = text[piv:idx + 1]
                piv = idx + 1

        elif char == '.':
            if idx > len(text) - 3:
                sent = text[piv:]
                piv = -1
            elif (text[idx + 1] == ' ') and (
                    text[idx + 2] in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ-"' + "'"):
                sent = text[piv:idx + 1]
                piv = idx + 1

        if sent != '':
            toks = CoNLL_tokenizer(sent)
            if len(toks) > 100:
                while True:
                    rawTok = [t for t in tokenize_regex.split(sent) if t]
                    cut = ''.join(rawTok[:200])
                    sent = ''.join(rawTok[200:])
                    sentences.append(cut)

                    if len(CoNLL_tokenizer(sent)) < 100:
                        if sent.strip() == '':
                            sent = ''
                            break
                        else:
                            sentences.append(sent)
                            sent = ''
                            break
            else:
                sentences.append(sent)
                sent = ''

            if piv == -1:
                break

    if piv != -1:
        sent = text[piv:]
        toks = CoNLL_tokenizer(sent)
        if len(toks) > 100:
            while True:
                rawTok = [t for t in tokenize_regex.split(sent) if t]
                cut = ''.join(rawTok[:200])
                sent = ''.join(rawTok[200:])
                sentences.append(cut)

                if len(CoNLL_tokenizer(sent)) < 100:
                    if sent.strip() == '':
                        sent = ''
                        break
                    else:
                        sentences.append(sent)
                        sent = ''
                        break
        else:
            sentences.append(sent)
            sent = ''

    return sentences
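
# Example of the splitting behaviour (added for illustration):
# sentence_split('Alpha beta. Gamma delta.') -> ['Alpha beta.', ' Gamma delta.']
# The leading space of a following sentence is kept, so the concatenation of
# the returned sentences reproduces the original text and the character
# offsets computed in input_form stay aligned.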

def get_prob(data, sent_data, predicDict, logitsDict, entity_types=None):
    '''Decode per-sentence BIO tags (predicDict) into character-offset entity
    spans and attach a probability to each span, averaged over the max
    softmax probabilities of its tokens (logitsDict).'''
    for idx, paper in enumerate(data):
        pmid = paper['pmid']

        if len(paper['abstract']) > 0:
            content = paper['title'] + ' ' + paper['abstract']
        else:
            content = paper['title']

        for ent_type in entity_types:
            paper['entities'][ent_type] = []
        paper['prob'] = dict()

        for dtype in entity_types:
            for sentidx, tags in enumerate(predicDict[dtype][pmid]):
                B_flag = False
                # get position of entity corresponding to types
                for widx, tag in enumerate(tags):
                    if tag == 'O':
                        if B_flag:
                            tmpSE["end"] = \
                                sent_data[pmid]['wordPos'][sentidx][widx - 1][1]
                            paper['entities'][dtype].append(tmpSE)
                        B_flag = False
                        continue
                    elif tag == 'B':
                        if B_flag:
                            tmpSE["end"] = \
                                sent_data[pmid]['wordPos'][sentidx][widx - 1][1]
                            paper['entities'][dtype].append(tmpSE)
                        tmpSE = {
                            "start": sent_data[pmid]['wordPos'][sentidx][widx][0]}
                        B_flag = True
                    elif tag == "I":
                        continue
                if B_flag:
                    tmpSE["end"] = sent_data[pmid]['wordPos'][sentidx][-1][1]
                    paper['entities'][dtype].append(tmpSE)

            # get prob. of entity logits corresponding to types
            logs = list()
            for t_sent in logitsDict[dtype][pmid]:
                logs.extend(t_sent)
            paper['prob'][dtype] = list()
            for pos in paper['entities'][dtype]:
                tok_start = len(CoNLL_tokenizer(content[:pos['start']]))
                tok_end = len(CoNLL_tokenizer(content[:pos['end']]))
                if pos['start'] == pos['end']:
                    # span covering a single character: take the one token covering it
                    soft = softmax(logs[tok_start:tok_end + 1])
                else:
                    soft = softmax(logs[tok_start:tok_end])
                paper['prob'][dtype].append((pos, float(np.average(soft))))

    return data
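
# Illustrative sketch (added; not from the original code) of the structures
# get_prob expects. The entity type 'gene', the pmid and the tag/logit values
# are invented; only the shapes matter: predicDict[dtype][pmid] holds one BIO
# tag sequence per sentence, logitsDict[dtype][pmid] one logit vector per token.
def _example_get_prob():
    docs = [{'pmid': '12345', 'title': 'BRCA1 is mutated.', 'abstract': '',
             'entities': {}}]
    sent_data = input_form(json_to_sent(docs))
    # 'BRCA1 is mutated.' tokenizes to ['BRCA1', 'is', 'mutated', '.']
    predicDict = {'gene': {'12345': [['B', 'O', 'O', 'O']]}}
    logitsDict = {'gene': {'12345': [[[5.0, 0.1, 0.1]] * 4]}}
    return get_prob(docs, sent_data, predicDict, logitsDict,
                    entity_types=['gene'])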

def detokenize(tokens, predicts, logits):
    pred = dict({
        'toks': tokens[:],
        'labels': predicts[:],
        'logit': logits[:]
    })  # dictionary for predicted tokens and labels.

    bert_toks = list()
    bert_labels = list()
    bert_logits = list()
    tmp_p = list()
    tmp_l = list()
    tmp_s = list()
    for t, l, s in zip(pred['toks'], pred['labels'], pred['logit']):
        if t == '[CLS]' or t == '<s>':  # non-text tokens will not be evaluated.
            continue
        elif t == '[SEP]' or t == '</s>':  # end of a sentence: flush the current buffers
            bert_toks.append(tmp_p)
            bert_labels.append(tmp_l)
            bert_logits.append(tmp_s)
            tmp_p = list()
            tmp_l = list()
            tmp_s = list()
            continue
        elif t[:2] == '##':  # a word piece broken off by the WordPiece tokenizer
            tmp_p[-1] = tmp_p[-1] + t[2:]  # append pieces
        elif t.startswith('Ġ'):  # RoBERTa tokenizer
            t = t.replace('Ġ', ' ')
            tmp_p[-1] = tmp_p[-1] + t
        else:
            tmp_p.append(t)
            tmp_l.append(l)
            tmp_s.append(s)
    return bert_toks, bert_labels, bert_logits
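
# Illustrative sketch (added): detokenize() regroups WordPiece output into whole
# words per sentence. The token, label and logit values below are invented.
def _example_detokenize():
    toks = ['[CLS]', 'BR', '##CA', '##1', 'is', 'mutated', '[SEP]']
    labels = ['X', 'B', 'X', 'X', 'O', 'O', 'X']
    logits = [[0.0, 0.0]] * len(toks)
    words, word_labels, word_logits = detokenize(toks, labels, logits)
    # words       == [['BRCA1', 'is', 'mutated']]
    # word_labels == [['B', 'O', 'O']]   (labels of '##' pieces are dropped)
    return words, word_labels, word_logits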

# https://stackoverflow.com/a/3620972
PROF_DATA = {}

class Profile(object):
    def __init__(self, prefix):
        self.prefix = prefix

    def __call__(self, fn):
        def with_profiling(*args, **kwargs):
            global PROF_DATA
            start_time = time.time()
            ret = fn(*args, **kwargs)

            elapsed_time = time.time() - start_time
            key = '[' + self.prefix + '].' + fn.__name__

            if key not in PROF_DATA:
                PROF_DATA[key] = [0, list()]
            PROF_DATA[key][0] += 1
            PROF_DATA[key][1].append(elapsed_time)

            return ret

        return with_profiling

def show_prof_data():
    for fname, data in sorted(PROF_DATA.items()):
        max_time = max(data[1])
        avg_time = sum(data[1]) / len(data[1])
        total_time = sum(data[1])
        print("\n{} -> called {} times".format(fname, data[0]))
        print("Time total: {:.3f}, max: {:.3f}, avg: {:.3f}".format(
            total_time, max_time, avg_time))

def clear_prof_data():
    global PROF_DATA
    PROF_DATA = {}
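
# Usage sketch for the profiling helpers above (added for illustration; the
# decorated function and prefix are hypothetical):
#
#     @Profile('ops')
#     def run_ner(batch):
#         ...
#
#     run_ner(batch)       # timings accumulate in PROF_DATA
#     show_prof_data()     # print call counts and total/max/avg times
#     clear_prof_data()    # reset the collected statistics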

# Ref. dict of SR4GN
species_human_excl_homo_sapiens = \
    'person|infant|Child|people|participants|woman|' \
    'Girls|Man|Peoples|Men|Participant|Patients|' \
    'humans|Persons|mans|participant|Infants|Boys|' \
    'Human|Humans|Women|children|Mans|child|Participants|Girl|' \
    'Infant|girl|patient|patients|boys|men|infants|' \
    'man|girls|Children|Boy|women|persons|human|Woman|' \
    'peoples|Patient|People|boy|Person'.split('|')

def filter_entities(ner_results):
    '''Drop species mentions whose surface form is a generic human word from
    the SR4GN reference list above; return one (pmid, num_filtered) pair per
    document.'''
    num_filtered_species_per_doc = list()

    for idx, paper in enumerate(ner_results):

        if len(paper['abstract']) > 0:
            content = paper['title'] + ' ' + paper['abstract']
        else:
            content = paper['title']

        valid_species = list()
        species = paper['entities']['species']
        for spcs in species:
            entity_mention = content[spcs['start']:spcs['end']+1]
            if entity_mention in species_human_excl_homo_sapiens:
                spcs['end'] += 1
                continue
            valid_species.append(spcs)

        num_filtered_species = len(species) - len(valid_species)
        if num_filtered_species > 0:
            paper['entities']['species'] = valid_species

        num_filtered_species_per_doc.append((paper['pmid'],
                                             num_filtered_species))

    return num_filtered_species_per_doc
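
# Illustration (added): with the list above, a species mention whose surface
# form is e.g. 'patients' or 'humans' is removed from
# paper['entities']['species'], and filter_entities() returns one
# (pmid, number_of_removed_mentions) pair per document.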

# from convert.py
def pubtator2dict_list(pubtator_file_path):
    dict_list = list()

    title_pmid = ''
    # abstract_pmid = ''
    title = ''
    abstract_text = ''
    doc_line_num = 0

    with open(pubtator_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.rstrip()
            if len(line) == 0:

                doc_dict = {
                    'pmid': title_pmid,
                    'entities': {},
                }
                doc_dict['title'] = title
                doc_dict['abstract'] = abstract_text

                dict_list.append(doc_dict)

                doc_line_num = 0
                continue

            if doc_line_num == 0:
                title_cols = line.split('|t|')

                if len(title_cols) != 2:
                    # braces of the JSON template are escaped so .format() works
                    return '{{"error": "wrong #title_cols {}"}}'\
                        .format(len(title_cols))

                title_pmid = title_cols[0]

                if '- No text -' == title_cols[1]:
                    # make tmvar2 results empty
                    title = ''
                else:
                    title = title_cols[1]
            elif doc_line_num == 1:
                abstract_cols = line.split('|a|')

                if len(abstract_cols) != 2:
                    if len(abstract_cols) > 2:
                        abstract_text = "|a|".join(abstract_cols[1:])
                    else:
                        return '{{"error": "wrong #abstract_cols {}"}}'.format(
                            len(abstract_cols))
                else:
                    if '- No text -' == abstract_cols[1]:
                        # make tmvar2 results empty
                        abstract_text = ''
                    else:
                        abstract_text = abstract_cols[1]

            doc_line_num += 1
    return dict_list
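
# Expected PubTator input (illustration added here; the values are invented).
# Each document is a '|t|' title line and an '|a|' abstract line, terminated by
# a blank line:
#
#     12345|t|BRCA1 mutations in breast cancer.
#     12345|a|We study BRCA1. Mutations were frequent.
#     <blank line>
#
# A document is only appended to dict_list once its terminating blank line has
# been read.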

def preprocess(text):
    text = text.replace('\r ', ' ')

    text = text.replace('\u2028', ' ')
    text = text.replace('\u2029', ' ')

    # HAIR SPACE
    # https://www.fileformat.info/info/unicode/char/200a/index.htm
    text = text.replace('\u200A', ' ')

    # THIN SPACE
    # https://www.fileformat.info/info/unicode/char/2009/index.htm
    text = text.replace('\u2009', ' ')
    text = text.replace('\u2008', ' ')

    # FOUR-PER-EM SPACE
    # https://www.fileformat.info/info/unicode/char/2005/index.htm
    text = text.replace('\u2005', ' ')
    text = text.replace('\u2004', ' ')
    text = text.replace('\u2003', ' ')

    # EN SPACE
    # https://www.fileformat.info/info/unicode/char/2002/index.htm
    text = text.replace('\u2002', ' ')

    # NO-BREAK SPACE
    # https://www.fileformat.info/info/unicode/char/00a0/index.htm
    text = text.replace('\u00A0', ' ')

    # https://www.fileformat.info/info/unicode/char/f8ff/index.htm
    text = text.replace('\uF8FF', ' ')

    # https://www.fileformat.info/info/unicode/char/202f/index.htm
    text = text.replace('\u202F', ' ')

    text = text.replace('\uFEFF', ' ')
    text = text.replace('\uF044', ' ')
    text = text.replace('\uF02D', ' ')
    text = text.replace('\uF0BB', ' ')

    text = text.replace('\uF048', 'Η')
    text = text.replace('\uF0B0', '°')

    # MIDLINE HORIZONTAL ELLIPSIS: ⋯
    # https://www.fileformat.info/info/unicode/char/22ef/index.htm
    # text = text.replace('\u22EF', '...')

    return text
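
# Example (added): preprocess('5\u00A0mg') == '5 mg' -- the NO-BREAK SPACE and
# the other exotic space characters above are normalised to a plain ASCII space.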