'''
Dataset for training
Written by Whalechen
'''

import os
from os import listdir, mkdir
import random
import numpy as np
from torch.utils.data import Dataset
import torch


class MyDataset(Dataset):
    """Per-patient dataset that yields fixed-length token sequences.

    Every entry of ``root_dir`` is loaded with ``np.load(...).item()`` and is
    expected to yield a dict with the keys ``'patientEmbedding'``,
    ``'position'``, ``'cluster'``, ``'Dead'`` and ``'FollowUpTime'``.
    ``patientEmbedding`` and ``position`` are indexed as 2-D arrays (one row
    per token); ``cluster`` carries one id per token (presumably shaped
    ``(N, 1)`` -- confirm against the code that writes these files).

    Args:
        root_dir: Directory whose entries are the per-patient files.
        sequence_len: Number of tokens in every returned sequence.
        max_num_cluster: Cluster id written into padded positions.
        status: Stored on the instance; not used inside this class.
        input_pool: If False, sequences are zero-padded up to
            ``sequence_len``. If True, ``sequence_len`` rows are drawn at
            random with replacement and nothing is padded.
    """

    def __init__(self, root_dir, sequence_len, max_num_cluster, status='train', input_pool=False):
        # listdir() returns entries in arbitrary order; sorting keeps the
        # idx -> file mapping stable across runs and machines.
        self.patientData_list = [os.path.join(root_dir, pth) for pth in sorted(listdir(root_dir))]
        self.sequence_len = sequence_len
        self.max_num_cluster = max_num_cluster
        self.status = status
        self.input_pool = input_pool

    def __len__(self):
        """Return the number of patient files found in ``root_dir``."""
        return len(self.patientData_list)

    def _pad_amount(self, num_rows, field, path):
        """Return how many padding rows ``field`` needs, or raise if it is too long."""
        missing = self.sequence_len - num_rows
        if missing < 0:
            raise ValueError(
                f"{path}: '{field}' has {num_rows} rows, which exceeds "
                f"sequence_len={self.sequence_len}; increase sequence_len or "
                f"use input_pool=True"
            )
        return missing

    def _pad_rows(self, rows, field, path):
        """Append all-zero rows to the 2-D tensor ``rows`` up to ``sequence_len``."""
        missing = self._pad_amount(rows.shape[0], field, path)
        return torch.cat((rows, torch.zeros(missing, rows.shape[1])))

    @staticmethod
    def _cluster_ids(raw):
        """Convert ``raw`` to a squeezed int64 tensor that is never 0-dimensional."""
        ids = torch.tensor(raw).to(torch.int64).squeeze()
        # squeeze() collapses a single id to a scalar, which can be neither
        # measured with shape[0] nor concatenated with the padding.
        if ids.dim() == 0:
            ids = ids.reshape(1)
        return ids

    def __getitem__(self, idx):
        """Load patient ``idx`` and return its fixed-length sample.

        Returns:
            Tuple ``(patientEmbedding, pos, keyPaddingMask, cluster, Dead,
            followUpTime, idx)``. ``keyPaddingMask`` is a uint8 tensor of
            length ``sequence_len`` holding 1 at padded positions and 0
            elsewhere; ``cluster`` and ``Dead`` are int64, ``followUpTime``
            is float32.

        Raises:
            ValueError: In padding mode, if the patient has more rows than
                ``sequence_len``.
        """
        path = self.patientData_list[idx]
        # NOTE(security): allow_pickle=True unpickles the file content, which
        # can execute arbitrary code -- only point root_dir at trusted data.
        record = np.load(path, allow_pickle=True).item()
        patientEmbedding = record['patientEmbedding']
        pos = record['position']
        cluster = record['cluster']

        if self.input_pool:
            # Sample sequence_len rows with replacement; every position is a
            # real token, so the padding mask is all zeros.
            selectedIndex = random.choices(range(len(patientEmbedding)), k=self.sequence_len)
            patientEmbedding = torch.tensor(patientEmbedding[selectedIndex, :])
            pos = torch.tensor(pos[selectedIndex, :])
            keyPaddingMask = torch.zeros(self.sequence_len, dtype=torch.uint8)
            cluster = self._cluster_ids(cluster[selectedIndex, :])
        else:
            # No pooling: keep every row and pad up to sequence_len.
            patientEmbedding = self._pad_rows(torch.tensor(patientEmbedding), 'patientEmbedding', path)
            pos = self._pad_rows(torch.tensor(pos), 'position', path)

            num_valid = cluster.shape[0]
            self._pad_amount(num_valid, 'cluster', path)
            keyPaddingMask = torch.zeros(self.sequence_len, dtype=torch.uint8)
            keyPaddingMask[num_valid:] = 1

            cluster = self._cluster_ids(cluster)
            missing = self._pad_amount(cluster.shape[0], 'cluster', path)
            # Padded positions get max_num_cluster as their id; building the
            # filler directly as int64 avoids a mixed-dtype concatenation.
            filler = torch.full((missing,), self.max_num_cluster, dtype=torch.int64)
            cluster = torch.cat((cluster, filler))

        Dead = torch.tensor(record['Dead']).to(torch.int64)
        followUpTime = torch.tensor(record['FollowUpTime']).to(torch.float32)

        return (patientEmbedding, pos, keyPaddingMask, cluster, Dead, followUpTime, idx)