|
a |
|
b/datasets/dataset_survival.py |
|
|
1 |
from __future__ import print_function, division |
|
|
2 |
import math |
|
|
3 |
import os |
|
|
4 |
import pdb |
|
|
5 |
import pickle |
|
|
6 |
import re |
|
|
7 |
|
|
|
8 |
import h5py |
|
|
9 |
import numpy as np |
|
|
10 |
import pandas as pd |
|
|
11 |
from scipy import stats |
|
|
12 |
from sklearn.preprocessing import StandardScaler |
|
|
13 |
|
|
|
14 |
import torch |
|
|
15 |
from torch.utils.data import Dataset |
|
|
16 |
|
|
|
17 |
from utils.utils import generate_split, nth |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
class Generic_WSI_Survival_Dataset(Dataset):
    def __init__(self,
        csv_path = 'dataset_csv/ccrcc_clean.csv', mode = 'omic', apply_sig = False,
        shuffle = False, seed = 7, print_info = True, n_bins = 4, ignore=[],
        patient_strat=False, label_col = None, filter_dict = {}, eps=1e-6):
        r"""
        Generic_WSI_Survival_Dataset

        Loads a survival CSV, discretizes the continuous survival label into
        ``n_bins`` quantile intervals (quantiles computed on uncensored
        patients only), and maps each (time-bin, censorship) pair to a single
        integer class label.

        Args:
            csv_path (string): Path to the csv file with annotations.
            mode (string): modality mode forwarded to the generated splits.
            apply_sig (boolean): whether to load gene-signature definitions.
            shuffle (boolean): Whether to shuffle
            seed (int): random seed for shuffling the data
            print_info (boolean): Whether to print a summary of the dataset
            n_bins (int): number of quantile bins for discretizing survival time
            ignore (list): List containing class labels to ignore
            patient_strat (boolean): if True, __len__ counts patients, not slides
            label_col (string): column holding the continuous survival label
                (defaults to 'survival_months')
            filter_dict (dict): unused here; kept for interface compatibility
            eps (float): padding added to the outermost bin edges so the
                min/max survival values fall strictly inside the bins
        """
        self.custom_test_ids = None
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
        self.data_dir = None

        slide_data = pd.read_csv(csv_path, low_memory=False)
        #slide_data = slide_data.drop(['Unnamed: 0'], axis=1)

        if shuffle:
            # BUGFIX: the original shuffled *before* the CSV was read (NameError
            # on slide_data), and np.random.shuffle cannot permute a DataFrame
            # in place; sample(frac=1) is the seedable row shuffle.
            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)

        if 'case_id' not in slide_data:
            # derive TCGA case ids from the first 12 chars of the barcode index
            slide_data.index = slide_data.index.str[:12]
            slide_data['case_id'] = slide_data.index
            slide_data = slide_data.reset_index(drop=True)

        if not label_col:
            label_col = 'survival_months'
        else:
            assert label_col in slide_data.columns
        self.label_col = label_col

        # BUGFIX: `"IDC" in series` tests membership in the *index*, not the
        # values; use .values for a value-membership test.
        if "IDC" in slide_data['oncotree_code'].values: # must be BRCA (and if so, use only IDCs)
            slide_data = slide_data[slide_data['oncotree_code'] == 'IDC']

        patients_df = slide_data.drop_duplicates(['case_id']).copy()
        uncensored_df = patients_df[patients_df['censorship'] < 1]

        # quantile edges come from uncensored patients only; widen the outermost
        # edges so every (possibly censored) patient falls inside some bin
        disc_labels, q_bins = pd.qcut(uncensored_df[label_col], q=n_bins, retbins=True, labels=False)
        q_bins[-1] = slide_data[label_col].max() + eps
        q_bins[0] = slide_data[label_col].min() - eps

        disc_labels, q_bins = pd.cut(patients_df[label_col], bins=q_bins, retbins=True, labels=False, right=False, include_lowest=True)
        patients_df.insert(2, 'label', disc_labels.values.astype(int))

        # map each patient to the array of its slide ids
        patient_dict = {}
        slide_data = slide_data.set_index('case_id')
        for patient in patients_df['case_id']:
            slide_ids = slide_data.loc[patient, 'slide_id']
            if isinstance(slide_ids, str):
                slide_ids = np.array(slide_ids).reshape(-1)
            else:
                slide_ids = slide_ids.values
            patient_dict.update({patient:slide_ids})

        self.patient_dict = patient_dict

        # from here on the table is patient-level; slide_id mirrors case_id
        slide_data = patients_df
        slide_data.reset_index(drop=True, inplace=True)
        slide_data = slide_data.assign(slide_id=slide_data['case_id'])

        # class label = (time bin, censorship) pair flattened to an int
        label_dict = {}
        key_count = 0
        for i in range(len(q_bins)-1):
            for c in [0, 1]:
                print('{} : {}'.format((i, c), key_count))
                label_dict.update({(i, c):key_count})
                key_count+=1

        self.label_dict = label_dict
        for i in slide_data.index:
            key = slide_data.loc[i, 'label']
            slide_data.at[i, 'disc_label'] = key
            censorship = slide_data.loc[i, 'censorship']
            key = (key, int(censorship))
            slide_data.at[i, 'label'] = label_dict[key]

        self.bins = q_bins
        self.num_classes=len(self.label_dict)
        patients_df = slide_data.drop_duplicates(['case_id'])
        self.patient_data = {'case_id':patients_df['case_id'].values, 'label':patients_df['label'].values}

        #new_cols = list(slide_data.columns[-2:]) + list(slide_data.columns[:-2]) ### ICCV
        new_cols = list(slide_data.columns[-1:]) + list(slide_data.columns[:-1]) ### PORPOISE
        slide_data = slide_data[new_cols]
        self.slide_data = slide_data
        metadata = ['disc_label', 'Unnamed: 0', 'case_id', 'label', 'slide_id', 'age', 'site', 'survival_months', 'censorship', 'is_female', 'oncotree_code', 'train']
        self.metadata = slide_data.columns[:12]

        # sanity check: every non-metadata column should look genomic.
        # BUGFIX: the original pattern '|_cnv|_rnaseq|_rna|_mut' began with '|',
        # which alternates with the empty string and therefore matched every
        # column name, so this diagnostic could never fire.
        for col in slide_data.drop(self.metadata, axis=1).columns:
            if not pd.Series(col).str.contains('_cnv|_rnaseq|_rna|_mut')[0]:
                print(col)
        #pdb.set_trace()

        assert self.metadata.equals(pd.Index(metadata))
        self.mode = mode
        self.cls_ids_prep()

        ### ICCV discrepancies
        # For BLCA, TPTEP1_rnaseq was accidentally appended to the metadata
        #pdb.set_trace()

        ### Signatures
        self.apply_sig = apply_sig
        if self.apply_sig:
            self.signatures = pd.read_csv('./datasets_csv_sig/signatures.csv')
        else:
            self.signatures = None

        # NOTE: the original called summarize() twice (before and after loading
        # the signatures); once is sufficient.
        if print_info:
            self.summarize()


    def cls_ids_prep(self):
        r"""Cache, for each class, the row indices at patient level and slide level."""
        self.patient_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]


    def patient_data_prep(self):
        r"""Rebuild self.patient_data: unique case ids plus one label per patient
        (the label of the patient's first slide row)."""
        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            label = self.slide_data['label'][locations[0]] # get patient label
            patient_labels.append(label)

        self.patient_data = {'case_id':patients, 'label':np.array(patient_labels)}


    @staticmethod
    def df_prep(data, n_bins, ignore, label_col):
        r"""Drop rows whose label is in ``ignore`` and cut ``label_col`` into
        ``n_bins`` equal-width bins.

        Returns:
            (DataFrame, ndarray): the filtered frame and the bin edges.
        """
        mask = data[label_col].isin(ignore)
        data = data[~mask]
        data.reset_index(drop=True, inplace=True)
        # BUGFIX: pd.cut only returns the bin edges when retbins=True; the
        # original two-target unpacking raised a ValueError.
        disc_labels, bins = pd.cut(data[label_col], bins=n_bins, retbins=True)
        return data, bins

    def __len__(self):
        # patient-level or slide-level count, depending on stratification mode
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        else:
            return len(self.slide_data)

    def summarize(self):
        """Print label mapping, class count, and per-class sample counts."""
        print("label column: {}".format(self.label_col))
        print("label dictionary: {}".format(self.label_dict))
        print("number of classes: {}".format(self.num_classes))
        print("slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort = False))
        for i in range(self.num_classes):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))


    def get_split_from_df(self, all_splits: dict, split_key: str='train', scaler=None):
        """Build a Generic_Split from the slide ids listed under ``split_key``
        in ``all_splits``; returns None when the column is empty."""
        split = all_splits[split_key]
        split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            df_slice = self.slide_data[mask].reset_index(drop=True)
            split = Generic_Split(df_slice, metadata=self.metadata, mode=self.mode, signatures=self.signatures, data_dir=self.data_dir, label_col=self.label_col, patient_dict=self.patient_dict, num_classes=self.num_classes)
        else:
            split = None

        return split


    def return_splits(self, from_id: bool=True, csv_path: str=None):
        """Load train/val splits from a CSV of slide ids, z-score-normalize
        genomic features with scalers fit on the train split, and return
        (train_split, val_split)."""
        if from_id:
            raise NotImplementedError
        else:
            assert csv_path
            all_splits = pd.read_csv(csv_path)
            train_split = self.get_split_from_df(all_splits=all_splits, split_key='train')
            val_split = self.get_split_from_df(all_splits=all_splits, split_key='val')
            test_split = None #self.get_split_from_df(all_splits=all_splits, split_key='test')

            ### --> Normalizing Data
            print("****** Normalizing Data ******")
            scalers = train_split.get_scaler()
            train_split.apply_scaler(scalers=scalers)
            val_split.apply_scaler(scalers=scalers)
            #test_split.apply_scaler(scalers=scalers)
            ### <--
        return train_split, val_split#, test_split


    def get_list(self, ids):
        # slide ids for the given positional row indices
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids):
        # class labels for the given positional row indices
        return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        # NOTE: the original defined this method twice, identically; the
        # duplicate was removed. Subclasses provide the real implementation.
        return None
