a b/src/iterpretability/datasets/data_loader.py
1
import pickle
2
import pandas as pd
3
import numpy as np
4
from catenets.datasets import load as catenets_load
5
from src.iterpretability.datasets.news.process_news import process_news
6
from src.iterpretability.datasets.tcga.process_tcga import process_tcga
7
from src.utils import gen_unique_ids
8
import os
9
10
def remove_unbalanced(df, threshold=0.8, max_unique=6):
    """Drop low-cardinality columns that are dominated by a single value.

    Only non-object columns are inspected. A column is dropped when it has
    fewer than ``max_unique`` distinct values and its most frequent value
    makes up strictly more than ``threshold`` of the counted entries.

    Args:
        df: Input ``pandas.DataFrame``; it is not modified.
        threshold: Maximum tolerated relative frequency of the most common
            value (compared with ``>``).
        max_unique: Columns with at least this many distinct values are
            never dropped. Defaults to the previously hard-coded ``6``.

    Returns:
        A new DataFrame without the unbalanced columns.
    """
    ub_col = [
        col
        for col in df.select_dtypes(exclude='object').columns
        if df[col].nunique() < max_unique
        and df[col].value_counts(normalize=True).max() > threshold
    ]
    return df.drop(columns=ub_col)
21
22
def _load_cached_pickle(pickle_path, rebuild):
    """Unpickle ``pickle_path``; on failure call ``rebuild()`` once and retry."""
    # NOTE: unpickling can execute arbitrary code. Only point this at cache
    # files written by this project's own preprocessing.
    try:
        with open(pickle_path, "rb") as handle:
            return pickle.load(handle)
    except Exception:
        # Best-effort fallback kept from the original code: any failure to
        # read the cache triggers a rebuild (but Ctrl-C is no longer swallowed).
        rebuild()
        with open(pickle_path, "rb") as handle:
            return pickle.load(handle)


def _split_outcome_columns(data, outcome_cols):
    """Return ``(data without outcome_cols, outcomes as an (n, k, 1) array)``."""
    outcomes = data[outcome_cols].to_numpy()
    outcomes = outcomes.reshape(outcomes.shape[0], outcomes.shape[1], 1)
    return data.drop(outcome_cols, axis=1), outcomes


def _validated_outcome_array(outcomes):
    """Validate an outcome DataFrame and convert it to an (n, k, 1) array."""
    if outcomes.isnull().values.any():
        raise ValueError("Outcomes contain nan values")
    if outcomes.dtypes.apply(lambda x: x == 'object').any():
        raise ValueError("Outcomes contain string values")

    # Booleans are mapped to 0/1 before the conversion to numpy.
    outcomes = outcomes.replace({False:0, True:1}).to_numpy()
    return outcomes.reshape(outcomes.shape[0], outcomes.shape[1], 1)


