Switch to unified view

a b/datasets/dataset_survival.py
1
from __future__ import print_function, division
2
import math
3
import os
4
import pdb
5
import pickle
6
import re
7
8
import h5py
9
import numpy as np
10
import pandas as pd
11
from scipy import stats
12
from sklearn.preprocessing import StandardScaler
13
14
import torch
15
from torch.utils.data import Dataset
16
17
from utils.utils import generate_split, nth
18
19
20
class Generic_WSI_Survival_Dataset(Dataset):
    def __init__(self,
        csv_path = 'dataset_csv/ccrcc_clean.csv', mode = 'omic', apply_sig = False,
        shuffle = False, seed = 7, print_info = True, n_bins = 4, ignore=[],
        patient_strat=False, label_col = None, filter_dict = {}, eps=1e-6):
        r"""
        Generic_WSI_Survival_Dataset

        Loads a survival CSV, discretizes survival time into `n_bins` quantile
        bins (computed on uncensored patients only), and builds a combined
        (time-bin, censorship) class label for each patient.

        Args:
            csv_path (string): Path to the csv file with annotations.
            mode (string): Feature modality (e.g. 'omic', 'path', ...); stored for splits.
            apply_sig (boolean): Whether to load pathway signature definitions.
            shuffle (boolean): Whether to shuffle the rows after loading.
            seed (int): random seed for shuffling the data.
            print_info (boolean): Whether to print a summary of the dataset.
            n_bins (int): Number of quantile bins for survival-time discretization.
            ignore (list): List containing class labels to ignore (unused here; see df_prep).
            patient_strat (boolean): Whether __len__ counts patients instead of slides.
            label_col (string): Column holding the survival time; defaults to 'survival_months'.
            filter_dict (dict): Reserved for row filtering (currently unused).
            eps (float): Margin added to the outermost bin edges so min/max values bin cleanly.
        """
        self.custom_test_ids = None
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids  = (None, None, None)
        self.data_dir = None

        slide_data = pd.read_csv(csv_path, low_memory=False)
        #slide_data = slide_data.drop(['Unnamed: 0'], axis=1)

        # BUGFIX: the original shuffled `slide_data` BEFORE it was read from csv
        # (NameError whenever shuffle=True). Shuffle after loading, via a seeded
        # resample, which is the safe way to shuffle a DataFrame.
        if shuffle:
            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)

        if 'case_id' not in slide_data:
            # Derive TCGA case ids (first 12 chars of the barcode index).
            slide_data.index = slide_data.index.str[:12]
            slide_data['case_id'] = slide_data.index
            slide_data = slide_data.reset_index(drop=True)

        if not label_col:
            label_col = 'survival_months'
        else:
            assert label_col in slide_data.columns
        self.label_col = label_col

        # BUGFIX: `"IDC" in series` tests the Series *index*, not its values, so
        # the original filter never triggered. Check the column values instead.
        if (slide_data['oncotree_code'] == 'IDC').any(): # must be BRCA (and if so, use only IDCs)
            slide_data = slide_data[slide_data['oncotree_code'] == 'IDC']

        patients_df = slide_data.drop_duplicates(['case_id']).copy()
        # Quantile bin edges are estimated from uncensored patients only.
        uncensored_df = patients_df[patients_df['censorship'] < 1]

        disc_labels, q_bins = pd.qcut(uncensored_df[label_col], q=n_bins, retbins=True, labels=False)
        # Widen the outermost edges so every observed value falls inside a bin.
        q_bins[-1] = slide_data[label_col].max() + eps
        q_bins[0] = slide_data[label_col].min() - eps

        # Re-bin ALL patients (censored included) with the adjusted edges.
        disc_labels, q_bins = pd.cut(patients_df[label_col], bins=q_bins, retbins=True, labels=False, right=False, include_lowest=True)
        patients_df.insert(2, 'label', disc_labels.values.astype(int))

        # Map each patient to the array of its slide ids.
        patient_dict = {}
        slide_data = slide_data.set_index('case_id')
        for patient in patients_df['case_id']:
            slide_ids = slide_data.loc[patient, 'slide_id']
            if isinstance(slide_ids, str):
                # Single slide: .loc returned a scalar string.
                slide_ids = np.array(slide_ids).reshape(-1)
            else:
                slide_ids = slide_ids.values
            patient_dict.update({patient:slide_ids})

        self.patient_dict = patient_dict

        # From here on the dataset is patient-level: one row per case.
        slide_data = patients_df
        slide_data.reset_index(drop=True, inplace=True)
        slide_data = slide_data.assign(slide_id=slide_data['case_id'])

        # Combined class label: (time bin, censorship status) -> consecutive int.
        label_dict = {}
        key_count = 0
        for i in range(len(q_bins)-1):
            for c in [0, 1]:
                print('{} : {}'.format((i, c), key_count))
                label_dict.update({(i, c):key_count})
                key_count+=1

        self.label_dict = label_dict
        for i in slide_data.index:
            key = slide_data.loc[i, 'label']
            slide_data.at[i, 'disc_label'] = key
            censorship = slide_data.loc[i, 'censorship']
            key = (key, int(censorship))
            slide_data.at[i, 'label'] = label_dict[key]

        self.bins = q_bins
        self.num_classes=len(self.label_dict)
        patients_df = slide_data.drop_duplicates(['case_id'])
        self.patient_data = {'case_id':patients_df['case_id'].values, 'label':patients_df['label'].values}

        # Move 'disc_label' (appended last above) to the front.   ### PORPOISE
        new_cols = list(slide_data.columns[-1:]) + list(slide_data.columns[:-1])
        slide_data = slide_data[new_cols]
        self.slide_data = slide_data
        metadata = ['disc_label', 'Unnamed: 0', 'case_id', 'label', 'slide_id', 'age', 'site', 'survival_months', 'censorship', 'is_female', 'oncotree_code', 'train']
        self.metadata = slide_data.columns[:12]

        # Diagnostic: flag non-metadata columns that do not look like genomic
        # features. BUGFIX: the original pattern started with '|' (an empty
        # alternative that matches every string), so nothing was ever flagged.
        for col in slide_data.drop(self.metadata, axis=1).columns:
            if not re.search(r'_cnv|_rnaseq|_rna|_mut', col):
                print(col)

        assert self.metadata.equals(pd.Index(metadata))
        self.mode = mode
        self.cls_ids_prep()

        ### Signatures
        self.apply_sig = apply_sig
        if self.apply_sig:
            self.signatures = pd.read_csv('./datasets_csv_sig/signatures.csv')
        else:
            self.signatures = None

        # BUGFIX: summarize() was called twice under the same flag; once suffices.
        if print_info:
            self.summarize()


    def cls_ids_prep(self):
        r"""Cache, per class, the row indices of patients and of slides."""
        # Indices into self.patient_data for each class label.
        self.patient_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

        # Indices into self.slide_data for each class label.
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]


    def patient_data_prep(self):
        r"""Rebuild self.patient_data from self.slide_data (one label per unique case)."""
        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            # All slides of a patient share one label; take the first.
            label = self.slide_data['label'][locations[0]] # get patient label
            patient_labels.append(label)

        self.patient_data = {'case_id':patients, 'label':np.array(patient_labels)}


    @staticmethod
    def df_prep(data, n_bins, ignore, label_col):
        r"""Drop rows whose label is in `ignore`, then cut `label_col` into `n_bins` equal-width bins.

        Returns:
            (DataFrame, ndarray): filtered data and the bin edges.
        """
        mask = data[label_col].isin(ignore)
        data = data[~mask]
        data.reset_index(drop=True, inplace=True)
        disc_labels, bins = pd.cut(data[label_col], bins=n_bins)
        return data, bins

    def __len__(self):
        # Patient-level length when stratifying by patient, else slide-level.
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        else:
            return len(self.slide_data)

    def summarize(self):
        """Print label column, label dictionary, and per-class patient/slide counts."""
        print("label column: {}".format(self.label_col))
        print("label dictionary: {}".format(self.label_dict))
        print("number of classes: {}".format(self.num_classes))
        print("slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort = False))
        for i in range(self.num_classes):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))


    def get_split_from_df(self, all_splits: dict, split_key: str='train', scaler=None):
        """Build a Generic_Split for the slide ids listed under `split_key`, or None if empty.

        `scaler` is accepted for interface compatibility but unused here.
        """
        split = all_splits[split_key]
        split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            df_slice = self.slide_data[mask].reset_index(drop=True)
            split = Generic_Split(df_slice, metadata=self.metadata, mode=self.mode, signatures=self.signatures, data_dir=self.data_dir, label_col=self.label_col, patient_dict=self.patient_dict, num_classes=self.num_classes)
        else:
            split = None

        return split


    def return_splits(self, from_id: bool=True, csv_path: str=None):
        """Return (train_split, val_split) read from a split csv, with genomic
        features standardized using statistics fitted on the train split only.
        """
        if from_id:
            raise NotImplementedError
        else:
            assert csv_path
            all_splits = pd.read_csv(csv_path)
            train_split = self.get_split_from_df(all_splits=all_splits, split_key='train')
            val_split = self.get_split_from_df(all_splits=all_splits, split_key='val')
            test_split = None #self.get_split_from_df(all_splits=all_splits, split_key='test')

            ### --> Normalizing Data
            print("****** Normalizing Data ******")
            scalers = train_split.get_scaler()
            train_split.apply_scaler(scalers=scalers)
            val_split.apply_scaler(scalers=scalers)
            #test_split.apply_scaler(scalers=scalers)
            ### <--
        return train_split, val_split#, test_split


    def get_list(self, ids):
        """Return the slide ids at the given row positions."""
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids):
        """Return the class labels at the given row positions."""
        return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        # Item access is implemented by subclasses (see Generic_MIL_Survival_Dataset).
        # BUGFIX: this method was defined twice verbatim; the duplicate is removed.
        return None
