--- a
+++ b/datasets.py
@@ -0,0 +1,154 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+import glob
+import os
+from math import ceil, floor
+from medpy.io import load, header
+from models import Model
+import utils
+import pandas as pd
+import matplotlib.pyplot as plt
+
+
+class RadDataset(Dataset):
+    """Dataset of CT crops around tumor and lymph-node bounding boxes.
+
+    Each item is ``(ct_tumor, ct_lymphnodes, y, time, event, ID)``. The two
+    CT entries are tensors stacking four crops taken at different positions
+    along the last image axis (Z) of the corresponding padded bounding box;
+    the rest are the row's ``DFS_3years``, ``DFS`` and ``DFS_censor`` values
+    (float32) and its ``radiology_folder_name``.
+
+    Args:
+        df: One row per patient. Columns read: ``DFS_3years``, ``DFS``,
+            ``DFS_censor``, ``radiology_folder_name`` and the box columns
+            ``{X,Y,Z}_{min,max}_{tumor,lymph}``. Rows are addressed by
+            position, so ``df`` need not have a 0..n-1 index.
+        root_data: Directory holding ``<radiology_folder_name>/CT_img.nii.gz``.
+        train_flag: If True, crops are randomly augmented; otherwise they are
+            only converted to tensors, with cropping done under a fixed seed.
+        dim: Crop size forwarded to ``utils.random_crop``
+            (default ``[48, 48, 3]``).
+        ring: Margin added on each side of the bounding boxes, in the same
+            units as the box columns.
+        eval_seed: Seed applied while cropping in evaluation mode.
+    """
+
+    def __init__(self, df, root_data, train_flag=True, dim=None, ring=15,
+                 eval_seed=42):
+        self.df = df
+        self.root_data = root_data
+        self.train_flag = train_flag
+        # Built per instance to avoid a shared mutable default argument.
+        self.dim = [48, 48, 3] if dim is None else dim
+        self.ring = ring
+        self.eval_seed = eval_seed
+
+        self.transforms = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.RandomAffine(3, scale=(0.95, 1.05)),
+            transforms.RandomHorizontalFlip(0.5),
+            transforms.RandomVerticalFlip(0.5),
+        ])
+        self.test_transforms = transforms.Compose([
+            transforms.ToTensor(),
+        ])
+
+        self.y = np.array(df["DFS_3years"]).astype(np.float32)
+        self.time = np.array(df["DFS"]).astype(np.float32)
+        self.event = np.array(df["DFS_censor"]).astype(np.float32)
+        self.ID = np.array(df["radiology_folder_name"])
+
+    def __len__(self):
+        return len(self.y)
+
+    def get_radiology(self, ct_image, index, train_flag):
+        """Crop four sub-volumes around the tumor box and the lymph-node box.
+
+        Args:
+            ct_image: CT volume (windowed by ``__getitem__``); its last axis
+                is treated as Z.
+            index: Positional row index into ``self.df``.
+            train_flag: Apply random augmentation when True.
+
+        Returns:
+            ``[tumor_vol, lymph_vol]``, each a ``torch.stack`` of four crops.
+        """
+        if train_flag:
+            # Deliberately not reseeded: resetting the RNGs on every call
+            # would give every sample the identical "random" augmentation.
+            # For reproducible training, seed once at startup and pass a
+            # generator / worker_init_fn to the DataLoader instead.
+            return self._crop_locations(ct_image, index, self.transforms)
+
+        # Evaluation crops use a fixed seed; the process-wide NumPy and torch
+        # CPU RNG states are restored afterwards so callers are unaffected.
+        np_state = np.random.get_state()
+        torch_state = torch.get_rng_state()
+        try:
+            np.random.seed(self.eval_seed)
+            torch.default_generator.manual_seed(self.eval_seed)
+            return self._crop_locations(ct_image, index, self.test_transforms)
+        finally:
+            np.random.set_state(np_state)
+            torch.set_rng_state(torch_state)
+
+    def _crop_locations(self, ct_image, index, transform):
+        # Crop, transform and stack the four sub-volumes of each location.
+        return [
+            torch.stack([
+                transform(utils.random_crop(ct_image, self.dim, center))
+                for center in self._crop_centers(
+                    index, location, ct_image.shape[-1])
+            ])
+            for location in ("tumor", "lymph")
+        ]
+
+    def _crop_centers(self, index, location, depth):
+        # Return four [Y, X, Z] crop centers for the box of `location`.
+        def coord(name):
+            # Positional lookup, consistent with `.iloc` in __getitem__ and
+            # with the positional self.y / self.time / self.event arrays.
+            return self.df[name + "_" + location].iloc[index]
+
+        ring = self.ring
+        x_min, x_max = coord("X_min") - ring, coord("X_max") + ring
+        y_min, y_max = coord("Y_min") - ring, coord("Y_max") + ring
+        # Z is clamped to [3, depth - 1].
+        z_min = max(3, coord("Z_min") - ring)
+        z_max = min(depth - 1, coord("Z_max") + ring)
+
+        center_y = int(ceil(int(y_min + y_max) / 2))
+        center_x = int(ceil(int(x_min + x_max) / 2))
+
+        span = z_max - z_min
+        z_q1 = z_min + int(span / 4)
+        z_q2 = z_min + int(span / 2)
+        z_q3 = z_min + int(3 * span / 4)
+
+        # NOTE(review): these sample offsets of 1/8, 3/8, 1/4 and 3/4 of the
+        # padded Z range. If four evenly spaced crops (1/8, 3/8, 5/8, 7/8)
+        # were intended, the last two should be int((z_q2 + z_q3) / 2) and
+        # int((z_q3 + z_max) / 2). Left unchanged pending confirmation.
+        z_centers = (
+            int((z_min + z_q1) / 2),
+            int((z_q1 + z_q2) / 2),
+            int(z_q1),
+            int(z_q3),
+        )
+        return [[center_y, center_x, z] for z in z_centers]
+
+    def __getitem__(self, index):
+        """Load, window and crop the CT of the patient at position ``index``."""
+        ct_image, _ = load(os.path.join(
+            self.root_data, self.df["radiology_folder_name"].iloc[index],
+            "CT_img.nii.gz"))
+        ct_image = utils.soft_tissue_window(ct_image)
+
+        ct_tumor, ct_lymphnodes = self.get_radiology(
+            ct_image, index, self.train_flag)
+
+        return (ct_tumor, ct_lymphnodes, self.y[index], self.time[index],
+                self.event[index], self.ID[index])
\ No newline at end of file