[dc0d36]: / create_splits.py

Download this file

126 lines (106 with data), 5.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import pdb
import os
import pandas as pd
import numpy as np
from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset, save_splits
import argparse
parser = argparse.ArgumentParser(description='Creating splits for whole slide classification')
parser.add_argument('--label_frac', type=float, default= 1,
help='fraction of labels (default: [0.25, 0.5, 0.75, 1.0])')
parser.add_argument('--seed', type=int, default=1,
help='random seed (default: 1)')
parser.add_argument('--k', type=int, default=10,
help='number of splits (default: 10)')
parser.add_argument('--task', type=str, choices=['cardiac-grade','cardiac-mtl'],
help='select one of the supported tasks for which to perform the split')
parser.add_argument('--hold_out_test', action='store_true', default=False,
help='hold-out the test set for each split')
parser.add_argument('--split_code', type=str, default=None)
args = parser.parse_args()
# The splits are made to account for all possible diagnosis across all splits:
# grade - cell rejection only low grade
# - low grade cellular, high-grade antibody case ...
# - ....
# mtl - low grade cellular only
# - high grade cellular only
# - low grade cellular + quilty
# - .....
#---------------------------------------------------
if args.task == 'cardiac-grade':
dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/CardiacDummy_GradeSplit.csv',
shuffle = False,
seed = args.seed,
print_info = True,
label_dict = {'cell_only_low' : 0,
'cell_only_high' : 1,
'cell_low_quilty' : 2,
'cell_high_quilty' : 3,
'amr_only_low' : 4,
'amr_only_high' : 5,
'amr_low_quilty' : 6,
'amr_high_quilty' : 7,
'cell_amr_low' : 8,
'cell_amr_high' : 9,
'cell_amr_quilty_low' : 10,
'cell_amr_quilty_high' : 11},
patient_strat= True,
ignore=[])
p_val = 0.1 # use 10% of data in validation
p_test = 0.2 # to use hold-out test set set p_test = 0
elif args.task == 'cardiac-mtl':
dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/CardiacDummy_MTLSplit.csv',
shuffle = False,
seed = args.seed,
print_info = True,
label_dict = {'healthy' :0,
'quilty' :1,
'cell_only_low' :2,
'cell_only_high' :3,
'cell_low_quilty' :4,
'cell_high_quilty' :5,
'amr_only_low' :6,
'amr_only_high' :7,
'amr_low_quilty' :8,
'amr_high_quilty' :9,
'cell_amr_low' :10,
'cell_amr_high' :11,
'cell_amr_quilty_low' :12,
'cell_amr_quilty_high' :13},
patient_strat= True,
ignore=[])
p_val = 0.1 # use 10% of data in validation
p_test = 0.2 # use 20% data for test set
else:
raise NotImplementedError
# splits
num_slides_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids])
val_num = np.floor(num_slides_cls * p_val).astype(int) # use 10% data in validation
test_num = np.floor(num_slides_cls * p_test).astype(int) # use 20% for test set
print("---------------------------------")
print(f"validation set size = {val_num} ")
print(f"test set size = {test_num}")
print("---------------------------------")
if __name__ == '__main__':
if args.label_frac > 0:
label_fracs = [args.label_frac]
else:
label_fracs = [0.25, 0.5, 0.75, 1.0]
if args.hold_out_test:
custom_test_ids = dataset.sample_held_out(test_num=test_num)
else:
custom_test_ids = None
for lf in label_fracs:
if args.split_code is not None:
split_dir = 'splits/'+ str(args.split_code) + '_{}'.format(int(lf * 100))
else:
split_dir = 'splits/'+ str(args.task) + '_{}'.format(int(lf * 100))
os.makedirs(split_dir, exist_ok=True)
#pdb.set_trace()
dataset.create_splits(k = args.k, val_num = val_num, test_num = test_num, label_frac=lf, custom_test_ids=custom_test_ids)
for i in range(args.k):
dataset.set_splits()
descriptor_df = dataset.test_split_gen(return_descriptor=True)
splits = dataset.return_splits(from_id=True)
save_splits(splits, ['train', 'val','test'], os.path.join(split_dir, 'splits_{}.csv'.format(i)))
save_splits(splits, ['train', 'val','test'], os.path.join(split_dir, 'splits_{}_bool.csv'.format(i)), boolean_style=True)
descriptor_df.to_csv(os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i)))