--- /dev/null
+++ b/pretrain/data_util.py
@@ -0,0 +1,189 @@
+import os
+import numpy as np
+import pandas as pd
+from transformers import AutoTokenizer
+from load_umls import UMLS
+from torch.utils.data import Dataset, DataLoader
+from random import sample
+from sampler_util import FixedLengthBatchSampler, my_collate_fn
+from torch.utils.data.sampler import RandomSampler
+from time import time
+import json
+
+
+def pad(list_ids, pad_length, pad_mark=0):
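+    """Truncate or right-pad each id list in list_ids to exactly pad_length.
+
+    e.g. pad([[1, 2, 3], [4]], 2) -> [[1, 2], [4, 0]]
+    """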
+    output = []
+    for l in list_ids:
+        if len(l) > pad_length:
+            output.append(l[0:pad_length])
+        else:
+            output.append(l + [pad_mark] * (pad_length - len(l)))
+    return output
+
+
+def my_sample(lst, lst_length, start, length):
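+    """Take a circular slice of `length` items from lst, starting at start % lst_length.
+
+    e.g. my_sample([0, 1, 2, 3], 4, start=3, length=2) -> [3, 0]
+    """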
+    start = start % lst_length
+    if start + length < lst_length:
+        return lst[start:start + length]
+    return lst[start:] + lst[0:start + length - lst_length]
+
+
+class UMLSDataset(Dataset):
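+    """Dataset over the relation triples in self.umls.rel.
+
+    Each item pairs up to max_lui_per_cui surface strings of two related
+    CUIs with an equal number of randomly drawn negative CUIs, plus the
+    class indices of their CUI, semantic type (STY) and relation labels.
+    """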
+    def __init__(self, umls_folder, model_name_or_path, lang, json_save_path=None, max_lui_per_cui=8, max_length=32):
+        self.umls = UMLS(umls_folder, lang_range=lang)
+        self.len = len(self.umls.rel)
+        self.max_lui_per_cui = max_lui_per_cui
+        self.max_length = max_length
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.json_save_path = json_save_path
+        self.calculate_class_count()
+
+    def calculate_class_count(self):
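+        """Build string-to-index maps for the CUI, RE, REL and STY label spaces."""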
+        print("Calculate class count")
+
+        self.cui2id = {cui: index for index, cui in enumerate(self.umls.cui2str)}
+
+        self.re_set = set()
+        self.rel_set = set()
+        for r in self.umls.rel:
+            _, _, re, rel = r.split("\t")
+            self.re_set.add(re)
+            self.rel_set.add(rel)
+        self.re_set = sorted(self.re_set)
+        self.rel_set = sorted(self.rel_set)
+
+        self.re2id = {re: index for index, re in enumerate(self.re_set)}
+        self.rel2id = {rel: index for index, rel in enumerate(self.rel_set)}
+
+        sty_list = sorted(set(self.umls.cui2sty.values()))
+        self.sty2id = {sty: index for index, sty in enumerate(sty_list)}
+
+        if self.json_save_path:
+            with open(os.path.join(self.json_save_path, "re2id.json"), "w") as f:
+                json.dump(self.re2id, f)
+            with open(os.path.join(self.json_save_path, "rel2id.json"), "w") as f:
+                json.dump(self.rel2id, f)
+            with open(os.path.join(self.json_save_path, "sty2id.json"), "w") as f:
+                json.dump(self.sty2id, f)
+
+        print("CUI:", len(self.cui2id))
+        print("RE:", len(self.re2id))
+        print("REL:", len(self.rel2id))
+        print("STY:", len(self.sty2id))
+
+    def tokenize_one(self, string):
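+        # Returns variable-length token ids; padding to max_length happens in pad().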
+        return self.tokenizer.encode_plus(string, max_length=self.max_length, truncation=True)['input_ids']
+
+    def __getitem__(self, index):
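+        """Return tokenized string groups and label lists for relation triple `index`."""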
+        cui0, cui1, re, rel = self.umls.rel[index].split("\t")
+
+        str0_list = list(self.umls.cui2str[cui0])
+        str1_list = list(self.umls.cui2str[cui1])
+        if len(str0_list) > self.max_lui_per_cui:
+            str0_list = sample(str0_list, self.max_lui_per_cui)
+        if len(str1_list) > self.max_lui_per_cui:
+            str1_list = sample(str1_list, self.max_lui_per_cui)
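+        # Keep both string lists the same length so they pair one-to-one.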
+        use_len = min(len(str0_list), len(str1_list))
+        str0_list = str0_list[0:use_len]
+        str1_list = str1_list[0:use_len]
+
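+        # Semantic-type (STY) class labels of the two related CUIs.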
+        sty0_index = self.sty2id[self.umls.cui2sty[cui0]]
+        sty1_index = self.sty2id[self.umls.cui2sty[cui1]]
+
+        str2_list = []
+        cui2_index_list = []
+        sty2_index_list = []
+
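+        # Deterministically slice 2 * use_len candidate negative CUIs from the
+        # global CUI list, offsetting the start by the item index.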
+        cui2 = my_sample(self.umls.cui, self.umls.cui_count,
+                         index * self.max_lui_per_cui, use_len * 2)
+        sample_index = 0
+        while len(str2_list) < use_len:
+            if sample_index < len(cui2):
+                use_cui2 = cui2[sample_index]
+            else:
+                sample_index = 0
+                cui2 = my_sample(self.umls.cui, self.umls.cui_count,
+                                 index * self.max_lui_per_cui, use_len * 2)
+                use_cui2 = cui2[sample_index]
+            # if not "\t".join([cui0, use_cui2, re, rel]) in self.umls.rel: # TOO SLOW!
+            if True:
+                cui2_index_list.append(self.cui2id[use_cui2])
+                sty2_index_list.append(
+                    self.sty2id[self.umls.cui2sty[use_cui2]])
+                str2_list.append(sample(self.umls.cui2str[use_cui2], 1)[0])
+                sample_index += 1
+
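+        # Tokenize every string, pad to max_length, then split back into the
+        # cui0, cui1 and negative groups.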
+        input_ids = [self.tokenize_one(s)
+                     for s in str0_list + str1_list + str2_list]
+        input_ids = pad(input_ids, self.max_length)
+        input_ids_0 = input_ids[0:use_len]
+        input_ids_1 = input_ids[use_len:2 * use_len]
+        input_ids_2 = input_ids[2 * use_len:]
+
+        cui0_index = self.cui2id[cui0]
+        cui1_index = self.cui2id[cui1]
+
+        re_index = self.re2id[re]
+        rel_index = self.rel2id[rel]
+        return input_ids_0, input_ids_1, input_ids_2, \
+            [cui0_index] * use_len, [cui1_index] * use_len, cui2_index_list, \
+            [sty0_index] * use_len, [sty1_index] * use_len, sty2_index_list, \
+            [re_index] * use_len, \
+            [rel_index] * use_len
+
+    def __len__(self):
+        return self.len
+
+
+def fixed_length_dataloader(umls_dataset, fixed_length=96, num_workers=0):
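+    """Build a DataLoader whose batches are formed by FixedLengthBatchSampler.
+
+    Items contribute a variable number (use_len) of sequences each; the
+    sampler (see sampler_util) is assumed to pack items so each batch holds
+    a fixed total of `fixed_length` sequences, flattened by my_collate_fn.
+    """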
+    base_sampler = RandomSampler(umls_dataset)
+    batch_sampler = FixedLengthBatchSampler(
+        sampler=base_sampler, fixed_length=fixed_length, drop_last=True)
+    dataloader = DataLoader(umls_dataset, batch_sampler=batch_sampler,
+                            collate_fn=my_collate_fn, num_workers=num_workers, pin_memory=True)
+    return dataloader
+
+
+if __name__ == "__main__":
+    umls_dataset = UMLSDataset(umls_folder="../umls",
+                               model_name_or_path="../biobert_v1.1",
+                               lang=None)
+    umls_dataloader = fixed_length_dataloader(umls_dataset, num_workers=4)
+    now_time = time()
+    for index, batch in enumerate(umls_dataloader):
+        print(time() - now_time)
+        now_time = time()
+        if index < 10:
+            for item in batch:
+                print(item.shape)
+        else:
+            break