--- /dev/null
+++ b/create_splits.py
@@ -0,0 +1,125 @@
+import os
+import argparse
+import numpy as np
+from datasets.dataset_generic import Generic_WSI_Classification_Dataset, save_splits
+
+parser = argparse.ArgumentParser(description='Creating splits for whole slide classification')
+parser.add_argument('--label_frac', type=float, default=1.0,
+                    help='fraction of training labels to use (default: 1.0; pass a value <= 0 to sweep [0.25, 0.5, 0.75, 1.0])')
+parser.add_argument('--seed', type=int, default=1,
+                    help='random seed (default: 1)')
+parser.add_argument('--k', type=int, default=10,
+                    help='number of splits (default: 10)')
+parser.add_argument('--task', type=str, choices=['cardiac-grade','cardiac-mtl'],
+                    help='select one of the supported tasks for which to perform the split')
+parser.add_argument('--hold_out_test', action='store_true', default=False,
+                    help='sample a single held-out test set shared across all splits')
+parser.add_argument('--split_code', type=str, default=None,
+                    help='custom prefix for the split directory name (defaults to the task name)')
+
+args = parser.parse_args()
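+
+# Example invocations (hypothetical split_code value; the flags are the ones
+# defined above):
+#   python create_splits.py --task cardiac-grade --seed 1 --k 10 --label_frac 1.0
+#   python create_splits.py --task cardiac-mtl --hold_out_test --split_code mtl_heldout --label_frac 0
+# Passing --label_frac 0 sweeps the fractions [0.25, 0.5, 0.75, 1.0] (see __main__ below).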
+
+
+# The splits are stratified so that every possible diagnosis appears across
+# all splits:
+# grade - cellular rejection only, low grade
+#       - low-grade cellular, high-grade antibody case ...
+#       - ...
+# mtl   - low-grade cellular only
+#       - high-grade cellular only
+#       - low-grade cellular + quilty
+#       - ...
+#---------------------------------------------------
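+# Both CSVs are assumed to follow the CLAM-style schema expected by
+# Generic_WSI_Classification_Dataset (case_id, slide_id, label columns),
+# with label values matching the label_dict keys below, e.g.:
+#   case_id,slide_id,label
+#   patient_0,slide_0_0,cell_only_low
+#   patient_1,slide_1_0,amr_high_quilty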
+if args.task == 'cardiac-grade':
+    dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/CardiacDummy_GradeSplit.csv',
+                            shuffle = False,
+                            seed = args.seed,
+                            print_info = True,
+                            label_dict = {'cell_only_low'         : 0,
+                                          'cell_only_high'        : 1,
+                                          'cell_low_quilty'       : 2,
+                                          'cell_high_quilty'      : 3,
+                                          'amr_only_low'          : 4,
+                                          'amr_only_high'         : 5,
+                                          'amr_low_quilty'        : 6,
+                                          'amr_high_quilty'       : 7,
+                                          'cell_amr_low'          : 8,
+                                          'cell_amr_high'         : 9,
+                                          'cell_amr_quilty_low'   : 10,
+                                          'cell_amr_quilty_high'  : 11},
+                            patient_strat = True,
+                            ignore = [])
+
+    p_val  = 0.1   # use 10% of data per class for validation
+    p_test = 0.2   # use 20% of data per class for test (also sizes the held-out set with --hold_out_test)
+
+
+elif args.task == 'cardiac-mtl':
+    dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/CardiacDummy_MTLSplit.csv',
+                            shuffle = False,
+                            seed = args.seed,
+                            print_info = True,
+                            label_dict = {'healthy'               : 0,
+                                          'quilty'                : 1,
+                                          'cell_only_low'         : 2,
+                                          'cell_only_high'        : 3,
+                                          'cell_low_quilty'       : 4,
+                                          'cell_high_quilty'      : 5,
+                                          'amr_only_low'          : 6,
+                                          'amr_only_high'         : 7,
+                                          'amr_low_quilty'        : 8,
+                                          'amr_high_quilty'       : 9,
+                                          'cell_amr_low'          : 10,
+                                          'cell_amr_high'         : 11,
+                                          'cell_amr_quilty_low'   : 12,
+                                          'cell_amr_quilty_high'  : 13},
+                            patient_strat = True,
+                            ignore = [])
+
+    p_val  = 0.1   # use 10% of data per class for validation
+    p_test = 0.2   # use 20% of data per class for test
+
+else:
+    raise NotImplementedError
+
+
+# per-class split sizes (patient-level counts, since patient_strat=True)
+num_patients_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids])
+val_num  = np.floor(num_patients_cls * p_val).astype(int)    # patients per class in validation
+test_num = np.floor(num_patients_cls * p_test).astype(int)   # patients per class in test
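+# Worked example (hypothetical counts): a class with 47 patients yields
+# floor(47 * 0.1) = 4 validation and floor(47 * 0.2) = 9 test patients,
+# leaving 34 for training before any label_frac subsampling.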
+print("---------------------------------")
+print(f"validation set size = {val_num} ")
+print(f"test set size       = {test_num}")
+print("---------------------------------")
+
+
+if __name__ == '__main__':
+    if args.label_frac > 0:
+        label_fracs = [args.label_frac]
+    else:
+        label_fracs = [0.25, 0.5, 0.75, 1.0]
+
+    if args.hold_out_test:
+        custom_test_ids = dataset.sample_held_out(test_num=test_num)
+    else:
+        custom_test_ids = None
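+    # Assuming CLAM-style split generation: with --hold_out_test a single test
+    # set of test_num patients per class is sampled once and reused across all
+    # k folds; otherwise each fold samples its own test set of that size.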
+
+    for lf in label_fracs:
+        if args.split_code is not None:
+            split_dir = os.path.join('splits', f'{args.split_code}_{int(lf * 100)}')
+        else:
+            split_dir = os.path.join('splits', f'{args.task}_{int(lf * 100)}')
+
+        os.makedirs(split_dir, exist_ok=True)
+        dataset.create_splits(k = args.k, val_num = val_num, test_num = test_num, label_frac=lf, custom_test_ids=custom_test_ids)
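+        # Assuming CLAM's split API: create_splits() builds a generator of k
+        # stratified splits up front, and each set_splits() call below steps
+        # the dataset to the next fold.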
+        for i in range(args.k):
+            dataset.set_splits()
+            descriptor_df = dataset.test_split_gen(return_descriptor=True)
+            splits = dataset.return_splits(from_id=True)
+            save_splits(splits, ['train', 'val', 'test'], os.path.join(split_dir, f'splits_{i}.csv'))
+            save_splits(splits, ['train', 'val', 'test'], os.path.join(split_dir, f'splits_{i}_bool.csv'), boolean_style=True)
+            descriptor_df.to_csv(os.path.join(split_dir, f'splits_{i}_descriptor.csv'))
+
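+# Expected output layout (e.g. --task cardiac-grade with label_frac 1.0):
+#   splits/cardiac-grade_100/splits_0.csv             train/val/test ids for fold 0
+#   splits/cardiac-grade_100/splits_0_bool.csv        the same split as boolean masks
+#   splits/cardiac-grade_100/splits_0_descriptor.csv  per-class counts per split (assuming CLAM's test_split_gen)
+#   ... (repeated for folds 1 .. k-1)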