|
a |
|
b/02_plot_results.py |
|
|
1 |
import argparse |
|
|
2 |
import sys |
|
|
3 |
import os |
|
|
4 |
from typing import Any |
|
|
5 |
import pandas as pd |
|
|
6 |
import src.iterpretability.logger as log |
|
|
7 |
from src.iterpretability.experiments.sample_size_sensitivity import SampleSizeSensitivity |
|
|
8 |
from src.iterpretability.experiments.important_feature_num_sensitivity import ImportantFeatureNumSensitivity |
|
|
9 |
from src.iterpretability.experiments.propensity_scale_sensitivity import PropensityScaleSensitivity |
|
|
10 |
from src.iterpretability.experiments.predictive_scale_sensitivity import PredictiveScaleSensitivity |
|
|
11 |
from src.iterpretability.experiments.treatment_space_sensitivity import TreatmentSpaceSensitivity |
|
|
12 |
from src.iterpretability.experiments.sample_size_sensitivity_visualization import SampleSizeSensitivityVisualization |
|
|
13 |
from src.iterpretability.experiments.data_dimensionality_sensitivity import DataDimensionalitySensitivity |
|
|
14 |
from src.iterpretability.experiments.feature_overlap_sensitivity import FeatureOverlapSensitivity |
|
|
15 |
from src.iterpretability.experiments.expertise_sensitivity import ExpertiseSensitivity |
|
|
16 |
from src.iterpretability.simulators import TYSimulator, TSimulator |
|
|
17 |
from src.iterpretability.datasets.data_loader import load |
|
|
18 |
|
|
|
19 |
# Hydra for configuration |
|
|
20 |
import hydra |
|
|
21 |
from omegaconf import DictConfig, OmegaConf |
|
|
22 |
|
|
|
23 |
def plot_result(cfg, fig_type="all", compare_axis_values=None):
    """Instantiate the experiment named by ``cfg.experiment_name`` and plot its saved results.

    Args:
        cfg: Run configuration. ``cfg.experiment_name`` selects the experiment class;
            some experiments also read ``cfg.simulator.simulation_type``,
            ``cfg.simulator.dim_Y`` and ``cfg.simulator.num_T`` to check support.
        fig_type: Figure family forwarded unchanged to ``load_and_plot_results``
            (this script passes "expertise", "performance_f_cf" and "performance";
            the meaning of each, including the default "all", is defined by the
            experiment classes).
        compare_axis_values: Forwarded unchanged to ``load_and_plot_results``.

    Returns:
        None. When the experiment does not support the configured simulator, a
        message is printed and the function returns early without plotting.

    Raises:
        ValueError: If ``cfg.experiment_name`` is not a known experiment.
    """
    experiment_name = cfg.experiment_name

    if experiment_name == "sample_size_sensitivity":
        exp = SampleSizeSensitivity(cfg)

    elif experiment_name == "important_feature_num_sensitivity":
        exp = ImportantFeatureNumSensitivity(cfg)

    elif experiment_name == "propensity_scale_sensitivity":
        exp = PropensityScaleSensitivity(cfg)

    elif experiment_name == "predictive_scale_sensitivity":
        if cfg.simulator.simulation_type == "T":
            print("PredictiveScaleSensitivity is not supported for TSimulator")
            return None
        exp = PredictiveScaleSensitivity(cfg)

    elif experiment_name == "treatment_space_sensitivity":
        if cfg.simulator.simulation_type == "T":
            # Previously reported the wrong experiment (PredictiveScaleSensitivity).
            print("TreatmentSpaceSensitivity is not supported for TSimulator")
            return None
        exp = TreatmentSpaceSensitivity(cfg)

    elif experiment_name == "sample_size_sensitivity_visualization":
        # The message now names the config key that is actually checked (num_T).
        if cfg.simulator.dim_Y != 1 or cfg.simulator.num_T != 2:
            print("SampleSizeSensitivityVisualization is only supported for dim_Y=1 and num_T=2")
            return None
        if cfg.simulator.simulation_type == "T":
            print("Make sure to only pick two treatments for this to work. And remove this if statement.")
            return None
        exp = SampleSizeSensitivityVisualization(cfg)

    elif experiment_name == "data_dimensionality_sensitivity":
        exp = DataDimensionalitySensitivity(cfg)

    elif experiment_name == "feature_overlap_sensitivity":
        if cfg.simulator.simulation_type == "T":
            print("FeatureOverlapSensitivity is not supported for TSimulator")
            return None
        exp = FeatureOverlapSensitivity(cfg)

    elif experiment_name == "expertise_sensitivity":
        exp = ExpertiseSensitivity(cfg)

    else:
        raise ValueError(f"Invalid experiment name: {cfg.experiment_name}")

    exp.load_and_plot_results(fig_type=fig_type, compare_axis_values=compare_axis_values)
