# coderpp/train/dataset.py
import pickle
from random import sample

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

from load_umls import UMLS
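
# Note: `load_umls.UMLS` is the repo-local UMLS loader; the code below relies
# on its `cui2stridx`, `str2cui`, and `stridx_list` attributes.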
|
|
class UMLSDataset(Dataset):
    """Pairs each UMLS string with retrieved neighbours and same-CUI synonyms."""

    def __init__(self,
                 umls_folder='../umls',
                 model_name_or_path='GanjinZero/UMLSBert_ENG',
                 idx2phrase_path='data/idx2string.pkl',
                 phrase2idx_path='data/string2idx.pkl',
                 indices_path='data/indices.npy',
                 max_length=32):
        super().__init__()
        self.umls = UMLS(umls_folder, phrase2idx_path=phrase2idx_path, idx2phrase_path=idx2phrase_path)
        # Precomputed nearest-neighbour indices, one row per string.
        self.indices = np.load(indices_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Dense integer label for every CUI.
        self.cui2idx = {cui: idx for idx, cui in enumerate(self.umls.cui2stridx.keys())}
        self.idx2phrase = self._load_pickle(idx2phrase_path)
        self.max_length = max_length

    def _load_pickle(self, path):
        with open(path, 'rb') as f:
            return pickle.load(f)

    def tokenize_one(self, string):
        # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`.
        tokenized = self.tokenizer.encode_plus(string,
                                               max_length=self.max_length,
                                               truncation=True,
                                               padding='max_length',
                                               add_special_tokens=True)
        return tokenized['input_ids'], tokenized['attention_mask']
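
    # Assuming `indices` stores 30 neighbours per row (per the comment below),
    # each item packs 61 strings: the anchor, its top-30 retrieved neighbours,
    # and 30 strings sampled from the anchor's CUI (padded with the anchor if
    # the CUI has fewer synonyms).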
|
|
    def __getitem__(self, index):
        input_str_list = []  # [current_str, top30_str, 30*rand_same_cui_str]
        current_str_idx = self.umls.stridx_list[index]
        input_str_list.append(self.idx2phrase[current_str_idx])
        # Retrieved nearest neighbours of the current string.
        input_str_list += [self.idx2phrase[idx] for idx in self.indices[current_str_idx]]
        current_cui = self.umls.str2cui[self.idx2phrase[current_str_idx]]
        stridx_set_for_current_cui = self.umls.cui2stridx[current_cui]
        # random.sample needs a sequence; sampling directly from a set was
        # deprecated in Python 3.9 and removed in 3.11.
        idx_list = sample(list(stridx_set_for_current_cui - {current_str_idx}),
                          min(30, len(stridx_set_for_current_cui) - 1))
        if len(idx_list) < 30:
            # Pad with the current string when the CUI has fewer than 30 synonyms.
            idx_list += [current_str_idx] * (30 - len(idx_list))
        input_str_list += [self.idx2phrase[idx] for idx in idx_list]
        input_cui_idx_list = [self.cui2idx[self.umls.str2cui[s]] for s in input_str_list]
        # Tokenize each string once instead of once for ids and again for mask.
        tokenized = [self.tokenize_one(s) for s in input_str_list]
        input_ids = [t[0] for t in tokenized]
        attention_mask = [t[1] for t in tokenized]
        return input_ids, input_cui_idx_list, attention_mask

    def __len__(self):
        return len(self.umls.stridx_list)


def my_collate_fn(batch):
    # `item` avoids shadowing the imported `sample` from random.
    output_ids = torch.LongTensor([item[0] for item in batch])
    output_label = torch.LongTensor([item[1] for item in batch])
    output_attention_mask = torch.LongTensor([item[2] for item in batch])
    # Flatten (batch, strings_per_item, ...) so each row is a single string.
    output_ids = output_ids.reshape(-1, output_ids.shape[-1])
    output_label = output_label.reshape(-1)
    output_attention_mask = output_attention_mask.reshape(-1, output_attention_mask.shape[-1])
    return output_ids, output_label, output_attention_mask
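
# After collation every row is an independent string; `output_label` carries one
# CUI index per row so downstream code can tell which rows name the same concept.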
|
|
if __name__ == '__main__':
    umls_dataset = UMLSDataset()
    print(umls_dataset[400])
    print(len(umls_dataset[400][0]))
    umls_dataloader = DataLoader(umls_dataset,
                                 batch_size=5,
                                 shuffle=True,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=True,
                                 collate_fn=my_collate_fn)
    data, label, mask = next(iter(umls_dataloader))
    print(data.shape)
    print(label.shape)
    print(mask.shape)
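    # Expected shapes, assuming `indices` stores 30 neighbours per string
    # (61 strings per item, batch_size=5, max_length=32):
    #   data: (305, 32), label: (305,), mask: (305, 32)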