--- /dev/null
+++ b/data/dataset_split.py
@@ -0,0 +1,125 @@
+import os
+import sys
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+def pairing(odds, idx=''):
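+    """Pair each odd slice with the adjacent even slice in its (case, day)
+    group, returning one row per slice keyed by a shared 'pair_idx'.
+
+    'idx' is an optional prefix for the generated pair ids. Note that the
+    helper columns 'dummy' and 'pair' are added to the input frame in place;
+    the batching code below relies on 'dummy' surviving on the caller's frame.
+    """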
+    odds['dummy'] = 1
+    # Sort descending so that shift() pulls the next-higher slice_id within
+    # each (case, day) group; filtering to odd slice_ids then pairs each odd
+    # slice with the even slice directly above it
+    odds['pair'] = odds.sort_values(['case', 'slice_id', 'day'], ascending=False).groupby(['case', 'day'])['slice_id'].shift()
+    odds = odds[odds['slice_id'] % 2 == 1]
+    # Number the pairs, with an optional prefix
+    odds['pair_idx'] = idx + odds['dummy'].cumsum().astype(str)
+
+    # Stack the odd slices and their even partners into one frame sharing pair_idx
+    odds = pd.concat([
+        odds[['case', 'slice_id', 'day', 'pair_idx', 'dummy']],
+        odds[['case', 'pair', 'day', 'pair_idx', 'dummy']].rename({'pair': 'slice_id'}, axis='columns'),
+    ])
+
+    # shift() yields NaN at group edges; drop pairs missing their even member,
+    # restore the integer dtype, and keep only pair ids with both slices present
+    odds = odds.dropna(subset=['slice_id'])
+    odds['slice_id'] = odds['slice_id'].astype(int)
+    odds['pair_total'] = odds.groupby('pair_idx')['dummy'].transform('sum')
+    odds = odds[odds['pair_total'] == 2]
+    return odds.drop(['pair_total', 'dummy'], axis='columns')
+
+def dataset_split(dataset_path, output_folder, train_prop=0.7, val_prop=0.2, test_prop=0.1):
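+    """Split the slice-level dataset into train/val/test CSVs at the
+    patient-day level, batching the training set so that both slices of a
+    pair land in the same batch of 16 pairs.
+    """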
+    # Default split is 70/20/10; the three proportions must sum to 1
+    total = train_prop + val_prop + test_prop
+    if abs(total - 1) > 1e-7:
+        raise ValueError(f"Train, validation, and test proportions must sum to 1; got {round(total, 3)}")
+
+    # Create the output folder if it does not already exist
+    os.makedirs(output_folder, exist_ok=True)
+
+    data = pd.read_csv(dataset_path, index_col=0)
+
+    # Patient version: split on case only
+    #cases = pd.DataFrame(data['case'].unique(), columns=['case'])
+
+    # Patient-day version: split on (case, day)
+    cases = data[['case', 'day']].drop_duplicates()
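+    # Grouping at the patient-day level keeps near-duplicate neighbouring
+    # slices out of different partitions, avoiding train/test leakage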
+
+    # Split into train and val+test datasets
+    train_cases, others_cases = train_test_split(cases, test_size=val_prop+test_prop, random_state=0)
+
+    # Split the val+test datasets into validation and test
+    val_cases, test_cases = train_test_split(others_cases, test_size=test_prop/(val_prop+test_prop), random_state=0)
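+    # (with the 70/20/10 default, test_size here is 0.1 / 0.3, i.e. one third
+    # of the 30% hold-out, which recovers the intended 10% of the full data)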
+
+    # Patient version
+    #train = pd.merge(train_cases, data, how="left", on=['case'])
+    #val = pd.merge(val_cases, data, how="left", on=['case'])
+    #test = pd.merge(test_cases, data, how="left", on=['case'])
+
+    # Patient-day version: pull every slice belonging to the selected patient-days
+    train = pd.merge(train_cases, data, how="left", on=['case', 'day'])
+    val = pd.merge(val_cases, data, how="left", on=['case', 'day'])
+    test = pd.merge(test_cases, data, how="left", on=['case', 'day'])
+
+    # Batching for train
+    pair_ids = pairing(train)
+    train = pd.merge(pair_ids, train, on=['case', 'slice_id', 'day'])
+    train['dummy'] = train.groupby('pair_idx')['dummy'].cumsum()
+    train1 = train[train['dummy'] == 1]
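+    # After the merge, the per-pair cumulative sum turns 'dummy' into a
+    # within-pair position: train1 holds the first slice of every pair and
+    # train2, built below, the second, so both slices can share a batch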
+
+    # Shuffle the pairs and re-draw batches of 16 until no batch consists of
+    # slices from a single case (a plain row count, unlike nunique, would
+    # never change across reshuffles, since shuffling does not alter batch sizes)
+    min_check = 0
+    while min_check <= 1:
+        train1 = train1.sample(frac=1).reset_index(drop=True)
+        train1['batch'] = train1.index // 16
+        check = train1.groupby('batch')['case'].nunique()
+        min_check = min(check)
+
+    # Give the second slice of each pair the batch its partner was assigned
+    train2 = pd.merge(train[train['dummy'] == 2], train1[['pair_idx', 'batch']], on='pair_idx')
+    train = pd.concat([train1, train2]).sort_values(['batch', 'pair_idx']).reset_index(drop=True)
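+    # A full batch is therefore 32 rows: 16 pairs, each contributing 2 slices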
+
+    # Output train, val, test datasets
+    train.to_csv(os.path.join(output_folder, "train_dataset.csv"), index=False)
+    val.to_csv(os.path.join(output_folder, "val_dataset.csv"), index=False)
+    test.to_csv(os.path.join(output_folder, "test_dataset.csv"), index=False)
+
+if __name__ == '__main__':
+    # usage: python data/dataset_split.py [final.csv] [output folder] [train percent] [val percent] [test percent]
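+    # example invocation (paths are illustrative):
+    #   python data/dataset_split.py data/final.csv data/splits 70 20 10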
+    dataset_path = sys.argv[1]
+    output_folder = sys.argv[2]
+    if len(sys.argv) >= 6:
+        train_prop = int(sys.argv[3]) / 100
+        val_prop = int(sys.argv[4]) / 100
+        test_prop = int(sys.argv[5]) / 100
+        #print(f"Using train-val-test split of {sys.argv[3]}%-{sys.argv[4]}%-{sys.argv[5]}%")
+        dataset_split(dataset_path, output_folder, train_prop, val_prop, test_prop)
+    else:
+        # No percentages supplied: fall back to the 70/20/10 default (an
+        # explicit argument-count check avoids a bare except, which would
+        # also swallow errors raised inside dataset_split itself)
+        #print("Using default train-val-test split of 70%-20%-10%")
+        dataset_split(dataset_path, output_folder)