248
class Generic_MIL_Survival_Dataset(Generic_WSI_Survival_Dataset):
    def __init__(self, data_dir, mode: str='omic', **kwargs):
        """MIL variant that materializes WSI feature bags and/or genomic features per case.

        Args:
            data_dir (str or dict): Root of the extracted features; a dict maps
                oncotree codes to per-source roots.
            mode (str): One of 'path', 'cluster', 'omic', 'pathomic',
                'pathomic_fast', 'coattn' — controls what __getitem__ returns.
            **kwargs: Forwarded to Generic_WSI_Survival_Dataset.
        """
        super(Generic_MIL_Survival_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.mode = mode
        self.use_h5 = False

    def load_from_h5(self, toggle):
        # Toggle whether features should be served from h5 files.
        self.use_h5 = toggle

    @staticmethod
    def _slide_stem(slide_id):
        """Return `slide_id` without a trailing '.svs' extension.

        BUGFIX: the original used str.rstrip('.svs'), which strips any trailing
        run of the *characters* '.', 's', 'v' and therefore corrupts slide ids
        whose stem ends in 's' or 'v' (e.g. 'abcs.svs' -> 'abc').
        """
        if slide_id.endswith('.svs'):
            return slide_id[:-len('.svs')]
        return slide_id

    def _load_path_features(self, data_dir, slide_ids):
        """Load the .pt feature bag of every slide of a case and concatenate them."""
        path_features = []
        for slide_id in slide_ids:
            wsi_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(self._slide_stem(slide_id)))
            path_features.append(torch.load(wsi_path))
        return torch.cat(path_features, dim=0)

    def __getitem__(self, idx):
        """Return one case's features and survival targets; the tuple layout depends on self.mode.

        NOTE(review): when self.use_h5 is True this method falls through and
        returns None, matching the original behavior — confirm an h5 path was
        intended to be added here.
        """
        case_id = self.slide_data['case_id'][idx]
        label = torch.Tensor([self.slide_data['disc_label'][idx]])
        event_time = torch.Tensor([self.slide_data[self.label_col][idx]])
        c = torch.Tensor([self.slide_data['censorship'][idx]])
        slide_ids = self.patient_dict[case_id]

        # Multi-cohort setups store one feature root per oncotree code.
        if type(self.data_dir) == dict:
            source = self.slide_data['oncotree_code'][idx]
            data_dir = self.data_dir[source]
        else:
            data_dir = self.data_dir

        if not self.use_h5:
            if self.data_dir:
                if self.mode == 'path':
                    path_features = self._load_path_features(data_dir, slide_ids)
                    return (path_features, torch.zeros((1,1)), label, event_time, c)

                elif self.mode == 'cluster':
                    path_features = self._load_path_features(data_dir, slide_ids)
                    cluster_ids = []
                    for slide_id in slide_ids:
                        # fname2ids keys use the original filename with a .pt suffix.
                        cluster_ids.extend(self.fname2ids[slide_id[:-4]+'.pt'])
                    cluster_ids = torch.Tensor(cluster_ids)
                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
                    return (path_features, cluster_ids, genomic_features, label, event_time, c)

                elif self.mode == 'omic':
                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
                    return (torch.zeros((1,1)), genomic_features.unsqueeze(dim=0), label, event_time, c)

                elif self.mode == 'pathomic':
                    path_features = self._load_path_features(data_dir, slide_ids)
                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
                    return (path_features, genomic_features.unsqueeze(dim=0), label, event_time, c)

                elif self.mode == 'pathomic_fast':
                    # Pre-concatenated case-level bags cached per split (see set_split_id).
                    casefeat_path = os.path.join(data_dir, f'split_{self.split_id}_case_pt', f'{case_id}.pt')
                    path_features = torch.load(casefeat_path)
                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
                    return (path_features, genomic_features.unsqueeze(dim=0), label, event_time, c)

                elif self.mode == 'coattn':
                    path_features = self._load_path_features(data_dir, slide_ids)
                    # One genomic tensor per pathway signature group.
                    omics = [torch.tensor(self.genomic_features[self.omic_names[k]].iloc[idx]) for k in range(6)]
                    return (path_features, omics[0], omics[1], omics[2], omics[3], omics[4], omics[5], label, event_time, c)

                else:
                    raise NotImplementedError('Mode [%s] not implemented.' % self.mode)
            else:
                # No feature root configured: return ids and targets only.
                return slide_ids, label, event_time, c
