Switch to side-by-side view

--- a
+++ b/AICare-baselines/datasets/loader/datamodule.py
@@ -0,0 +1,53 @@
+import os
+
+import lightning as L
+import pandas as pd
+import torch
+import torch.utils.data as data
+
+
+class EhrDataset(data.Dataset):
+    def __init__(self, data_path, mode='train'):
+        super().__init__()
+        self.data = pd.read_pickle(os.path.join(data_path,f'{mode}_x.pkl'))
+        self.label = pd.read_pickle(os.path.join(data_path,f'{mode}_y.pkl'))
+        self.pid = pd.read_pickle(os.path.join(data_path,f'{mode}_pid.pkl'))
+
+    def __len__(self):
+        return len(self.label) # number of patients
+
+    def __getitem__(self, index):
+        return self.data[index], self.label[index], self.pid[index]
+
+
+class EhrDataModule(L.LightningDataModule):
+    def __init__(self, data_path, batch_size=32):
+        super().__init__()
+        self.data_path = data_path
+        self.batch_size = batch_size
+
+    def setup(self, stage: str):
+        if stage=="fit":
+            self.train_dataset = EhrDataset(self.data_path, mode="train")
+            self.val_dataset = EhrDataset(self.data_path, mode='val')
+        if stage=="test":
+            self.test_dataset = EhrDataset(self.data_path, mode='test')
+
+    def train_dataloader(self):
+        return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True , collate_fn=self.pad_collate, num_workers=8)
+
+    def val_dataloader(self):
+        return data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False , collate_fn=self.pad_collate, num_workers=8)
+
+    def test_dataloader(self):
+        return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False , collate_fn=self.pad_collate, num_workers=8)
+
+    def pad_collate(self, batch):
+        xx, yy, pid = zip(*batch)
+        lens = torch.as_tensor([len(x) for x in xx])
+        # convert to tensor
+        xx = [torch.tensor(x) for x in xx]
+        yy = [torch.tensor(y) for y in yy]
+        xx_pad = torch.nn.utils.rnn.pad_sequence(xx, batch_first=True, padding_value=0)
+        yy_pad = torch.nn.utils.rnn.pad_sequence(yy, batch_first=True, padding_value=0)
+        return xx_pad.float(), yy_pad.float(), lens, pid
\ No newline at end of file