# datasets/dataset_generic.py
from __future__ import print_function, division
import os
import torch
import numpy as np
import pandas as pd
from scipy import stats

from torch.utils.data import Dataset
import h5py

from utils.utils import generate_split, nth

def save_splits(split_datasets, column_keys, filename, boolean_style=False):
    """Save the slide ids of each split to csv, either as one column per split
    or as a boolean membership table indexed by slide id."""
    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
    if not boolean_style:
        df = pd.concat(splits, ignore_index=True, axis=1)
        df.columns = column_keys
    else:
        df = pd.concat(splits, ignore_index=True, axis=0)
        index = df.values.tolist()
        one_hot = np.eye(len(split_datasets)).astype(bool)
        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
        df = pd.DataFrame(bool_array, index=index, columns=column_keys)

    df.to_csv(filename)
    print()

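# A usage sketch for save_splits(), assuming three Generic_Split objects have
# already been produced (e.g. by return_splits() below):
#
#     save_splits([train_split, val_split, test_split],
#                 ['train', 'val', 'test'], 'splits_0.csv')
#     save_splits([train_split, val_split, test_split],
#                 ['train', 'val', 'test'], 'splits_0_bool.csv',
#                 boolean_style=True)
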
class Generic_WSI_Classification_Dataset(Dataset):
    def __init__(self,
        csv_path='dataset_csv/ccrcc_clean.csv',
        shuffle=False,
        seed=7,
        print_info=True,
        label_dict={},
        filter_dict={},
        ignore=[],
        patient_strat=False,
        label_col=None,
        patient_voting='max',
        ):
        """
        Args:
            csv_path (string): Path to the csv file with annotations.
            shuffle (boolean): Whether to shuffle the data
            seed (int): Random seed for shuffling the data
            print_info (boolean): Whether to print a summary of the dataset
            label_dict (dict): Dictionary with key, value pairs for converting str labels to int
            filter_dict (dict): Dictionary mapping column names to lists of allowed values
            ignore (list): List containing class labels to ignore
            patient_strat (boolean): Whether to stratify splits at the patient (case) level
            label_col (string): Name of the csv column holding the labels (defaults to 'label')
            patient_voting (string): How slide labels are aggregated into a patient label ('max' or 'maj')
        """
        self.label_dict = label_dict
        self.num_classes = len(set(self.label_dict.values()))
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
        self.data_dir = None
        if not label_col:
            label_col = 'label'
        self.label_col = label_col

        slide_data = pd.read_csv(csv_path)
        slide_data = self.filter_df(slide_data, filter_dict)
        slide_data = self.df_prep(slide_data, self.label_dict, ignore, self.label_col)

        # shuffle data; np.random.shuffle cannot shuffle a DataFrame in place,
        # so sample all rows in random order and reset the index instead
        if shuffle:
            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)

        self.slide_data = slide_data

        self.patient_data_prep(patient_voting)
        self.cls_ids_prep()

        if print_info:
            self.summarize()

    def cls_ids_prep(self):
        # store indices corresponding to each class at the patient level
        self.patient_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

        # store indices corresponding to each class at the slide level
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def patient_data_prep(self, patient_voting='max'):
        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            label = self.slide_data['label'][locations].values
            if patient_voting == 'max':
                label = label.max() # get patient label (MIL convention)
            elif patient_voting == 'maj':
                label = stats.mode(label)[0] # majority vote across the patient's slides
            else:
                raise NotImplementedError
            patient_labels.append(label)

        self.patient_data = {'case_id': patients, 'label': np.array(patient_labels)}

    @staticmethod
    def df_prep(data, label_dict, ignore, label_col):
        if label_col != 'label':
            data['label'] = data[label_col].copy()

        # drop rows whose label is in the ignore list
        mask = data['label'].isin(ignore)
        data = data[~mask]
        data.reset_index(drop=True, inplace=True)
        # map string labels to integer class indices via label_dict
        for i in data.index:
            key = data.loc[i, 'label']
            data.at[i, 'label'] = label_dict[key]

        return data

    def filter_df(self, df, filter_dict={}):
        # keep only rows whose value in each filtered column is among the allowed values
        if len(filter_dict) > 0:
            filter_mask = np.full(len(df), True, dtype=bool)
            for key, val in filter_dict.items():
                mask = df[key].isin(val)
                filter_mask = np.logical_and(filter_mask, mask)
            df = df[filter_mask]
        return df

    def __len__(self):
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        else:
            return len(self.slide_data)

    def summarize(self):
        print("label column: {}".format(self.label_col))
        print("label dictionary: {}".format(self.label_dict))
        print("number of classes: {}".format(self.num_classes))
        print("slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort=False))
        for i in range(self.num_classes):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))

    def create_splits(self, k=3, val_num=(25, 25), test_num=(40, 40), label_frac=1.0, custom_test_ids=None):
        settings = {
            'n_splits': k,
            'val_num': val_num,
            'test_num': test_num,
            'label_frac': label_frac,
            'seed': self.seed,
            'custom_test_ids': custom_test_ids
        }

        if self.patient_strat:
            settings.update({'cls_ids': self.patient_cls_ids, 'samples': len(self.patient_data['case_id'])})
        else:
            settings.update({'cls_ids': self.slide_cls_ids, 'samples': len(self.slide_data)})

        self.split_gen = generate_split(**settings)

    def set_splits(self, start_from=None):
        if start_from:
            ids = nth(self.split_gen, start_from)
        else:
            ids = next(self.split_gen)

        if self.patient_strat:
            # expand each patient-level split into the indices of that patient's slides
            slide_ids = [[] for i in range(len(ids))]

            for split in range(len(ids)):
                for idx in ids[split]:
                    case_id = self.patient_data['case_id'][idx]
                    slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                    slide_ids[split].extend(slide_indices)

            self.train_ids, self.val_ids, self.test_ids = slide_ids[0], slide_ids[1], slide_ids[2]
        else:
            self.train_ids, self.val_ids, self.test_ids = ids

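    # A split-generation sketch (parameter values are illustrative; val_num and
    # test_num give per-class counts, so their length should match num_classes):
    #
    #     dataset.create_splits(k=10, val_num=(25, 25), test_num=(40, 40),
    #                           label_frac=1.0)
    #     for fold in range(10):
    #         dataset.set_splits()
    #         dataset.save_split('splits_{}.csv'.format(fold))
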
    def get_split_from_df(self, all_splits, split_key='train'):
        split = all_splits[split_key]
        split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            df_slice = self.slide_data[mask].reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes)
        else:
            split = None

        return split

    def get_merged_split_from_df(self, all_splits, split_keys=['train']):
        merged_split = []
        for split_key in split_keys:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True).tolist()
            merged_split.extend(split)

        # check the merged list, not just the last individual split
        if len(merged_split) > 0:
            mask = self.slide_data['slide_id'].isin(merged_split)
            df_slice = self.slide_data[mask].reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes)
        else:
            split = None

        return split

    def return_splits(self, from_id=True, csv_path=None):
        if from_id:
            if len(self.train_ids) > 0:
                train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True)
                train_split = Generic_Split(train_data, data_dir=self.data_dir, num_classes=self.num_classes)
            else:
                train_split = None

            if len(self.val_ids) > 0:
                val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes)
            else:
                val_split = None

            if len(self.test_ids) > 0:
                test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes)
            else:
                test_split = None

        else:
            assert csv_path
            # Without dtype=self.slide_data['slide_id'].dtype, read_csv() converts
            # all-number columns to a numerical type. Even if we convert numerical
            # columns back to objects later, we may lose zero-padding in the
            # process; the columns must be correctly read in from the get-go. When
            # we compare the individual train/val/test columns to
            # self.slide_data['slide_id'] in get_split_from_df(), we cannot compare
            # objects (strings) to numbers or to incorrectly zero-padded
            # objects/strings. An example of this breaking:
            # https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01
            all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype)
            train_split = self.get_split_from_df(all_splits, 'train')
            val_split = self.get_split_from_df(all_splits, 'val')
            test_split = self.get_split_from_df(all_splits, 'test')

        return train_split, val_split, test_split

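    # A sketch of reloading splits from a previously saved csv (the path is
    # hypothetical):
    #
    #     train_split, val_split, test_split = dataset.return_splits(
    #         from_id=False, csv_path='splits/task_1/splits_0.csv')
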
    def get_list(self, ids):
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids):
        return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        return None

    def test_split_gen(self, return_descriptor=False):
        if return_descriptor:
            index = [list(self.label_dict.keys())[list(self.label_dict.values()).index(i)] for i in range(self.num_classes)]
            columns = ['train', 'val', 'test']
            df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index=index, columns=columns)

        for split_name, split_ids in (('train', self.train_ids), ('val', self.val_ids), ('test', self.test_ids)):
            print('\nnumber of {} samples: {}'.format(split_name, len(split_ids)))
            labels = self.getlabel(split_ids)
            unique, counts = np.unique(labels, return_counts=True)
            for u in range(len(unique)):
                print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
                if return_descriptor:
                    # index by the class label itself so counts stay aligned
                    # even when a class is absent from a split
                    df.loc[index[unique[u]], split_name] = counts[u]

        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0

        if return_descriptor:
            return df

    def save_split(self, filename):
        train_split = self.get_list(self.train_ids)
        val_split = self.get_list(self.val_ids)
        test_split = self.get_list(self.test_ids)
        df_tr = pd.DataFrame({'train': train_split})
        df_v = pd.DataFrame({'val': val_split})
        df_t = pd.DataFrame({'test': test_split})
        df = pd.concat([df_tr, df_v, df_t], axis=1)
        df.to_csv(filename, index=False)


class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset):
    def __init__(self,
        data_dir,
        **kwargs):

        super(Generic_MIL_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.use_h5 = False

    def load_from_h5(self, toggle):
        self.use_h5 = toggle

    def __getitem__(self, idx):
        slide_id = self.slide_data['slide_id'][idx]
        label = self.slide_data['label'][idx]
        # data_dir may be a single directory or a dict mapping each data source
        # to its own directory
        if isinstance(self.data_dir, dict):
            source = self.slide_data['source'][idx]
            data_dir = self.data_dir[source]
        else:
            data_dir = self.data_dir

        if not self.use_h5:
            if self.data_dir:
                full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
                features = torch.load(full_path)
                return features, label
            else:
                return slide_id, label
        else:
            full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
            with h5py.File(full_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
                coords = hdf5_file['coords'][:]

            features = torch.from_numpy(features)
            return features, label, coords

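# A usage sketch for Generic_MIL_Dataset (paths and label names are
# hypothetical; pre-extracted features are expected under
# <data_dir>/pt_files or <data_dir>/h5_files):
#
#     dataset = Generic_MIL_Dataset(
#         data_dir='data/tumor_vs_normal_features',
#         csv_path='dataset_csv/tumor_vs_normal.csv',
#         label_dict={'normal_tissue': 0, 'tumor_tissue': 1})
#     features, label = dataset[0]          # bag of patch features + slide label
#     dataset.load_from_h5(True)
#     features, label, coords = dataset[0]  # also returns patch coordinates
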
class Generic_Split(Generic_MIL_Dataset):
    def __init__(self, slide_data, data_dir=None, num_classes=2):
        # a lightweight view over one split; deliberately skips the parent __init__
        self.use_h5 = False
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.num_classes = num_classes
        # store indices corresponding to each class at the slide level
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def __len__(self):
        return len(self.slide_data)
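
# A minimal end-to-end sketch (assuming utils.utils.generate_split yields
# (train, val, test) index tuples, as consumed by set_splits(); batch_size=1
# because each item is a variable-length bag of patch features):
#
#     from torch.utils.data import DataLoader
#
#     dataset.create_splits(k=1)
#     dataset.set_splits()
#     train_split, val_split, test_split = dataset.return_splits(from_id=True)
#     loader = DataLoader(train_split, batch_size=1, shuffle=True)
#     for features, label in loader:
#         ...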