# create_splits.py — generate cross-validation splits for whole-slide-image
# classification tasks (cardiac grading / multi-task labels).
import pdb
import os
import pandas as pd
import numpy as np
from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset, save_splits
import argparse

# ---------------------------------------------------------------------------
# Command-line interface.
#
# NOTE: the original --label_frac help text claimed the default was the list
# [0.25, 0.5, 0.75, 1.0]; the actual default is 1.0.  The list sweep is only
# triggered by passing a non-positive value (see the __main__ block below).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Creating splits for whole slide classification')
parser.add_argument('--label_frac', type=float, default=1.0,
                    help='fraction of labels to use; a non-positive value sweeps '
                         '[0.25, 0.5, 0.75, 1.0] (default: 1.0)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--k', type=int, default=10,
                    help='number of splits (default: 10)')
parser.add_argument('--task', type=str, choices=['cardiac-grade', 'cardiac-mtl'],
                    help='select one of the supported tasks for which to perform the split')
parser.add_argument('--hold_out_test', action='store_true', default=False,
                    help='hold-out the test set for each split')
parser.add_argument('--split_code', type=str, default=None,
                    help='optional prefix for the output split directory '
                         '(defaults to the task name)')

# Parsed once at import time; the rest of the script reads this module-level
# namespace directly.
args = parser.parse_args()
22
23
24
# The splits are made to account for all possible diagnosis across all splits:
25
# grade - cell rejection only low grade
26
#       - low grade cellular, high-grade antibody case ...
27
#       - ....
28
# mtl  - low grade cellular only
29
#      - high grade cellular only
30
#      - low grade cellular + quilty
31
#      - .....
32
#---------------------------------------------------
33
if args.task == 'cardiac-grade':
34
    dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/CardiacDummy_GradeSplit.csv',
35
                            shuffle = False,
36
                            seed = args.seed,
37
                            print_info = True,
38
                            label_dict = {'cell_only_low'       : 0,
39
                          'cell_only_high'      : 1,
40
                      'cell_low_quilty'     : 2,
41
                      'cell_high_quilty'        : 3,
42
                      'amr_only_low'        : 4,
43
                      'amr_only_high'       : 5,
44
                      'amr_low_quilty'      : 6,
45
                      'amr_high_quilty'     : 7,
46
                      'cell_amr_low'        : 8,
47
                      'cell_amr_high'       : 9,
48
                      'cell_amr_quilty_low'     : 10,
49
                      'cell_amr_quilty_high'    : 11},
50
                            patient_strat= True,
51
                            ignore=[])
52
53
    p_val  = 0.1   # use 10% of data in validation
54
    p_test = 0.2   # to use hold-out test set set p_test = 0
55
56
57
elif args.task == 'cardiac-mtl':
58
    dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/CardiacDummy_MTLSplit.csv',
59
                            shuffle = False,
60
                            seed = args.seed,
61
                            print_info = True,
62
                            label_dict = {'healthy'                 :0, 
63
                                          'quilty'                  :1,
64
                                          'cell_only_low'           :2, 
65
                                          'cell_only_high'          :3,
66
                                          'cell_low_quilty'         :4, 
67
                                          'cell_high_quilty'        :5,
68
                                          'amr_only_low'            :6, 
69
                                          'amr_only_high'           :7, 
70
                                          'amr_low_quilty'          :8, 
71
                                          'amr_high_quilty'         :9,
72
                                          'cell_amr_low'            :10,
73
                                          'cell_amr_high'           :11, 
74
                                          'cell_amr_quilty_low'     :12, 
75
                                          'cell_amr_quilty_high'    :13},
76
                            patient_strat= True,
77
                            ignore=[])
78
                
79
                
80
81
    p_val  = 0.1   # use 10% of data in validation
82
    p_test = 0.2   # use 20% data for test set
83
84
else:
85
    raise NotImplementedError
86
87
88
# splits
89
num_slides_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids])
90
val_num = np.floor(num_slides_cls * p_val).astype(int)      # use 10% data in validation
91
test_num = np.floor(num_slides_cls * p_test).astype(int)     # use 20% for test set
92
print("---------------------------------")
93
print(f"validation set size = {val_num} ")
94
print(f"test set size       = {test_num}")
95
print("---------------------------------")
96
97
98
if __name__ == '__main__':
99
    if args.label_frac > 0:
100
        label_fracs = [args.label_frac]
101
    else:
102
        label_fracs = [0.25, 0.5, 0.75, 1.0]
103
104
    if args.hold_out_test:
105
        custom_test_ids = dataset.sample_held_out(test_num=test_num)
106
    else:
107
        custom_test_ids = None
108
109
    for lf in label_fracs:
110
        if args.split_code is not None:
111
            split_dir = 'splits/'+ str(args.split_code) + '_{}'.format(int(lf * 100))
112
        else:
113
            split_dir = 'splits/'+ str(args.task) + '_{}'.format(int(lf * 100))
114
115
        os.makedirs(split_dir, exist_ok=True)
116
        #pdb.set_trace()
117
        dataset.create_splits(k = args.k, val_num = val_num, test_num = test_num, label_frac=lf, custom_test_ids=custom_test_ids)
118
        for i in range(args.k):
119
            dataset.set_splits()
120
            descriptor_df = dataset.test_split_gen(return_descriptor=True)
121
            splits = dataset.return_splits(from_id=True)
122
            save_splits(splits, ['train', 'val','test'], os.path.join(split_dir, 'splits_{}.csv'.format(i)))
123
            save_splits(splits, ['train', 'val','test'], os.path.join(split_dir, 'splits_{}_bool.csv'.format(i)), boolean_style=True)
124
            descriptor_df.to_csv(os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i)))
125