Diff of /pretrain/data_util.py [000000] .. [c3444c]


--- /dev/null
+++ b/pretrain/data_util.py
import os
from transformers import AutoTokenizer
from load_umls import UMLS
from torch.utils.data import Dataset, DataLoader
from random import sample
from sampler_util import FixedLengthBatchSampler, my_collate_fn
from torch.utils.data.sampler import RandomSampler
import ipdb
from time import time
import json
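

# Truncate each list of token ids to pad_length, or right-pad it with pad_mark.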
def pad(list_ids, pad_length, pad_mark=0):
    output = []
    for l in list_ids:
        if len(l) > pad_length:
            output.append(l[0:pad_length])
        else:
            output.append(l + [pad_mark] * (pad_length - len(l)))
    return output
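

# Deterministic circular slice: take `length` items from `lst` starting
# at `start`, wrapping past the end of the list when necessary.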
def my_sample(lst, lst_length, start, length):
    start = start % lst_length
    if start + length < lst_length:
        return lst[start:start + length]
    return lst[start:] + lst[0:start + length - lst_length]
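

# Dataset over UMLS relation triples (cui0, cui1, re, rel); each item yields
# tokenized surface strings for the two related CUIs plus sampled negatives.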
class UMLSDataset(Dataset):
    def __init__(self, umls_folder, model_name_or_path, lang, json_save_path=None, max_lui_per_cui=8, max_length=32):
        self.umls = UMLS(umls_folder, lang_range=lang)
        self.len = len(self.umls.rel)
        self.max_lui_per_cui = max_lui_per_cui
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.json_save_path = json_save_path
        self.calculate_class_count()
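
    # Build contiguous integer ids for every CUI, relation type (RE),
    # relation (REL) and semantic type (STY); optionally dump the maps
    # to JSON under json_save_path.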
    def calculate_class_count(self):
        print("Calculating class count")

        self.cui2id = {cui: index for index,
                       cui in enumerate(self.umls.cui2str.keys())}

        self.re_set = set()
        self.rel_set = set()
        for r in self.umls.rel:
            _, _, re, rel = r.split("\t")
            self.re_set.add(re)
            self.rel_set.add(rel)
        # Sort so that id assignment is deterministic across runs.
        self.re_set = sorted(self.re_set)
        self.rel_set = sorted(self.rel_set)

        self.re2id = {re: index for index, re in enumerate(self.re_set)}
        self.rel2id = {rel: index for index, rel in enumerate(self.rel_set)}

        sty_list = sorted(set(self.umls.cui2sty.values()))
        self.sty2id = {sty: index for index, sty in enumerate(sty_list)}

        if self.json_save_path:
            with open(os.path.join(self.json_save_path, "re2id.json"), "w") as f:
                json.dump(self.re2id, f)
            with open(os.path.join(self.json_save_path, "rel2id.json"), "w") as f:
                json.dump(self.rel2id, f)
            with open(os.path.join(self.json_save_path, "sty2id.json"), "w") as f:
                json.dump(self.sty2id, f)

        print("CUI:", len(self.cui2id))
        print("RE:", len(self.re2id))
        print("REL:", len(self.rel2id))
        print("STY:", len(self.sty2id))
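
    # Tokenize a single string; truncation keeps at most max_length ids.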
    def tokenize_one(self, string):
        return self.tokenizer.encode_plus(string, max_length=self.max_length, truncation=True)['input_ids']

    # @profile
    def __getitem__(self, index):
        cui0, cui1, re, rel = self.umls.rel[index].split("\t")

        # Positive pair: keep at most max_lui_per_cui surface strings per
        # CUI, then trim both sides to the same length.
        str0_list = list(self.umls.cui2str[cui0])
        str1_list = list(self.umls.cui2str[cui1])
        if len(str0_list) > self.max_lui_per_cui:
            str0_list = sample(str0_list, self.max_lui_per_cui)
        if len(str1_list) > self.max_lui_per_cui:
            str1_list = sample(str1_list, self.max_lui_per_cui)
        use_len = min(len(str0_list), len(str1_list))
        str0_list = str0_list[0:use_len]
        str1_list = str1_list[0:use_len]

        sty0_index = self.sty2id[self.umls.cui2sty[cui0]]
        sty1_index = self.sty2id[self.umls.cui2sty[cui1]]

        # Negative CUIs: a deterministic circular slice over the full CUI
        # list, keyed on the sample index.
        str2_list = []
        cui2_index_list = []
        sty2_index_list = []

        cui2 = my_sample(self.umls.cui, self.umls.cui_count,
                         index * self.max_lui_per_cui, use_len * 2)
        sample_index = 0
        while len(str2_list) < use_len:
            if sample_index < len(cui2):
                use_cui2 = cui2[sample_index]
            else:
                sample_index = 0
                cui2 = my_sample(self.umls.cui, self.umls.cui_count,
                                 index * self.max_lui_per_cui, use_len * 2)
                use_cui2 = cui2[sample_index]
            # Rejecting candidates with
            # "\t".join([cui0, use_cui2, re, rel]) in self.umls.rel
            # would filter false negatives, but that lookup is too slow,
            # so every sampled CUI is accepted.
            cui2_index_list.append(self.cui2id[use_cui2])
            sty2_index_list.append(
                self.sty2id[self.umls.cui2sty[use_cui2]])
            # list() is needed because random.sample rejects sets on
            # Python >= 3.11.
            str2_list.append(sample(list(self.umls.cui2str[use_cui2]), 1)[0])
            sample_index += 1

        input_ids = [self.tokenize_one(s)
                     for s in str0_list + str1_list + str2_list]
        input_ids = pad(input_ids, self.max_length)
        input_ids_0 = input_ids[0:use_len]
        input_ids_1 = input_ids[use_len:2 * use_len]
        input_ids_2 = input_ids[2 * use_len:]

        cui0_index = self.cui2id[cui0]
        cui1_index = self.cui2id[cui1]

        re_index = self.re2id[re]
        rel_index = self.rel2id[rel]
        # Label lists are repeated so they line up one-to-one with the strings.
        return input_ids_0, input_ids_1, input_ids_2, \
            [cui0_index] * use_len, [cui1_index] * use_len, cui2_index_list, \
            [sty0_index] * use_len, [sty1_index] * use_len, sty2_index_list, \
            [re_index] * use_len, \
            [rel_index] * use_len

    def __len__(self):
        return self.len
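

# Batch size is governed by FixedLengthBatchSampler from sampler_util
# rather than a fixed number of samples per batch.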
def fixed_length_dataloader(umls_dataset, fixed_length=96, num_workers=0):
    base_sampler = RandomSampler(umls_dataset)
    batch_sampler = FixedLengthBatchSampler(
        sampler=base_sampler, fixed_length=fixed_length, drop_last=True)
    dataloader = DataLoader(umls_dataset, batch_sampler=batch_sampler,
                            collate_fn=my_collate_fn, num_workers=num_workers, pin_memory=True)
    return dataloader
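

# Smoke test: build the dataset, break into the debugger, then time each
# batch and print the collated tensor shapes for the first ten batches.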
if __name__ == "__main__":
    umls_dataset = UMLSDataset(umls_folder="../umls",
                               model_name_or_path="../biobert_v1.1",
                               lang=None)
    ipdb.set_trace()
    umls_dataloader = fixed_length_dataloader(umls_dataset, num_workers=4)
    now_time = time()
    for index, batch in enumerate(umls_dataloader):
        print(time() - now_time)
        now_time = time()
        if index < 10:
            for item in batch:
                print(item.shape)
        else:
            break