
--- a
+++ b/datasets/dataset_survival.py
@@ -0,0 +1,390 @@
+from __future__ import print_function, division
+import os
+import pickle
+import re
+
+import numpy as np
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+
+import torch
+from torch.utils.data import Dataset
+
+
+class Generic_WSI_Survival_Dataset(Dataset):
+    def __init__(self,
+        csv_path = 'dataset_csv/ccrcc_clean.csv', mode = 'omic', apply_sig = False,
+        shuffle = False, seed = 7, print_info = True, n_bins = 4, ignore=[],
+        patient_strat=False, label_col = None, filter_dict = {}, eps=1e-6):
+        r"""
+        Generic_WSI_Survival_Dataset 
+
+        Args:
+            csv_file (string): Path to the csv file with annotations.
+            shuffle (boolean): Whether to shuffle
+            seed (int): random seed for shuffling the data
+            print_info (boolean): Whether to print a summary of the dataset
+            label_dict (dict): Dictionary with key, value pairs for converting str labels to int
+            ignore (list): List containing class labels to ignore
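+
+        Example (arguments shown are hypothetical):
+            >>> dataset = Generic_WSI_Survival_Dataset(
+            ...     csv_path='dataset_csv/tcga_brca_clean.csv',
+            ...     mode='omic', n_bins=4, label_col='survival_months')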
+        """
+        self.custom_test_ids = None
+        self.seed = seed
+        self.print_info = print_info
+        self.patient_strat = patient_strat
+        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
+        self.data_dir = None
+
+        slide_data = pd.read_csv(csv_path, low_memory=False)
+        #slide_data = slide_data.drop(['Unnamed: 0'], axis=1)
+        if 'case_id' not in slide_data:
+            # Derive case_id from the first 12 characters of the index
+            # (the TCGA patient barcode prefix).
+            slide_data.index = slide_data.index.str[:12]
+            slide_data['case_id'] = slide_data.index
+            slide_data = slide_data.reset_index(drop=True)
+
+        if shuffle:
+            # Shuffle rows reproducibly; the CSV must be read before shuffling.
+            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)
+
+        if not label_col:
+            label_col = 'survival_months'
+        else:
+            assert label_col in slide_data.columns
+        self.label_col = label_col
+
+        if "IDC" in slide_data['oncotree_code']: # must be BRCA (and if so, use only IDCs)
+            slide_data = slide_data[slide_data['oncotree_code'] == 'IDC']
+
+        patients_df = slide_data.drop_duplicates(['case_id']).copy()
+        uncensored_df = patients_df[patients_df['censorship'] < 1]
+
+        # Quantile bin edges are computed on uncensored patients only, then
+        # widened so that every observed survival time falls inside the bins.
+        _, q_bins = pd.qcut(uncensored_df[label_col], q=n_bins, retbins=True, labels=False)
+        q_bins[-1] = slide_data[label_col].max() + eps
+        q_bins[0] = slide_data[label_col].min() - eps
+
+        # All patients (censored and uncensored) are then binned on those edges.
+        disc_labels, q_bins = pd.cut(patients_df[label_col], bins=q_bins, retbins=True, labels=False, right=False, include_lowest=True)
+        patients_df.insert(2, 'label', disc_labels.values.astype(int))
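+        # Worked example (numbers hypothetical): with n_bins=4 and uncensored
+        # quartile edges at 12, 30, and 60 months, q_bins becomes
+        # [min - eps, 12, 30, 60, max + eps] and every patient receives a
+        # time-bin label in {0, 1, 2, 3}.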
+
+        patient_dict = {}
+        slide_data = slide_data.set_index('case_id')
+        for patient in patients_df['case_id']:
+            slide_ids = slide_data.loc[patient, 'slide_id']
+            if isinstance(slide_ids, str):
+                slide_ids = np.array(slide_ids).reshape(-1)
+            else:
+                slide_ids = slide_ids.values
+            patient_dict.update({patient:slide_ids})
+
+        self.patient_dict = patient_dict
+    
+        # Collapse to one row per patient; per-patient slide lists live in patient_dict.
+        slide_data = patients_df
+        slide_data.reset_index(drop=True, inplace=True)
+        slide_data = slide_data.assign(slide_id=slide_data['case_id'])
+
+        # Build the joint (time-bin, censorship) -> class-index dictionary.
+        label_dict = {}
+        key_count = 0
+        for i in range(len(q_bins)-1):
+            for c in [0, 1]:
+                print('{} : {}'.format((i, c), key_count))
+                label_dict.update({(i, c):key_count})
+                key_count+=1
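+        # Example: with n_bins=4 this yields 2 * n_bins = 8 classes:
+        # (0, 0) -> 0, (0, 1) -> 1, (1, 0) -> 2, ..., (3, 1) -> 7.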
+
+        self.label_dict = label_dict
+        # Store the raw time bin as disc_label and re-map 'label' to the joint class.
+        for i in slide_data.index:
+            key = slide_data.loc[i, 'label']
+            slide_data.at[i, 'disc_label'] = key
+            censorship = slide_data.loc[i, 'censorship']
+            key = (key, int(censorship))
+            slide_data.at[i, 'label'] = label_dict[key]
+
+        self.bins = q_bins
+        self.num_classes=len(self.label_dict)
+        patients_df = slide_data.drop_duplicates(['case_id'])
+        self.patient_data = {'case_id':patients_df['case_id'].values, 'label':patients_df['label'].values}
+
+        #new_cols = list(slide_data.columns[-2:]) + list(slide_data.columns[:-2]) ### ICCV
+        new_cols = list(slide_data.columns[-1:]) + list(slide_data.columns[:-1])  ### PORPOISE
+        slide_data = slide_data[new_cols]
+        self.slide_data = slide_data
+        metadata = ['disc_label', 'Unnamed: 0', 'case_id', 'label', 'slide_id', 'age', 'site', 'survival_months', 'censorship', 'is_female', 'oncotree_code', 'train']
+        self.metadata = slide_data.columns[:len(metadata)]
+        
+        # Sanity check: print any non-metadata column that does not look like a
+        # genomic feature (mutation / CNV / RNA-seq suffix).
+        for col in slide_data.drop(self.metadata, axis=1).columns:
+            if not re.search('_cnv|_rnaseq|_rna|_mut', col):
+                print(col)
+
+        assert self.metadata.equals(pd.Index(metadata))
+        self.mode = mode
+        self.cls_ids_prep()
+
+        ### ICCV discrepancies
+        # For BLCA, TPTEP1_rnaseq was accidentally appended to the metadata
+
+        if print_info:
+            self.summarize()
+
+        ### Signatures
+        self.apply_sig = apply_sig
+        if self.apply_sig:
+            self.signatures = pd.read_csv('./datasets_csv_sig/signatures.csv')
+        else:
+            self.signatures = None
+
+
+    def cls_ids_prep(self):
+        r"""
+
+        """
+        self.patient_cls_ids = [[] for i in range(self.num_classes)]        
+        for i in range(self.num_classes):
+            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]
+
+        self.slide_cls_ids = [[] for i in range(self.num_classes)]
+        for i in range(self.num_classes):
+            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]
+
+
+    def patient_data_prep(self):
+        r"""
+        
+        """
+        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
+        patient_labels = []
+        
+        for p in patients:
+            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
+            assert len(locations) > 0
+            label = self.slide_data['label'][locations[0]] # get patient label
+            patient_labels.append(label)
+        
+        self.patient_data = {'case_id':patients, 'label':np.array(patient_labels)}
+
+
+    @staticmethod
+    def df_prep(data, n_bins, ignore, label_col):
+        r"""
+        
+        """
+
+        mask = data[label_col].isin(ignore)
+        data = data[~mask]
+        data.reset_index(drop=True, inplace=True)
+        disc_labels, bins = pd.cut(data[label_col], bins=n_bins)
+        return data, bins
+
+    def __len__(self):
+        if self.patient_strat:
+            return len(self.patient_data['case_id'])
+        else:
+            return len(self.slide_data)
+
+    def summarize(self):
+        print("label column: {}".format(self.label_col))
+        print("label dictionary: {}".format(self.label_dict))
+        print("number of classes: {}".format(self.num_classes))
+        print("slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort = False))
+        for i in range(self.num_classes):
+            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
+            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))
+
+
+    def get_split_from_df(self, all_splits: pd.DataFrame, split_key: str='train', scaler=None):
+        split = all_splits[split_key]
+        split = split.dropna().reset_index(drop=True)
+
+        if len(split) > 0:
+            mask = self.slide_data['slide_id'].isin(split.tolist())
+            df_slice = self.slide_data[mask].reset_index(drop=True)
+            split = Generic_Split(df_slice, metadata=self.metadata, mode=self.mode, signatures=self.signatures, data_dir=self.data_dir, label_col=self.label_col, patient_dict=self.patient_dict, num_classes=self.num_classes)
+        else:
+            split = None
+        
+        return split
+
+
+    def return_splits(self, from_id: bool=True, csv_path: str=None):
+        if from_id:
+            raise NotImplementedError
+        else:
+            assert csv_path 
+            all_splits = pd.read_csv(csv_path)
+            train_split = self.get_split_from_df(all_splits=all_splits, split_key='train')
+            val_split = self.get_split_from_df(all_splits=all_splits, split_key='val')
+            test_split = None #self.get_split_from_df(all_splits=all_splits, split_key='test')
+
+            ### --> Normalizing Data
+            print("****** Normalizing Data ******")
+            scalers = train_split.get_scaler()
+            train_split.apply_scaler(scalers=scalers)
+            val_split.apply_scaler(scalers=scalers)
+            #test_split.apply_scaler(scalers=scalers)
+            ### <--
+        return train_split, val_split#, test_split
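+    # Usage sketch (csv_path hypothetical):
+    #   train_split, val_split = dataset.return_splits(
+    #       from_id=False, csv_path='splits/5foldcv/tcga_brca/splits_0.csv')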
+
+
+    def get_list(self, ids):
+        return self.slide_data['slide_id'][ids]
+
+    def getlabel(self, ids):
+        return self.slide_data['label'][ids]
+
+    def __getitem__(self, idx):
+        return None
+
+
+class Generic_MIL_Survival_Dataset(Generic_WSI_Survival_Dataset):
+    def __init__(self, data_dir, mode: str='omic', **kwargs):
+        super(Generic_MIL_Survival_Dataset, self).__init__(**kwargs)
+        self.data_dir = data_dir
+        self.mode = mode
+        self.use_h5 = False
+
+    def load_from_h5(self, toggle):
+        self.use_h5 = toggle
+
+    def __getitem__(self, idx):
+        case_id = self.slide_data['case_id'][idx]
+        label = torch.Tensor([self.slide_data['disc_label'][idx]])
+        event_time = torch.Tensor([self.slide_data[self.label_col][idx]])
+        c = torch.Tensor([self.slide_data['censorship'][idx]])
+        slide_ids = self.patient_dict[case_id]
+
+        if isinstance(self.data_dir, dict):
+            source = self.slide_data['oncotree_code'][idx]
+            data_dir = self.data_dir[source]
+        else:
+            data_dir = self.data_dir
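+        # Example (paths hypothetical): for multi-cohort training, data_dir may
+        # map oncotree codes to separate feature roots, e.g.
+        # {'IDC': '/data/brca_feats', 'CCRCC': '/data/kirc_feats'}.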
+        
+        if not self.use_h5:
+            if self.data_dir:
+                if self.mode == 'path':
+                    path_features = []
+                    for slide_id in slide_ids:
+                        wsi_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id.replace('.svs', '')))
+                        wsi_bag = torch.load(wsi_path)
+                        path_features.append(wsi_bag)
+                    path_features = torch.cat(path_features, dim=0)
+                    return (path_features, torch.zeros((1,1)), label, event_time, c)
+
+                elif self.mode == 'cluster':
+                    path_features = []
+                    cluster_ids = []
+                    for slide_id in slide_ids:
+                        wsi_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id.replace('.svs', '')))
+                        wsi_bag = torch.load(wsi_path)
+                        path_features.append(wsi_bag)
+                        cluster_ids.extend(self.fname2ids[slide_id[:-4]+'.pt'])
+                    path_features = torch.cat(path_features, dim=0)
+                    cluster_ids = torch.Tensor(cluster_ids)
+                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
+                    return (path_features, cluster_ids, genomic_features, label, event_time, c)
+
+                elif self.mode == 'omic':
+                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
+                    return (torch.zeros((1,1)), genomic_features.unsqueeze(dim=0), label, event_time, c)
+
+                elif self.mode == 'pathomic':
+                    path_features = []
+                    for slide_id in slide_ids:
+                        wsi_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id.replace('.svs', '')))
+                        wsi_bag = torch.load(wsi_path)
+                        path_features.append(wsi_bag)
+                    path_features = torch.cat(path_features, dim=0)
+                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
+                    return (path_features, genomic_features.unsqueeze(dim=0), label, event_time, c)
+
+                elif self.mode == 'pathomic_fast':
+                    casefeat_path = os.path.join(data_dir, f'split_{self.split_id}_case_pt', f'{case_id}.pt')
+                    path_features = torch.load(casefeat_path)
+                    genomic_features = torch.tensor(self.genomic_features.iloc[idx])
+                    return (path_features, genomic_features.unsqueeze(dim=0), label, event_time, c)
+
+                elif self.mode == 'coattn':
+                    path_features = []
+                    for slide_id in slide_ids:
+                        wsi_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id.replace('.svs', '')))
+                        wsi_bag = torch.load(wsi_path)
+                        path_features.append(wsi_bag)
+                    path_features = torch.cat(path_features, dim=0)
+                    omic1 = torch.tensor(self.genomic_features[self.omic_names[0]].iloc[idx])
+                    omic2 = torch.tensor(self.genomic_features[self.omic_names[1]].iloc[idx])
+                    omic3 = torch.tensor(self.genomic_features[self.omic_names[2]].iloc[idx])
+                    omic4 = torch.tensor(self.genomic_features[self.omic_names[3]].iloc[idx])
+                    omic5 = torch.tensor(self.genomic_features[self.omic_names[4]].iloc[idx])
+                    omic6 = torch.tensor(self.genomic_features[self.omic_names[5]].iloc[idx])
+                    return (path_features, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c)
+
+                else:
+                    raise NotImplementedError('Mode [%s] not implemented.' % self.mode)
+            else:
+                return slide_ids, label, event_time, c
+
+
+class Generic_Split(Generic_MIL_Survival_Dataset):
+    def __init__(self, slide_data, metadata, mode, 
+        signatures=None, data_dir=None, label_col=None, patient_dict=None, num_classes=2):
+        self.use_h5 = False
+        self.slide_data = slide_data
+        self.metadata = metadata
+        self.mode = mode
+        self.data_dir = data_dir
+        self.num_classes = num_classes
+        self.label_col = label_col
+        self.patient_dict = patient_dict
+        self.slide_cls_ids = [[] for i in range(self.num_classes)]
+        for i in range(self.num_classes):
+            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]
+
+        ### --> Initializing genomic features in Generic Split
+        self.genomic_features = self.slide_data.drop(self.metadata, axis=1)
+        self.signatures = signatures
+
+        if mode == 'cluster':
+            with open(os.path.join(data_dir, 'fast_cluster_ids.pkl'), 'rb') as handle:
+                self.fname2ids = pickle.load(handle)
+
+        def series_intersection(s1, s2):
+            return pd.Series(list(set(s1) & set(s2)))
+
+        if self.signatures is not None:
+            self.omic_names = []
+            for col in self.signatures.columns:
+                omic = self.signatures[col].dropna().unique()
+                omic = np.concatenate([omic + suffix for suffix in ['_mut', '_cnv', '_rnaseq']])
+                omic = sorted(series_intersection(omic, self.genomic_features.columns))
+                self.omic_names.append(omic)
+            self.omic_sizes = [len(omic) for omic in self.omic_names]
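+            # Example (signature file hypothetical): a signatures.csv column
+            # listing genes ['TP53', 'RB1'] expands to candidates ['TP53_mut',
+            # 'TP53_cnv', 'TP53_rnaseq', 'RB1_mut', ...]; only candidates found
+            # in self.genomic_features.columns are kept, so omic_sizes holds the
+            # number of measured features per signature.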
+        print("Shape", self.genomic_features.shape)
+        ### <--
+
+    def __len__(self):
+        return len(self.slide_data)
+
+    ### --> Getting StandardScaler of self.genomic_features
+    def get_scaler(self):
+        scaler_omic = StandardScaler().fit(self.genomic_features)
+        return (scaler_omic,)
+    ### <--
+
+    ### --> Applying StandardScaler to self.genomic_features
+    def apply_scaler(self, scalers: tuple=None):
+        transformed = pd.DataFrame(scalers[0].transform(self.genomic_features),
+                                   index=self.genomic_features.index,
+                                   columns=self.genomic_features.columns)
+        self.genomic_features = transformed
+    ### <--
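+    # Usage sketch: fit scaling statistics on the training split only and
+    # reuse them on validation/test (as done in return_splits above):
+    #   scalers = train_split.get_scaler()
+    #   val_split.apply_scaler(scalers=scalers)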
+
+    def set_split_id(self, split_id):
+        self.split_id = split_id
\ No newline at end of file