|
|
246 |
|
|
|
247 |
|
|
|
248 |
class Generic_MIL_Survival_Dataset(Generic_WSI_Survival_Dataset):
    def __init__(self, data_dir, mode: str='omic', **kwargs):
        """Multiple-instance survival dataset returning WSI feature bags and/or
        genomic features, depending on ``mode``.

        Args:
            data_dir: directory of extracted features, or a dict mapping an
                oncotree code to such a directory (multi-cohort setups).
            mode (str): one of 'path', 'cluster', 'omic', 'pathomic',
                'pathomic_fast', 'coattn'.
            **kwargs: forwarded to Generic_WSI_Survival_Dataset.__init__.
        """
        super(Generic_MIL_Survival_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.mode = mode
        self.use_h5 = False

    def load_from_h5(self, toggle):
        # when True, __getitem__ returns raw slide ids instead of loaded features
        self.use_h5 = toggle

    @staticmethod
    def _strip_svs(slide_id):
        """Drop a trailing '.svs' extension from a slide id.

        BUGFIX: the original used ``slide_id.rstrip('.svs')``; rstrip treats
        its argument as a *character set*, so any trailing run of '.', 's',
        'v' was stripped — ids genuinely ending in 's' or 'v' were corrupted.
        """
        return slide_id[:-len('.svs')] if slide_id.endswith('.svs') else slide_id

    def _load_wsi_bags(self, data_dir, slide_ids):
        """Load every slide's feature bag for one patient and concatenate them
        along the instance dimension (factored out of four __getitem__ branches)."""
        bags = []
        for slide_id in slide_ids:
            wsi_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(self._strip_svs(slide_id)))
            bags.append(torch.load(wsi_path))
        return torch.cat(bags, dim=0)

    def __getitem__(self, idx):
        """Return one patient's sample for the configured mode.

        Every mode returns a tuple ending in (label, event_time, censorship);
        the leading entries are mode-dependent feature tensors.
        """
        case_id = self.slide_data['case_id'][idx]
        label = torch.Tensor([self.slide_data['disc_label'][idx]])
        event_time = torch.Tensor([self.slide_data[self.label_col][idx]])
        c = torch.Tensor([self.slide_data['censorship'][idx]])
        slide_ids = self.patient_dict[case_id]

        # multi-cohort: pick the feature directory by the patient's oncotree code
        if type(self.data_dir) == dict:
            source = self.slide_data['oncotree_code'][idx]
            data_dir = self.data_dir[source]
        else:
            data_dir = self.data_dir

        if not self.use_h5:
            if self.data_dir:
                if self.mode == 'path':
                    path_features = self._load_wsi_bags(data_dir, slide_ids)
                    return (path_features, torch.zeros((1,1)), label, event_time, c)

                elif self.mode == 'cluster':
                    path_features = []
                    cluster_ids = []
                    for slide_id in slide_ids:
                        wsi_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(self._strip_svs(slide_id)))
                        path_features.append(torch.load(wsi_path))
                        # fname2ids is keyed by '<stem>.pt'; loaded in cluster mode only
                        cluster_ids.extend(self.fname2ids[self._strip_svs(slide_id)+'.pt'])
                    path_features = torch.cat(path_features, dim=0)
                    cluster_ids = torch.Tensor(cluster_ids)
                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
                    return (path_features, cluster_ids, genomic_features, label, event_time, c)

                elif self.mode == 'omic':
                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
                    return (torch.zeros((1,1)), genomic_features.unsqueeze(dim=0), label, event_time, c)

                elif self.mode == 'pathomic':
                    path_features = self._load_wsi_bags(data_dir, slide_ids)
                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
                    return (path_features, genomic_features.unsqueeze(dim=0), label, event_time, c)

                elif self.mode == 'pathomic_fast':
                    # requires set_split_id() to have been called on this split
                    casefeat_path = os.path.join(data_dir, f'split_{self.split_id}_case_pt', f'{case_id}.pt')
                    path_features = torch.load(casefeat_path)
                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
                    return (path_features, genomic_features.unsqueeze(dim=0), label, event_time, c)

                elif self.mode == 'coattn':
                    path_features = self._load_wsi_bags(data_dir, slide_ids)
                    omic1 = torch.tensor(self.genomic_features[self.omic_names[0]].iloc[idx])
                    omic2 = torch.tensor(self.genomic_features[self.omic_names[1]].iloc[idx])
                    omic3 = torch.tensor(self.genomic_features[self.omic_names[2]].iloc[idx])
                    omic4 = torch.tensor(self.genomic_features[self.omic_names[3]].iloc[idx])
                    omic5 = torch.tensor(self.genomic_features[self.omic_names[4]].iloc[idx])
                    omic6 = torch.tensor(self.genomic_features[self.omic_names[5]].iloc[idx])
                    return (path_features, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c)

                else:
                    raise NotImplementedError('Mode [%s] not implemented.' % self.mode)
            # NOTE(review): when use_h5 is False and data_dir is falsy, control
            # falls through and returns None — behavior preserved from original.
        else:
            return slide_ids, label, event_time, c
