Diff of /create_splits_seq.py [000000] .. [0fdc30]

Switch to unified view

a b/create_splits_seq.py
1
import pdb
2
import os
3
import pandas as pd
4
from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset, save_splits
5
from datasets.dataset_survival import Generic_WSI_Survival_Dataset, Generic_MIL_Survival_Dataset
6
import argparse
7
import numpy as np
8
9
parser = argparse.ArgumentParser(description='Creating splits for whole slide classification')
10
parser.add_argument('--label_frac', type=float, default= 1.0,
11
                    help='fraction of labels (default: 1)')
12
parser.add_argument('--seed', type=int, default=1,
13
                    help='random seed (default: 1)')
14
parser.add_argument('--k', type=int, default=10,
15
                    help='number of splits (default: 10)')
16
parser.add_argument('--task', type=str, choices=[
17
    'task_1_tumor_vs_normal', 
18
    'task_2_tumor_subtyping', 
19
    'task_3_survival_prediction',
20
    'task_3_survival_prediction_augmented',
21
    'task_3_survival_prediction_after_T', 
22
    'task_4_tumor_grading_kat2',
23
    'task_4_tumor_grading_kat4', 
24
    'task_5_tumor_subtyping',
25
    'task_6_survival_prediction_augmented',
26
    'task_7_tumor_grading_kat2_augmented',
27
    'task_7_tumor_grading_kat4_augmented',
28
    'task_8_tumor_subtyping_augmented',
29
    'task_9_survival_prediction_augmented_random'])
30
parser.add_argument('--csv_path', type=str, default=None, help='Path to csv dataset.')
31
parser.add_argument('--split_name', type=str, default=None, help='Name of split folder.')
32
parser.add_argument('--val_frac', type=float, default= 0.1,
33
                    help='fraction of labels for validation (default: 0.1)')
34
parser.add_argument('--test_frac', type=float, default= 0.1,
35
                    help='fraction of labels for test (default: 0.1)')
36
37
args = parser.parse_args()
38
39
if args.task == 'task_1_tumor_vs_normal':
40
    args.n_classes=2
41
    dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/tumor_vs_normal_dummy_clean.csv',
42
                            shuffle = False, 
43
                            seed = args.seed, 
44
                            print_info = True,
45
                            label_dict = {'normal_tissue':0, 'tumor_tissue':1},
46
                            patient_strat=True,
47
                            ignore=[])
48
49
elif args.task == 'task_2_tumor_subtyping':
50
    args.n_classes=3
51
    dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/tumor_subtyping_dummy_clean.csv',
52
                            shuffle = False, 
53
                            seed = args.seed, 
54
                            print_info = True,
55
                            label_dict = {'subtype_1':0, 'subtype_2':1, 'subtype_3':2},
56
                            patient_strat= True,
57
                            patient_voting='maj',
58
                            ignore=[])
59
60
elif args.task == 'task_3_survival_prediction':
61
    args.n_classes=2
62
63
    if args.csv_path == None:
64
        raise ValueError('Must provide a csv dataset file.')
65
    else:
66
        csv_path = args.csv_path
67
68
    dataset = Generic_WSI_Survival_Dataset(csv_path = csv_path,
69
                        shuffle = False, 
70
                        seed = args.seed, 
71
                        print_info = True,
72
                        label_dict = {'lebt':0, 'tod':1},
73
                        event_col = 'event',
74
                        time_col = 'time',
75
                        patient_strat=True,
76
                        ignore=[])
77
78
79
elif args.task == 'task_3_survival_prediction_augmented':
80
    args.n_classes=2
81
82
    if args.csv_path == None:
83
        raise ValueError('Must provide a csv dataset file.')
84
    else:
85
        csv_path = args.csv_path
86
87
    dataset = Generic_WSI_Survival_Dataset(csv_path = csv_path,
88
                        shuffle = False, 
89
                        seed = args.seed, 
90
                        print_info = True,
91
                        label_dict = {'lebt':0, 'tod':1},
92
                        event_col = 'event',
93
                        time_col = 'time',
94
                        patient_strat=True,
95
                        ignore=[])
96
97
elif args.task == 'task_3_survival_prediction_after_T':
98
    args.n_classes=2
99
100
    if args.csv_path == None:
101
        raise ValueError('Must provide a csv dataset file.')
102
    else:
103
        csv_path = args.csv_path
104
105
    dataset = Generic_WSI_Classification_Dataset(csv_path = csv_path,
106
                        shuffle = False, 
107
                        seed = args.seed, 
108
                        print_info = True,
109
                        label_dict = {'lebt':0, 'tod':1},
110
                        label_col = 'Survival_after_T',
111
                        patient_strat=True,
112
                        ignore=[])
113
114
elif args.task == 'task_4_tumor_grading_kat2':
115
    args.n_classes=2
116
    if args.csv_path == None:
117
        raise ValueError('Must provide a csv dataset file.')
118
    else:
119
        csv_path = args.csv_path
120
    dataset = Generic_WSI_Classification_Dataset(csv_path = csv_path,
121
                        shuffle = False, 
122
                        seed = args.seed, 
123
                        print_info = True,
124
                        label_dict = {'G1 G2':0, 'G3 G4':1},
125
                        label_col = 'Grading_kat2',
126
                        patient_strat=True,
127
                        ignore=[])
128
129
130
elif args.task == 'task_4_tumor_grading_kat4':
131
    args.n_classes=4
132
    if args.csv_path == None:
133
        raise ValueError('Must provide a csv dataset file.')
134
    else:
135
        csv_path = args.csv_path
