Diff of /datasets/dataset_mtl.py [000000] .. [4cd6c8]

--- a/datasets/dataset_mtl.py
+++ b/datasets/dataset_mtl.py
from __future__ import print_function, division
import os
import torch
import numpy as np
import pandas as pd
import math
import re
import pdb
import pickle
from scipy import stats

from torch.utils.data import Dataset
import h5py

from utils.utils import generate_split, nth

def save_splits(split_datasets, column_keys, filename, boolean_style=False):
    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
    if not boolean_style:
        df = pd.concat(splits, ignore_index=True, axis=1)
        df.columns = column_keys
    else:
        df = pd.concat(splits, ignore_index=True, axis=0)
        index = df.values.tolist()
        one_hot = np.eye(len(split_datasets)).astype(bool)
        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
        df = pd.DataFrame(bool_array, index=index, columns=['train', 'val', 'test'])

    df.to_csv(filename)
    print()
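
# Illustration (not from the original file; the slide ids are hypothetical):
# with three splits of sizes 2, 1 and 1, save_splits writes one of the two
# csv layouts below.
#
#   boolean_style=False              boolean_style=True
#   ,train,val,test                  ,train,val,test
#   0,slide_a,slide_c,slide_d        slide_a,True,False,False
#   1,slide_b,,                      slide_b,True,False,False
#                                    slide_c,False,True,False
#                                    slide_d,False,False,True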


class Generic_WSI_MTL_Dataset(Dataset):
    def __init__(self,
        csv_path='dataset_csv/ccrcc_clean.csv',
        shuffle=False,
        seed=7,
        print_info=True,
        label_dicts=[{}, {}],
        ignore=[],
        patient_strat=False,
        label_cols=['label', 'label_2'],
        patient_voting='max',
        multi_site=False,
        filter_dict={},
        patient_level=False
        ):
        """
        Args:
            csv_path (string): Path to the csv file with annotations.
            shuffle (boolean): Whether to shuffle the data.
            seed (int): Random seed used when shuffling.
            print_info (boolean): Whether to print a summary of the dataset.
            label_dicts (list of dict): One dict per task with key, value pairs for converting str labels to int.
            ignore (list): List containing class labels to ignore.
            patient_strat (boolean): Whether to stratify splits at the patient level.
            label_cols (list of string): Name of the csv column holding each task's labels.
            patient_voting (string): Rule for deciding the patient-level label ('max' or 'maj').
            multi_site (boolean): Whether to expand the first task's label dict into per-site classes.
            filter_dict (dict): Mapping of column name to allowed values, used to filter rows.
            patient_level (boolean): Whether a sample is a patient (all slides of a case) rather than a single slide.
        """
        self.custom_test_ids = None
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
        self.data_dir = None
        self.label_cols = label_cols

        slide_data = pd.read_csv(csv_path)
        slide_data = self.filter_df(slide_data, filter_dict)

        self.patient_level = patient_level

        if multi_site:
            label_dicts[0] = self.init_multi_site_label_dict(slide_data, label_dicts[0])

        self.label_dicts = label_dicts
        self.num_classes = [len(set(label_dict.values())) for label_dict in self.label_dicts]

        slide_data = self.df_prep(slide_data, self.label_dicts, ignore, self.label_cols, multi_site)
        ### shuffle data
        # np.random.shuffle cannot shuffle a DataFrame in place; use pandas instead
        if shuffle:
            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)

        self.slide_data = slide_data

        #self.patient_data_prep(patient_voting)
        #self.cls_ids_prep()

        #if print_info:
        #   self.summarize()

        if self.patient_level:
            self.patient_dict = self.build_patient_dict()
            #self.slide_data   = self.slide_data.drop_duplicates(subset=['case_id'])
        else:
            self.patient_dict = {}

    def build_patient_dict(self):
        patient_dict = {}
        patient_cases = self.slide_data['case_id'].unique()
        slide_cases = self.slide_data.set_index('case_id')

        for patient in patient_cases:
            slide_ids = slide_cases.loc[patient, 'slide_id']

            if isinstance(slide_ids, str):
                # single slide for this patient: wrap the id in a 1-element array
                slide_ids = np.array(slide_ids).reshape(-1)
            else:
                slide_ids = slide_ids.values

            patient_dict.update({patient: slide_ids})

        return patient_dict
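
    # Illustration (not from the original file; case and slide ids are
    # hypothetical): build_patient_dict maps each case to its slide ids, e.g.
    #   {'case_0': array(['case_0_slide_0', 'case_0_slide_1']),
    #    'case_1': array(['case_1_slide_0'])}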

    def cls_ids_prep(self):
        # weighted sampling is currently disabled; the per-class id lists are
        # only built when b_weighted_samples is set to True
        b_weighted_samples = False

        if b_weighted_samples:
            # store ids corresponding to each class at the patient or case level
            self.patient_cls_ids = [[] for i in range(self.num_classes[0])]
            for i in range(self.num_classes[0]):
                self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

            # store ids corresponding to each class at the slide level
            self.slide_cls_ids = [[] for i in range(self.num_classes[0])]
            for i in range(self.num_classes[0]):
                self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

        else:
            self.patient_cls_ids = None
            self.slide_cls_ids = None

    def patient_data_prep(self, patient_voting='max'):
        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            label = self.slide_data['label'][locations].values
            if patient_voting == 'max':
                label = label.max() # get patient label (MIL convention)
            elif patient_voting == 'maj':
                label = stats.mode(label)[0]
            else:
                raise NotImplementedError
            patient_labels.append(label)

        self.patient_data = {'case_id': patients, 'label': np.array(patient_labels)}

    @staticmethod
    def init_multi_site_label_dict(slide_data, label_dict):
        print('initializing multi-site label dictionary')
        sites = np.unique(slide_data['site'].values)
        multi_site_dict = {}
        num_classes = len(label_dict)
        for key, val in label_dict.items():
            for idx, site in enumerate(sites):
                site_key = (key, site)
                site_val = val + idx * num_classes
                multi_site_dict.update({site_key: site_val})
                print('{} : {}'.format(site_key, site_val))
        return multi_site_dict
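
    # Illustration (not from the original file; labels and sites are
    # hypothetical): with label_dict = {'a': 0, 'b': 1} and sites
    # ['site1', 'site2'], the expanded dict becomes
    #   {('a', 'site1'): 0, ('b', 'site1'): 1, ('a', 'site2'): 2, ('b', 'site2'): 3}
    # i.e. every (label, site) pair is treated as its own class.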

    @staticmethod
    def filter_df(df, filter_dict={}):
        if len(filter_dict) > 0:
            filter_mask = np.full(len(df), True, bool)
            # assert 'label' not in filter_dict.keys()
            for key, val in filter_dict.items():
                mask = df[key].isin(val)
                filter_mask = np.logical_and(filter_mask, mask)
            df = df[filter_mask]
        return df

    @staticmethod
    def df_prep(data, label_dicts, ignore, label_cols, multi_site=False):
        for idx, (label_dict, label_col) in enumerate(zip(label_dicts, label_cols)):
            print(label_dict, label_col)
            data[label_col] = data[label_col].map(label_dict)

        return data

    def __len__(self):
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        else:
            return len(self.slide_data)

    def summarize(self):
        # note: the per-class counts below require patient_data_prep() and
        # cls_ids_prep() to have been run first
        for task in range(len(self.label_dicts)):
            print('task: ', task)
            print("label column: {}".format(self.label_cols[task]))
            print("label dictionary: {}".format(self.label_dicts[task]))
            print("number of classes: {}".format(self.num_classes[task]))
            print("slide-level counts: ", '\n', self.slide_data[self.label_cols[task]].value_counts(sort=False))

        for i in range(self.num_classes[0]):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))

    def create_splits(self, k=3, val_num=(25, 25), test_num=(40, 40), label_frac=1.0, custom_test_ids=None):
        settings = {
                    'n_splits': k,
                    'val_num': val_num,
                    'test_num': test_num,
                    'label_frac': label_frac,
                    'seed': self.seed,
                    'custom_test_ids': custom_test_ids
                    }

        if self.patient_strat:
            settings.update({'cls_ids': self.patient_cls_ids, 'samples': len(self.patient_data['case_id'])})
        else:
            settings.update({'cls_ids': self.slide_cls_ids, 'samples': len(self.slide_data)})

        self.split_gen = generate_split(**settings)

    def sample_held_out(self, test_num=(40, 40)):
        test_ids = []
        np.random.seed(self.seed) # fix seed

        if self.patient_strat:
            cls_ids = self.patient_cls_ids
        else:
            cls_ids = self.slide_cls_ids

        for c in range(len(test_num)):
            test_ids.extend(np.random.choice(cls_ids[c], test_num[c], replace=False)) # held-out test ids

        if self.patient_strat:
            slide_ids = []
            for idx in test_ids:
                case_id = self.patient_data['case_id'][idx]
                slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                slide_ids.extend(slide_indices)

            return slide_ids
        else:
            return test_ids

    def set_splits(self, start_from=None):
        if start_from:
            ids = nth(self.split_gen, start_from)
        else:
            ids = next(self.split_gen)

        if self.patient_strat:
            slide_ids = [[] for i in range(len(ids))]

            for split in range(len(ids)):
                for idx in ids[split]:
                    case_id = self.patient_data['case_id'][idx]
                    slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                    slide_ids[split].extend(slide_indices)

            self.train_ids, self.val_ids, self.test_ids = slide_ids[0], slide_ids[1], slide_ids[2]
        else:
            self.train_ids, self.val_ids, self.test_ids = ids

    def get_split_from_df(self, all_splits=None, split_key='train', split=None):
        if split is None:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
        else:
            split = None

        return split

    def get_merged_split_from_df(self, all_splits, split_keys=['train']):
        merged_split = []
        for split_key in split_keys:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True).tolist()
            merged_split.extend(split)

        # check the merged list, not just the last split read in the loop above
        if len(merged_split) > 0:
            mask = self.slide_data['slide_id'].isin(merged_split)
            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
        else:
            split = None

        return split

    def return_splits(self, from_id=True, csv_path=None):
        if from_id:
            if len(self.train_ids) > 0:
                train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True)
                train_split = Generic_Split(train_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
            else:
                train_split = None

            if len(self.val_ids) > 0:
                val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
            else:
                val_split = None

            if len(self.test_ids) > 0:
                test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
            else:
                test_split = None

        else:
            assert csv_path
            all_splits = pd.read_csv(csv_path)
            train_split = self.get_split_from_df(all_splits, 'train')
            val_split = self.get_split_from_df(all_splits, 'val')
            test_split = self.get_split_from_df(all_splits, 'test')

        return train_split, val_split, test_split

    def get_list(self, ids):
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids, task):
        # task 0 labels live in the 'label' column; the remaining tasks use
        # the column names given in label_cols
        if task > 0:
            return self.slide_data[self.label_cols[task]][ids]
        else:
            return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        return None

    def test_split_gen(self, return_descriptor=False):
        if return_descriptor:
            dfs = []
            for task in range(len(self.label_dicts)):
                index = [list(self.label_dicts[task].keys())[list(self.label_dicts[task].values()).index(i)] for i in range(self.num_classes[task])]
                columns = ['train', 'val', 'test']
                df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index=index, columns=columns)
                dfs.append(df)

        for task in range(len(self.label_dicts)):
            index = [list(self.label_dicts[task].keys())[list(self.label_dicts[task].values()).index(i)] for i in range(self.num_classes[task])]
            # per-class counts for each split; classes absent from a split are
            # padded in with a count of zero
            for split_name, split_ids in (('train', self.train_ids), ('val', self.val_ids), ('test', self.test_ids)):
                print('\nnumber of {} samples: {}'.format(split_name, len(split_ids)))
                labels = self.getlabel(split_ids, task)
                unique, counts = np.unique(labels, return_counts=True)
                missing_classes = np.setdiff1d(np.arange(self.num_classes[task]), unique)
                unique = np.append(unique, missing_classes)
                counts = np.append(counts, np.full(len(missing_classes), 0))
                inds = unique.argsort()
                # sort classes and counts together so they stay aligned
                unique = unique[inds]
                counts = counts[inds]
                for u in range(len(unique)):
                    print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
                    if return_descriptor:
                        dfs[task].loc[index[u], split_name] = counts[u]

        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0

        if return_descriptor:
            df = pd.concat(dfs, axis=0)
            return df

    def save_split(self, filename):
        train_split = self.get_list(self.train_ids)
        val_split = self.get_list(self.val_ids)
        test_split = self.get_list(self.test_ids)
        df_tr = pd.DataFrame({'train': train_split})
        df_v = pd.DataFrame({'val': val_split})
        df_t = pd.DataFrame({'test': test_split})
        df = pd.concat([df_tr, df_v, df_t], axis=1)
        df.to_csv(filename, index=False)


class Generic_MIL_MTL_Dataset(Generic_WSI_MTL_Dataset):
    def __init__(self, data_dir, **kwargs):
        super(Generic_MIL_MTL_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.use_h5 = False

    def load_from_h5(self, toggle):
        self.use_h5 = toggle

    def __getitem__(self, idx):
        if not self.patient_level:
            slide_id = self.slide_data['slide_id'][idx]
            # note: assumes label_cols holds at least three entries
            label_task1 = self.slide_data[self.label_cols[0]][idx]
            label_task2 = self.slide_data[self.label_cols[1]][idx]
            label_task3 = self.slide_data[self.label_cols[2]][idx]
            if type(self.data_dir) == dict:
                source = self.slide_data['source'][idx]
                data_dir = self.data_dir[source]
            else:
                data_dir = self.data_dir

            if not self.use_h5:
                if self.data_dir:
                    full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
                    features = torch.load(full_path)
                    return features, label_task1, label_task2, label_task3
                else:
                    return slide_id, label_task1, label_task2, label_task3
            else:
                full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
                with h5py.File(full_path, 'r') as hdf5_file:
                    features = hdf5_file['features'][:]
                    coords = hdf5_file['coords'][:]

                features = torch.from_numpy(features)
                return features, label_task1, label_task2, label_task3, coords

        else:
            # patient-level sample: concatenate the features of every slide
            # belonging to this case
            case_id = self.slide_data['case_id'][idx]
            slide_ids = self.patient_dict[case_id]
            label_task1 = self.slide_data[self.label_cols[0]][idx]
            label_task2 = self.slide_data[self.label_cols[1]][idx]
            label_task3 = self.slide_data[self.label_cols[2]][idx]

            if type(self.data_dir) == dict:
                source = self.slide_data['source'][idx]
                data_dir = self.data_dir[source]
            else:
                data_dir = self.data_dir

            if not self.use_h5:
                features_list = []
                for slide_id in slide_ids:
                    full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
                    slide_features = torch.load(full_path)
                    features_list.append(slide_features)

                features = torch.cat(features_list, dim=0)
                return features, label_task1, label_task2, label_task3
            else:
                features_list = []
                coords_list = []

                for slide_id in slide_ids:
                    full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
                    with h5py.File(full_path, 'r') as hdf5_file:
                        slide_features_t = torch.from_numpy(hdf5_file['features'][:])
                        slide_coords_t = torch.from_numpy(hdf5_file['coords'][:])

                        features_list.append(slide_features_t)
                        coords_list.append(slide_coords_t)

                features = torch.cat(features_list, dim=0)
                coords = torch.cat(coords_list, dim=0)
                return features, label_task1, label_task2, label_task3, coords


class Generic_Split(Generic_MIL_MTL_Dataset):
    def __init__(self, slide_data, data_dir=None, num_classes=2, label_cols=None, patient_level=False):
        self.use_h5 = False
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.label_cols = label_cols
        # per-class id caching is disabled (see the commented-out code below)
        self.slide_cls_ids = None
        #self.slide_cls_ids = [[] for i in range(self.num_classes[0])]
        #for i in range(self.num_classes[0]):
        #   self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

        self.patient_level = patient_level
        if self.patient_level:
            self.patient_dict = self.build_patient_dict()
            #self.slide_data   = self.slide_data.drop_duplicates(subset=['case_id'])
        else:
            self.patient_dict = {}

    def __len__(self):
        return len(self.slide_data)


class Generic_WSI_Inference_Dataset(Dataset):
    def __init__(self,
        data_dir,
        csv_path=None,
        print_info=True,
        ):
        self.data_dir = data_dir
        self.print_info = print_info

        if csv_path is not None:
            data = pd.read_csv(csv_path)
            self.slide_data = data['slide_id'].values
        else:
            # np.char.strip(..., chars='.pt') would strip any of the characters
            # '.', 'p', 't' from both ends of a name; split off the extension instead
            data = os.listdir(data_dir)
            self.slide_data = np.array([os.path.splitext(f)[0] for f in data])
        if print_info:
            print('total number of slides to infer: ', len(self.slide_data))

    def __len__(self):
        return len(self.slide_data)

    def __getitem__(self, idx):
        slide_file = self.slide_data[idx] + '.pt'
        full_path = os.path.join(self.data_dir, 'pt_files', slide_file)
        features = torch.load(full_path)
        return features
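

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The csv path,
# feature directory, label dictionaries and split csv below are hypothetical
# placeholders; substitute your own data layout.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    dataset = Generic_MIL_MTL_Dataset(
        data_dir='features',                                # hypothetical root containing pt_files/ (and h5_files/)
        csv_path='dataset_csv/example.csv',                 # hypothetical annotation csv
        label_dicts=[{'subtype_a': 0, 'subtype_b': 1},      # one str -> int dict per task
                     {'grade_low': 0, 'grade_high': 1},
                     {'site_x': 0, 'site_y': 1}],
        label_cols=['label', 'label_2', 'label_3'],         # csv columns holding each task's labels
        shuffle=False,
        print_info=True)

    # read train/val/test membership from a previously saved split csv
    train_split, val_split, test_split = dataset.return_splits(
        from_id=False, csv_path='splits/example_splits_0.csv')

    # each item is one bag of patch features plus one label per task
    features, y1, y2, y3 = train_split[0]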