datasets/dataset_generic.py
from __future__ import print_function, division
import os
import torch
import numpy as np
import pandas as pd
import math
import re
import pdb
import pickle
from scipy import stats

from torch.utils.data import Dataset
import h5py

from utils.utils import generate_split, nth


def save_splits(split_datasets, column_keys, filename, boolean_style=False):
    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
    if not boolean_style:
        df = pd.concat(splits, ignore_index=True, axis=1)
        df.columns = column_keys
    else:
        df = pd.concat(splits, ignore_index=True, axis=0)
        index = df.values.tolist()
        one_hot = np.eye(len(split_datasets)).astype(bool)
        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
        df = pd.DataFrame(bool_array, index=index, columns=['train', 'val', 'test'])

    df.to_csv(filename)
    print()
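
# Example: with boolean_style=True and three splits, the saved csv has one row
# per slide_id and one boolean column per split (slide names are placeholders):
#
#     ,train,val,test
#     slide_a,True,False,False
#     slide_b,False,True,False
#     slide_c,False,False,True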


class Generic_WSI_Classification_Dataset(Dataset):
    def __init__(self,
                 csv_path='dataset_csv/ccrcc_clean.csv',
                 shuffle=False,
                 seed=7,
                 print_info=True,
                 label_dict={},
                 ignore=[],
                 patient_strat=False,
                 label_col=None,
                 patient_voting='max',
                 multi_site=False,
                 filter_dict={},
                 patient_level=False
                 ):
        """
        Args:
            csv_path (string): Path to the csv file with annotations.
            shuffle (boolean): Whether to shuffle the data.
            seed (int): Random seed for shuffling the data.
            print_info (boolean): Whether to print a summary of the dataset.
            label_dict (dict): Dictionary with key, value pairs for converting str labels to int.
            ignore (list): List containing class labels to ignore.
            patient_strat (boolean): Whether to stratify splits at the patient level.
            label_col (string): Name of the csv column holding the labels (defaults to 'label').
            patient_voting (string): Rule ('max' or 'maj') for deciding the patient-level label.
            multi_site (boolean): Whether to expand the label dictionary with per-site labels.
            filter_dict (dict): Maps column names to lists of allowed values for row filtering.
            patient_level (boolean): Whether __getitem__ returns all slides of a patient.
        """
        self.custom_test_ids = None
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
        self.data_dir = None
        self.split_gen = None
        self.patient_level = patient_level

        if not label_col:
            label_col = 'label'
        self.label_col = label_col

        slide_data = pd.read_csv(csv_path)
        slide_data = self.filter_df(slide_data, filter_dict)

        if multi_site:
            label_dict = self.init_multi_site_label_dict(slide_data, label_dict)

        self.label_dict = label_dict
        self.num_classes = len(set(self.label_dict.values()))

        slide_data = self.df_prep(slide_data, self.label_dict, ignore, self.label_col, multi_site)

        # shuffle data
        if shuffle:
            np.random.seed(seed)
            np.random.shuffle(slide_data)

        self.slide_data = slide_data

        self.patient_data_prep(patient_voting)
        self.cls_ids_prep()

        if print_info:
            self.summarize()

        if self.patient_level:
            self.patient_dict = self.build_patient_dict()
            # self.slide_data = self.slide_data.drop_duplicates(subset=['case_id'])
        else:
            self.patient_dict = {}

    def build_patient_dict(self):
        # map each case_id to the array of slide_ids belonging to that patient
        patient_dict = {}
        patient_cases = self.slide_data['case_id'].unique()
        slide_cases = self.slide_data.set_index('case_id')

        for patient in patient_cases:
            slide_ids = slide_cases.loc[patient, 'slide_id']

            if isinstance(slide_ids, str):
                # a single slide comes back as a scalar; wrap it in an array
                slide_ids = np.array(slide_ids).reshape(-1)
            else:
                slide_ids = slide_ids.values

            patient_dict.update({patient: slide_ids})

        return patient_dict
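
    # Example: for two patients where 'case_1' contributed two slides, the
    # resulting dict looks like (ids are placeholders):
    #     {'case_1': array(['slide_1a', 'slide_1b']), 'case_2': array(['slide_2a'])}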

    def cls_ids_prep(self):
        # store ids corresponding to each class at the patient or case level
        self.patient_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

        # store ids corresponding to each class at the slide level
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def patient_data_prep(self, patient_voting='max'):
        patients = np.unique(np.array(self.slide_data['case_id']))  # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            label = self.slide_data['label'][locations].values
            if patient_voting == 'max':
                label = label.max()  # get patient label (MIL convention)
            elif patient_voting == 'maj':
                label = stats.mode(label)[0]
            else:
                raise NotImplementedError
            patient_labels.append(label)

        self.patient_data = {'case_id': patients, 'label': np.array(patient_labels)}
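
    # Example: a patient with slide labels [0, 0, 1] receives label 1 under
    # 'max' voting (any positive slide makes the patient positive) and
    # label 0 under 'maj' (majority) voting.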

    @staticmethod
    def init_multi_site_label_dict(slide_data, label_dict):
        print('initiating multi-source label dictionary')
        sites = np.unique(slide_data['site'].values)
        multi_site_dict = {}
        num_classes = len(label_dict)
        for key, val in label_dict.items():
            for idx, site in enumerate(sites):
                site_key = (key, site)
                site_val = val + idx * num_classes
                multi_site_dict.update({site_key: site_val})
                print('{} : {}'.format(site_key, site_val))
        return multi_site_dict
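
    # Example (placeholder labels/sites): label_dict={'luad': 0, 'lusc': 1}
    # with sites ['A', 'B'] expands to
    #     {('luad', 'A'): 0, ('luad', 'B'): 2, ('lusc', 'A'): 1, ('lusc', 'B'): 3}
    # so each (label, site) pair gets its own integer class.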

    @staticmethod
    def filter_df(df, filter_dict={}):
        if len(filter_dict) > 0:
            filter_mask = np.full(len(df), True, bool)
            # assert 'label' not in filter_dict.keys()
            for key, val in filter_dict.items():
                mask = df[key].isin(val)
                filter_mask = np.logical_and(filter_mask, mask)
            df = df[filter_mask]
        return df
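
    # Example: filter_dict={'site': ['A', 'B']} keeps only rows whose 'site'
    # column value is 'A' or 'B'; multiple keys are combined with logical AND.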

    @staticmethod
    def df_prep(data, label_dict, ignore, label_col, multi_site=False):
        if label_col != 'label':
            data['label'] = data[label_col].copy()

        mask = data['label'].isin(ignore)
        data = data[~mask]
        data.reset_index(drop=True, inplace=True)
        for i in data.index:
            key = data.loc[i, 'label']
            if multi_site:
                site = data.loc[i, 'site']
                key = (key, site)
            data.at[i, 'label'] = label_dict[key]

        return data

    def __len__(self):
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        else:
            return len(self.slide_data)

    def summarize(self):
        print("label column: {}".format(self.label_col))
        print("label dictionary: {}".format(self.label_dict))
        print("number of classes: {}".format(self.num_classes))
        print("slide-level counts: ", self.slide_data['label'].value_counts(sort=False))
        for i in range(self.num_classes):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))

    def create_splits(self, k=3, val_num=(25, 25), test_num=(40, 40), label_frac=1.0, custom_test_ids=None):
        settings = {
            'n_splits': k,
            'val_num': val_num,
            'test_num': test_num,
            'label_frac': label_frac,
            'seed': self.seed,
            'custom_test_ids': custom_test_ids
        }

        if self.patient_strat:
            settings.update({'cls_ids': self.patient_cls_ids, 'samples': len(self.patient_data['case_id'])})
        else:
            settings.update({'cls_ids': self.slide_cls_ids, 'samples': len(self.slide_data)})

        self.split_gen = generate_split(**settings)

    def sample_held_out(self, test_num=(40, 40)):

        test_ids = []
        np.random.seed(self.seed)  # fix seed

        if self.patient_strat:
            cls_ids = self.patient_cls_ids
        else:
            cls_ids = self.slide_cls_ids

        for c in range(len(test_num)):
            test_ids.extend(np.random.choice(cls_ids[c], test_num[c], replace=False))  # held-out test ids

        # if self.patient_strat:
        #     slide_ids = []
        #     for idx in test_ids:
        #         case_id = self.patient_data['case_id'][idx]
        #         slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
        #         slide_ids.extend(slide_indices)
        #     return slide_ids
        # else:
        #     return test_ids
        return test_ids

    def set_splits(self, start_from=None):
        if start_from:
            ids = nth(self.split_gen, start_from)
        else:
            ids = next(self.split_gen)

        if self.patient_strat:
            slide_ids = [[] for i in range(len(ids))]

            for split in range(len(ids)):
                for idx in ids[split]:
                    case_id = self.patient_data['case_id'][idx]
                    slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                    slide_ids[split].extend(slide_indices)

            self.train_ids, self.val_ids, self.test_ids = slide_ids[0], slide_ids[1], slide_ids[2]

        else:
            self.train_ids, self.val_ids, self.test_ids = ids

    def get_split_from_df(self, all_splits=None, split_key='train', split=None, return_ids_only=False):
        if split is None:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            if return_ids_only:
                ids = np.where(mask)[0]
                return ids

            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
        else:
            split = None

        return split

    def get_merged_split_from_df(self, all_splits, split_keys=['train']):
        merged_split = []
        for split_key in split_keys:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True).tolist()
            merged_split.extend(split)

        if len(merged_split) > 0:  # check the merged list, not just the last split read
            mask = self.slide_data['slide_id'].isin(merged_split)
            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
        else:
            split = None

        return split

    def return_splits(self, from_id=True, csv_path=None):

        if from_id:
            if len(self.train_ids) > 0:
                train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True)
                train_split = Generic_Split(train_data, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
            else:
                train_split = None

            if len(self.val_ids) > 0:
                val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
            else:
                val_split = None

            if len(self.test_ids) > 0:
                test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)
            else:  # no test set - use a copy of the validation set
                # test_split = None
                test_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes, patient_level=self.patient_level)

        else:
            assert csv_path
            all_splits = pd.read_csv(csv_path)
            train_split = self.get_split_from_df(all_splits, 'train')
            val_split = self.get_split_from_df(all_splits, 'val')
            test_split = self.get_split_from_df(all_splits, 'test')

        return train_split, val_split, test_split
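
    # Example: a splits csv consumed via return_splits(from_id=False, csv_path=...)
    # is expected to hold 'train', 'val' and 'test' columns of slide_ids
    # (ids below are placeholders):
    #     ,train,val,test
    #     0,slide_1,slide_7,slide_9
    #     1,slide_2,slide_8,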

    def get_list(self, ids):
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids):
        return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        return None

    def test_split_gen(self, return_descriptor=False):
        if return_descriptor:
            index = [list(self.label_dict.keys())[list(self.label_dict.values()).index(i)] for i in range(self.num_classes)]
            columns = ['train', 'val', 'test']
            df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index=index,
                              columns=columns)
        count = len(self.train_ids)
        print('\nnumber of training samples: {}'.format(count))
        labels = self.getlabel(self.train_ids)
        unique, counts = np.unique(labels, return_counts=True)
        missing_classes = np.setdiff1d(np.arange(self.num_classes), unique)
        unique = np.append(unique, missing_classes)
        counts = np.append(counts, np.full(len(missing_classes), 0))
        inds = unique.argsort()
        unique = unique[inds]  # keep labels and counts aligned after sorting
        counts = counts[inds]
        for u in range(len(unique)):
            print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
            if return_descriptor:
                df.loc[index[u], 'train'] = counts[u]

        count = len(self.val_ids)
        print('\nnumber of val samples: {}'.format(count))
        labels = self.getlabel(self.val_ids)
        unique, counts = np.unique(labels, return_counts=True)
        missing_classes = np.setdiff1d(np.arange(self.num_classes), unique)
        unique = np.append(unique, missing_classes)
        counts = np.append(counts, np.full(len(missing_classes), 0))
        inds = unique.argsort()
        unique = unique[inds]  # keep labels and counts aligned after sorting
        counts = counts[inds]
        for u in range(len(unique)):
            print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
            if return_descriptor:
                df.loc[index[u], 'val'] = counts[u]

        count = len(self.test_ids)
        print('\nnumber of test samples: {}'.format(count))
        labels = self.getlabel(self.test_ids)
        unique, counts = np.unique(labels, return_counts=True)
        missing_classes = np.setdiff1d(np.arange(self.num_classes), unique)
        unique = np.append(unique, missing_classes)
        counts = np.append(counts, np.full(len(missing_classes), 0))
        inds = unique.argsort()
        unique = unique[inds]  # keep labels and counts aligned after sorting
        counts = counts[inds]
        for u in range(len(unique)):
            print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
            if return_descriptor:
                df.loc[index[u], 'test'] = counts[u]

        # splits must be disjoint
        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0

        if return_descriptor:
            return df

    def save_split(self, filename):
        train_split = self.get_list(self.train_ids)
        val_split = self.get_list(self.val_ids)
        test_split = self.get_list(self.test_ids)
        df_tr = pd.DataFrame({'train': train_split})
        df_v = pd.DataFrame({'val': val_split})
        df_t = pd.DataFrame({'test': test_split})
        df = pd.concat([df_tr, df_v, df_t], axis=1)
        df.to_csv(filename, index=False)


class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset):
    def __init__(self,
                 data_dir,
                 **kwargs):
        super(Generic_MIL_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.use_h5 = False

    def load_from_h5(self, toggle):
        self.use_h5 = toggle

    def __getitem__(self, idx):

        if not self.patient_level:
            slide_id = self.slide_data['slide_id'][idx]
            label = self.slide_data['label'][idx]
            if type(self.data_dir) == dict:
                source = self.slide_data['source'][idx]
                data_dir = self.data_dir[source]
            else:
                data_dir = self.data_dir

            if not self.use_h5:
                if self.data_dir:
                    full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
                    features = torch.load(full_path)
                    return features, label
                else:
                    return slide_id, label

            else:
                full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
                with h5py.File(full_path, 'r') as hdf5_file:
                    features = hdf5_file['features'][:]
                    coords = hdf5_file['coords'][:]

                features = torch.from_numpy(features)
                return features, label, coords

        else:
            # patient-level: return the concatenated bag over all of the patient's slides
            case_id = self.slide_data['case_id'][idx]
            label = self.slide_data['label'][idx]
            slide_ids = self.patient_dict[case_id]

            if type(self.data_dir) == dict:
                source = self.slide_data['source'][idx]
                data_dir = self.data_dir[source]
            else:
                data_dir = self.data_dir

            if not self.use_h5:
                features_list = []

                for slide_id in slide_ids:
                    full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
                    slide_features = torch.load(full_path)
                    features_list.append(slide_features)

                features = torch.cat(features_list, dim=0)
                return features, label

            else:
                features_list = []
                coords_list = []

                for slide_id in slide_ids:
                    full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
                    with h5py.File(full_path, 'r') as hdf5_file:
                        slide_features = hdf5_file['features'][:]
                        slide_coords = hdf5_file['coords'][:]
                    slide_features_t = torch.from_numpy(slide_features)
                    slide_coords_t = torch.from_numpy(slide_coords)

                    features_list.append(slide_features_t)
                    coords_list.append(slide_coords_t)

                features = torch.cat(features_list, dim=0)
                coords = torch.cat(coords_list, dim=0)
                return features, label, coords
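
    # Note: each item is a bag: an (N, D) float tensor of N patch features
    # (D depends on the feature extractor), the integer label, and, when
    # use_h5 is toggled on, the corresponding patch coordinates. At the
    # patient level, features from all of a patient's slides are
    # concatenated along dim 0 into a single bag.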


class Generic_Split(Generic_MIL_Dataset):
    def __init__(self, slide_data, data_dir=None, num_classes=2, patient_level=False):
        self.use_h5 = False
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

        self.patient_level = patient_level
        if self.patient_level:
            self.patient_dict = self.build_patient_dict()
            # self.slide_data = self.slide_data.drop_duplicates(subset=['case_id'])
        else:
            self.patient_dict = {}

    def __len__(self):
        return len(self.slide_data)


class Generic_WSI_Inference_Dataset(Dataset):
    def __init__(self,
                 data_dir,
                 csv_path=None,
                 print_info=True,
                 ):
        self.data_dir = data_dir
        self.print_info = print_info

        if csv_path is not None:
            data = pd.read_csv(csv_path)
            self.slide_data = data['slide_id'].values
        else:
            data = os.listdir(data_dir)
            # strip the '.pt' extension; np.char.strip(data, chars='.pt') would
            # also eat leading/trailing 'p', 't' and '.' from the slide name itself
            self.slide_data = np.array([os.path.splitext(f)[0] for f in data])
        if print_info:
            print('total number of slides to infer: ', len(self.slide_data))

    def __len__(self):
        return len(self.slide_data)

    def __getitem__(self, idx):
        slide_file = self.slide_data[idx] + '.pt'
        full_path = os.path.join(self.data_dir, 'pt_files', slide_file)
        features = torch.load(full_path)
        return features
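

# Minimal usage sketch (illustrative only): the csv path, feature directory and
# label_dict below are placeholders, not part of this module; adapt them to
# your own data before running.
if __name__ == '__main__':
    dataset = Generic_MIL_Dataset(data_dir='path/to/features',  # placeholder: root containing pt_files/
                                  csv_path='dataset_csv/ccrcc_clean.csv',  # placeholder annotation csv
                                  shuffle=False,
                                  seed=7,
                                  print_info=True,
                                  label_dict={'normal': 0, 'tumor': 1},  # placeholder str -> int mapping
                                  patient_strat=True)

    # draw k random folds, materialize the first one, and save it to csv
    dataset.create_splits(k=3, val_num=(10, 10), test_num=(10, 10), label_frac=1.0)
    dataset.set_splits()
    train_split, val_split, test_split = dataset.return_splits(from_id=True)
    dataset.save_split('splits_0.csv')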