--- a
+++ b/Cluster-ViT/datasets/MyData.py
@@ -0,0 +1,90 @@
+'''
+Dataset for training
+Written by Whalechen
+'''
+
+from os import listdir, mkdir
+import random
+import numpy as np
+from torch.utils.data import Dataset
+import random
+import torch
+class MyDataset(Dataset):
+    """Per-patient dataset that reads one ``np.load``-able file per sample.
+
+    Every file must hold a pickled dict with the keys ``patientEmbedding``,
+    ``position``, ``cluster``, ``Dead`` and ``FollowUpTime``.
+
+    Args:
+        root_dir: Directory whose entries are the per-patient files.
+        sequence_len: Number of tokens in every returned sample.
+        max_num_cluster: Cluster id written into padded positions.
+        status: Stored on the instance; not read anywhere in this class.
+        input_pool: If true, draw ``sequence_len`` rows at random (with
+            replacement) instead of zero-padding the whole sequence.
+    """
+
+    def __init__(self, root_dir, sequence_len, max_num_cluster, status='train', input_pool=False):
+        # listdir() returns entries in arbitrary order; sort them so that the
+        # idx -> file mapping is the same on every run and machine.
+        self.patientData_list = [root_dir + "/" + pth for pth in sorted(listdir(root_dir))]
+        self.sequence_len = sequence_len
+        self.max_num_cluster = max_num_cluster
+        self.status = status
+        self.input_pool = input_pool
+
+    def __len__(self):
+        """Return the number of entries found in ``root_dir``."""
+        return len(self.patientData_list)
+
+    def __getitem__(self, idx):
+        """Load sample ``idx`` and bring it to ``sequence_len`` tokens.
+
+        Returns:
+            Tuple ``(patientEmbedding, pos, keyPaddingMask, cluster, Dead,
+            followUpTime, idx)``. ``keyPaddingMask`` is uint8 and is 1 on
+            padded tokens, 0 elsewhere.
+
+        Raises:
+            ValueError: Without pooling, if the sample has more than
+                ``sequence_len`` tokens.
+        """
+        path = self.patientData_list[idx]
+        # SECURITY: allow_pickle=True unpickles the file; load trusted data only.
+        temp = np.load(path, allow_pickle=True).item()
+        patientEmbedding = temp['patientEmbedding']
+        pos = temp['position']
+        cluster = temp['cluster']
+        # The labels are converted the same way in both modes.
+        Dead = torch.tensor(temp['Dead']).to(torch.int64)
+        followUpTime = torch.tensor(temp['FollowUpTime']).to(torch.float32)
+
+        if not self.input_pool:
+            # No pooling: keep every token and pad up to sequence_len.
+            longest = max(len(patientEmbedding), len(pos), len(cluster))
+            if longest > self.sequence_len:
+                # torch.zeros() with a negative size would fail with an opaque error.
+                raise ValueError(
+                    f"{path}: {longest} tokens exceed sequence_len={self.sequence_len}"
+                )
+            patientEmbedding = torch.tensor(patientEmbedding)
+            pos = torch.tensor(pos)
+            patientEmbedding = torch.cat((patientEmbedding, torch.zeros(self.sequence_len - patientEmbedding.shape[0], patientEmbedding.shape[1])))
+            pos = torch.cat((pos, torch.zeros(self.sequence_len - pos.shape[0], pos.shape[1])))
+            keyPaddingMask = torch.cat((torch.zeros(cluster.shape[0]), torch.ones(self.sequence_len - cluster.shape[0])))
+            keyPaddingMask = keyPaddingMask.type(torch.ByteTensor)
+            # reshape(-1) rather than squeeze(): squeeze() turns a one-token
+            # sample into a 0-dim tensor that cannot be padded below.
+            cluster = torch.tensor(cluster).to(torch.int64).reshape(-1)
+            cluster = torch.cat((cluster, self.max_num_cluster * torch.ones(self.sequence_len - cluster.shape[0]))).to(torch.int64)
+        else:
+            # Pooling: sample sequence_len rows with replacement; nothing is padded.
+            selectedIndex = random.choices(range(len(patientEmbedding)), k=self.sequence_len)
+            patientEmbedding = torch.tensor(patientEmbedding[selectedIndex, :])
+            pos = torch.tensor(pos[selectedIndex, :])
+            keyPaddingMask = torch.zeros(self.sequence_len).type(torch.ByteTensor)
+            # squeeze(-1) drops only the trailing axis, so sequence_len == 1
+            # still yields a 1-D tensor.
+            cluster = torch.tensor(cluster[selectedIndex, :]).to(torch.int64).squeeze(-1)
+
+        return (patientEmbedding, pos, keyPaddingMask, cluster, Dead, followUpTime, idx)
\ No newline at end of file