# datasets/dataset_generic.py
from __future__ import print_function, division
import os

import torch
import numpy as np
import pandas as pd
from scipy import stats

from torch.utils.data import Dataset
import h5py

from utils.utils import generate_split, nth


def save_splits(split_datasets, column_keys, filename, boolean_style=False):
    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
    if not boolean_style:
        df = pd.concat(splits, ignore_index=True, axis=1)
        df.columns = column_keys
    else:
        df = pd.concat(splits, ignore_index=True, axis=0)
        index = df.values.tolist()
        # one boolean column per split: each row is True only in the split it belongs to
        one_hot = np.eye(len(split_datasets)).astype(bool)
        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
        df = pd.DataFrame(bool_array, index=index, columns=column_keys)

    df.to_csv(filename)
    print()

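# Usage sketch (hedged): assumes three Generic_Split objects as returned by
# return_splits below; 'splits_0_bool.csv' is a hypothetical filename.
#
#   save_splits([train_split, val_split, test_split],
#               column_keys=['train', 'val', 'test'],
#               filename='splits_0_bool.csv', boolean_style=True)
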
class Generic_WSI_Classification_Dataset(Dataset):
    def __init__(self,
        csv_path = 'dataset_csv/ccrcc_clean.csv',
        shuffle = False,
        seed = 7,
        print_info = True,
        label_dict = {},
        ignore=[],
        patient_strat=False,
        label_col = None,
        patient_voting = 'max',
        multi_site = False,
        filter_dict = {},
        ):
        """
        Args:
            csv_path (string): Path to the csv file with annotations.
            shuffle (boolean): Whether to shuffle the data
            seed (int): Random seed for shuffling the data
            print_info (boolean): Whether to print a summary of the dataset
            label_dict (dict): Dictionary with key, value pairs for converting str labels to int
            ignore (list): List containing class labels to ignore
            patient_strat (boolean): Whether to stratify splits at the patient level
            label_col (string): Name of the csv column holding the labels (defaults to 'label')
            patient_voting (string): Rule for deciding the patient-level label ('max' or 'maj')
            multi_site (boolean): Whether to expand the label space with per-site labels
            filter_dict (dict): Mapping of column name -> list of allowed values, used to filter rows
        """
        self.custom_test_ids = None
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
        self.data_dir = None
        if not label_col:
            label_col = 'label'
        self.label_col = label_col

        slide_data = pd.read_csv(csv_path)
        slide_data = self.filter_df(slide_data, filter_dict)

        if multi_site:
            label_dict = self.init_multi_site_label_dict(slide_data, label_dict)

        self.label_dict = label_dict
        self.num_classes = len(set(self.label_dict.values()))

        slide_data = self.df_prep(slide_data, self.label_dict, ignore, self.label_col, multi_site)

        ### shuffle data
        if shuffle:
            # np.random.shuffle does not work reliably on a DataFrame; sample instead
            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)

        self.slide_data = slide_data

        self.patient_data_prep(patient_voting)
        self.cls_ids_prep()

        if print_info:
            self.summarize()

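    # Construction sketch (hedged): the csv path mirrors the default above; the
    # label names are assumptions about that csv's contents.
    #
    #   dataset = Generic_WSI_Classification_Dataset(
    #       csv_path='dataset_csv/ccrcc_clean.csv',
    #       shuffle=False, seed=7,
    #       label_dict={'normal': 0, 'tumor': 1},
    #       patient_strat=True, patient_voting='max')
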
    def cls_ids_prep(self):
        # store ids corresponding to each class at the patient or case level
        self.patient_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

        # store ids corresponding to each class at the slide level
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def patient_data_prep(self, patient_voting='max'):
        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            label = self.slide_data['label'][locations].values
            if patient_voting == 'max':
                label = label.max() # get patient label (MIL convention: positive if any slide is positive)
            elif patient_voting == 'maj':
                label = stats.mode(label)[0] # majority vote across the patient's slides
            else:
                raise NotImplementedError
            patient_labels.append(label)

        self.patient_data = {'case_id': patients, 'label': np.array(patient_labels)}

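    # Voting sketch: for a patient whose slides carry labels [0, 0, 1],
    # 'max' voting yields 1 (any positive slide makes the patient positive),
    # while 'maj' voting yields 0 (the most frequent slide label wins).
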
    @staticmethod
    def init_multi_site_label_dict(slide_data, label_dict):
        print('initiating multi-source label dictionary')
        sites = np.unique(slide_data['site'].values)
        multi_site_dict = {}
        num_classes = len(label_dict)
        for key, val in label_dict.items():
            for idx, site in enumerate(sites):
                site_key = (key, site)
                site_val = val + idx * num_classes # each site gets its own block of class indices
                multi_site_dict.update({site_key: site_val})
                print('{} : {}'.format(site_key, site_val))
        return multi_site_dict

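    # Expansion sketch: with label_dict={'normal': 0, 'tumor': 1} and sites ['A', 'B']
    # (hypothetical values), the multi-site dictionary becomes
    # {('normal','A'): 0, ('normal','B'): 2, ('tumor','A'): 1, ('tumor','B'): 3}.
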
    @staticmethod
    def filter_df(df, filter_dict={}):
        if len(filter_dict) > 0:
            filter_mask = np.full(len(df), True, bool)
            # assert 'label' not in filter_dict.keys()
            for key, val in filter_dict.items():
                mask = df[key].isin(val) # keep rows whose value for this column is in the allowed list
                filter_mask = np.logical_and(filter_mask, mask)
            df = df[filter_mask]
        return df

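    # Filtering sketch ('site' and its values are hypothetical csv contents):
    #
    #   df = Generic_WSI_Classification_Dataset.filter_df(df, {'site': ['A', 'B']})
    #
    # keeps only the rows whose 'site' column is 'A' or 'B'.
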
    @staticmethod
    def df_prep(data, label_dict, ignore, label_col, multi_site=False):
        if label_col != 'label':
            data['label'] = data[label_col].copy()

        # drop rows whose label is in the ignore list
        mask = data['label'].isin(ignore)
        data = data[~mask]
        data.reset_index(drop=True, inplace=True)
        for i in data.index:
            key = data.loc[i, 'label']
            if multi_site:
                site = data.loc[i, 'site']
                key = (key, site) # multi-site keys are (label, site) tuples
            data.at[i, 'label'] = label_dict[key]

        return data

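    # Mapping sketch: after df_prep, the 'label' column holds ints, e.g. a row
    # labeled 'tumor' becomes 1 under label_dict={'normal': 0, 'tumor': 1}
    # (hypothetical labels), and rows whose label appears in `ignore` are dropped.
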
    def __len__(self):
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        else:
            return len(self.slide_data)

    def summarize(self):
        print("label column: {}".format(self.label_col))
        print("label dictionary: {}".format(self.label_dict))
        print("number of classes: {}".format(self.num_classes))
        print("slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort=False))
        for i in range(self.num_classes):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))

    def create_splits(self, k=3, val_num=(25, 25), test_num=(40, 40), label_frac=1.0, custom_test_ids=None):
        # val_num / test_num give per-class sample counts; label_frac is the
        # fraction of labeled training samples to use
        settings = {
                    'n_splits': k,
                    'val_num': val_num,
                    'test_num': test_num,
                    'label_frac': label_frac,
                    'seed': self.seed,
                    'custom_test_ids': custom_test_ids
                    }

        if self.patient_strat:
            settings.update({'cls_ids': self.patient_cls_ids, 'samples': len(self.patient_data['case_id'])})
        else:
            settings.update({'cls_ids': self.slide_cls_ids, 'samples': len(self.slide_data)})

        self.split_gen = generate_split(**settings)

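    # Split-generation sketch (k and the per-class counts are illustrative):
    #
    #   dataset.create_splits(k=10, val_num=(25, 25), test_num=(40, 40), label_frac=1.0)
    #   dataset.set_splits()  # advance the generator and materialize train/val/test ids
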
    def sample_held_out(self, test_num=(40, 40)):

        test_ids = []
        np.random.seed(self.seed) # fix seed

        if self.patient_strat:
            cls_ids = self.patient_cls_ids
        else:
            cls_ids = self.slide_cls_ids

        for c in range(len(test_num)):
            test_ids.extend(np.random.choice(cls_ids[c], test_num[c], replace=False)) # held-out test ids for class c

        if self.patient_strat:
            # map sampled patients back to all of their slides
            slide_ids = []
            for idx in test_ids:
                case_id = self.patient_data['case_id'][idx]
                slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                slide_ids.extend(slide_indices)

            return slide_ids
        else:
            return test_ids

    def set_splits(self, start_from=None):
        if start_from:
            ids = nth(self.split_gen, start_from) # skip ahead to the start_from-th split
        else:
            ids = next(self.split_gen)

        if self.patient_strat:
            # map patient-level ids back to slide-level indices
            slide_ids = [[] for i in range(len(ids))]

            for split in range(len(ids)):
                for idx in ids[split]:
                    case_id = self.patient_data['case_id'][idx]
                    slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                    slide_ids[split].extend(slide_indices)

            self.train_ids, self.val_ids, self.test_ids = slide_ids[0], slide_ids[1], slide_ids[2]

        else:
            self.train_ids, self.val_ids, self.test_ids = ids

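    # End-to-end split sketch (the filename pattern is hypothetical):
    #
    #   dataset.create_splits(k=10)
    #   for i in range(10):
    #       dataset.set_splits()
    #       dataset.save_split('splits_{}.csv'.format(i))
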
    def get_split_from_df(self, all_splits, split_key='train'):
        split = all_splits[split_key]
        split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes)
        else:
            split = None

        return split

    def get_merged_split_from_df(self, all_splits, split_keys=['train']):
        merged_split = []
        for split_key in split_keys:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True).tolist()
            merged_split.extend(split)

        if len(merged_split) > 0: # check the merged list, not just the last split read
            mask = self.slide_data['slide_id'].isin(merged_split)
            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes)
        else:
            split = None

        return split

    def return_splits(self, from_id=True, csv_path=None):
        if from_id:
            if len(self.train_ids) > 0:
                train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True)
                train_split = Generic_Split(train_data, data_dir=self.data_dir, num_classes=self.num_classes)
            else:
                train_split = None

            if len(self.val_ids) > 0:
                val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes)
            else:
                val_split = None

            if len(self.test_ids) > 0:
                test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes)
            else:
                test_split = None

        else:
            assert csv_path
            all_splits = pd.read_csv(csv_path)
            train_split = self.get_split_from_df(all_splits, 'train')
            val_split = self.get_split_from_df(all_splits, 'val')
            test_split = self.get_split_from_df(all_splits, 'test')

        return train_split, val_split, test_split

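    # Loading-from-csv sketch ('splits_0.csv' is a hypothetical file written by
    # save_split, with 'train'/'val'/'test' columns of slide ids):
    #
    #   train_split, val_split, test_split = dataset.return_splits(
    #       from_id=False, csv_path='splits_0.csv')
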
    def get_list(self, ids):
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids):
        return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        return None # intentionally a no-op; subclasses implement feature loading

    def test_split_gen(self, return_descriptor=False):
        if return_descriptor:
            index = [list(self.label_dict.keys())[list(self.label_dict.values()).index(i)] for i in range(self.num_classes)]
            columns = ['train', 'val', 'test']
            df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index=index,
                            columns=columns)

        # report per-class counts for each split; classes absent from a split get a count of 0
        for split_name, split_ids in (('train', self.train_ids), ('val', self.val_ids), ('test', self.test_ids)):
            print('\nnumber of {} samples: {}'.format(split_name, len(split_ids)))
            labels = self.getlabel(split_ids)
            unique, counts = np.unique(labels, return_counts=True)
            missing_classes = np.setdiff1d(np.arange(self.num_classes), unique)
            unique = np.append(unique, missing_classes)
            counts = np.append(counts, np.full(len(missing_classes), 0))
            inds = unique.argsort()
            unique = unique[inds] # sort the padded class array too, so classes and counts stay aligned
            counts = counts[inds]
            for u in range(len(unique)):
                print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
                if return_descriptor:
                    df.loc[index[u], split_name] = counts[u]

        # the three splits must be pairwise disjoint
        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0

        if return_descriptor:
            return df

    def save_split(self, filename):
        train_split = self.get_list(self.train_ids)
        val_split = self.get_list(self.val_ids)
        test_split = self.get_list(self.test_ids)
        df_tr = pd.DataFrame({'train': train_split})
        df_v = pd.DataFrame({'val': val_split})
        df_t = pd.DataFrame({'test': test_split})
        df = pd.concat([df_tr, df_v, df_t], axis=1)
        df.to_csv(filename, index=False)


class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset):
    def __init__(self,
        data_dir,
        **kwargs):
        super(Generic_MIL_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.use_h5 = False

    def load_from_h5(self, toggle):
        self.use_h5 = toggle

    def __getitem__(self, idx):
        slide_id = self.slide_data['slide_id'][idx]
        label = self.slide_data['label'][idx]
        if isinstance(self.data_dir, dict):
            # multiple feature sources: pick the directory matching this slide's source
            source = self.slide_data['source'][idx]
            data_dir = self.data_dir[source]
        else:
            data_dir = self.data_dir

        if not self.use_h5:
            if self.data_dir:
                full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
                features = torch.load(full_path)
                return features, label
            else:
                return slide_id, label

        else:
            full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
            with h5py.File(full_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
                coords = hdf5_file['coords'][:]

            features = torch.from_numpy(features)
            return features, label, coords

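# DataLoader sketch: each bag holds a variable number of patch features, so
# batch_size=1 is the safe choice ('features/', 'dataset_csv/tumor.csv' and the
# label names are hypothetical):
#
#   dataset = Generic_MIL_Dataset(data_dir='features/',
#                                 csv_path='dataset_csv/tumor.csv',
#                                 label_dict={'normal': 0, 'tumor': 1})
#   loader = torch.utils.data.DataLoader(dataset, batch_size=1)
#   features, label = next(iter(loader))
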
class Generic_Split(Generic_MIL_Dataset):
    def __init__(self, slide_data, data_dir=None, num_classes=2):
        # lightweight split wrapper: skips the parent __init__ and reuses its __getitem__
        self.use_h5 = False
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def __len__(self):
        return len(self.slide_data)


class Generic_WSI_Inference_Dataset(Dataset):
    def __init__(self,
        data_dir,
        csv_path = None,
        print_info = True,
        ):
        self.data_dir = data_dir
        self.print_info = print_info

        if csv_path is not None:
            data = pd.read_csv(csv_path)
            self.slide_data = data['slide_id'].values
        else:
            # np.char.strip(..., chars='.pt') would strip any of '.', 'p', 't' from both
            # ends of each name; split off the extension explicitly instead, and list the
            # pt_files subdirectory that __getitem__ actually reads from
            data = os.listdir(os.path.join(data_dir, 'pt_files'))
            self.slide_data = np.array([os.path.splitext(f)[0] for f in data])
        if print_info:
            print('total number of slides to infer: ', len(self.slide_data))

    def __len__(self):
        return len(self.slide_data)

    def __getitem__(self, idx):
        slide_file = self.slide_data[idx] + '.pt'
        full_path = os.path.join(self.data_dir, 'pt_files', slide_file)
        features = torch.load(full_path)
        return features
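
# Inference usage sketch ('features/' is a hypothetical directory containing pt_files):
#
#   infer_dset = Generic_WSI_Inference_Dataset(data_dir='features/')
#   loader = torch.utils.data.DataLoader(infer_dset, batch_size=1)
#   for features in loader:
#       pass  # run the model on each bag of patch features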