--- a
+++ b/create_splits.py
@@ -0,0 +1,123 @@
+"""Create k-fold train/val/test split files for multi-task whole-slide classification.
+
+For each requested label fraction the splits are written to
+``splits/<split_code or task>_<percent>/`` as:
+
+* ``splits_{i}.csv``            -- ``save_splits`` output for split ``i``;
+* ``splits_{i}_bool.csv``       -- the same splits saved with ``boolean_style=True``;
+* ``splits_{i}_descriptor.csv`` -- ``dataset.test_split_gen(return_descriptor=True)``.
+"""
+import argparse
+import os
+import pdb  # noqa: F401 -- currently unused
+
+import numpy as np
+import pandas as pd
+
+from datasets.dataset_mtl_concat import Generic_WSI_MTL_Dataset, Generic_MIL_MTL_Dataset, save_splits  # noqa: F401
+
+# Label-name -> integer-index mappings for the 'dummy_mtl_concat' task, one dict
+# per entry of ``label_cols`` ('label', 'site', 'sex'), in the same order.
+# NOTE(review): keys must match the CSV verbatim (e.g. 'Esophagagostric').
+DUMMY_MTL_LABEL_DICTS = [
+    {'Lung': 0, 'Breast': 1, 'Colorectal': 2, 'Ovarian': 3,
+     'Pancreatic': 4, 'Adrenal': 5,
+     'Skin': 6, 'Prostate': 7, 'Renal': 8, 'Bladder': 9,
+     'Esophagagostric': 10, 'Thyroid': 11,
+     'Head Neck': 12, 'Glioma': 13,
+     'Germ Cell': 14, 'Endometrial': 15, 'Cervix': 16, 'Liver': 17},
+    {'Primary': 0, 'Metastatic': 1},
+    {'F': 0, 'M': 1},
+]
+
+parser = argparse.ArgumentParser(description='Creating splits for whole slide classification')
+parser.add_argument('--label_frac', type=float, default=-1,
+                    help='fraction of labels to use; values <= 0 mean 1.0 (default: -1)')
+parser.add_argument('--seed', type=int, default=1,
+                    help='random seed (default: 1)')
+parser.add_argument('--k', type=int, default=10,
+                    help='number of splits (default: 10)')
+parser.add_argument('--hold_out_test', action='store_true', default=False,
+                    help='sample a held-out test set once, before creating the splits (default: off)')
+parser.add_argument('--val_frac', type=float, default=0.1,
+                    help='fraction of each class used for validation, floored per class (default: 0.1)')
+parser.add_argument('--test_frac', type=float, default=0.2,
+                    help='fraction of each class used for testing, floored per class (default: 0.2)')
+parser.add_argument('--split_code', type=str, default=None,
+                    help='output directory name under splits/ (default: the task name)')
+parser.add_argument('--task', type=str, choices=['dummy_mtl_concat'], required=True)
+
+
+def build_dataset(task, seed):
+    """Return the dataset object configured for ``task``.
+
+    Raises:
+        NotImplementedError: if ``task`` has no dataset configuration.
+    """
+    if task == 'dummy_mtl_concat':
+        return Generic_WSI_MTL_Dataset(csv_path='dataset_csv/dummy_dataset.csv',
+                                       shuffle=False,
+                                       seed=seed,
+                                       print_info=True,
+                                       label_dicts=DUMMY_MTL_LABEL_DICTS,
+                                       label_cols=['label', 'site', 'sex'],
+                                       patient_strat=False)
+    raise NotImplementedError('unsupported task: {}'.format(task))
+
+
+def _per_class_counts(dataset, frac):
+    """Return ``floor(frac * len(ids))`` as ints for each class in ``dataset.patient_cls_ids``."""
+    num_per_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids])
+    return np.floor(num_per_cls * frac).astype(int)
+
+
+def _split_dir(prefix, label_frac):
+    """Return ``splits/<prefix>_<percent>``, e.g. ``splits/task_100`` for label_frac=1.0."""
+    # round() first: int(0.29 * 100) == 28 because of floating-point error.
+    return os.path.join('splits', '{}_{}'.format(prefix, int(round(label_frac * 100))))
+
+
+def main():
+    """Parse CLI arguments and write the split CSVs for every label fraction."""
+    args = parser.parse_args()
+    dataset = build_dataset(args.task, args.seed)
+
+    # Per-class numbers of validation / test samples.
+    val_num = _per_class_counts(dataset, args.val_frac)
+    test_num = _per_class_counts(dataset, args.test_frac)
+    print(val_num)
+    print(test_num)
+
+    label_fracs = [args.label_frac] if args.label_frac > 0 else [1.0]
+    custom_test_ids = dataset.sample_held_out(test_num=test_num) if args.hold_out_test else None
+    prefix = args.split_code if args.split_code is not None else args.task
+
+    for lf in label_fracs:
+        split_dir = _split_dir(prefix, lf)
+        dataset.create_splits(k=args.k, val_num=val_num, test_num=test_num,
+                              label_frac=lf, custom_test_ids=custom_test_ids)
+        os.makedirs(split_dir, exist_ok=True)
+
+        for i in range(args.k):
+            split_csv = os.path.join(split_dir, 'splits_{}.csv'.format(i))
+            if dataset.split_gen is None:
+                # No split generator: load split i from the existing
+                # splits_{i}.csv in split_dir (read once, used for all keys).
+                split_df = pd.read_csv(split_csv)
+                dataset.train_ids, dataset.val_ids, dataset.test_ids = [
+                    dataset.get_split_from_df(split_df, split_key=key, return_ids_only=True)
+                    for key in ('train', 'val', 'test')]
+            else:
+                dataset.set_splits()
+
+            descriptor_df = dataset.test_split_gen(return_descriptor=True)
+            descriptor_df.to_csv(os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i)))
+
+            splits = dataset.return_splits(from_id=True)
+            save_splits(splits, ['train', 'val', 'test'], split_csv)
+            save_splits(splits, ['train', 'val', 'test'],
+                        os.path.join(split_dir, 'splits_{}_bool.csv'.format(i)), boolean_style=True)
+
+
+if __name__ == '__main__':
+    main()