--- /dev/null
+++ b/datasets/dataset_mtl.py
@@ -0,0 +1,560 @@
+from __future__ import print_function, division
+import os
+import torch
+import numpy as np
+import pandas as pd
+import math
+import re
+import pdb
+import pickle
+from scipy import stats
+
+from torch.utils.data import Dataset
+import h5py
+
+from utils.utils import generate_split, nth
+
+
+def save_splits(split_datasets, column_keys, filename, boolean_style=False):
+    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]
+    if not boolean_style:
+        df = pd.concat(splits, ignore_index=True, axis=1)
+        df.columns = column_keys
+    else:
+        df = pd.concat(splits, ignore_index=True, axis=0)
+        index = df.values.tolist()
+        one_hot = np.eye(len(split_datasets)).astype(bool)
+        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
+        df = pd.DataFrame(bool_array, index=index, columns=['train', 'val', 'test'])
+
+    df.to_csv(filename)
+    print()
+
+
+class Generic_WSI_MTL_Dataset(Dataset):
+    def __init__(self,
+                 csv_path='dataset_csv/ccrcc_clean.csv',
+                 shuffle=False,
+                 seed=7,
+                 print_info=True,
+                 label_dicts=[{}, {}],
+                 ignore=[],
+                 patient_strat=False,
+                 label_cols=['label', 'label_2'],
+                 patient_voting='max',
+                 multi_site=False,
+                 filter_dict={},
+                 patient_level=False
+                 ):
+        """
+        Args:
+            csv_path (string): Path to the csv file with annotations.
+            shuffle (boolean): Whether to shuffle the data.
+            seed (int): Random seed for shuffling the data.
+            print_info (boolean): Whether to print a summary of the dataset.
+            label_dicts (list of dict): One dict per task mapping str labels to ints.
+            ignore (list): List containing class labels to ignore.
+            patient_strat (boolean): Whether to stratify splits at the patient level.
+            label_cols (list): Names of the label columns, one per task.
+            patient_voting (string): Rule for deciding the patient-level label.
+            multi_site (boolean): Whether to expand task-0 labels per site.
+            filter_dict (dict): Column -> allowed values used to filter rows.
+            patient_level (boolean): Whether items are cases rather than slides.
+        """
+        self.custom_test_ids = None
+        self.seed = seed
+        self.print_info = print_info
+        self.patient_strat = patient_strat
+        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
+        self.data_dir = None
+        self.label_cols = label_cols
+
+        slide_data = pd.read_csv(csv_path)
+        slide_data = self.filter_df(slide_data, filter_dict)
+
+        self.patient_level = patient_level
+
+        if multi_site:
+            label_dicts[0] = self.init_multi_site_label_dict(slide_data, label_dicts[0])
+
+        self.label_dicts = label_dicts
+        self.num_classes = [len(set(label_dict.values())) for label_dict in self.label_dicts]
+
+        slide_data = self.df_prep(slide_data, self.label_dicts, ignore, self.label_cols, multi_site)
+        # shuffle rows reproducibly (np.random.shuffle does not operate reliably
+        # on a DataFrame, so use pandas' own sampling instead)
+        if shuffle:
+            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)
+
+        self.slide_data = slide_data
+
+        #self.patient_data_prep(patient_voting)
+        #self.cls_ids_prep()
+
+        #if print_info:
+        #    self.summarize()
+
+        if self.patient_level:
+            self.patient_dict = self.build_patient_dict()
+            #self.slide_data = self.slide_data.drop_duplicates(subset=['case_id'])
+        else:
+            self.patient_dict = {}
+
+    def build_patient_dict(self):
+        patient_dict = {}
+        patient_cases = self.slide_data['case_id'].unique()
+        slide_cases = self.slide_data.set_index('case_id')
+
+        for patient in patient_cases:
+            slide_ids = slide_cases.loc[patient, 'slide_id']
+
+            # .loc returns a scalar string when the case has a single slide
+            if isinstance(slide_ids, str):
+                slide_ids = np.array(slide_ids).reshape(-1)
+            else:
+                slide_ids = slide_ids.values
+
+            patient_dict.update({patient: slide_ids})
+
+        return patient_dict
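+    # Example (hypothetical rows): for a CSV containing
+    #   case_id,slide_id
+    #   patient_0,slide_a
+    #   patient_0,slide_b
+    #   patient_1,slide_c
+    # build_patient_dict() returns
+    #   {'patient_0': array(['slide_a', 'slide_b']), 'patient_1': array(['slide_c'])}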
+    def cls_ids_prep(self):
+        b_weighted_samples = False
+
+        if b_weighted_samples:
+            # store ids corresponding to each class at the patient or case level
+            self.patient_cls_ids = [[] for i in range(self.num_classes[0])]
+            for i in range(self.num_classes[0]):
+                self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]
+
+            # store ids corresponding to each class at the slide level
+            self.slide_cls_ids = [[] for i in range(self.num_classes[0])]
+            for i in range(self.num_classes[0]):
+                self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]
+        else:
+            self.patient_cls_ids = None
+            self.slide_cls_ids = None
+
+    def patient_data_prep(self, patient_voting='max'):
+        patients = np.unique(np.array(self.slide_data['case_id']))  # get unique patients
+        patient_labels = []
+
+        for p in patients:
+            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
+            assert len(locations) > 0
+            label = self.slide_data['label'][locations].values
+            if patient_voting == 'max':
+                label = label.max()  # get patient label (MIL convention)
+            elif patient_voting == 'maj':
+                label = stats.mode(label)[0]
+            else:
+                raise NotImplementedError
+            patient_labels.append(label)
+
+        self.patient_data = {'case_id': patients, 'label': np.array(patient_labels)}
+
+    @staticmethod
+    def init_multi_site_label_dict(slide_data, label_dict):
+        print('initiating multi-source label dictionary')
+        sites = np.unique(slide_data['site'].values)
+        multi_site_dict = {}
+        num_classes = len(label_dict)
+        for key, val in label_dict.items():
+            for idx, site in enumerate(sites):
+                site_key = (key, site)
+                site_val = val + idx * num_classes
+                multi_site_dict.update({site_key: site_val})
+                print('{} : {}'.format(site_key, site_val))
+        return multi_site_dict
+
+    @staticmethod
+    def filter_df(df, filter_dict={}):
+        if len(filter_dict) > 0:
+            filter_mask = np.full(len(df), True, bool)
+            # assert 'label' not in filter_dict.keys()
+            for key, val in filter_dict.items():
+                mask = df[key].isin(val)
+                filter_mask = np.logical_and(filter_mask, mask)
+            df = df[filter_mask]
+        return df
+
+    @staticmethod
+    def df_prep(data, label_dicts, ignore, label_cols, multi_site=False):
+        for idx, (label_dict, label_col) in enumerate(zip(label_dicts, label_cols)):
+            print(label_dict, label_col)
+            data[label_col] = data[label_col].map(label_dict)
+
+        return data
+
+    def __len__(self):
+        if self.patient_strat:
+            return len(self.patient_data['case_id'])
+        else:
+            return len(self.slide_data)
+
+    def summarize(self):
+        for task in range(len(self.label_dicts)):
+            print('task: ', task)
+            print("label column: {}".format(self.label_cols[task]))
+            print("label dictionary: {}".format(self.label_dicts[task]))
+            print("number of classes: {}".format(self.num_classes[task]))
+            print("slide-level counts: ", '\n', self.slide_data[self.label_cols[task]].value_counts(sort=False))
+
+        for i in range(self.num_classes[0]):
+            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
+            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))
+
+    def create_splits(self, k=3, val_num=(25, 25), test_num=(40, 40), label_frac=1.0, custom_test_ids=None):
+        settings = {
+            'n_splits': k,
+            'val_num': val_num,
+            'test_num': test_num,
+            'label_frac': label_frac,
+            'seed': self.seed,
+            'custom_test_ids': custom_test_ids
+        }
+
+        if self.patient_strat:
+            settings.update({'cls_ids': self.patient_cls_ids, 'samples': len(self.patient_data['case_id'])})
+        else:
+            settings.update({'cls_ids': self.slide_cls_ids, 'samples': len(self.slide_data)})
+
+        self.split_gen = generate_split(**settings)
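+    # Usage sketch (hypothetical sizes): build the split generator once, then
+    # advance it per fold; generate_split() and nth() come from utils.utils.
+    # Assumes cls_ids_prep()/patient_data_prep() have populated the class-id
+    # lists that create_splits() reads.
+    #
+    #   dataset.create_splits(k=10, val_num=(25, 25), test_num=(40, 40))
+    #   for fold in range(10):
+    #       dataset.set_splits()
+    #       train, val, test = dataset.return_splits(from_id=True)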
+    def sample_held_out(self, test_num=(40, 40)):
+        test_ids = []
+        np.random.seed(self.seed)  # fix seed
+
+        if self.patient_strat:
+            cls_ids = self.patient_cls_ids
+        else:
+            cls_ids = self.slide_cls_ids
+
+        for c in range(len(test_num)):
+            test_ids.extend(np.random.choice(cls_ids[c], test_num[c], replace=False))  # held-out test ids
+
+        if self.patient_strat:
+            slide_ids = []
+            for idx in test_ids:
+                case_id = self.patient_data['case_id'][idx]
+                slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
+                slide_ids.extend(slide_indices)
+            return slide_ids
+        else:
+            return test_ids
+
+    def set_splits(self, start_from=None):
+        if start_from:
+            ids = nth(self.split_gen, start_from)
+        else:
+            ids = next(self.split_gen)
+
+        if self.patient_strat:
+            slide_ids = [[] for i in range(len(ids))]
+
+            for split in range(len(ids)):
+                for idx in ids[split]:
+                    case_id = self.patient_data['case_id'][idx]
+                    slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
+                    slide_ids[split].extend(slide_indices)
+
+            self.train_ids, self.val_ids, self.test_ids = slide_ids[0], slide_ids[1], slide_ids[2]
+        else:
+            self.train_ids, self.val_ids, self.test_ids = ids
+
+    def get_split_from_df(self, all_splits=None, split_key='train', split=None):
+        if split is None:
+            split = all_splits[split_key]
+        split = split.dropna().reset_index(drop=True)
+
+        if len(split) > 0:
+            mask = self.slide_data['slide_id'].isin(split.tolist())
+            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
+            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
+        else:
+            split = None
+
+        return split
+
+    def get_merged_split_from_df(self, all_splits, split_keys=['train']):
+        merged_split = []
+        for split_key in split_keys:
+            split = all_splits[split_key]
+            split = split.dropna().reset_index(drop=True).tolist()
+            merged_split.extend(split)
+
+        # check the merged list, not just the last split read
+        if len(merged_split) > 0:
+            mask = self.slide_data['slide_id'].isin(merged_split)
+            df_slice = self.slide_data[mask].dropna().reset_index(drop=True)
+            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
+        else:
+            split = None
+
+        return split
+
+    def return_splits(self, from_id=True, csv_path=None):
+        if from_id:
+            if len(self.train_ids) > 0:
+                train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True)
+                train_split = Generic_Split(train_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
+            else:
+                train_split = None
+
+            if len(self.val_ids) > 0:
+                val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
+                val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
+            else:
+                val_split = None
+
+            if len(self.test_ids) > 0:
+                test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True)
+                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes, label_cols=self.label_cols, patient_level=self.patient_level)
+            else:
+                test_split = None
+        else:
+            assert csv_path
+            all_splits = pd.read_csv(csv_path)
+            train_split = self.get_split_from_df(all_splits, 'train')
+            val_split = self.get_split_from_df(all_splits, 'val')
+            test_split = self.get_split_from_df(all_splits, 'test')
+
+        return train_split, val_split, test_split
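+    # Example split CSV (hypothetical slide ids) consumed by
+    # return_splits(from_id=False, csv_path=...): one column per split, padded
+    # with empty cells that dropna() discards.
+    #
+    #   ,train,val,test
+    #   0,slide_1,slide_3,slide_5
+    #   1,slide_2,,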
+    def get_list(self, ids):
+        return self.slide_data['slide_id'][ids]
+
+    def getlabel(self, ids, task):
+        # task 0 is read from the hard-coded 'label' column
+        if task > 0:
+            return self.slide_data[self.label_cols[task]][ids]
+        else:
+            return self.slide_data['label'][ids]
+
+    def __getitem__(self, idx):
+        return None
+
+    def test_split_gen(self, return_descriptor=False):
+        if return_descriptor:
+            dfs = []
+            for task in range(len(self.label_dicts)):
+                index = [list(self.label_dicts[task].keys())[list(self.label_dicts[task].values()).index(i)] for i in range(self.num_classes[task])]
+                columns = ['train', 'val', 'test']
+                df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index=index, columns=columns)
+                dfs.append(df)
+
+        for task in range(len(self.label_dicts)):
+            index = [list(self.label_dicts[task].keys())[list(self.label_dicts[task].values()).index(i)] for i in range(self.num_classes[task])]
+            # the same counting logic previously appeared three times (train/val/test)
+            for split_name, split_ids in [('train', self.train_ids), ('val', self.val_ids), ('test', self.test_ids)]:
+                print('\nnumber of {} samples: {}'.format(split_name, len(split_ids)))
+                labels = self.getlabel(split_ids, task)
+                unique, counts = np.unique(labels, return_counts=True)
+                missing_classes = np.setdiff1d(np.arange(self.num_classes[task]), unique)
+                unique = np.append(unique, missing_classes)
+                counts = np.append(counts, np.full(len(missing_classes), 0))
+                inds = unique.argsort()
+                unique = unique[inds]  # keep class ids aligned with their sorted counts
+                counts = counts[inds]
+                for u in range(len(unique)):
+                    print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
+                    if return_descriptor:
+                        dfs[task].loc[index[u], split_name] = counts[u]
+
+        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
+        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
+        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0
+
+        if return_descriptor:
+            df = pd.concat(dfs, axis=0)
+            return df
+
+    def save_split(self, filename):
+        train_split = self.get_list(self.train_ids)
+        val_split = self.get_list(self.val_ids)
+        test_split = self.get_list(self.test_ids)
+        df_tr = pd.DataFrame({'train': train_split})
+        df_v = pd.DataFrame({'val': val_split})
+        df_t = pd.DataFrame({'test': test_split})
+        df = pd.concat([df_tr, df_v, df_t], axis=1)
+        df.to_csv(filename, index=False)
+
+
+class Generic_MIL_MTL_Dataset(Generic_WSI_MTL_Dataset):
+    def __init__(self,
+                 data_dir,
+                 **kwargs):
+        super(Generic_MIL_MTL_Dataset, self).__init__(**kwargs)
+        self.data_dir = data_dir
+        self.use_h5 = False
+
+    def load_from_h5(self, toggle):
+        self.use_h5 = toggle
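+    # __getitem__ returns (features, label_task1, label_task2, label_task3), plus
+    # patch coords when use_h5 is set. Note that it indexes label_cols[0..2], so a
+    # three-task label_cols list is expected here even though the base class
+    # defaults to two columns. With patient_level=True, the feature bags of all
+    # slides belonging to a case are concatenated along dim 0.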
+    def __getitem__(self, idx):
+        if not self.patient_level:
+            slide_id = self.slide_data['slide_id'][idx]
+            label_task1 = self.slide_data[self.label_cols[0]][idx]
+            label_task2 = self.slide_data[self.label_cols[1]][idx]
+            label_task3 = self.slide_data[self.label_cols[2]][idx]
+            if isinstance(self.data_dir, dict):
+                source = self.slide_data['source'][idx]
+                data_dir = self.data_dir[source]
+            else:
+                data_dir = self.data_dir
+
+            if not self.use_h5:
+                if self.data_dir:
+                    full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
+                    features = torch.load(full_path)
+                    return features, label_task1, label_task2, label_task3
+                else:
+                    return slide_id, label_task1, label_task2, label_task3
+            else:
+                full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
+                with h5py.File(full_path, 'r') as hdf5_file:
+                    features = hdf5_file['features'][:]
+                    coords = hdf5_file['coords'][:]
+
+                features = torch.from_numpy(features)
+                return features, label_task1, label_task2, label_task3, coords
+        else:
+            case_id = self.slide_data['case_id'][idx]
+            slide_ids = self.patient_dict[case_id]
+            label_task1 = self.slide_data[self.label_cols[0]][idx]
+            label_task2 = self.slide_data[self.label_cols[1]][idx]
+            label_task3 = self.slide_data[self.label_cols[2]][idx]
+
+            if isinstance(self.data_dir, dict):
+                source = self.slide_data['source'][idx]
+                data_dir = self.data_dir[source]
+            else:
+                data_dir = self.data_dir
+
+            if not self.use_h5:
+                features_list = []
+                for slide_id in slide_ids:
+                    full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
+                    slide_features = torch.load(full_path)
+                    features_list.append(slide_features)
+
+                features = torch.cat(features_list, dim=0)
+                return features, label_task1, label_task2, label_task3
+            else:
+                features_list = []
+                coords_list = []
+
+                for slide_id in slide_ids:
+                    full_path = os.path.join(data_dir, 'h5_files', '{}.h5'.format(slide_id))
+                    with h5py.File(full_path, 'r') as hdf5_file:
+                        slide_features = hdf5_file['features'][:]
+                        slide_coords = hdf5_file['coords'][:]
+                    slide_features_t = torch.from_numpy(slide_features)  # fixed typo: was silide_features_t
+                    slide_coords_t = torch.from_numpy(slide_coords)
+
+                    features_list.append(slide_features_t)
+                    coords_list.append(slide_coords_t)
+
+                features = torch.cat(features_list, dim=0)  # fixed typo: was troch.cat
+                coords = torch.cat(coords_list, dim=0)
+                return features, label_task1, label_task2, label_task3, coords
+
+
+class Generic_Split(Generic_MIL_MTL_Dataset):
+    def __init__(self, slide_data, data_dir=None, num_classes=2, label_cols=None, patient_level=False):
+        self.use_h5 = False
+        self.slide_data = slide_data
+        self.data_dir = data_dir
+        self.num_classes = num_classes
+        self.label_cols = label_cols
+        self.slide_cls_ids = None
+        #for i in range(self.num_classes[0]):
+        #    self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]
+
+        self.patient_level = patient_level
+        if self.patient_level:
+            self.patient_dict = self.build_patient_dict()
+            #self.slide_data = self.slide_data.drop_duplicates(subset=['case_id'])
+        else:
+            self.patient_dict = {}
+
+    def __len__(self):
+        return len(self.slide_data)
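+# Usage sketch (hypothetical paths and label maps): each item is one variable-
+# length feature bag, so the dataset is typically wrapped in a DataLoader with
+# batch_size=1.
+#
+#   dataset = Generic_MIL_MTL_Dataset(
+#       data_dir='features/',
+#       csv_path='dataset_csv/ccrcc_clean.csv',
+#       label_dicts=[{'benign': 0, 'tumor': 1}, {'low': 0, 'high': 1}, {'neg': 0, 'pos': 1}],
+#       label_cols=['label', 'label_2', 'label_3'])
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)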
+class Generic_WSI_Inference_Dataset(Dataset):
+    def __init__(self,
+                 data_dir,
+                 csv_path=None,
+                 print_info=True,
+                 ):
+        self.data_dir = data_dir
+        self.print_info = print_info
+
+        if csv_path is not None:
+            data = pd.read_csv(csv_path)
+            self.slide_data = data['slide_id'].values
+        else:
+            data = np.array(os.listdir(data_dir))
+            # drop the '.pt' extension; np.char.strip(data, chars='.pt') would also
+            # remove leading/trailing 'p', 't' and '.' characters from slide ids
+            self.slide_data = np.array([os.path.splitext(f)[0] for f in data])
+        if print_info:
+            print('total number of slides to infer: ', len(self.slide_data))
+
+    def __len__(self):
+        return len(self.slide_data)
+
+    def __getitem__(self, idx):
+        slide_file = self.slide_data[idx] + '.pt'
+        full_path = os.path.join(self.data_dir, 'pt_files', slide_file)
+        features = torch.load(full_path)
+        return features
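+# Usage sketch (hypothetical directory and model): iterate over every .pt
+# feature bag found for inference.
+#
+#   infer_set = Generic_WSI_Inference_Dataset(data_dir='features/')
+#   loader = torch.utils.data.DataLoader(infer_set, batch_size=1)
+#   for features in loader:
+#       logits = model(features)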