|
|
334 |
|
|
|
335 |
|
|
|
336 |
class Generic_Split(Generic_MIL_Survival_Dataset):
    def __init__(self, slide_data, metadata, mode,
        signatures=None, data_dir=None, label_col=None, patient_dict=None, num_classes=2):
        """One train/val/test split over an already-processed slide table.

        Deliberately does NOT call the parent __init__: every piece of state
        is injected by the parent dataset's get_split_from_df.
        """
        self.use_h5 = False
        self.slide_data = slide_data
        self.metadata = metadata
        self.mode = mode
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.label_col = label_col
        self.patient_dict = patient_dict

        # per-class row indices at slide level
        self.slide_cls_ids = [np.where(self.slide_data['label'] == cls)[0]
                              for cls in range(self.num_classes)]

        ### --> Initializing genomic features in Generic Split
        # everything that is not metadata is treated as a genomic feature
        self.genomic_features = self.slide_data.drop(self.metadata, axis=1)
        self.signatures = signatures

        if mode == 'cluster':
            # mapping from '<stem>.pt' feature file to its patch-cluster ids
            with open(os.path.join(data_dir, 'fast_cluster_ids.pkl'), 'rb') as pkl_file:
                self.fname2ids = pickle.load(pkl_file)

        def series_intersection(s1, s2):
            # elements present in both iterables, as a pandas Series
            return pd.Series(list(set(s1) & set(s2)))

        if self.signatures is not None:
            # one sorted gene-feature list per signature column, restricted to
            # features that actually exist in this split's genomic table
            self.omic_names = []
            for sig_col in self.signatures.columns:
                genes = self.signatures[sig_col].dropna().unique()
                genes = np.concatenate([genes + suffix for suffix in ['_mut', '_cnv', '_rnaseq']])
                genes = sorted(series_intersection(genes, self.genomic_features.columns))
                self.omic_names.append(genes)
            self.omic_sizes = [len(gene_set) for gene_set in self.omic_names]
        print("Shape", self.genomic_features.shape)
        ### <--

    def __len__(self):
        """Number of rows (patients) in this split."""
        return len(self.slide_data)

    ### --> Getting StandardScaler of self.genomic_features
    def get_scaler(self):
        """Fit a StandardScaler on this split's genomic features; returned as
        a 1-tuple so callers can carry multiple scalers later."""
        fitted = StandardScaler().fit(self.genomic_features)
        return (fitted,)
    ### <--

    ### --> Applying StandardScaler to self.genomic_features
    def apply_scaler(self, scalers: tuple=None):
        """Replace the genomic-feature table with its scaled version, keeping
        the original column names."""
        scaled = scalers[0].transform(self.genomic_features)
        self.genomic_features = pd.DataFrame(scaled, columns=self.genomic_features.columns)
    ### <--

    def set_split_id(self, split_id):
        # used by the 'pathomic_fast' mode to locate per-split cached features
        self.split_id = split_id