[6ac965]: / 01_run_simulation_experiment.py

Download this file

108 lines (85 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
import argparse
import sys
import os
import wandb
from typing import Any
import pandas as pd
import src.iterpretability.logger as log
# from src.iterpretability.experiments.experiments_ext import (
# PredictiveSensitivity,
# PropensitySensitivity,
# NonLinearitySensitivity,
# CohortSizeSensitivity,
# )
from src.iterpretability.simulators import TYSimulator, TSimulator
from src.iterpretability.experiments.sample_size_sensitivity import SampleSizeSensitivity
from src.iterpretability.experiments.important_feature_num_sensitivity import ImportantFeatureNumSensitivity
from src.iterpretability.experiments.propensity_scale_sensitivity import PropensityScaleSensitivity
from src.iterpretability.experiments.predictive_scale_sensitivity import PredictiveScaleSensitivity
from src.iterpretability.experiments.treatment_space_sensitivity import TreatmentSpaceSensitivity
from src.iterpretability.experiments.sample_size_sensitivity_visualization import SampleSizeSensitivityVisualization
from src.iterpretability.experiments.data_dimensionality_sensitivity import DataDimensionalitySensitivity
from src.iterpretability.experiments.feature_overlap_sensitivity import FeatureOverlapSensitivity
from src.iterpretability.experiments.expertise_sensitivity import ExpertiseSensitivity
from src.iterpretability.datasets.data_loader import load
# Hydra for configuration
import hydra
from omegaconf import DictConfig, OmegaConf
# Maps ``cfg.experiment_name`` to the experiment class implementing it. Each
# class is constructed with the full config and then executed via ``.run()``.
_EXPERIMENTS = {
    "sample_size_sensitivity": SampleSizeSensitivity,
    "important_feature_num_sensitivity": ImportantFeatureNumSensitivity,
    "propensity_scale_sensitivity": PropensityScaleSensitivity,
    "predictive_scale_sensitivity": PredictiveScaleSensitivity,
    "treatment_space_sensitivity": TreatmentSpaceSensitivity,
    "sample_size_sensitivity_visualization": SampleSizeSensitivityVisualization,
    "data_dimensionality_sensitivity": DataDimensionalitySensitivity,
    "feature_overlap_sensitivity": FeatureOverlapSensitivity,
    "expertise_sensitivity": ExpertiseSensitivity,
}

# Experiments that refuse to run when ``cfg.simulator.simulation_type == "T"``.
_UNSUPPORTED_FOR_T_SIMULATOR = frozenset({
    "predictive_scale_sensitivity",
    "treatment_space_sensitivity",
    "feature_overlap_sensitivity",
})


def _unsupported_reason(cfg: DictConfig):
    """Return a message explaining why the configured experiment is skipped, or None.

    Simulator settings are only read for the experiments that depend on them,
    so configs lacking e.g. ``dim_Y`` still work for the other experiments.
    """
    name = cfg.experiment_name
    if name == "sample_size_sensitivity_visualization":
        if cfg.simulator.dim_Y != 1 or cfg.simulator.num_T != 2:
            return "SampleSizeSensitivityVisualization is only supported for dim_Y=1 and num_T=2"
        if cfg.simulator.simulation_type == "T":
            return "Make sure to only pick two treatments for this to work. And remove this if statement."
    elif name in _UNSUPPORTED_FOR_T_SIMULATOR and cfg.simulator.simulation_type == "T":
        # Derive the label from the class so the message always names the
        # experiment actually being skipped.
        return f"{_EXPERIMENTS[name].__name__} is not supported for TSimulator"
    return None


@hydra.main(config_path="conf", config_name="config_TY_tcga", version_base=None)
def main(cfg: DictConfig):
    """Run the simulation experiment selected by ``cfg.experiment_name``.

    Hydra builds ``cfg`` from the ``config_TY_tcga`` config in ``conf/``.
    Logging to stderr is configured at ``cfg.log_level``. If the experiment
    does not support the configured simulator, a message is printed and
    ``None`` is returned without running anything.

    Raises:
        ValueError: if ``cfg.experiment_name`` is not a known experiment.
    """
    ############################
    # 1. SETUP
    ############################
    log.add(sink=sys.stderr, level=cfg.log_level)

    ############################
    # 2. EXPERIMENTS
    ############################
    experiment_cls = _EXPERIMENTS.get(cfg.experiment_name)
    if experiment_cls is None:
        raise ValueError(f"Invalid experiment name: {cfg.experiment_name}")

    reason = _unsupported_reason(cfg)
    if reason is not None:
        print(reason)
        return None

    experiment_cls(cfg).run()


if __name__ == "__main__":
    main()