--- /dev/null
+++ b/src/iterpretability/datasets/data_loader.py
@@ -0,0 +1,224 @@
+import os
+import pickle
+
+import numpy as np
+import pandas as pd
+from catenets.datasets import load as catenets_load
+
+from src.iterpretability.datasets.news.process_news import process_news
+from src.iterpretability.datasets.tcga.process_tcga import process_tcga
+from src.utils import gen_unique_ids
+
+
+def remove_unbalanced(df, threshold=0.8):
+    """Drop low-cardinality numeric columns dominated by a single value.
+
+    A column is removed when it has fewer than six unique values and its
+    most frequent value accounts for more than `threshold` of the rows.
+    """
+    ub_col = []
+    for col in df.select_dtypes(exclude="object").columns:
+        if df[col].nunique() < 6:
+            distribution = df[col].value_counts(normalize=True)
+            if distribution.max() > threshold:
+                ub_col.append(col)
+
+    return df.drop(columns=ub_col)
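+
+
+# A minimal illustration of remove_unbalanced (hypothetical values, not part
+# of the loader's API): "drop" is constant, so its majority share (1.0)
+# exceeds the 0.8 threshold and the column is removed, while "keep" is split
+# 3/2 (majority share 0.6) and survives.
+#
+#   >>> df = pd.DataFrame({"keep": [0, 1, 0, 1, 0], "drop": [1, 1, 1, 1, 1]})
+#   >>> list(remove_unbalanced(df, threshold=0.8).columns)
+#   ['keep']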
dataset_name.startswith("melanoma_semi_synthetic"): + data = pd.read_csv(repo_path + "/data/DepMap_24Q2/real/"+dataset_name+".csv", index_col=0) + outcomes = data[["pred_a0_y2 (immuno)", "pred_a1_y2 (immuno)"]].to_numpy() + outcomes = outcomes.reshape(outcomes.shape[0], outcomes.shape[1], 1) + outcomes = outcomes.astype(int) + data = data.drop(["pred_a0_y2 (immuno)", "pred_a1_y2 (immuno)"], axis=1) + X_raw = data.to_numpy() + + elif dataset_name.startswith("toy_data"): + if sim_type == "T": + raise ValueError("Toy example data does not have treatment outcomes") + data = pd.read_csv(repo_path + "/data/toy/"+dataset_name+".csv", index_col=0) + X_raw = data.to_numpy() + + elif dataset_name.startswith("cytof"): + data = pd.read_csv(directory_path_ + dataset_name + ".csv", index_col=0) + + feature_names = data.columns + + # Also return true outcomes for TSimulation + if sim_type == "T": + all_treatment_outcomes_cols = [col for col in data.columns if col.startswith("09")] + outcomes = data[all_treatment_outcomes_cols[:2]] + data = data.drop(all_treatment_outcomes_cols, axis=1) + + # Raise value error if outcomes contains string or nan values + if outcomes.isnull().values.any(): + raise ValueError("Outcomes contain nan values") + if outcomes.dtypes.apply(lambda x: x == 'object').any(): + raise ValueError("Outcomes contain string values") + + if type(outcomes) == pd.DataFrame: + outcomes = outcomes.replace({False:0, True:1}) + outcomes = outcomes.to_numpy() + outcomes = outcomes.reshape(outcomes.shape[0], outcomes.shape[1], 1) + + X_raw = data + + + elif directory_path_ is not None and os.path.isfile(directory_path_ + dataset_name + ".csv"): + data = pd.read_csv(directory_path_ + dataset_name + ".csv", index_col=0) + + all_treatment_outcomes_cols = [col for col in data.columns if col.startswith("09")][:5] + + # Also return true outcomes for TSimulation + if sim_type == "T": + outcomes = data[all_treatment_outcomes_cols] + + # Raise value error if outcomes contains string or nan values + if outcomes.isnull().values.any(): + raise ValueError("Outcomes contain nan values") + if outcomes.dtypes.apply(lambda x: x == 'object').any(): + raise ValueError("Outcomes contain string values") + + if type(outcomes) == pd.DataFrame: + outcomes = outcomes.replace({False:0, True:1}) + outcomes = outcomes.to_numpy() + outcomes = outcomes.reshape(outcomes.shape[0], outcomes.shape[1], 1) + + # Drop all columns starting with tre or out + X_raw = data.loc[:, ~data.columns.str.startswith('tre')] + X_raw = X_raw.loc[:, ~X_raw.columns.str.startswith('out')] + + # One hot encode all categorical features + X_raw = pd.get_dummies(X_raw) + + # Make sure there are no boolean variables + X_raw = X_raw.replace({False:0, True:1}) + + # Remove highly unbalanced features + X_raw = remove_unbalanced(X_raw, 0.8) + + # Normalize all columns of the dataframe + X_raw = X_raw.apply(lambda x: (x - x.min()) / (x.max() - x.min())) + + # X_raw.index = gen_unique_ids(X_raw.shape[0]) + + else: + raise Exception("Unknown dataset " + str(dataset_name) + "File:", directory_path_ + dataset_name + ".csv") + + debug_size = 200 + if type(X_raw) == pd.DataFrame: + X_raw = X_raw.to_numpy() + + if train_ratio == 1.0: + if debug: + X_raw = X_raw[:int(debug_size*train_ratio),:] + if sim_type == "T": + outcomes = outcomes[:int(debug_size*train_ratio),:] + + if sim_type == "T": + return X_raw, outcomes, feature_names + else: + return X_raw, feature_names + + else: + X_raw_train = X_raw[: int(train_ratio * X_raw.shape[0])] + X_raw_test = 
+
+    if sim_type == "T":
+        # Split the outcomes at the same cutoff as the covariates so that
+        # rows stay aligned with X_raw_train / X_raw_test.
+        split = int(train_ratio * outcomes.shape[0])
+        outcomes_train = outcomes[:split, :]
+        outcomes_test = outcomes[split:, :]
+
+    if debug:
+        X_raw_train = X_raw_train[: int(debug_size * train_ratio), :]
+        X_raw_test = X_raw_test[: int(debug_size * (1 - train_ratio)), :]
+        if sim_type == "T":
+            outcomes_train = outcomes_train[: int(debug_size * train_ratio), :]
+            outcomes_test = outcomes_test[: int(debug_size * (1 - train_ratio)), :]
+
+    if sim_type == "T":
+        return X_raw_train, X_raw_test, outcomes_train, outcomes_test, feature_names
+    return X_raw_train, X_raw_test, feature_names
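+
+
+# Hedged usage sketch: a quick smoke test of the two return signatures,
+# assuming the catenets "twins" dataset can be fetched in this environment.
+# Purely illustrative; not part of the loader's API.
+if __name__ == "__main__":
+    X_all, names = load("twins")
+    print("full:", X_all.shape, "feature_names:", names)
+
+    X_tr, X_te, names = load("twins", train_ratio=0.8)
+    print("train:", X_tr.shape, "test:", X_te.shape)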