[6ac965]: / 02_plot_results.py

Download this file

197 lines (157 with data), 13.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import argparse
import sys
import os
from typing import Any
import pandas as pd
import src.iterpretability.logger as log
from src.iterpretability.experiments.sample_size_sensitivity import SampleSizeSensitivity
from src.iterpretability.experiments.important_feature_num_sensitivity import ImportantFeatureNumSensitivity
from src.iterpretability.experiments.propensity_scale_sensitivity import PropensityScaleSensitivity
from src.iterpretability.experiments.predictive_scale_sensitivity import PredictiveScaleSensitivity
from src.iterpretability.experiments.treatment_space_sensitivity import TreatmentSpaceSensitivity
from src.iterpretability.experiments.sample_size_sensitivity_visualization import SampleSizeSensitivityVisualization
from src.iterpretability.experiments.data_dimensionality_sensitivity import DataDimensionalitySensitivity
from src.iterpretability.experiments.feature_overlap_sensitivity import FeatureOverlapSensitivity
from src.iterpretability.experiments.expertise_sensitivity import ExpertiseSensitivity
from src.iterpretability.simulators import TYSimulator, TSimulator
from src.iterpretability.datasets.data_loader import load
# Hydra for configuration
import hydra
from omegaconf import DictConfig, OmegaConf
def plot_result(cfg, fig_type="all", compare_axis_values=None):
    """Build the experiment named in *cfg* and plot its stored results.

    Args:
        cfg: Configuration object. ``cfg.experiment_name`` selects the
            experiment class; some experiments additionally read
            ``cfg.simulator.simulation_type``, ``cfg.simulator.dim_Y`` and
            ``cfg.simulator.num_T`` to decide whether they can run.
        fig_type: Forwarded unchanged to ``load_and_plot_results``.
        compare_axis_values: Forwarded unchanged to ``load_and_plot_results``.

    Returns:
        None. When the experiment does not support the configured simulator a
        message is printed and the function returns without plotting.

    Raises:
        ValueError: If ``cfg.experiment_name`` is not a known experiment.
    """
    name = cfg.experiment_name

    # The experiment classes are only looked up inside the matching branch, and
    # ``cfg.simulator`` is only touched by the experiments that need it.
    if name == "sample_size_sensitivity":
        exp = SampleSizeSensitivity(cfg)
    elif name == "important_feature_num_sensitivity":
        exp = ImportantFeatureNumSensitivity(cfg)
    elif name == "propensity_scale_sensitivity":
        exp = PropensityScaleSensitivity(cfg)
    elif name == "predictive_scale_sensitivity":
        if cfg.simulator.simulation_type == "T":
            print("PredictiveScaleSensitivity is not supported for TSimulator")
            return None
        exp = PredictiveScaleSensitivity(cfg)
    elif name == "treatment_space_sensitivity":
        if cfg.simulator.simulation_type == "T":
            # Message used to name PredictiveScaleSensitivity (copy-paste slip).
            print("TreatmentSpaceSensitivity is not supported for TSimulator")
            return None
        exp = TreatmentSpaceSensitivity(cfg)
    elif name == "sample_size_sensitivity_visualization":
        if cfg.simulator.dim_Y != 1 or cfg.simulator.num_T != 2:
            # The check is on ``num_T``; the message now uses the same key name.
            print("SampleSizeSensitivityVisualization is only supported for dim_Y=1 and num_T=2")
            return None
        if cfg.simulator.simulation_type == "T":
            print("Make sure to only pick two treatments for this to work. And remove this if statement.")
            return None
        exp = SampleSizeSensitivityVisualization(cfg)
    elif name == "data_dimensionality_sensitivity":
        exp = DataDimensionalitySensitivity(cfg)
    elif name == "feature_overlap_sensitivity":
        if cfg.simulator.simulation_type == "T":
            print("FeatureOverlapSensitivity is not supported for TSimulator")
            return None
        exp = FeatureOverlapSensitivity(cfg)
    elif name == "expertise_sensitivity":
        exp = ExpertiseSensitivity(cfg)
    else:
        raise ValueError(f"Invalid experiment name: {cfg.experiment_name}")

    exp.load_and_plot_results(fig_type=fig_type, compare_axis_values=compare_axis_values)
# Metric names written to ``metrics_to_plot`` of every run config before plotting.
_METRICS_TO_PLOT = [
    "Policy Precision", "Pred Precision", "GT In-context Var",
    "GT Total Expertise", "GT Prog Expertise", "GT Tre Expertise", "GT Pred Expertise",
    "RMSE Y0", "RMSE Y1", "PEHE",
    "Upd. GT Prog Expertise", "Upd. GT Tre Expertise", "Upd. GT Pred Expertise",
    "Factual RMSE Y0", "CF RMSE Y0", "Factual RMSE Y1", "CF RMSE Y1",
    "Factual RMSE", "CF RMSE",
    'Normalized F-RMSE', 'Normalized CF-RMSE', 'Normalized PEHE',
    'Swap AUROC@all', 'Swap AUPRC@all',
    "FC PEHE", "FC CF-RMSE", "FC Swap AUROC", "FC Swap AUPRC",
    'Pred: Pred features ACC', 'Pred: Prog features ACC',
    'Prog: Prog features ACC', 'Prog: Pred features ACC',
    "GT Expertise Ratio", "GT-ES Pred Expertise Diff", "GT-ES Prog Expertise Diff",
    "GT-ES Total Expertise Diff",
    'T Distribution: Train', 'T Distribution: Test', 'Training Duration',
]