136
    dataset = Generic_WSI_Classification_Dataset(csv_path = csv_path,
137
                        shuffle = False, 
138
                        seed = args.seed, 
139
                        print_info = True,
140
                        label_dict = {'niedriger Malignitätsgrad':0, 'mittlerer Malignitätsgrad':1, 'hoher Malignitätsgrad':2, 'sehr hoher Malignitätsgrad':3},
141
                        label_col = 'Grading_kat4',
142
                        patient_strat=True,
143
                        ignore=[])
144
145
146
elif args.task == 'task_5_tumor_subtyping':
147
    args.n_classes=4
148
    if args.csv_path == None:
149
        raise ValueError('Must provide a csv dataset file.')
150
    else:
151
        csv_path = args.csv_path
152
    dataset = Generic_WSI_Classification_Dataset(csv_path = csv_path,
153
                        shuffle = False, 
154
                        seed = args.seed, 
155
                        print_info = True,
156
                        label_dict = {'Plattenepithelkarzinom':0, 'Adenokarzinom+BAC':1, 'grosszelliges Karzinom':2, 'NSCLC NOS':3},
157
                        label_col = 'Histo_kat6',
158
                        patient_strat=True,
159
                        ignore=[])
160
161
162
elif args.task == 'task_6_survival_prediction_augmented':
163
    args.n_classes=2
164
    if args.csv_path == None:
165
        raise ValueError('Must provide a csv dataset file.')
166
    else:
167
        csv_path = args.csv_path
168
    dataset = Generic_WSI_Classification_Dataset(csv_path = csv_path,
169
                        shuffle = False, 
170
                        seed = args.seed, 
171
                        print_info = True,
172
                        label_dict = {'lebt':0, 'tod':1},
173
                        label_col = 'Survival_after_T',
174
                        patient_strat=True,
175
                        ignore=[])
176
177
elif args.task == 'task_7_tumor_grading_kat2_augmented':
178
    args.n_classes=2
179
    if args.csv_path == None:
180
        raise ValueError('Must provide a csv dataset file.')
181
    else:
182
        csv_path = args.csv_path
183
    dataset = Generic_WSI_Classification_Dataset(csv_path = csv_path,
184
                        shuffle = False, 
185
                        seed = args.seed, 
186
                        print_info = True,
187
                        label_dict = {'G1 G2':0, 'G3 G4':1},
188
                        label_col = 'Grading_kat2',
189
                        patient_strat=True,
190
                        ignore=[])
191
192
elif args.task == 'task_7_tumor_grading_kat4_augmented':
193
    args.n_classes=4
194
    if args.csv_path == None:
195
        raise ValueError('Must provide a csv dataset file.')
196
    else:
197
        csv_path = args.csv_path
198
    dataset = Generic_WSI_Classification_Dataset(csv_path = csv_path,
199
                        shuffle = False, 
200
                        seed = args.seed, 
201
                        print_info = True,
202
                        label_dict = {'niedriger Malignitätsgrad':0, 'mittlerer Malignitätsgrad':1, 'hoher Malignitätsgrad':2, 'sehr hoher Malignitätsgrad':3},
203
                        label_col = 'Grading_kat4',
204
                        patient_strat=True,
205
                        ignore=[])
206
207
elif args.task == 'task_8_tumor_subtyping_augmented':
208
    args.n_classes=4
209
    if args.csv_path == None:
210
        raise ValueError('Must provide a csv dataset file.')
211
    else:
212
        csv_path = args.csv_path
213
    dataset = Generic_WSI_Classification_Dataset(csv_path = csv_path,
214
                        shuffle = False, 
215
                        seed = args.seed, 
216
                        print_info = True,
217
                        label_dict = {'Plattenepithelkarzinom':0, 'Adenokarzinom+BAC':1, 'grosszelliges Karzinom':2, 'NSCLC NOS':3},
218
                        label_col = 'Histo_kat6',
219
                        patient_strat=True,
220
                        ignore=[])
221
elif args.task == 'task_9_survival_prediction_augmented_random':
222
    args.n_classes=2
223
    dataset = Generic_WSI_Classification_Dataset(csv_path = '/home/ammeling/projects/TMA/annotations/aug_survival_prediction_random.csv',
224
                        shuffle = False, 
225
                        seed = args.seed, 
226
                        print_info = True,
227
                        label_dict = {'lebt':0, 'tod':1},
228
                        label_col = 'Survival_Status',
229
                        patient_strat=True,
230
                        ignore=[])
231
232
233
else:
234
    raise NotImplementedError
235
236
num_slides_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids])
237
val_num = np.round(num_slides_cls * args.val_frac).astype(int)
238
test_num = np.round(num_slides_cls * args.test_frac).astype(int)
239
240
if __name__ == '__main__':
241
    if args.label_frac > 0:
242
        label_fracs = [args.label_frac]
243
    else:
244
        label_fracs = [0.1, 0.25, 0.5, 0.75, 1.0]
245
246
    if args.split_name is not None:
247
        split_name = args.split_name
248
    else:
249
        split_name = ''
250
        
251
    for lf in label_fracs:
252
        split_dir = 'splits/'+ str(args.task)  +'_{}_{}'.format(split_name, int(lf * 100))
253
        os.makedirs(split_dir, exist_ok=True)
254
        dataset.create_splits(k = args.k, val_num = val_num, test_num = test_num, label_frac=lf)
255
        for i in range(args.k):
256
            dataset.set_splits()
257
            descriptor_df = dataset.test_split_gen(return_descriptor=True)
258
            splits = dataset.return_splits(from_id=True)
259
            save_splits(splits, ['train', 'val', 'test'], os.path.join(split_dir, 'splits_{}.csv'.format(i)))
260
            save_splits(splits, ['train', 'val', 'test'], os.path.join(split_dir, 'splits_{}_bool.csv'.format(i)), boolean_style=True)
261
            descriptor_df.to_csv(os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i)))
262
263
264