--- a
+++ b/src/dataset.py
@@ -0,0 +1,241 @@
+import numpy as np
+import os
+import cv2
+import pandas as pd
+from torch.utils.data import Dataset
+# import jpeg4py as jpeg
+from utils import get_windowing, window_image
+import pydicom
+
+# Image IDs that the dataset constructors in this file drop before serving
+# any sample.
+IGNORE_IDS = [
+    'ID_6431af929',
+]
+
+# Fixed windowing presets. load_dicom_image() passes each pair to
+# window_image() as (window_center, window_width); load_multi_images() uses
+# the keys as sub-directory names.
+# NOTE(review): 'subdual' looks like a misspelling of 'subdural', but it is a
+# runtime key / directory name, so renaming it means renaming data on disk.
+windows_range = {
+    'brain': [40, 80],
+    'bone': [600, 2800],
+    'subdual': [75, 215]
+}
+
+# Label columns read from the input DataFrame; the second list omits "any".
+LABEL_COLS = ["epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural", "any"]
+LABEL_COLS_WITHOUT_ANY = ["epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural"]
+
+
+def load_dicom_image(path):
+    data = pydicom.read_file(path)
+    image = data.pixel_array
+    window_center, window_width, intercept, slope = get_windowing(data)
+    images = []
+    image_windowed = window_image(image, window_center, window_width, intercept, slope)
+    images.append(image_windowed)
+
+    for k, v in windows_range.items():
+        image_windowed = window_image(image, v[0], v[1], intercept, slope)
+        images.append(image_windowed)
+
+    images = np.asarray(images).transpose((1, 2, 0))
+    images = images / 255
+    return images
+
+
+def load_image(path):
+    image = cv2.imread(path)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    return image
+
+
+def load_random_windows(path, id):
+    random_window = np.random.choice(['brain', 'bone', 'subdual'], 1)[0]
+    return load_image(os.path.join(path, random_window, id + ".jpg"))
+
+
+def load_multi_images(root, image_name):
+    images = []
+    for i, (k, v) in enumerate(windows_range.items()):
+        image = cv2.imread(os.path.join(root, k, image_name), 0)
+        images.append(image)
+
+    images = np.asarray(images).transpose((1, 2, 0))
+
+    return images
+
+
+# def load_jpeg_image(path):
+#     image = jpeg.JPEG(path).decode()
+#     return image
+
+
+import random
+def get_balance_set(df):
+    patients = set(df["patient_id"].unique())
+    patients_pos = set(df[df["any"] == 1]["patient_id"].unique())
+    patients_neg = patients - patients_pos
+    patients_neg_balance = random.sample(patients_neg, len(patients_pos))
+    patients_balance = patients_pos.union(patients_neg_balance)
+
+    print(len(patients), len(patients_pos), len(patients), len(patients_balance))
+
+    return df[df["patient_id"].isin(patients_balance)]
+
+
+from sklearn.preprocessing import MinMaxScaler
+# DICOM metadata columns that RSNADataset reads from the metadata CSV and
+# returns as the 'meta' entry of each sample.
+# NOTE(review): 'image_orientation_patient_1' is absent from this list —
+# confirm whether the omission is intentional.
+meta_data_cols = [
+    'image_position_patient_0', 'image_position_patient_1', 'image_position_patient_2',
+    'image_orientation_patient_0', 'image_orientation_patient_2', 'image_orientation_patient_3',
+    'image_orientation_patient_4', 'image_orientation_patient_5'
+]
+
+
+class RSNADataset(Dataset):
+    """
+    Read JPG images
+    """
+    def __init__(self, csv_file, root, with_any, transform, mode='train', image_type='jpg'):
+        if isinstance(csv_file, pd.DataFrame):
+            df = csv_file
+        else:
+            print(csv_file)
+            df = pd.read_csv(csv_file)
+        if mode == 'train':
+            # df = df
+            df = get_balance_set(df)
+        if mode in ['train', 'valid']:
+            meta_data = pd.read_csv(f"/data/df_dicom_metadata_train.csv", usecols=meta_data_cols + ['sop_instance_uid'])
+        else:
+            meta_data = pd.read_csv(f"/data/df_dicom_metadata_test.csv", usecols=meta_data_cols + ['sop_instance_uid'])
+            df["sop_instance_uid"] = "ID_" + df["sop_instance_uid"]
+        meta_data = meta_data[meta_data['sop_instance_uid'].isin(df['sop_instance_uid'])]
+        df = df.merge(meta_data, on='sop_instance_uid', how='left')
+        ID_col = "Image" if "Image" in df.columns else "ID" if "ID" in df.columns else "sop_instance_uid"
+        df = df[~df[ID_col].isin(IGNORE_IDS)]
+        self.ids = df[ID_col].values
+        self.metadata = df[meta_data_cols].values
+        self.with_any = with_any
+        if with_any:
+            self.labels = df[LABEL_COLS].values
+        else:
+            self.labels = df[LABEL_COLS_WITHOUT_ANY].values
+        self.root = root
+        self.transform = transform
+        self.image_type = image_type
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, idx):
+        id = self.ids[idx]
+        label = self.labels[idx].astype(np.float32)
+
+        meta = self.metadata[idx].astype(np.float32)
+
+        if not "ID" in id:
+            id = "ID_" + id
+
+        image = os.path.join(self.root, id + "." + self.image_type)
+        image = load_image(image)
+
+        if self.transform:
+            augmented = self.transform(image=image)
+            image = augmented['image']
+
+        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
+
+        return {
+            'images': image,
+            'targets': label,
+            'meta': meta
+        }
+
+
+class RSNARandomWindowDataset(RSNADataset):
+    """
+    Random select bone, brain and subdual during the training
+    """
+
+    def __getitem__(self, idx):
+        id = self.ids[idx]
+        label = self.labels[idx].astype(np.float32)
+
+        image = load_random_windows(self.root, id)
+
+        if self.transform:
+            augmented = self.transform(image=image)
+            image = augmented['image']
+
+        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
+
+        return {
+            'images': image,
+            'targets': label
+        }
+
+
+class RSNADicomDataset(RSNADataset):
+    """
+    load dicom image directly. windows are applied on the fly.
+    """
+    def __init__(self, csv_file, root, with_any, transform, mode='train'):
+        super(RSNADicomDataset, self).__init__(csv_file, root, with_any, transform, mode)
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, idx):
+        id = self.ids[idx]
+        label = self.labels[idx].astype(np.float32)
+
+        image = os.path.join(self.root, id + ".dcm")
+        image = load_dicom_image(image)
+
+        if self.transform:
+            augmented = self.transform(image=image)
+            image = augmented['image']
+
+        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
+
+        return {
+            'images': image,
+            'targets': label
+        }
+
+
+class RSNAMultiWindowsDataset(Dataset):
+    """
+    Read all window images then concatinate.
+    """
+    def __init__(self, csv_file, root, with_any, transform):
+        if isinstance(csv_file, pd.DataFrame):
+            df = csv_file
+        else:
+            df = pd.read_csv(csv_file)
+        ID_col = "Image" if "Image" in df.columns else "ID" if "ID" in df.columns else "sop_instance_uid"
+        df = df[~df[ID_col].isin(IGNORE_IDS)]
+        self.ids = df[ID_col].values
+        self.with_any = with_any
+        if with_any:
+            self.labels = df[LABEL_COLS].values
+        else:
+            self.labels = df[LABEL_COLS_WITHOUT_ANY].values
+        self.root = root
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, idx):
+        id = self.ids[idx]
+        label = self.labels[idx].astype(np.float32)
+
+        # image = os.path.join(self.root, id + ".jpg")
+        image = load_multi_images(self.root, id + ".jpg")
+
+        if self.transform:
+            augmented = self.transform(image=image)
+            image = augmented['image']
+
+        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
+
+        return {
+            'images': image,
+            'targets': label
+        }