# (file-name prefix inside the plots folder, fig_type handed to plot_result).
_PLOT_VARIANTS = (
    ("bias", "expertise"),
    ("pehe_and_rmse_per_model", "performance_f_cf"),
    ("model_performance_comparison", "performance"),
)

# Disabled presets, kept for reference. Each one set ``plot_name_prefix`` to
# plots/<name>, optionally overrode ``model_names``, and called plot_result
# with the default fig_type:
#   prec_all_models            (model_names unchanged)
#   v1_no_S_Action             (model_names unchanged)
#   v2_only_expertise_paper    Torch_TARNet, Torch_DragonNet, Torch_DragonNet_2, Torch_DragonNet_4,
#                              Torch_CFRNet_0.01, Torch_CFRNet_0.001, Torch_CFRNet_0.0001
#   v3_only_balancing          Torch_CFRNet_0.01, Torch_CFRNet_0.001, Torch_CFRNet_0.0001
#   v4_torch_vs_econml         Torch_TLearner, Torch_SLearner, EconML_TLearner_Lasso, EconML_SLearner_Lasso
#   v5_only_TLearner_Lasso     EconML_TLearner_Lasso
#   v6_direct_vs_indirect      EconML_TLearner_Lasso, Torch_TLearner, Torch_XLearner,
#                              Torch_CFRNet_0.001, Torch_DragonNet
#   v7_only_action_predictive  Torch_DragonNet, Torch_DragonNet_2, Torch_DragonNet_4
# Another disabled option restricted compare_axis_values for
# "propensity_scale_sensitivity" runs to
#   ["toy1_nonlinear", "toy3_nonlinear", "toy2_nonlinear", "toy6_nonlinear"].


def _direct_files(directory):
    """Return the names of the non-directory entries directly inside *directory*."""
    with os.scandir(directory) as entries:
        return [entry.name for entry in entries if not entry.is_dir()]


def _plot_saved_config(config_path, plots_folder="plots/", compare_axis_values=None):
    """Load one saved run config and draw every entry of ``_PLOT_VARIANTS`` for it."""
    run_cfg = OmegaConf.load(config_path)
    run_cfg.metrics_to_plot = list(_METRICS_TO_PLOT)

    for prefix, fig_type in _PLOT_VARIANTS:
        # Best effort: a failing figure must not stop the remaining figures
        # or the remaining runs, so the error is reported and skipped.
        try:
            run_cfg.plot_name_prefix = plots_folder + prefix
            plot_result(run_cfg, fig_type=fig_type, compare_axis_values=compare_axis_values)
        except Exception as e:
            print(f"Error while plotting '{fig_type}' for {config_path}:", e)


@hydra.main(config_path="conf", config_name="config_TY_tcga", version_base=None)
def main(cfg: DictConfig):
    """Re-plot every stored run found below ``cfg.results_path``.

    Walks the results tree, deletes empty directories it encounters, makes
    sure each visited directory has a ``plots`` sub-directory, and for every
    ``.yaml`` file located directly in a visited directory loads it as the run
    configuration and produces the figures listed in ``_PLOT_VARIANTS``.

    Args:
        cfg: Hydra configuration; only ``log_level`` and ``results_path`` are
            read here. The per-run configs come from the YAML files on disk.
    """
    ############################
    # 1. Plotting
    ############################
    # Set logging level
    log.add(sink=sys.stderr, level=cfg.log_level)

    # NOTE(review): the outer walk already descends into every directory and
    # the inner walk descends again, so deep directories are visited (and
    # re-plotted) more than once. Kept as-is to preserve which directories get
    # cleaned up and plotted.
    for root_out, dirs_out, _ in os.walk(cfg.results_path):
        if os.path.exists(root_out) and not os.listdir(root_out):
            os.rmdir(root_out)
            continue

        for dir_out in dirs_out:
            if dir_out.startswith("archive"):
                continue

            for root, dirs, _ in os.walk(os.path.join(root_out, dir_out)):
                for dir_name in dirs:
                    run_dir = os.path.join(root, dir_name)

                    # Drop empty leftovers instead of creating plots in them.
                    if os.path.exists(run_dir) and not os.listdir(run_dir):
                        os.rmdir(run_dir)
                        continue

                    # create plots dir if it does not exist
                    os.makedirs(os.path.join(run_dir, "plots"), exist_ok=True)

                    # Only files directly inside run_dir are considered. The
                    # previous recursive walk joined names found in nested
                    # folders onto run_dir, producing paths that do not exist;
                    # nested folders are handled when they are visited as
                    # run_dir themselves.
                    for file_name in _direct_files(run_dir):
                        if file_name.endswith(".yaml"):
                            _plot_saved_config(os.path.join(run_dir, file_name))
# Script entry point: main is wrapped by @hydra.main, so it is invoked here
# without arguments.
if __name__ == "__main__":
    main()