trivial_classifiers/chance_level_classification.py

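# Chance-level ("trivial") classifier baselines: simulates either a model that predicts
# labels at the training-set imbalance probability or one that always predicts the
# majority class, and reports precision/recall/F1/balanced accuracy/AUROC on the heldout split.
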
import os
import numpy as np
import pandas as pd
from joblib import load
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score, auc, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve

if __name__ == "__main__":

    # choose between window-level simulation or subject-level simulations
    WINDOW_LEVEL_SIMULATION = False
    # choose which class is considered positive
    IS_POSITIVE_CLASS_MAJORITY = True
    # NOTE: even though the positive class may change when calculating metrics, the label encoding is fixed: 0 - diseased, 1 - healthy
    # pos_label = 0 if positive class is diseased (majority), = 1 when positive class is healthy (minority)
    POS_LABEL = int(not IS_POSITIVE_CLASS_MAJORITY)
    POS_CLASS = "diseased" if POS_LABEL == 0 else "healthy"
    # 2 types of chance models: "predict-imbalance-probability" or "predict-all-as-majority-class"
    # MODEL_TYPE = "predict-imbalance-probability"
    MODEL_TYPE = "predict-all-as-majority-class"
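    # chance-model behaviour: "predict-imbalance-probability" draws label 1 (healthy, minority) at the training-set minority rate,
    # while "predict-all-as-majority-class" always outputs label 0 (diseased, majority)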

    print("** LABEL ENCODING IS FIXED: HEALTHY = 1, DISEASED = 0 **\n")
    print(f"** METRICS CALCULATION: POSITIVE CLASS = {POS_CLASS}, IS POSITIVE CLASS MAJORITY = {IS_POSITIVE_CLASS_MAJORITY} **\n")
    print(f"** CHOSEN TRIVIAL MODEL: {MODEL_TYPE} **\n")

    MASTER_DATASET_INDEX = pd.read_csv("master_metadata_index.csv")
    subjects = MASTER_DATASET_INDEX["patient_ID"].astype("str").unique()
    print("[MAIN] Subject list fetched! Total subjects: {}...".format(len(subjects)))

    # CAUTION: splitting whole subjects into train+validation and heldout test
    SEED = 42
    train_subjects, test_subjects = train_test_split(subjects, test_size=0.30, random_state=SEED)
    print("[MAIN] (Train + validation) and (heldout test) split made at subject level. 30 percent of subjects held out for testing.")

    if WINDOW_LEVEL_SIMULATION:
        print("** WINDOW-LEVEL SIMULATIONS **\n")

        NUM_TEST_SAMPLES = 68778
        # imbalance factor = #diseased/#healthy
        IMBALANCE_FACTOR = 8.96220
        # prob threshold = #minority/(#minority + #majority) - TAKEN FROM TRAINING SET WINDOWS
        # interpretation depends on whether minority or majority is the positive class
        PREDICTION_PROBABILITY_THRESHOLD = 0.100379
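        # sanity check: threshold = #healthy / (#healthy + #diseased) = 1 / (1 + IMBALANCE_FACTOR) ~= 0.1004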

        # get indices for test subjects!
        heldout_test_indices = MASTER_DATASET_INDEX.index[MASTER_DATASET_INDEX["patient_ID"].astype("str").isin(test_subjects)].tolist()
        y = load("labels_y", mmap_mode='r')
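        # np.unique with return_inverse=True re-encodes the raw labels as 0/1 (sorted order);
        # label_mapping records which original label maps to 0 and which to 1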
        label_mapping, y = np.unique(y, return_inverse=True)
        print("[MAIN] unique labels to [0 1] mapping:", label_mapping)
        truth_labels = np.array(y[heldout_test_indices])

    else:
        print("** SUBJECT-LEVEL SIMULATIONS **\n")

        # subject-level simulations!
        NUM_TEST_SAMPLES = 478
        # imbalance factor = #diseased/#healthy - TAKEN FROM TRAINING SET SUBJECTS!
        IMBALANCE_FACTOR = 6.384105
        # prob threshold = #minority/(#minority + #majority) - TAKEN FROM TRAINING SET SUBJECTS
        # interpretation depends on whether minority or majority is the positive class
        PREDICTION_PROBABILITY_THRESHOLD = 0.135426
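        # sanity check: threshold = #healthy / (#healthy + #diseased) = 1 / (1 + IMBALANCE_FACTOR) ~= 0.1354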

        # NOTE: labeling healthy = 1, diseased = 0, consistent with the window-level branch above
        truth_labels = np.array([1 if "sub-" in x else 0 for x in test_subjects])

    np.random.seed(SEED)
    print("GROUND TRUTH LABELS: ", np.unique(truth_labels, return_counts=True))
    assert len(truth_labels) == NUM_TEST_SAMPLES

    if MODEL_TYPE == "predict-imbalance-probability":

        # run the simulation many times (1000 repetitions) and aggregate the metrics
        precision_scores = []
        recall_scores = []
        f1_scores = []
        bal_acc_scores = []
        auroc_scores = []

        for i in range(1000):

            # make chance-level predictions with a blind model - predict the minority class (healthy = 1) with the training-set imbalance probability
            # NOTE: ASSUMING TEST DISTRIBUTION FOLLOWS THE TRAINING LABEL DISTRIBUTION! (which it does for the scope of the paper)
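            # label 0 (diseased, majority) is drawn with probability 1 - threshold, label 1 (healthy, minority) with probability threshold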
            predictions = np.random.choice([0, 1], NUM_TEST_SAMPLES, p=[(1-PREDICTION_PROBABILITY_THRESHOLD), PREDICTION_PROBABILITY_THRESHOLD])
            # predicted probability assigned to the majority class (label 0): 1.0 when predicting diseased, 0.0 when predicting healthy
            prediction_probabilities = np.array([1.0 if x == 0 else 0.0 for x in list(predictions)])
            print("PREDICTIONS: ", np.unique(predictions, return_counts=True))
            # print("PREDICTION PROBA: ", np.unique(prediction_probabilities, return_counts=True))

            # compute metrics for this simulation run
            precision_test = precision_score(truth_labels, predictions, pos_label=POS_LABEL)
            recall_test = recall_score(truth_labels, predictions, pos_label=POS_LABEL)
            f1_test = f1_score(truth_labels, predictions, pos_label=POS_LABEL)
            bal_acc_test = balanced_accuracy_score(truth_labels, predictions)
            auroc_test = roc_auc_score(truth_labels, prediction_probabilities)

            precision_scores.append(precision_test)
            recall_scores.append(recall_test)
            f1_scores.append(f1_test)
            bal_acc_scores.append(bal_acc_test)
            auroc_scores.append(auroc_test)

        # print mean and std. dev. across all simulations
        import statistics as stats
        print(f"PRECISION: {stats.mean(precision_scores)} ({stats.stdev(precision_scores)})")
        print(f"RECALL: {stats.mean(recall_scores)} ({stats.stdev(recall_scores)})")
        print(f"F-1: {stats.mean(f1_scores)} ({stats.stdev(f1_scores)})")
        print(f"BALANCED ACCURACY: {stats.mean(bal_acc_scores)} ({stats.stdev(bal_acc_scores)})")
        print(f"AUC: {stats.mean(auroc_scores)} ({stats.stdev(auroc_scores)})")
        print("[MAIN] exiting...")

    elif MODEL_TYPE == "predict-all-as-majority-class":
        predictions = np.zeros((NUM_TEST_SAMPLES, ), dtype=int)
        # the score for each sample is the probability assigned to the majority class (label 0), hence 1.0 everywhere
        prediction_probabilities = np.ones((NUM_TEST_SAMPLES, ), dtype=float)
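        # with a single unique score the ROC curve reduces to the diagonal, so the AUROC computed below is 0.5 by construction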
        print("PREDICTIONS: ", np.unique(predictions, return_counts=True))

        # get metrics for the majority-class model
        precision_test = precision_score(truth_labels, predictions, pos_label=POS_LABEL)
        recall_test = recall_score(truth_labels, predictions, pos_label=POS_LABEL)
        f1_test = f1_score(truth_labels, predictions, pos_label=POS_LABEL)
        bal_acc_test = balanced_accuracy_score(truth_labels, predictions)
        auroc_test = roc_auc_score(truth_labels, prediction_probabilities)

        print(f"Precision: {precision_test}\nRecall: {recall_test}\nF-1: {f1_test}\nBalanced Accuracy: {bal_acc_test}\nAUC: {auroc_test}")
        print("[MAIN] exiting...")