a b/src/iterpretability/experiments/expertise_sensitivity.py
1
from src.iterpretability.experiments.experiments_base import ExperimentBase
2
3
from pathlib import Path
4
import os
5
import catenets.models as cate_models
6
import numpy as np
7
import pandas as pd
8
import wandb 
9
from PIL import Image
10
import src.iterpretability.logger as log
11
from src.plotting import (
12
    plot_results_datasets_compare, 
13
    merge_pngs
14
)
15
from src.iterpretability.explain import Explainer
16
from src.iterpretability.datasets.data_loader import load
17
from src.iterpretability.simulators import (
18
    TYSimulator,
19
    TSimulator
20
)
21
from src.iterpretability.utils import (
22
    attribution_accuracy,
23
)
24
25
# For contour plotting
26
import umap 
27
from sklearn.decomposition import PCA
28
from sklearn.manifold import TSNE
29
from sklearn.linear_model import LogisticRegression
30
from sklearn.metrics import mean_squared_error
31
import matplotlib.pyplot as plt
32
import matplotlib.tri as tri
33
from sklearn.model_selection import KFold, StratifiedKFold
34
import matplotlib.gridspec as gridspec
35
from matplotlib.colors import Normalize
36
from matplotlib.ticker import FuncFormatter
37
import imageio
38
import torch
39
import shap
40
41
# Hydra for configuration
42
import hydra
43
from omegaconf import DictConfig, OmegaConf
44
45
46
class ExpertiseSensitivity(ExperimentBase):
    """
    Sensitivity analysis over the simulator's ``alpha`` and ``propensity_type`` settings.

    For every combination of seed (``self.seeds``), propensity type
    (``cfg.propensity_types``) and alpha (``cfg.alphas``), a fresh simulator is built,
    treatments/outcomes are simulated on ``self.X``, and ``compute_metrics`` is called
    on every cross-validation fold. The metrics table returned by the last
    ``compute_metrics`` call is passed to ``save_results`` (which, presumably, writes
    the .csv — confirm in ``ExperimentBase``).

    NOTE(review): an earlier version of this docstring said the experiment also
    produces a gif of dimensionality-reduced spaces; ``run`` does not do that directly.
    Confirm whether ``compute_metrics``/``save_results`` in the base class does.
    """

    def __init__(self, cfg: DictConfig) -> None:
        """
        Args:
            cfg: Experiment config. Besides whatever ``ExperimentBase`` reads, this
                class reads ``cfg.alphas`` and ``cfg.propensity_types``; both are
                iterated over in ``run``.
        """
        super().__init__(cfg)

        # Experiment specific settings: the two axes of the sweep.
        self.alphas = cfg.alphas
        self.propensity_types = cfg.propensity_types

    def _init_sweep_simulator(self, seed, alpha, propensity_type):
        """Build the simulator chosen by ``self.simulation_type`` and apply the sweep settings."""
        if self.simulation_type == "TY":
            simulator_cls = TYSimulator
        elif self.simulation_type == "T":
            simulator_cls = TSimulator
        else:
            # Without this, `sim` would be left unbound and fail with an opaque error.
            raise ValueError(
                f"Unknown simulation_type {self.simulation_type!r}; expected 'TY' or 'T'."
            )

        sim = simulator_cls(dim_X=self.X.shape[1], **self.cfg.simulator, seed=seed)
        # Overwrite the swept settings (alpha and propensity type) after construction.
        sim.alpha = alpha
        sim.propensity_type = propensity_type
        return sim

    def _make_fold_splitter(self):
        """Return the CV splitter: stratified for discrete outcomes, plain KFold otherwise."""
        if self.discrete_outcome:
            return StratifiedKFold(n_splits=self.n_splits)
        return KFold(n_splits=self.n_splits)

    def run(self) -> None:
        """
        Run the sweep and save the resulting metrics.

        Raises:
            ValueError: if ``self.simulation_type`` is not "TY" or "T", or if the
                sweep is empty (no seeds, propensity types or alphas), in which
                case there are no metrics to save.
        """
        log.info(
            f"Starting propensity scale sensitivity experiment for dataset {self.cfg.dataset}."
        )

        # Passed to every compute_metrics call; presumably accumulated there (confirm in base).
        results_data = []
        metrics_df = None
        # The splitter depends only on config, so build it once.
        kf = self._make_fold_splitter()

        for seed in self.seeds:
            for propensity_type in self.propensity_types:
                for alpha in self.alphas:
                    log.info(
                        f"Running experiment for seed {seed} and alpha: {alpha}."
                    )

                    sim = self._init_sweep_simulator(seed, alpha, propensity_type)

                    # Expose the simulator's ground-truth feature groups on the experiment.
                    self.all_important_features = sim.all_important_features
                    self.pred_features = sim.predictive_features
                    self.prog_features = sim.prognostic_features
                    self.select_features = sim.selective_features

                    # Simulate outcomes and treatment assignments.
                    sim.simulate(X=self.X, outcomes=self.outcomes)
                    X, T, Y, outcomes, propensities = sim.get_simulated_data()

                    if self.discrete_outcome:
                        # Class labels, also required by StratifiedKFold.split.
                        Y = Y.astype(bool)

                    for split_id, (train_index, test_index) in enumerate(kf.split(X, Y)):
                        X_train, X_test = X[train_index], X[test_index]
                        T_train, T_test = T[train_index], T[test_index]
                        Y_train, Y_test = Y[train_index], Y[test_index]
                        outcomes_train, outcomes_test = outcomes[train_index], outcomes[test_index]
                        propensities_train, propensities_test = (
                            propensities[train_index],
                            propensities[test_index],
                        )

                        log.info(
                            f"Running experiment for seed {seed} and alpha: {alpha}."
                        )

                        metrics_df = self.compute_metrics(
                            results_data,
                            sim,
                            X_train,
                            Y_train,
                            T_train,
                            X_test,
                            Y_test,
                            T_test,
                            outcomes_train,
                            outcomes_test,
                            propensities_train,
                            propensities_test,
                            alpha,
                            "Alpha",
                            propensity_type,
                            "Propensity Type",
                            seed,
                            split_id,
                        )

        if metrics_df is None:
            # Previously this surfaced as an UnboundLocalError on `metrics_df`.
            raise ValueError(
                "No metrics were computed: seeds, propensity_types and alphas must all be non-empty."
            )

        # Save results and plot
        self.save_results(metrics_df)