# datasets/dataset_mtl_concat.py
from __future__ import print_function, division
import os
import torch
import numpy as np
import pandas as pd
import math
import re
import pdb
import pickle
from scipy import stats
from torch.utils.data import Dataset
import h5py

from utils.utils import generate_split, nth

def save_splits(split_datasets, column_keys, filename, boolean_style=False):
    # collect the slide_ids of each split (e.g. train/val/test)
    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
    if not boolean_style:
        # one column of slide_ids per split, padded with NaN where lengths differ
        df = pd.concat(splits, ignore_index=True, axis=1)
        df.columns = column_keys
    else:
        # single index of slide_ids with one boolean membership column per split
        df = pd.concat(splits, ignore_index=True, axis=0)
        index = df.values.tolist()
        one_hot = np.eye(len(split_datasets)).astype(bool)
        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
        df = pd.DataFrame(bool_array, index=index, columns=['train', 'val', 'test'])

    df.to_csv(filename)

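# Usage sketch for save_splits (hypothetical objects and filename, for
# illustration only):
#
#   train_split, val_split, test_split = dataset.return_splits(from_id=True)
#   save_splits([train_split, val_split, test_split],
#               ['train', 'val', 'test'], 'splits_0.csv')
#
# With boolean_style=True, the csv instead indexes every slide_id and marks
# split membership with one boolean column per split.
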
class Generic_WSI_MTL_Dataset(Dataset):
    def __init__(self,
        csv_path = None,
        shuffle = False,
        seed = 7,
        print_info = True,
        label_dicts = [{}, {}, {}],
        patient_strat = False,
        label_cols = ['label', 'site', 'sex'],
        patient_voting = 'max',
        filter_dict = {},
        ):
        """
        Args:
            csv_path (string): Path to the dataset csv file
            shuffle (boolean): Whether to shuffle the data
            seed (int): Random seed for shuffling the data
            print_info (boolean): Whether to print a summary of the dataset
            label_dicts (list of dict): List of dictionaries with key, value pairs for converting str labels to int for each label column
            label_cols (list): List of column headings to use as labels and map with label_dicts
            filter_dict (dict): Dictionary of key, value pairs used to filter the dataset, where each key is a column name
                                and the value is the list of values to keep in that column
            patient_voting (string): Rule for deciding the patient-level label ('max' or 'maj')
        """
        self.custom_test_ids = None
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
        self.data_dir = None
        self.label_cols = label_cols
        self.split_gen = None

        slide_data = pd.read_csv(csv_path)
        slide_data = self.filter_df(slide_data, filter_dict)

        self.label_dicts = label_dicts
        self.num_classes = [len(set(label_dict.values())) for label_dict in self.label_dicts]

        slide_data = self.df_prep(slide_data, self.label_dicts, self.label_cols)
        # shuffle data (np.random.shuffle does not work on a DataFrame;
        # use a seeded pandas sample instead)
        if shuffle:
            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)

        self.slide_data = slide_data

        self.patient_data_prep(patient_voting)
        self.cls_ids_prep()

        if print_info:
            self.summarize()

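    # Instantiation sketch (hypothetical csv path and label mappings, for
    # illustration only; the csv is assumed to contain 'case_id', 'slide_id'
    # and the three label columns):
    #
    #   dataset = Generic_WSI_MTL_Dataset(
    #       csv_path='dataset.csv',
    #       label_dicts=[{'benign': 0, 'tumor': 1},
    #                    {'primary': 0, 'metastatic': 1},
    #                    {'F': 0, 'M': 1}],
    #       label_cols=['label', 'site', 'sex'],
    #       patient_strat=True)
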
    def cls_ids_prep(self):
        # store indices corresponding to each class (task 0) at the patient level
        self.patient_cls_ids = [[] for i in range(self.num_classes[0])]
        for i in range(self.num_classes[0]):
            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

        # store indices corresponding to each class (task 0) at the slide level
        self.slide_cls_ids = [[] for i in range(self.num_classes[0])]
        for i in range(self.num_classes[0]):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def patient_data_prep(self, patient_voting='max'):
        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            label = self.slide_data['label'][locations].values
            if patient_voting == 'max':
                label = label.max() # get patient label (MIL convention: most severe slide label)
            elif patient_voting == 'maj':
                label = stats.mode(label)[0] # majority vote over the patient's slides
            else:
                raise NotImplementedError
            patient_labels.append(label)

        self.patient_data = {'case_id': patients, 'label': np.array(patient_labels)}

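    # Worked example of the voting rules: for a patient with slide labels
    # [0, 0, 2], 'max' assigns patient label 2 (any positive slide makes the
    # patient positive), while 'maj' assigns 0 (the most common slide label).
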
    @staticmethod
    def filter_df(df, filter_dict={}):
        if len(filter_dict) > 0:
            filter_mask = np.full(len(df), True, bool)
            # assert 'label' not in filter_dict.keys()
            for key, val in filter_dict.items():
                # keep only rows whose value in this column is in the given list
                mask = df[key].isin(val)
                filter_mask = np.logical_and(filter_mask, mask)
            df = df[filter_mask]
        return df

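    # Filtering sketch (hypothetical column and values, for illustration
    # only): filter_dict={'site': ['lung', 'skin']} keeps only the rows whose
    # 'site' column equals 'lung' or 'skin'; every other row is dropped.
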
    @staticmethod
    def df_prep(data, label_dicts, label_cols):
        if label_cols[0] != 'label':
            data['label'] = data[label_cols[0]].copy()

        data.reset_index(drop=True, inplace=True)
        # map the primary label column (task 0) through its label dictionary
        for i in data.index:
            key = data.loc[i, 'label']
            data.at[i, 'label'] = label_dicts[0][key]

        # map the remaining label columns (tasks 1+) through their dictionaries
        for label_dict, label_col in zip(label_dicts[1:], label_cols[1:]):
            print(label_dict, label_col)
            data[label_col] = data[label_col].map(label_dict)

        return data

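    # Mapping sketch (hypothetical dictionaries, for illustration only): with
    # label_dicts = [{'benign': 0, 'tumor': 1}, {'primary': 0, 'metastatic': 1},
    # {'F': 0, 'M': 1}] and label_cols = ['label', 'site', 'sex'], a row
    # ('tumor', 'primary', 'F') is encoded as (1, 0, 0) across the three columns.
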
    def __len__(self):
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        else:
            return len(self.slide_data)

    def summarize(self):
        for task in range(len(self.label_dicts)):
            print('task: ', task)
            print("label column: {}".format(self.label_cols[task]))
            print("label dictionary: {}".format(self.label_dicts[task]))
            print("number of classes: {}".format(self.num_classes[task]))
            print("slide-level counts: ", '\n', self.slide_data[self.label_cols[task]].value_counts(sort=False))

        for i in range(self.num_classes[0]):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))

    def create_splits(self, k=3, val_num=(25, 25), test_num=(40, 40), label_frac=1.0, custom_test_ids=None):
        # val_num and test_num give per-class (task 0) hold-out counts
        settings = {
                    'n_splits': k,
                    'val_num': val_num,
                    'test_num': test_num,
                    'label_frac': label_frac,
                    'seed': self.seed,
                    'custom_test_ids': custom_test_ids
                    }

        if self.patient_strat:
            settings.update({'cls_ids': self.patient_cls_ids, 'samples': len(self.patient_data['case_id'])})
        else:
            settings.update({'cls_ids': self.slide_cls_ids, 'samples': len(self.slide_data)})

        self.split_gen = generate_split(**settings)

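    # Split-generation sketch (hypothetical counts, for illustration only):
    # with two classes in task 0, create_splits(k=3, val_num=(25, 25),
    # test_num=(40, 40)) prepares 3 folds, each holding out 25 validation and
    # 40 test samples per class; every subsequent set_splits() call consumes
    # the next fold from the generator.
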
    def sample_held_out(self, test_num=(40, 40)):
        test_ids = []
        np.random.seed(self.seed) # fix seed

        if self.patient_strat:
            cls_ids = self.patient_cls_ids
        else:
            cls_ids = self.slide_cls_ids

        for c in range(len(test_num)):
            test_ids.extend(np.random.choice(cls_ids[c], test_num[c], replace=False)) # held-out test ids

        if self.patient_strat:
            # expand the sampled patients into all of their slides
            slide_ids = []
            for idx in test_ids:
                case_id = self.patient_data['case_id'][idx]
                slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                slide_ids.extend(slide_indices)

            return slide_ids
        else:
            return test_ids

    def set_splits(self, start_from=None):
        if start_from:
            ids = nth(self.split_gen, start_from)
        else:
            ids = next(self.split_gen)

        if self.patient_strat:
            # convert patient-level indices to slide-level indices for each split
            slide_ids = [[] for i in range(len(ids))]

            for split in range(len(ids)):
                for idx in ids[split]:
                    case_id = self.patient_data['case_id'][idx]
                    slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                    slide_ids[split].extend(slide_indices)

            self.train_ids, self.val_ids, self.test_ids = slide_ids[0], slide_ids[1], slide_ids[2]

        else:
            self.train_ids, self.val_ids, self.test_ids = ids

    def get_split_from_df(self, all_splits=None, split_key='train', return_ids_only=False, split=None):
        if split is None:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            if return_ids_only:
                ids = np.where(mask)[0]
                return ids

            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols)

        else:
            split = None

        return split

    def get_merged_split_from_df(self, all_splits, split_keys=['train']):
        merged_split = []
        for split_key in split_keys:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True).tolist()
            merged_split.extend(split)

        # check the merged list, not just the last individual split
        if len(merged_split) > 0:
            mask = self.slide_data['slide_id'].isin(merged_split)
            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols)
        else:
            split = None

        return split

    def return_splits(self, from_id=True, csv_path=None):
        if from_id:
            # build splits from the ids assigned by set_splits()
            if len(self.train_ids) > 0:
                train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True)
                train_split = Generic_Split(train_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols)
            else:
                train_split = None

            if len(self.val_ids) > 0:
                val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols)
            else:
                val_split = None

            if len(self.test_ids) > 0:
                test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols)
            else:
                test_split = None

        else:
            # build splits from a previously saved split csv
            assert csv_path
            all_splits = pd.read_csv(csv_path)
            train_split = self.get_split_from_df(all_splits, 'train')
            val_split = self.get_split_from_df(all_splits, 'val')
            test_split = self.get_split_from_df(all_splits, 'test')

        return train_split, val_split, test_split

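    # Usage sketch (hypothetical split csv, for illustration only):
    #
    #   dataset.create_splits(k=3)
    #   dataset.set_splits()
    #   train, val, test = dataset.return_splits(from_id=True)
    #
    #   # or, to restore splits saved earlier with save_split()/save_splits():
    #   train, val, test = dataset.return_splits(from_id=False, csv_path='splits_0.csv')
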
    def get_list(self, ids):
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids, task):
        # task 0 labels live in the canonical 'label' column; the other tasks
        # keep their original column names
        if task > 0:
            return self.slide_data[self.label_cols[task]][ids]
        else:
            return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        return None

    def test_split_gen(self, return_descriptor=False):
        if return_descriptor:
            dfs = []
            for task in range(len(self.label_dicts)):
                index = [list(self.label_dicts[task].keys())[list(self.label_dicts[task].values()).index(i)] for i in range(self.num_classes[task])]
                columns = ['train', 'val', 'test']
                df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index=index, columns=columns)
                dfs.append(df)

        for task in range(len(self.label_dicts)):
            index = [list(self.label_dicts[task].keys())[list(self.label_dicts[task].values()).index(i)] for i in range(self.num_classes[task])]
            for split_name, ids in zip(['train', 'val', 'test'], [self.train_ids, self.val_ids, self.test_ids]):
                count = len(ids)
                print('\nnumber of {} samples: {}'.format(split_name, count))
                labels = self.getlabel(ids, task)
                unique, counts = np.unique(labels, return_counts=True)
                # pad classes that are absent from this split with zero counts
                missing_classes = np.setdiff1d(np.arange(self.num_classes[task]), unique)
                unique = np.append(unique, missing_classes)
                counts = np.append(counts, np.full(len(missing_classes), 0))
                # sort both arrays by class id so counts stay aligned with unique
                inds = unique.argsort()
                unique = unique[inds]
                counts = counts[inds]
                for u in range(len(unique)):
                    print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
                    if return_descriptor:
                        dfs[task].loc[index[u], split_name] = counts[u]

        # the three splits must be pairwise disjoint
        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0

        if return_descriptor:
            df = pd.concat(dfs, axis=0)
            return df

    def save_split(self, filename):
        # reset the indices so the three columns align row-wise instead of
        # aligning on the original slide indices
        train_split = self.get_list(self.train_ids).reset_index(drop=True)
        val_split = self.get_list(self.val_ids).reset_index(drop=True)
        test_split = self.get_list(self.test_ids).reset_index(drop=True)
        df_tr = pd.DataFrame({'train': train_split})
        df_v = pd.DataFrame({'val': val_split})
        df_t = pd.DataFrame({'test': test_split})
        df = pd.concat([df_tr, df_v, df_t], axis=1)
        df.to_csv(filename, index=False)

class Generic_MIL_MTL_Dataset(Generic_WSI_MTL_Dataset):
    def __init__(self,
        data_dir,
        **kwargs):
        super(Generic_MIL_MTL_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.use_h5 = False

    def load_from_h5(self, toggle):
        # toggle between loading .pt feature files and .h5 files with coords
        self.use_h5 = toggle

    def __getitem__(self, idx):
        slide_id = self.slide_data['slide_id'][idx]
        label = self.slide_data['label'][idx]
        site = self.slide_data[self.label_cols[1]][idx]
        sex = self.slide_data[self.label_cols[2]][idx]
        # data_dir may be a single directory or a dict mapping each data
        # source to its own feature directory
        if isinstance(self.data_dir, dict):
            source = self.slide_data['source'][idx]
            data_dir = self.data_dir[source]
        else:
            data_dir = self.data_dir

        if not self.use_h5:
            # load pre-extracted features saved as a torch tensor
            full_path = os.path.join(data_dir, '{}.pt'.format(slide_id))
            features = torch.load(full_path)
            return features, label, site, sex

        else:
            # load features and patch coordinates from an hdf5 file
            full_path = os.path.join(data_dir, '{}.h5'.format(slide_id))
            with h5py.File(full_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
                coords = hdf5_file['coords'][:]

            features = torch.from_numpy(features)
            return features, label, site, sex, coords

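# Iteration sketch (hypothetical dataset object, for illustration only):
#
#   dataset.load_from_h5(False)
#   features, label, site, sex = dataset[0]          # features loaded from <slide_id>.pt
#
#   dataset.load_from_h5(True)
#   features, label, site, sex, coords = dataset[0]  # features and coords from <slide_id>.h5
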
class Generic_Split(Generic_MIL_MTL_Dataset):
    def __init__(self, slide_data, data_dir=None, num_classes=2, label_cols=None):
        # note: num_classes is expected to be a list of per-task class counts
        self.use_h5 = False
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.slide_cls_ids = [[] for i in range(self.num_classes[0])]
        self.label_cols = label_cols
        self.infer = False
        for i in range(self.num_classes[0]):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def __len__(self):
        return len(self.slide_data)