a b/02_plot_results.py
1
import argparse
2
import sys
3
import os
4
from typing import Any
5
import pandas as pd 
6
import src.iterpretability.logger as log
7
from src.iterpretability.experiments.sample_size_sensitivity import SampleSizeSensitivity
8
from src.iterpretability.experiments.important_feature_num_sensitivity import ImportantFeatureNumSensitivity
9
from src.iterpretability.experiments.propensity_scale_sensitivity import PropensityScaleSensitivity
10
from src.iterpretability.experiments.predictive_scale_sensitivity import PredictiveScaleSensitivity
11
from src.iterpretability.experiments.treatment_space_sensitivity import TreatmentSpaceSensitivity
12
from src.iterpretability.experiments.sample_size_sensitivity_visualization import SampleSizeSensitivityVisualization
13
from src.iterpretability.experiments.data_dimensionality_sensitivity import DataDimensionalitySensitivity
14
from src.iterpretability.experiments.feature_overlap_sensitivity import FeatureOverlapSensitivity
15
from src.iterpretability.experiments.expertise_sensitivity import ExpertiseSensitivity
16
from src.iterpretability.simulators import TYSimulator, TSimulator
17
from src.iterpretability.datasets.data_loader import load
18
19
# Hydra for configuration
20
import hydra
21
from omegaconf import DictConfig, OmegaConf
22
23
# Experiments that cannot be plotted for configs with
# ``cfg.simulator.simulation_type == "T"`` (TSimulator), mapped to the class
# name reported in the skip message.
_T_SIMULATOR_UNSUPPORTED = {
    "predictive_scale_sensitivity": "PredictiveScaleSensitivity",
    "treatment_space_sensitivity": "TreatmentSpaceSensitivity",
    "feature_overlap_sensitivity": "FeatureOverlapSensitivity",
}


def _unsupported_reason(cfg):
    """Return a message explaining why ``cfg`` cannot be plotted, or ``None``.

    ``cfg.simulator`` is only read for experiments that have simulator
    restrictions; unknown experiment names are not rejected here.
    """
    name = cfg.experiment_name
    if name in _T_SIMULATOR_UNSUPPORTED:
        if cfg.simulator.simulation_type == "T":
            return f"{_T_SIMULATOR_UNSUPPORTED[name]} is not supported for TSimulator"
        return None

    if name == "sample_size_sensitivity_visualization":
        simulator = cfg.simulator
        if simulator.dim_Y != 1 or simulator.num_T != 2:
            return "SampleSizeSensitivityVisualization is only supported for dim_Y=1 and num_T=2"
        if simulator.simulation_type == "T":
            # Deliberate guard left by the author; see the message for how to lift it.
            return "Make sure to only pick two treatments for this to work. And remove this if statement."

    return None


def plot_result(cfg, fig_type="all", compare_axis_values=None):
    """Instantiate the experiment named by ``cfg.experiment_name`` and plot its results.

    Args:
        cfg: Run configuration (e.g. an OmegaConf ``DictConfig`` loaded from a
            results folder). ``experiment_name`` selects the experiment class;
            ``simulator.simulation_type`` / ``simulator.dim_Y`` /
            ``simulator.num_T`` are read only for experiments with simulator
            restrictions.
        fig_type: Figure type forwarded to ``load_and_plot_results``.
        compare_axis_values: Forwarded unchanged to ``load_and_plot_results``.

    Returns:
        None. Unsupported experiment/simulator combinations print the reason
        and return without constructing or plotting the experiment.

    Raises:
        ValueError: If ``cfg.experiment_name`` is not a known experiment.
    """
    reason = _unsupported_reason(cfg)
    if reason is not None:
        print(reason)
        return None

    experiment_classes = {
        "sample_size_sensitivity": SampleSizeSensitivity,
        "important_feature_num_sensitivity": ImportantFeatureNumSensitivity,
        "propensity_scale_sensitivity": PropensityScaleSensitivity,
        "predictive_scale_sensitivity": PredictiveScaleSensitivity,
        "treatment_space_sensitivity": TreatmentSpaceSensitivity,
        "sample_size_sensitivity_visualization": SampleSizeSensitivityVisualization,
        "data_dimensionality_sensitivity": DataDimensionalitySensitivity,
        "feature_overlap_sensitivity": FeatureOverlapSensitivity,
        "expertise_sensitivity": ExpertiseSensitivity,
    }
    experiment_cls = experiment_classes.get(cfg.experiment_name)
    if experiment_cls is None:
        raise ValueError(f"Invalid experiment name: {cfg.experiment_name}")

    exp = experiment_cls(cfg)
    exp.load_and_plot_results(fig_type=fig_type, compare_axis_values=compare_axis_values)
