datasets/dataset_generic.py
from __future__ import print_function, division
import os
import torch
import numpy as np
import pandas as pd
import math
import re
import pdb
import pickle
from scipy import stats

from torch.utils.data import Dataset
import h5py

from utils.utils import generate_split, nth


def save_splits(split_datasets, column_keys, filename, boolean_style=False):
    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
    if not boolean_style:
        df = pd.concat(splits, ignore_index=True, axis=1)
        df.columns = column_keys
    else:
        df = pd.concat(splits, ignore_index=True, axis=0)
        index = df.values.tolist()
        one_hot = np.eye(len(split_datasets)).astype(bool)
        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
        df = pd.DataFrame(bool_array, index=index, columns=['train', 'val', 'test'])

    df.to_csv(filename)
    print()

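# Hedged usage sketch (the dataset object and csv path are assumptions, not part
# of this module):
#
#   train, val, test = dataset.return_splits(from_id=True)
#   save_splits([train, val, test], ['train', 'val', 'test'], 'splits_0.csv')
#
# This writes one column of slide_ids per split; with boolean_style=True it
# instead writes one row per slide_id with boolean train/val/test membership flags.
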

class Generic_WSI_Classification_Dataset(Dataset):
    def __init__(self,
        csv_path='dataset_csv/ccrcc_clean.csv',
        shuffle=False,
        seed=7,
        print_info=True,
        label_dict={},
        ignore=[],
        patient_strat=False,
        label_col=None,
        patient_voting='max',
        multi_site=False,
        filter_dict={},
        patient_level=False
        ):
        """
        Args:
            csv_path (string): Path to the csv file with annotations.
            shuffle (boolean): Whether to shuffle the data
            seed (int): Random seed for shuffling the data
            print_info (boolean): Whether to print a summary of the dataset
            label_dict (dict): Dictionary with key, value pairs for converting str labels to int
            ignore (list): List containing class labels to ignore
            patient_strat (boolean): Whether to stratify splits at the patient level
            label_col (string): Name of the csv column holding the labels (defaults to 'label')
            patient_voting (string): Rule ('max' or 'maj') for deciding the patient-level label
            multi_site (boolean): Whether to expand the label space with per-site labels
            filter_dict (dict): Mapping of column -> allowed values used to filter rows
            patient_level (boolean): Whether __getitem__ aggregates all slides of a patient
        """
        self.custom_test_ids = None
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
        self.data_dir = None
        self.split_gen = None
        self.patient_level = patient_level

        if not label_col:
            label_col = 'label'
        self.label_col = label_col

        slide_data = pd.read_csv(csv_path)
        slide_data = self.filter_df(slide_data, filter_dict)

        if multi_site:
            label_dict = self.init_multi_site_label_dict(slide_data, label_dict)

        self.label_dict = label_dict
        self.num_classes = len(set(self.label_dict.values()))

        slide_data = self.df_prep(slide_data, self.label_dict, ignore, self.label_col, multi_site)

        # shuffle data (np.random.shuffle does not work on a DataFrame; sample instead)
        if shuffle:
            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)

        self.slide_data = slide_data

        self.patient_data_prep(patient_voting)
        self.cls_ids_prep()

        if print_info:
            self.summarize()

        if self.patient_level:
            self.patient_dict = self.build_patient_dict()
            #self.slide_data = self.slide_data.drop_duplicates(subset=['case_id'])
        else:
            self.patient_dict = {}

    def build_patient_dict(self):
        """Map each case_id to the array of slide_ids belonging to that patient."""
        patient_dict = {}
        patient_cases = self.slide_data['case_id'].unique()
        slide_cases = self.slide_data.set_index('case_id')

        for patient in patient_cases:
            slide_ids = slide_cases.loc[patient, 'slide_id']

            # .loc returns a scalar for patients with a single slide; wrap it in an array
            if isinstance(slide_ids, str):
                slide_ids = np.array(slide_ids).reshape(-1)
            else:
                slide_ids = slide_ids.values

            patient_dict.update({patient: slide_ids})

        return patient_dict

    def cls_ids_prep(self):
        # store ids corresponding to each class at the patient level
        self.patient_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

        # store ids corresponding to each class at the slide level
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def patient_data_prep(self, patient_voting='max'):
        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            label = self.slide_data['label'][locations].values
            if patient_voting == 'max':
                label = label.max() # get patient label (MIL convention)
            elif patient_voting == 'maj':
                label = stats.mode(label)[0]
            else:
                raise NotImplementedError
            patient_labels.append(label)

        self.patient_data = {'case_id': patients, 'label': np.array(patient_labels)}

    @staticmethod
    def init_multi_site_label_dict(slide_data, label_dict):
        print('initiating multi-source label dictionary')
        sites = np.unique(slide_data['site'].values)
        multi_site_dict = {}
        num_classes = len(label_dict)
        for key, val in label_dict.items():
            for idx, site in enumerate(sites):
                site_key = (key, site)
                # offset by the site index so every (label, site) pair gets a unique id
                site_val = val + idx * num_classes
                multi_site_dict.update({site_key: site_val})
                print('{} : {}'.format(site_key, site_val))
        return multi_site_dict

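    # Illustration (assumed labels/sites, not from any shipped csv): with
    # label_dict = {'normal': 0, 'tumor': 1} and sites ['A', 'B'], the expanded
    # dictionary becomes
    # {('normal', 'A'): 0, ('tumor', 'A'): 1, ('normal', 'B'): 2, ('tumor', 'B'): 3}.
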
    @staticmethod
    def filter_df(df, filter_dict={}):
        if len(filter_dict) > 0:
            filter_mask = np.full(len(df), True, bool)
            # assert 'label' not in filter_dict.keys()
            for key, val in filter_dict.items():
                mask = df[key].isin(val)
                filter_mask = np.logical_and(filter_mask, mask)
            df = df[filter_mask]
        return df

    @staticmethod
    def df_prep(data, label_dict, ignore, label_col, multi_site=False):
        if label_col != 'label':
            data['label'] = data[label_col].copy()

        mask = data['label'].isin(ignore)
        data = data[~mask]
        data.reset_index(drop=True, inplace=True)
        for i in data.index:
            key = data.loc[i, 'label']
            if multi_site:
                site = data.loc[i, 'site']
                key = (key, site)
            data.at[i, 'label'] = label_dict[key]

        return data

    def __len__(self):
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        else:
            return len(self.slide_data)

    def summarize(self):
        print("label column: {}".format(self.label_col))
        print("label dictionary: {}".format(self.label_dict))
        print("number of classes: {}".format(self.num_classes))
        print("slide-level counts: ", self.slide_data['label'].value_counts(sort=False))
        for i in range(self.num_classes):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))

    def create_splits(self, k=3, val_num=(25, 25), test_num=(40, 40), label_frac=1.0, custom_test_ids=None):
        settings = {
                    'n_splits': k,
                    'val_num': val_num,
                    'test_num': test_num,
                    'label_frac': label_frac,
                    'seed': self.seed,
                    'custom_test_ids': custom_test_ids
                    }

        if self.patient_strat:
            settings.update({'cls_ids': self.patient_cls_ids, 'samples': len(self.patient_data['case_id'])})
        else:
            settings.update({'cls_ids': self.slide_cls_ids, 'samples': len(self.slide_data)})

        self.split_gen = generate_split(**settings)

    def sample_held_out(self, test_num=(40, 40)):
        test_ids = []
        np.random.seed(self.seed) # fix seed

        if self.patient_strat:
            cls_ids = self.patient_cls_ids
        else:
            cls_ids = self.slide_cls_ids

        for c in range(len(test_num)):
            test_ids.extend(np.random.choice(cls_ids[c], test_num[c], replace=False)) # held-out test ids

        # if self.patient_strat:
        #   slide_ids = []
        #   for idx in test_ids:
        #       case_id = self.patient_data['case_id'][idx]
        #       slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
        #       slide_ids.extend(slide_indices)
        #   return slide_ids
        # else:
        #   return test_ids
        return test_ids

    def set_splits(self, start_from=None):
        if start_from:
            ids = nth(self.split_gen, start_from)
        else:
            ids = next(self.split_gen)

        if self.patient_strat:
            # translate patient indices into the indices of all their slides
            slide_ids = [[] for i in range(len(ids))]

            for split in range(len(ids)):
                for idx in ids[split]:
                    case_id = self.patient_data['case_id'][idx]
                    slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                    slide_ids[split].extend(slide_indices)

            self.train_ids, self.val_ids, self.test_ids = slide_ids[0], slide_ids[1], slide_ids[2]
        else:
            self.train_ids, self.val_ids, self.test_ids = ids

    def get_split_from_df(self, all_splits=None, split_key='train', split=None, return_ids_only=False):
        if split is None:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            if return_ids_only:
                ids = np.where(mask)[0]
                return ids

            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
        else:
            split = None

        return split

    def get_merged_split_from_df(self, all_splits, split_keys=['train']):
        merged_split = []
        for split_key in split_keys:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True).tolist()
            merged_split.extend(split)

        # check the merged list, not the last individual split
        if len(merged_split) > 0:
            mask = self.slide_data['slide_id'].isin(merged_split)
            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
        else:
            split = None

        return split

    def return_splits(self, from_id=True, csv_path=None):
        if from_id:
            if len(self.train_ids) > 0:
                train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True)
                train_split = Generic_Split(train_data, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
            else:
                train_split = None

            if len(self.val_ids) > 0:
                val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
            else:
                val_split = None

            if len(self.test_ids) > 0:
                test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
            else: # no test set: fall back to a copy of the validation set
                test_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
        else:
            assert csv_path
            all_splits = pd.read_csv(csv_path)
            train_split = self.get_split_from_df(all_splits, 'train')
            val_split = self.get_split_from_df(all_splits, 'val')
            test_split = self.get_split_from_df(all_splits, 'test')

        return train_split, val_split, test_split

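    # Hedged usage sketch (csv path assumed): to resume predefined splits, point
    # return_splits at a csv whose 'train'/'val'/'test' columns hold slide_ids:
    #
    #   train, val, test = dataset.return_splits(from_id=False,
    #                                            csv_path='splits/splits_0.csv')
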
    def get_list(self, ids):
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids):
        return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        return None

    def test_split_gen(self, return_descriptor=False):
        if return_descriptor:
            index = [list(self.label_dict.keys())[list(self.label_dict.values()).index(i)] for i in range(self.num_classes)]
            columns = ['train', 'val', 'test']
            df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index=index, columns=columns)

        for split_name, split_ids in zip(('train', 'val', 'test'), (self.train_ids, self.val_ids, self.test_ids)):
            print('\nnumber of {} samples: {}'.format(split_name, len(split_ids)))
            labels = self.getlabel(split_ids)
            unique, counts = np.unique(labels, return_counts=True)
            # pad classes absent from this split with zero counts, then sort both
            # arrays by class id so unique[u] and counts[u] stay aligned
            missing_classes = np.setdiff1d(np.arange(self.num_classes), unique)
            unique = np.append(unique, missing_classes)
            counts = np.append(counts, np.full(len(missing_classes), 0))
            inds = unique.argsort()
            unique = unique[inds]
            counts = counts[inds]
            for u in range(len(unique)):
                print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
                if return_descriptor:
                    df.loc[index[u], split_name] = counts[u]

        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0

        if return_descriptor:
            return df

    def save_split(self, filename):
        train_split = self.get_list(self.train_ids)
        val_split = self.get_list(self.val_ids)
        test_split = self.get_list(self.test_ids)
        df_tr = pd.DataFrame({'train': train_split})
        df_v = pd.DataFrame({'val': val_split})
        df_t = pd.DataFrame({'test': test_split})
        df = pd.concat([df_tr, df_v, df_t], axis=1)
        df.to_csv(filename, index=False)


class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset):
    def __init__(self,
        data_dir,
        **kwargs):
        super(Generic_MIL_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.use_h5 = False

    def load_from_h5(self, toggle):
        self.use_h5 = toggle

    def __getitem__(self, idx):
        if not self.patient_level:
            slide_id = self.slide_data['slide_id'][idx]
            label = self.slide_data['label'][idx]
            if isinstance(self.data_dir, dict):
                source = self.slide_data['source'][idx]
                data_dir = self.data_dir[source]
            else:
                data_dir = self.data_dir

            if not self.use_h5:
                if self.data_dir:
                    full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
                    features = torch.load(full_path)
                    return features, label
                else:
                    return slide_id, label
            else:
                full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
                with h5py.File(full_path, 'r') as hdf5_file:
                    features = hdf5_file['features'][:]
                    coords = hdf5_file['coords'][:]

                features = torch.from_numpy(features)
                return features, label, coords

        else:
            # patient-level: concatenate the features of every slide for this case
            case_id = self.slide_data['case_id'][idx]
            label = self.slide_data['label'][idx]
            slide_ids = self.patient_dict[case_id]

            if isinstance(self.data_dir, dict):
                source = self.slide_data['source'][idx]
                data_dir = self.data_dir[source]
            else:
                data_dir = self.data_dir

            if not self.use_h5:
                features_list = []
                for slide_id in slide_ids:
                    full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
                    slide_features = torch.load(full_path)
                    features_list.append(slide_features)

                features = torch.cat(features_list, dim=0)
                return features, label

            else:
                features_list = []
                coords_list = []

                for slide_id in slide_ids:
                    full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
                    with h5py.File(full_path, 'r') as hdf5_file:
                        slide_features_t = torch.from_numpy(hdf5_file['features'][:])
                        slide_coords_t = torch.from_numpy(hdf5_file['coords'][:])

                    features_list.append(slide_features_t)
                    coords_list.append(slide_coords_t)

                features = torch.cat(features_list, dim=0)
                coords = torch.cat(coords_list, dim=0)
                return features, label, coords


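# Hedged end-to-end sketch (csv, feature directory, and label names are
# assumptions, not shipped with this module): build the dataset, draw k splits,
# then materialize one fold.
#
#   dataset = Generic_MIL_Dataset(data_dir='features/',
#                                 csv_path='dataset_csv/ccrcc_clean.csv',
#                                 label_dict={'normal': 0, 'tumor': 1},
#                                 patient_strat=True, seed=7)
#   dataset.create_splits(k=10, val_num=(25, 25), test_num=(40, 40))
#   dataset.set_splits()
#   train, val, test = dataset.return_splits(from_id=True)
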
class Generic_Split(Generic_MIL_Dataset):
    def __init__(self, slide_data, data_dir=None, num_classes=2, patient_level=False):
        self.use_h5 = False
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

        self.patient_level = patient_level
        if self.patient_level:
            self.patient_dict = self.build_patient_dict()
            #self.slide_data = self.slide_data.drop_duplicates(subset=['case_id'])
        else:
            self.patient_dict = {}

    def __len__(self):
        return len(self.slide_data)


class Generic_WSI_Inference_Dataset(Dataset):
    def __init__(self,
        data_dir,
        csv_path=None,
        print_info=True,
        ):
        self.data_dir = data_dir
        self.print_info = print_info

        if csv_path is not None:
            data = pd.read_csv(csv_path)
            self.slide_data = data['slide_id'].values
        else:
            # strip the .pt extension; np.char.strip would remove any leading or
            # trailing '.', 'p', 't' characters and can corrupt slide ids
            data = os.listdir(data_dir)
            self.slide_data = np.array([os.path.splitext(f)[0] for f in data])
        if print_info:
            print('total number of slides to infer: ', len(self.slide_data))

    def __len__(self):
        return len(self.slide_data)

    def __getitem__(self, idx):
        slide_file = self.slide_data[idx] + '.pt'
        full_path = os.path.join(self.data_dir, 'pt_files', slide_file)
        features = torch.load(full_path)
        return features
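
# Hedged inference sketch (csv and feature paths are assumptions; features are
# read from <data_dir>/pt_files/<slide_id>.pt as in __getitem__):
#
#   from torch.utils.data import DataLoader
#   infer_dset = Generic_WSI_Inference_Dataset(data_dir='features/',
#                                              csv_path='dataset_csv/infer.csv')
#   loader = DataLoader(infer_dset, batch_size=1)
#   for features in loader:
#       pass  # forward each bag of patch features through the trained model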