Diff of /create_splits.py [000000] .. [fdd588]

Switch to unified view

a b/create_splits.py
1
import pdb
2
import os
3
import pandas as pd
4
from datasets.dataset_mtl_concat import Generic_WSI_MTL_Dataset, Generic_MIL_MTL_Dataset, save_splits
5
import argparse
6
import numpy as np
7
8
parser = argparse.ArgumentParser(description='Creating splits for whole slide classification')
9
parser.add_argument('--label_frac', type=float, default= -1,
10
                                        help='fraction of labels (default: [1.0])')
11
parser.add_argument('--seed', type=int, default=1,
12
                                        help='random seed (default: 1)')
13
parser.add_argument('--k', type=int, default=10,
14
                                        help='number of splits (default: 10)')
15
parser.add_argument('--hold_out_test', action='store_true', default=False,
16
                                        help='fraction to hold out (default: 0)')
17
parser.add_argument('--split_code', type=str, default=None)
18
parser.add_argument('--task', type=str, choices=['dummy_mtl_concat'])
19
20
args = parser.parse_args()
21
22
if args.task == 'dummy_mtl_concat':
23
    args.n_classes=18
24
    dataset = Generic_WSI_MTL_Dataset(csv_path = 'dataset_csv/dummy_dataset.csv',
25
                            shuffle = False, 
26
                            seed = args.seed, 
27
                            print_info = True,
28
                            label_dicts = [{'Lung':0, 'Breast':1, 'Colorectal':2, 'Ovarian':3, 
29
                                                                'Pancreatic':4, 'Adrenal':5, 
30
                                                                'Skin':6, 'Prostate':7, 'Renal':8, 'Bladder':9, 
31
                                                                'Esophagagostric':10,  'Thyroid':11,
32
                                                                'Head Neck':12,  'Glioma':13, 
33
                                                                'Germ Cell':14, 'Endometrial': 15, 'Cervix': 16, 'Liver': 17},
34
                                            {'Primary':0,  'Metastatic':1},
35
                                            {'F':0, 'M':1}],
36
                            label_cols = ['label', 'site', 'sex'],
37
                            patient_strat= False)
38
39
         
40
else:
41
    raise NotImplementedError
42
43
num_slides_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids])
44
val_num = np.floor(num_slides_cls * 0.1).astype(int)
45
test_num = np.floor(num_slides_cls * 0.2).astype(int)
46
47
print(val_num)
48
print(test_num)
49
50
if __name__ == '__main__':
51
        if args.label_frac > 0:
52
            label_fracs = [args.label_frac]
53
        else:
54
            label_fracs = [1.0]
55
56
        if args.hold_out_test:
57
            custom_test_ids = dataset.sample_held_out(test_num=test_num)
58
        else:
59
            custom_test_ids = None
60
        for lf in label_fracs:
61
            if args.split_code is not None:
62
                split_dir = 'splits/'+ str(args.split_code) + '_{}'.format(int(lf * 100))
63
            else:
64
                split_dir = 'splits/'+ str(args.task) + '_{}'.format(int(lf * 100))
65
            
66
            dataset.create_splits(k = args.k, val_num = val_num, test_num = test_num, label_frac=lf, custom_test_ids=custom_test_ids)
67
68
            os.makedirs(split_dir, exist_ok=True)
69
            for i in range(args.k):
70
                if dataset.split_gen is None:
71
                    ids = []
72
                    for split in ['train', 'val', 'test']:
73
                        ids.append(dataset.get_split_from_df(pd.read_csv(os.path.join(split_dir, 'splits_{}.csv'.format(i))), split_key=split, return_ids_only=True))
74
                    
75
                    dataset.train_ids = ids[0]
76
                    dataset.val_ids = ids[1]
77
                    dataset.test_ids = ids[2]
78
                else:
79
                    dataset.set_splits()
80
81
                descriptor_df = dataset.test_split_gen(return_descriptor=True)
82
                descriptor_df.to_csv(os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i)))
83
                
84
                splits = dataset.return_splits(from_id=True)
85
                save_splits(splits, ['train', 'val', 'test'], os.path.join(split_dir, 'splits_{}.csv'.format(i)))
86
                save_splits(splits, ['train', 'val', 'test'], os.path.join(split_dir, 'splits_{}_bool.csv'.format(i)), boolean_style=True)
87
                
88
89
90