75
76
77
# Metrics assigned to every loaded run config before plotting.
_METRICS_TO_PLOT = (
    "Policy Precision", "Pred Precision", "GT In-context Var",
    "GT Total Expertise", "GT Prog Expertise", "GT Tre Expertise", "GT Pred Expertise",
    "RMSE Y0", "RMSE Y1", "PEHE",
    "Upd. GT Prog Expertise", "Upd. GT Tre Expertise", "Upd. GT Pred Expertise",
    "Factual RMSE Y0", "CF RMSE Y0", "Factual RMSE Y1", "CF RMSE Y1",
    "Factual RMSE", "CF RMSE",
    "Normalized F-RMSE", "Normalized CF-RMSE", "Normalized PEHE",
    "Swap AUROC@all", "Swap AUPRC@all",
    "FC PEHE", "FC CF-RMSE", "FC Swap AUROC", "FC Swap AUPRC",
    "Pred: Pred features ACC", "Pred: Prog features ACC",
    "Prog: Prog features ACC", "Prog: Pred features ACC",
    "GT Expertise Ratio",
    "GT-ES Pred Expertise Diff", "GT-ES Prog Expertise Diff", "GT-ES Total Expertise Diff",
    "T Distribution: Train", "T Distribution: Test",
    "Training Duration",
)

# Sub-folder (next to each run config) that receives the figures.
_PLOTS_DIR = "plots"

# (plot-name suffix, fig_type) pairs rendered for every run config. Further
# variants (e.g. restricting ``cfg.model_names`` to a subset of models) can be
# added here instead of copy-pasting try/except blocks.
_FIGURES = (
    ("bias", "expertise"),
    ("pehe_and_rmse_per_model", "performance_f_cf"),
    ("model_performance_comparison", "performance"),
)


def _remove_empty_dirs(top):
    """Delete every empty directory strictly below ``top``.

    Walks bottom-up, so a parent that becomes empty after its empty children
    are removed is deleted as well. ``top`` itself is never removed.
    """
    for dirpath, _, _ in os.walk(top, topdown=False):
        if dirpath != top and not os.listdir(dirpath):
            os.rmdir(dirpath)


def _iter_run_config_dirs(results_path, min_depth=2):
    """Yield ``(directory, sorted_yaml_filenames)`` for directories holding ``.yaml`` files.

    Each directory is visited exactly once. Any subtree whose directory name
    starts with ``archive`` is pruned entirely. Only directories at least
    ``min_depth`` levels below ``results_path`` (e.g.
    ``<results>/<experiment>/<run>/``) are reported.
    """
    for dirpath, dirnames, filenames in os.walk(results_path):
        # In-place pruning stops os.walk from descending into archived runs;
        # sorting makes the processing order deterministic.
        dirnames[:] = sorted(d for d in dirnames if not d.startswith("archive"))

        rel = os.path.relpath(dirpath, results_path)
        depth = 0 if rel == os.curdir else len(rel.split(os.sep))
        yaml_files = sorted(f for f in filenames if f.endswith(".yaml"))
        if depth >= min_depth and yaml_files:
            yield dirpath, yaml_files


def _plot_run_config(yaml_path, compare_axis_values=None):
    """Load one run config and render every figure listed in ``_FIGURES``.

    Figures are attempted independently: a failure is printed together with
    the figure type and config path, and the remaining figures still run.
    ``compare_axis_values`` is forwarded to ``plot_result``; ``None`` keeps all
    axis values.
    """
    run_cfg = OmegaConf.load(yaml_path)
    run_cfg.metrics_to_plot = list(_METRICS_TO_PLOT)

    for name_suffix, fig_type in _FIGURES:
        try:
            run_cfg.plot_name_prefix = f"{_PLOTS_DIR}/{name_suffix}"
            plot_result(run_cfg, fig_type=fig_type, compare_axis_values=compare_axis_values)
        except Exception as e:  # best-effort batch plotting: report and continue
            print(f"Error ({fig_type}, {yaml_path}):", e)


@hydra.main(config_path="conf", config_name="config_TY_tcga", version_base=None)
def main(cfg: DictConfig):
    """Regenerate plots for every stored run config under ``cfg.results_path``.

    Steps:
        1. Register a stderr log sink at ``cfg.log_level``.
        2. Remove empty directories below ``cfg.results_path``.
        3. For every non-archived directory at least two levels deep that holds
           ``.yaml`` run configs, ensure a ``plots/`` folder exists there and
           plot each config (see ``_plot_run_config``).
    """
    log.add(sink=sys.stderr, level=cfg.log_level)

    results_path = cfg.results_path
    _remove_empty_dirs(results_path)

    for run_dir, yaml_files in _iter_run_config_dirs(results_path):
        os.makedirs(os.path.join(run_dir, _PLOTS_DIR), exist_ok=True)
        for yaml_file in yaml_files:
            _plot_run_config(os.path.join(run_dir, yaml_file))


if __name__ == "__main__":
    main()