def load(dataset_name: str, train_ratio: float = 1.0,
         directory_path_: str = None,
         repo_path: str = None,
         debug=False,
         sim_type=None,
         n_samples=None):
    """Load a dataset as numpy arrays, optionally with outcomes and a split.

    The branch is chosen from ``dataset_name`` (substring match for
    ``tcga``/``news``/``twins``/``acic``, prefix match for the DepMap, toy
    and cytof datasets). Any other name is looked up as
    ``directory_path_ + dataset_name + ".csv"``.

    Args:
        dataset_name: Name of the dataset / file stem to load.
        train_ratio: Fraction of rows put in the training split. With the
            default ``1.0`` no split is performed.
        directory_path_: Prefix (including trailing separator) for the cytof
            and generic CSV files.
        repo_path: Repository root used for the tcga, DepMap and toy files.
        debug: If true, truncate the data to a 200-row budget
            (``200 * train_ratio`` train rows, ``200 * (1 - train_ratio)``
            test rows).
        sim_type: When equal to ``"T"`` the outcomes are returned as well.
        n_samples: Unused; kept for interface compatibility.

    Returns:
        With ``train_ratio == 1.0``: ``(X_raw, feature_names)`` or, for
        ``sim_type == "T"``, ``(X_raw, outcomes, feature_names)``.
        Otherwise: ``(X_raw_train, X_raw_test, feature_names)`` or, for
        ``sim_type == "T"``,
        ``(X_raw_train, X_raw_test, outcomes_train, outcomes_test, feature_names)``.
        ``feature_names`` is ``None`` for every dataset except cytof.

    Raises:
        ValueError: If outcomes are requested but the dataset has none, if
            outcome columns contain NaN or string values, or if the dataset
            is unknown.
    """
    feature_names = None
    outcomes = None
    want_outcomes = sim_type == "T"

    if "tcga" in dataset_name:
        tcga_dir = repo_path + "/data/tcga/"
        tcga_dataset = _load_cached_pickle(
            tcga_dir + str(dataset_name) + ".p",
            lambda: process_tcga(max_num_genes=100, file_location=tcga_dir),
        )
        X_raw = tcga_dataset["rnaseq"][:5000]

    elif "news" in dataset_name:
        news_dir = "src/iterpretability/datasets/news/"
        X_raw = _load_cached_pickle(
            news_dir + str(dataset_name) + ".p",
            lambda: process_news(max_num_features=100, file_location=news_dir),
        )

    elif "twins" in dataset_name:
        # Total features  = 39
        # The full data is always requested here; the split happens below.
        X_raw, _, _, _, _, _ = catenets_load(dataset_name, train_ratio=1.0)

    elif "acic" in dataset_name:
        # Total features  = 55
        X_raw, _, _, _, _, _, _, _ = catenets_load("acic2016")

    elif dataset_name.startswith("depmap_drug_screen"):
        data = pd.read_csv(repo_path + "/data/DepMap_24Q2/real/"+dataset_name+".csv", index_col=0)
        data, outcomes = _split_outcome_columns(data, ["LFC_az_628", "LFC_imatinib"])
        X_raw = data.to_numpy()

    elif dataset_name.startswith("depmap_crispr_screen"):
        data = pd.read_csv(repo_path + "/data/DepMap_24Q2/real/"+dataset_name+".csv", index_col=0)
        data, outcomes = _split_outcome_columns(data, ["LFC_BRAF", "LFC_EGFR"])
        X_raw = data.to_numpy()

    elif dataset_name.startswith("ovarian_semi_synthetic"):
        data = pd.read_csv(repo_path + "/data/DepMap_24Q2/real/"+dataset_name+".csv", index_col=0)
        data, outcomes = _split_outcome_columns(data, ["pred_a0_y0", "pred_a1_y0"])
        X_raw = data.to_numpy()

    elif dataset_name.startswith("melanoma_semi_synthetic"):
        data = pd.read_csv(repo_path + "/data/DepMap_24Q2/real/"+dataset_name+".csv", index_col=0)
        data, outcomes = _split_outcome_columns(
            data, ["pred_a0_y2 (immuno)", "pred_a1_y2 (immuno)"]
        )
        outcomes = outcomes.astype(int)
        X_raw = data.to_numpy()

    elif dataset_name.startswith("toy_data"):
        if want_outcomes:
            raise ValueError("Toy example data does not have treatment outcomes")
        data = pd.read_csv(repo_path + "/data/toy/"+dataset_name+".csv", index_col=0)
        X_raw = data.to_numpy()

    elif dataset_name.startswith("cytof"):
        data = pd.read_csv(directory_path_ + dataset_name + ".csv", index_col=0)

        # NOTE(review): feature names are captured before the outcome columns
        # are dropped, so for sim_type == "T" they still list the "09*"
        # columns that are no longer in X_raw - confirm this is intended.
        feature_names = data.columns

        # Also return true outcomes for TSimulation
        if want_outcomes:
            all_treatment_outcomes_cols = [col for col in data.columns if col.startswith("09")]
            # Only the first two outcome columns are used, but all are dropped.
            outcomes = _validated_outcome_array(data[all_treatment_outcomes_cols[:2]])
            data = data.drop(all_treatment_outcomes_cols, axis=1)

        X_raw = data

    elif directory_path_ is not None and os.path.isfile(directory_path_ + dataset_name + ".csv"):
        data = pd.read_csv(directory_path_ + dataset_name + ".csv", index_col=0)

        all_treatment_outcomes_cols = [col for col in data.columns if col.startswith("09")][:5]

        # Also return true outcomes for TSimulation
        if want_outcomes:
            outcomes = _validated_outcome_array(data[all_treatment_outcomes_cols])

        # Drop all columns starting with tre or out
        X_raw = data.loc[:, ~data.columns.str.startswith('tre')]
        X_raw = X_raw.loc[:, ~X_raw.columns.str.startswith('out')]

        # One hot encode all categorical features
        X_raw = pd.get_dummies(X_raw)

        # Make sure there are no boolean variables
        X_raw = X_raw.replace({False:0, True:1})

        # Remove highly unbalanced features
        X_raw = remove_unbalanced(X_raw, 0.8)

        # Min-max normalize all columns of the dataframe
        X_raw = X_raw.apply(lambda x: (x - x.min()) / (x.max() - x.min()))

    else:
        # directory_path_ may be None here, so only mention the file when
        # there is a path to report.
        message = "Unknown dataset " + str(dataset_name) + "."
        if directory_path_ is not None:
            message += " File: " + directory_path_ + dataset_name + ".csv"
        raise ValueError(message)

    if want_outcomes and outcomes is None:
        raise ValueError(
            "Dataset " + str(dataset_name) + " does not provide outcomes for sim_type='T'"
        )

    debug_size = 200
    if isinstance(X_raw, pd.DataFrame):
        X_raw = X_raw.to_numpy()

    if train_ratio == 1.0:
        if debug:
            X_raw = X_raw[:int(debug_size*train_ratio),:]
            if want_outcomes:
                outcomes = outcomes[:int(debug_size*train_ratio),:]

        if want_outcomes:
            return X_raw, outcomes, feature_names
        return X_raw, feature_names

    # Split features and outcomes at the same row so they stay aligned.
    split_idx = int(train_ratio * X_raw.shape[0])
    X_raw_train = X_raw[:split_idx]
    X_raw_test = X_raw[split_idx:]
    if want_outcomes:
        outcomes_train = outcomes[:split_idx]
        outcomes_test = outcomes[split_idx:]

    if debug:
        n_train = int(debug_size*train_ratio)
        n_test = int(debug_size*(1-train_ratio))
        X_raw_train = X_raw_train[:n_train,:]
        X_raw_test = X_raw_test[:n_test,:]
        if want_outcomes:
            outcomes_train = outcomes_train[:n_train,:]
            outcomes_test = outcomes_test[:n_test,:]

    if want_outcomes:
        return X_raw_train, X_raw_test, outcomes_train, outcomes_test, feature_names
    return X_raw_train, X_raw_test, feature_names