--- a +++ b/src/Parser/ops.py @@ -0,0 +1,426 @@ +import re +import copy +import json +import time + +import numpy as np +import xml.etree.ElementTree as ElTree + +from datetime import datetime, timezone +from operator import itemgetter + + +tokenize_regex = re.compile(r'([0-9a-zA-Z]+|[^0-9a-zA-Z])') + +def json_to_sent(data): + '''data: list of json file [{pmid,abstract,title}, ...] ''' + out = dict() + for paper in data: + sentences = list() + + if len(CoNLL_tokenizer(paper['title'])) < 50: + title = [paper['title']] + else: + title = sentence_split(paper['title']) + if len(title) != 1 or len(title[0].strip()) > 0: + sentences.extend(title) + + if len(paper['abstract']) > 0: + abst = sentence_split(paper['abstract']) + if len(abst) != 1 or len(abst[0].strip()) > 0: + sentences.extend(abst) + out[paper['pmid']] = dict() + out[paper['pmid']]['sentence'] = sentences + return out + +def input_form(sent_data): + '''sent_data: dict of sentence, key=pmid {pmid:[sent,sent, ...], pmid: ...}''' + for pmid in sent_data: + sent_data[pmid]['words'] = list() + sent_data[pmid]['wordPos'] = list() + doc_piv = 0 + for sent in sent_data[pmid]['sentence']: + wids = list() + wpos = list() + sent_piv = 0 + tok = CoNLL_tokenizer(sent) + + for w in tok: + if len(w) > 20: + wids.append(w[:10]) + else: + wids.append(w) + + start = doc_piv + sent_piv + sent[sent_piv:].find(w) + end = start + len(w) - 1 + sent_piv = end - doc_piv + 1 + wpos.append((start, end)) + doc_piv += len(sent) + sent_data[pmid]['words'].append(wids) + sent_data[pmid]['wordPos'].append(wpos) + + return sent_data + +def softmax(logits): + out = list() + for logit in logits: + temp = np.subtract(logit, np.max(logit)) + p = np.exp(temp) / np.sum(np.exp(temp)) + out.append(np.max(p)) + return out + +def CoNLL_tokenizer(text): + rawTok = [t for t in tokenize_regex.split(text) if t] + assert ''.join(rawTok) == text + tok = [t for t in rawTok if t != ' '] + return tok + +def sentence_split(text): + sentences = list() + sent = '' + piv = 0 + for idx, char in enumerate(text): + if char in "?!": + if idx > len(text) - 3: + sent = text[piv:] + piv = -1 + else: + sent = text[piv:idx + 1] + piv = idx + 1 + + elif char == '.': + if idx > len(text) - 3: + sent = text[piv:] + piv = -1 + elif (text[idx + 1] == ' ') and ( + text[idx + 2] in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ-"' + "'"): + sent = text[piv:idx + 1] + piv = idx + 1 + + if sent != '': + toks = CoNLL_tokenizer(sent) + if len(toks) > 100: + while True: + rawTok = [t for t in tokenize_regex.split(sent) if t] + cut = ''.join(rawTok[:200]) + sent = ''.join(rawTok[200:]) + sentences.append(cut) + + if len(CoNLL_tokenizer(sent)) < 100: + if sent.strip() == '': + sent = '' + break + else: + sentences.append(sent) + sent = '' + break + else: + sentences.append(sent) + sent = '' + + if piv == -1: + break + + if piv != -1: + sent = text[piv:] + toks = CoNLL_tokenizer(sent) + if len(toks) > 100: + while True: + rawTok = [t for t in tokenize_regex.split(sent) if t] + cut = ''.join(rawTok[:200]) + sent = ''.join(rawTok[200:]) + sentences.append(cut) + + if len(CoNLL_tokenizer(sent)) < 100: + if sent.strip() == '': + sent = '' + break + else: + sentences.append(sent) + sent = '' + break + else: + sentences.append(sent) + sent = '' + + return sentences + +def get_prob(data, sent_data, predicDict, logitsDict, entity_types=None): + for idx, paper in enumerate(data): + pmid = paper['pmid'] + + if len(paper['abstract']) > 0: + content = paper['title'] + ' ' + paper['abstract'] + else: + content = paper['title'] + + for 
ent_type in entity_types: + paper['entities'][ent_type] = [] + paper['prob'] = dict() + + for dtype in entity_types: + for sentidx, tags in enumerate(predicDict[dtype][pmid]): + B_flag = False + # get position of entity corresponding to types + for widx, tag in enumerate(tags): + if tag == 'O': + if B_flag: + tmpSE["end"] = \ + sent_data[pmid]['wordPos'][sentidx][widx - 1][1] + paper['entities'][dtype].append(tmpSE) + B_flag = False + continue + elif tag == 'B': + if B_flag: + tmpSE["end"] = \ + sent_data[pmid]['wordPos'][sentidx][widx - 1][1] + paper['entities'][dtype].append(tmpSE) + tmpSE = { + "start": sent_data[pmid]['wordPos'][sentidx][widx][ + 0]} + B_flag = True + elif tag == "I": + continue + if B_flag: + tmpSE["end"] = sent_data[pmid]['wordPos'][sentidx][-1][1] + paper['entities'][dtype].append(tmpSE) + + # get prob. of entity logits corresponding to types + logs = list() + for t_sent in logitsDict[dtype][pmid]: + logs.extend(t_sent) + paper['prob'][dtype] = list() + for pos in paper['entities'][dtype]: + if pos['start'] == pos['end']: + soft = softmax(logs[len( + CoNLL_tokenizer(content[:pos['start']])):len( + CoNLL_tokenizer(content[:pos['end']])) + 1]) + paper['prob'][dtype].append( + (pos, float(np.average(soft)))) + else: + soft = softmax(logs[len( + CoNLL_tokenizer(content[:pos['start']])):len( + CoNLL_tokenizer(content[:pos['end']]))]) + paper['prob'][dtype].append( + (pos, float(np.average(soft)))) + + return data + +def detokenize(tokens, predicts, logits): + pred = dict({ + 'toks': tokens[:], + 'labels': predicts[:], + 'logit': logits[:] + }) # dictionary for predicted tokens and labels. + + bert_toks = list() + bert_labels = list() + bert_logits = list() + tmp_p = list() + tmp_l = list() + tmp_s = list() + for t, l, s in zip(pred['toks'], pred['labels'], pred['logit']): + if t == '[CLS]' or t == '<s>': # non-text tokens will not be evaluated. + continue + elif t == '[SEP]' or t == '</s>': # newline + bert_toks.append(tmp_p) + bert_labels.append(tmp_l) + bert_logits.append(tmp_s) + tmp_p = list() + tmp_l = list() + tmp_s = list() + continue + elif t[:2] == '##': # if it is a piece of a word (broken by Word Piece tokenizer) + tmp_p[-1] = tmp_p[-1] + t[2:] # append pieces + elif t.startswith('Ġ'): # roberta tokenizer + t = t.replace('Ġ', ' ') + tmp_p[-1] = tmp_p[-1] + t + else: + tmp_p.append(t) + tmp_l.append(l) + tmp_s.append(s) + return bert_toks, bert_labels, bert_logits + +# https://stackoverflow.com/a/3620972 +PROF_DATA = {} + +class Profile(object): + def __init__(self, prefix): + self.prefix = prefix + + def __call__(self, fn): + def with_profiling(*args, **kwargs): + global PROF_DATA + start_time = time.time() + ret = fn(*args, **kwargs) + + elapsed_time = time.time() - start_time + key = '[' + self.prefix + '].' + fn.__name__ + + if key not in PROF_DATA: + PROF_DATA[key] = [0, list()] + PROF_DATA[key][0] += 1 + PROF_DATA[key][1].append(elapsed_time) + + return ret + + return with_profiling + +def show_prof_data(): + for fname, data in sorted(PROF_DATA.items()): + max_time = max(data[1]) + avg_time = sum(data[1]) / len(data[1]) + total_time = sum(data[1]) + print("\n{} -> called {} times".format(fname, data[0])) + print("Time total: {:.3f}, max: {:.3f}, avg: {:.3f}".format( + total_time, max_time, avg_time)) + +def clear_prof_data(): + global PROF_DATA + PROF_DATA = {} + +# Ref. 
dict of SR4GN +species_human_excl_homo_sapiens = \ + 'person|infant|Child|people|participants|woman|' \ + 'Girls|Man|Peoples|Men|Participant|Patients|' \ + 'humans|Persons|mans|participant|Infants|Boys|' \ + 'Human|Humans|Women|children|Mans|child|Participants|Girl|' \ + 'Infant|girl|patient|patients|boys|men|infants|' \ + 'man|girls|Children|Boy|women|persons|human|Woman|' \ + 'peoples|Patient|People|boy|Person'.split('|') + +def filter_entities(ner_results): + num_filtered_species_per_doc = list() + + for idx, paper in enumerate(ner_results): + + if len(paper['abstract']) > 0: + content = paper['title'] + ' ' + paper['abstract'] + else: + content = paper['title'] + + valid_species = list() + species = paper['entities']['species'] + for spcs in species: + entity_mention = content[spcs['start']:spcs['end']+1] + if entity_mention in species_human_excl_homo_sapiens: + spcs['end'] += 1 + continue + valid_species.append(spcs) + + num_filtered_species = len(species) - len(valid_species) + if num_filtered_species > 0: + paper['entities']['species'] = valid_species + + num_filtered_species_per_doc.append((paper['pmid'], + num_filtered_species)) + + return num_filtered_species_per_doc + +# from convert.py +def pubtator2dict_list(pubtator_file_path): + dict_list = list() + + title_pmid = '' + # abstract_pmid = '' + title = '' + abstract_text = '' + doc_line_num = 0 + + with open(pubtator_file_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.rstrip() + if len(line) == 0: + + doc_dict = { + 'pmid': title_pmid, + 'entities': {}, + } + doc_dict['title'] = title + doc_dict['abstract'] = abstract_text + + dict_list.append(doc_dict) + + doc_line_num = 0 + continue + + if doc_line_num == 0: + title_cols = line.split('|t|') + + if len(title_cols) != 2: + return '{"error": "wrong #title_cols {}"}'\ + .format(len(title_cols)) + + title_pmid = title_cols[0] + + if '- No text -' == title_cols[1]: + # make tmvar2 results empty + title = '' + else: + title = title_cols[1] + elif doc_line_num == 1: + abstract_cols = line.split('|a|') + + if len(abstract_cols) != 2: + if len(abstract_cols) > 2: + abstract_text = "|a|".join(abstract_cols[1:]) + else: + return '{"error": "wrong #abstract_cols {}"}'.format(len(abstract_cols)) + else: + if '- No text -' == abstract_cols[1]: + # make tmvar2 results empty + abstract_text = '' + else: + abstract_text = abstract_cols[1] + + doc_line_num += 1 + return dict_list + +def preprocess(text): + text = text.replace('\r ', ' ') + + text = text.replace('\u2028', ' ') + text = text.replace('\u2029', ' ') + + # HAIR SPACE + # https://www.fileformat.info/info/unicode/char/200a/index.htm + text = text.replace('\u200A', ' ') + + # THIN SPACE + # https://www.fileformat.info/info/unicode/char/2009/index.htm + text = text.replace('\u2009', ' ') + text = text.replace('\u2008', ' ') + + # FOUR-PER-EM SPACE + # https://www.fileformat.info/info/unicode/char/2005/index.htm + text = text.replace('\u2005', ' ') + text = text.replace('\u2004', ' ') + text = text.replace('\u2003', ' ') + + # EN SPACE + # https://www.fileformat.info/info/unicode/char/2002/index.htm + text = text.replace('\u2002', ' ') + + # NO-BREAK SPACE + # https://www.fileformat.info/info/unicode/char/00a0/index.htm + text = text.replace('\u00A0', ' ') + + # https://www.fileformat.info/info/unicode/char/f8ff/index.htm + text = text.replace('\uF8FF', ' ') + + # https://www.fileformat.info/info/unicode/char/202f/index.htm + text = text.replace('\u202F', ' ') + + text = text.replace('\uFEFF', ' ') + text = 
text.replace('\uF044', ' ')
+    text = text.replace('\uF02D', ' ')
+    text = text.replace('\uF0BB', ' ')
+
+    # U+F048 / U+F0B0 are Private Use Area codepoints that typically come from
+    # Symbol-font text; map them to the glyphs they most likely represent.
+    text = text.replace('\uF048', 'Η')
+    text = text.replace('\uF0B0', '°')
+
+    # MIDLINE HORIZONTAL ELLIPSIS: ⋯
+    # https://www.fileformat.info/info/unicode/char/22ef/index.htm
+    # text = text.replace('\u22EF', '...')
+
+    return text
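+
+
+if __name__ == '__main__':
+    # Minimal usage sketch of the helpers above (illustration only, not part of
+    # the pipeline proper). It assumes a PubTator-formatted file named
+    # 'example.pubtator' exists next to this module; that path is hypothetical.
+    docs = pubtator2dict_list('example.pubtator')
+    if isinstance(docs, str):
+        # pubtator2dict_list returns a JSON-style error string on malformed input
+        raise ValueError(docs)
+
+    # Normalize unusual whitespace and Private Use Area characters before splitting
+    for doc in docs:
+        doc['title'] = preprocess(doc['title'])
+        doc['abstract'] = preprocess(doc['abstract'])
+
+    # Split each document into sentences, then tokenize and record
+    # per-word character offsets for later entity alignment
+    sent_data = input_form(json_to_sent(docs))
+    for pmid, d in sent_data.items():
+        print(pmid, '->', len(d['sentence']), 'sentences,',
+              sum(len(ws) for ws in d['words']), 'tokens')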