336
class Generic_Split(Generic_MIL_Survival_Dataset):
    def __init__(self, slide_data, metadata, mode,
        signatures=None, data_dir=None, label_col=None, patient_dict=None, num_classes=2):
        """Lightweight view over one train/val/test split.

        Intentionally does NOT call super().__init__: all state is injected
        pre-computed by Generic_WSI_Survival_Dataset.get_split_from_df.

        Args:
            slide_data (DataFrame): Rows of the parent dataset belonging to this split.
            metadata (Index): Columns of slide_data that are NOT genomic features.
            mode (str): Feature modality; 'cluster' additionally loads cluster-id maps.
            signatures (DataFrame): Optional pathway signature definitions.
            data_dir (str or dict): Feature root(s), forwarded to __getitem__.
            label_col (str): Survival-time column name.
            patient_dict (dict): case_id -> array of slide ids.
            num_classes (int): Number of (bin, censorship) classes.
        """
        self.use_h5 = False
        self.slide_data = slide_data
        self.metadata = metadata
        self.mode = mode
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.label_col = label_col
        self.patient_dict = patient_dict
        # Per-class row indices within this split.
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

        ### --> Initializing genomic features in Generic Split
        self.genomic_features = self.slide_data.drop(self.metadata, axis=1)
        self.signatures = signatures

        if mode == 'cluster':
            with open(os.path.join(data_dir, 'fast_cluster_ids.pkl'), 'rb') as handle:
                self.fname2ids = pickle.load(handle)

        def series_intersection(s1, s2):
            return pd.Series(list(set(s1) & set(s2)))

        if self.signatures is not None:
            # For each signature column, collect the genomic feature names that
            # exist in this split for any of the three suffixed variants.
            self.omic_names = []
            for col in self.signatures.columns:
                omic = self.signatures[col].dropna().unique()
                # 'suffix' avoids shadowing the `mode` parameter, which the
                # original comprehension variable did.
                omic = np.concatenate([omic + suffix for suffix in ['_mut', '_cnv', '_rnaseq']])
                omic = sorted(series_intersection(omic, self.genomic_features.columns))
                self.omic_names.append(omic)
            self.omic_sizes = [len(omic) for omic in self.omic_names]
        print("Shape", self.genomic_features.shape)
        ### <--

    def __len__(self):
        return len(self.slide_data)

    ### --> Getting StandardScaler of self.genomic_features
    def get_scaler(self):
        """Fit a StandardScaler on this split's genomic features (call on the train split)."""
        scaler_omic = StandardScaler().fit(self.genomic_features)
        return (scaler_omic,)
    ### <--

    ### --> Applying StandardScaler to self.genomic_features
    def apply_scaler(self, scalers: tuple=None):
        """Standardize genomic features in place with a previously fitted scaler.

        BUGFIX: preserve the original index (and columns) so row alignment with
        self.slide_data survives even when the index is not a default RangeIndex;
        the original rebuilt the frame with a fresh 0..n-1 index.
        """
        self.genomic_features = pd.DataFrame(
            scalers[0].transform(self.genomic_features),
            index=self.genomic_features.index,
            columns=self.genomic_features.columns)
    ### <--

    def set_split_id(self, split_id):
        # Identifies which fold's cached case-level features 'pathomic_fast' should load.
        self.split_id = split_id