|
|
75 |
|
|
|
76 |
|
|
|
77 |
# Metrics requested from every experiment when re-plotting a stored run.
# This list overrides whatever ``metrics_to_plot`` the run's YAML config contains.
_METRICS_TO_PLOT = [
    "Policy Precision", "Pred Precision", "GT In-context Var", "GT Total Expertise",
    "GT Prog Expertise", "GT Tre Expertise", "GT Pred Expertise", "RMSE Y0", "RMSE Y1",
    "PEHE", "Upd. GT Prog Expertise", "Upd. GT Tre Expertise", "Upd. GT Pred Expertise",
    "Factual RMSE Y0", "CF RMSE Y0", "Factual RMSE Y1", "CF RMSE Y1", "Factual RMSE",
    "CF RMSE", "Normalized F-RMSE", "Normalized CF-RMSE", "Normalized PEHE",
    "Swap AUROC@all", "Swap AUPRC@all", "FC PEHE", "FC CF-RMSE", "FC Swap AUROC",
    "FC Swap AUPRC", "Pred: Pred features ACC", "Pred: Prog features ACC",
    "Prog: Prog features ACC", "Prog: Pred features ACC", "GT Expertise Ratio",
    "GT-ES Pred Expertise Diff", "GT-ES Prog Expertise Diff", "GT-ES Total Expertise Diff",
    "T Distribution: Train", "T Distribution: Test", "Training Duration",
]

# Figures rendered for every run config, in order:
# (plot-name suffix, fig_type forwarded to plot_result).
# The suffix is appended to _PLOTS_FOLDER to form ``cfg.plot_name_prefix``.
_PLOTS_FOLDER = "plots/"
_PLOT_SPECS = (
    ("bias", "expertise"),
    ("pehe_and_rmse_per_model", "performance_f_cf"),
    ("model_performance_comparison", "performance"),
)


def _remove_if_empty(path):
    """Remove ``path`` if it exists and is an empty directory.

    Returns:
        True if the directory was removed, False otherwise.
    """
    if os.path.exists(path) and not os.listdir(path):
        os.rmdir(path)
        return True
    return False


def _plot_run_configs(run_dir):
    """Render every figure in ``_PLOT_SPECS`` for each ``*.yaml`` file directly inside ``run_dir``.

    Only files directly in ``run_dir`` are considered. Nested directories are
    visited as run directories of their own by ``main``'s walk.

    A config that fails to load, or a figure that fails to render, is reported
    on stdout and skipped, so one broken run does not abort the whole sweep.
    """
    for file_name in sorted(os.listdir(run_dir)):
        config_path = os.path.join(run_dir, file_name)
        if not file_name.endswith(".yaml") or not os.path.isfile(config_path):
            continue

        try:
            run_cfg = OmegaConf.load(config_path)
            run_cfg.metrics_to_plot = list(_METRICS_TO_PLOT)
        except Exception as e:
            print(f"Error loading {config_path}:", e)
            continue

        for name_suffix, fig_type in _PLOT_SPECS:
            try:
                run_cfg.plot_name_prefix = _PLOTS_FOLDER + name_suffix
                plot_result(run_cfg, fig_type=fig_type, compare_axis_values=None)
            except Exception as e:
                print(f"Error ({fig_type}, {config_path}):", e)


@hydra.main(config_path="conf", config_name="config_TY_tcga", version_base=None)
def main(cfg: DictConfig):
    """Regenerate plots for the experiment runs stored under ``cfg.results_path``.

    ``cfg`` is built by ``@hydra.main`` from ``conf/config_TY_tcga``. Only
    ``cfg.log_level`` and ``cfg.results_path`` are read; each run is plotted with
    the config saved next to its results.

    Traversal (side effects on the results tree):
      * Empty directories that are encountered are deleted.
      * Each directory at least two levels below ``results_path`` is treated as a
        run directory. It receives a ``plots/`` subfolder, and every ``*.yaml``
        directly inside it is plotted once via ``_plot_run_configs``.
      * ``plots`` folders themselves are never treated as run directories.
      * Nested walks are not started from subdirectories whose name starts with
        ``archive``. NOTE(review): descendants of an ``archive*`` directory can
        still be reached from deeper levels of the outer walk — confirm intended.
    """
    log.add(sink=sys.stderr, level=cfg.log_level)

    plotted_dirs = set()
    for root_out, dirs_out, _ in os.walk(cfg.results_path):
        if _remove_if_empty(root_out):
            continue

        for dir_out in dirs_out:
            if dir_out.startswith("archive"):
                continue

            for root, dirs, _ in os.walk(os.path.join(root_out, dir_out)):
                for sub_dir in dirs:
                    run_dir = os.path.join(root, sub_dir)
                    if _remove_if_empty(run_dir):
                        continue
                    # "plots" folders hold output, not run configs; treating them as
                    # runs would create empty plots/plots folders.
                    if sub_dir == "plots":
                        continue
                    # The outer and inner walks overlap, so deeper directories are
                    # reached several times; plot each run directory only once.
                    if run_dir in plotted_dirs:
                        continue
                    plotted_dirs.add(run_dir)

                    plots_dir = os.path.join(run_dir, "plots")
                    if not os.path.exists(plots_dir):
                        os.makedirs(plots_dir)

                    _plot_run_configs(run_dir)
|
|
194 |
|
|
|
195 |
|
|
|
196 |
if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator on main
    # builds cfg from conf/config_TY_tcga and passes it in.
    main()