# 01_run_simulation_experiment.py
1
import argparse
2
import sys
3
import os
4
import wandb
5
from typing import Any
6
import pandas as pd 
7
import src.iterpretability.logger as log
8
# from src.iterpretability.experiments.experiments_ext import (
9
#     PredictiveSensitivity,
10
#     PropensitySensitivity,
11
#     NonLinearitySensitivity,
12
#     CohortSizeSensitivity,
13
# )
14
from src.iterpretability.simulators import TYSimulator, TSimulator
15
16
from src.iterpretability.experiments.sample_size_sensitivity import SampleSizeSensitivity
17
from src.iterpretability.experiments.important_feature_num_sensitivity import ImportantFeatureNumSensitivity
18
from src.iterpretability.experiments.propensity_scale_sensitivity import PropensityScaleSensitivity
19
from src.iterpretability.experiments.predictive_scale_sensitivity import PredictiveScaleSensitivity
20
from src.iterpretability.experiments.treatment_space_sensitivity import TreatmentSpaceSensitivity
21
from src.iterpretability.experiments.sample_size_sensitivity_visualization import SampleSizeSensitivityVisualization
22
from src.iterpretability.experiments.data_dimensionality_sensitivity import DataDimensionalitySensitivity
23
from src.iterpretability.experiments.feature_overlap_sensitivity import FeatureOverlapSensitivity
24
from src.iterpretability.experiments.expertise_sensitivity import ExpertiseSensitivity
25
26
27
from src.iterpretability.datasets.data_loader import load
28
29
# Hydra for configuration
30
import hydra
31
from omegaconf import DictConfig, OmegaConf
32
33
@hydra.main(config_path="conf", config_name="config_TY_tcga", version_base=None)
def main(cfg: DictConfig):
    """Run the sensitivity experiment named by ``cfg.experiment_name``.

    Hydra composes ``cfg`` from ``conf/config_TY_tcga`` (plus any
    command-line overrides) and passes it in.

    Args:
        cfg: Composed Hydra config. ``log_level`` and ``experiment_name`` are
            always read; some experiments additionally read fields of
            ``cfg.simulator`` to check that they can run (see
            ``_unsupported_reason``).

    Returns:
        None. If the selected experiment cannot run with the configured
        simulator, an explanatory message is printed and nothing is run.

    Raises:
        ValueError: If ``cfg.experiment_name`` is not a known experiment.
    """
    ############################
    # 1. SETUP
    ############################
    # Route the project logger to stderr at the configured verbosity.
    log.add(sink=sys.stderr, level=cfg.log_level)

    ############################
    # 2. EXPERIMENTS
    ############################
    # Experiment name -> class; the chosen class is built with cfg and run().
    experiments = {
        "sample_size_sensitivity": SampleSizeSensitivity,
        "important_feature_num_sensitivity": ImportantFeatureNumSensitivity,
        "propensity_scale_sensitivity": PropensityScaleSensitivity,
        "predictive_scale_sensitivity": PredictiveScaleSensitivity,
        "treatment_space_sensitivity": TreatmentSpaceSensitivity,
        "sample_size_sensitivity_visualization": SampleSizeSensitivityVisualization,
        "data_dimensionality_sensitivity": DataDimensionalitySensitivity,
        "feature_overlap_sensitivity": FeatureOverlapSensitivity,
        "expertise_sensitivity": ExpertiseSensitivity,
    }
    experiment_cls = experiments.get(cfg.experiment_name)
    if experiment_cls is None:
        raise ValueError(f"Invalid experiment name: {cfg.experiment_name}")

    # Skip (with a message) experiments the configured simulator cannot drive.
    reason = _unsupported_reason(cfg)
    if reason is not None:
        print(reason)
        return None

    experiment_cls(cfg).run()
    return None


def _unsupported_reason(cfg: DictConfig):
    """Return why the selected experiment cannot run, or ``None`` if it can.

    Only the ``cfg.simulator`` fields relevant to the selected experiment are
    read, so configs for other experiments need not define them.
    """
    name = cfg.experiment_name

    # Experiments that refuse to run when simulation_type == "T".
    t_unsupported = {
        "predictive_scale_sensitivity": "PredictiveScaleSensitivity",
        "treatment_space_sensitivity": "TreatmentSpaceSensitivity",
        "feature_overlap_sensitivity": "FeatureOverlapSensitivity",
    }
    if name in t_unsupported:
        if cfg.simulator.simulation_type == "T":
            return f"{t_unsupported[name]} is not supported for TSimulator"
        return None

    if name == "sample_size_sensitivity_visualization":
        if cfg.simulator.dim_Y != 1 or cfg.simulator.num_T != 2:
            return "SampleSizeSensitivityVisualization is only supported for dim_Y=1 and num_T=2"
        # NOTE(review): deliberate placeholder guard left by the author; it
        # blocks the "T" simulator until treatment selection is restricted.
        if cfg.simulator.simulation_type == "T":
            return "Make sure to only pick two treatments for this to work. And remove this if statement."

    return None
if __name__ == "__main__":
    # @hydra.main builds the config from the config files and command line,
    # which is why main() is called without arguments here.
    main()