a b/coderpp/train/dataset.py
1
import os
2
import numpy as np
3
import pandas as pd
4
from transformers import AutoTokenizer
5
from load_umls import UMLS
6
from torch.utils.data import Dataset, DataLoader
7
from random import sample
8
from torch.utils.data.sampler import RandomSampler
9
# import ipdb
10
from time import time
11
import json
12
import pickle
13
import ahocorasick
14
import torch
15
16
class UMLSDataset(Dataset):
    """Training items built from UMLS phrases for CUI-level contrastive learning.

    Each item is centred on one phrase (the "anchor") and contains:
      1. the anchor phrase itself,
      2. the phrases listed in ``self.indices`` for the anchor's phrase index,
      3. ``num_same_cui_samples`` phrases randomly drawn from the anchor's CUI.

    For every phrase the item carries its token ids, attention mask and an
    integer CUI label, so a batch can be flattened by ``my_collate_fn``.
    """

    def __init__(self, umls_folder='../umls', model_name_or_path='GanjinZero/UMLSBert_ENG',
                 idx2phrase_path='data/idx2string.pkl', phrase2idx_path='data/string2idx.pkl',
                 indices_path='data/indices.npy', max_length=32, num_same_cui_samples=30):
        """Load UMLS lookups, the neighbour index array and the tokenizer.

        Args:
            umls_folder: directory passed to ``UMLS``.
            model_name_or_path: tokenizer name/path for ``AutoTokenizer``.
            idx2phrase_path: pickle mapping phrase index -> phrase string.
            phrase2idx_path: pickle mapping phrase string -> phrase index.
            indices_path: ``.npy`` array; row ``i`` lists phrase indices paired
                with phrase ``i`` (presumably nearest neighbours -- confirm).
            max_length: token length every phrase is truncated/padded to.
            num_same_cui_samples: number of same-CUI phrases per item (the
                anchor is repeated when the CUI has too few other phrases).
        """
        super().__init__()
        self.umls = UMLS(umls_folder, phrase2idx_path=phrase2idx_path, idx2phrase_path=idx2phrase_path)
        self.indices = np.load(indices_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Dense integer label per CUI, numbered in cui2stridx iteration order.
        self.cui2idx = {cui: idx for idx, cui in enumerate(self.umls.cui2stridx.keys())}
        self.idx2phrase = self._load_pickle(idx2phrase_path)
        self.max_length = max_length
        self.num_same_cui_samples = num_same_cui_samples

    def _load_pickle(self, path):
        """Return the object stored in the pickle file at ``path``.

        NOTE: ``pickle.load`` can execute arbitrary code; only use on trusted
        local data files.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)

    def tokenize_one(self, string):
        """Tokenize one phrase to ``max_length`` tokens.

        Returns:
            ``(input_ids, attention_mask)``, both truncated/padded to
            ``self.max_length`` and including special tokens.
        """
        # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`.
        tokenized = self.tokenizer.encode_plus(string,
                                               max_length=self.max_length,
                                               truncation=True,
                                               padding='max_length',
                                               add_special_tokens=True)
        return tokenized['input_ids'], tokenized['attention_mask']

    def __getitem__(self, index):
        """Build the item anchored on ``self.umls.stridx_list[index]``.

        Returns:
            ``(input_ids, input_cui_idx_list, attention_mask)``: parallel lists
            ordered as [anchor, phrases from ``self.indices`` row,
            ``num_same_cui_samples`` same-CUI phrases].
        """
        current_str_idx = self.umls.stridx_list[index]
        current_phrase = self.idx2phrase[current_str_idx]

        input_str_list = [current_phrase]
        input_str_list += [self.idx2phrase[idx] for idx in self.indices[current_str_idx]]

        current_cui = self.umls.str2cui[current_phrase]
        num_samples = self.num_same_cui_samples
        # random.sample requires a sequence (sets are rejected on Python 3.11+).
        candidates = list(self.umls.cui2stridx[current_cui] - {current_str_idx})
        idx_list = sample(candidates, min(num_samples, len(candidates)))
        # Pad with the anchor itself so every item has the same length for collation.
        idx_list += [current_str_idx] * (num_samples - len(idx_list))
        input_str_list += [self.idx2phrase[idx] for idx in idx_list]

        input_cui_idx_list = [self.cui2idx[self.umls.str2cui[s]] for s in input_str_list]
        # Tokenize every phrase once, then split the (ids, mask) pairs.
        tokenized = [self.tokenize_one(s) for s in input_str_list]
        input_ids = [ids for ids, _ in tokenized]
        attention_mask = [mask for _, mask in tokenized]
        return input_ids, input_cui_idx_list, attention_mask

    def __len__(self):
        """Number of anchor phrases (one item per entry of ``stridx_list``)."""
        return len(self.umls.stridx_list)
57
58
def my_collate_fn(batch):
    """Collate dataset items into flat LongTensors.

    Each element of ``batch`` is ``(input_ids, labels, attention_mask)`` where
    ``input_ids`` / ``attention_mask`` are ``[n_phrases][seq_len]`` nested lists
    and ``labels`` is a ``[n_phrases]`` list; all items must share the same
    ``n_phrases`` and ``seq_len``.

    Returns:
        A tuple ``(ids, labels, mask)`` with shapes
        ``(len(batch) * n_phrases, seq_len)``, ``(len(batch) * n_phrases,)``
        and ``(len(batch) * n_phrases, seq_len)``.
    """
    per_field = (torch.LongTensor([item[field] for item in batch]) for field in range(3))
    # Merge the batch axis with the per-item phrase axis into one leading axis.
    return tuple(tensor.flatten(0, 1) for tensor in per_field)
66
67
    
68
if __name__ == '__main__':
    # Smoke test: build the dataset, inspect one item and one collated batch.
    umls_dataset = UMLSDataset()
    # __getitem__ draws random same-CUI samples and tokenizes every phrase,
    # so fetch the item once and reuse it rather than recomputing it.
    item = umls_dataset[400]
    print(item)
    print(len(item[0]))
    umls_dataloader = DataLoader(umls_dataset,
                                 batch_size=5,
                                 shuffle=True,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=True,
                                 collate_fn=my_collate_fn)
    data, label, mask = next(iter(umls_dataloader))
    print(data.shape)
    print(label.shape)
    print(mask.shape)