a b/scripts/plcom2012/plcom2012.py
1
import math
2
from tqdm import tqdm
3
import pickle
4
5
6
class RiskModel(object):
    """Base class for risk models that score one sample (a dict) at a time.

    Subclasses are expected to provide ``input_transformers``,
    ``scale_inputs`` and ``model`` (used by :meth:`forward`), as well as
    ``save_keys`` and ``save_prefix`` (used by :meth:`save_predictions`).
    The ``input_coef`` and ``input_transformers`` properties defined here are
    placeholders that return ``None``.
    """

    def __init__(self, args):
        """Store the run configuration.

        Args:
            args: Configuration object. This class reads
                ``args.save_predictions`` and ``args.results_path`` from it.
        """
        self.args = args

    def forward(self, batch):
        """Compute the risk for a single sample.

        Args:
            batch: Mapping that contains a value for every key of
                ``self.input_transformers``; other keys are ignored.

        Returns:
            Whatever ``self.model`` returns for the scaled inputs.
        """
        x_transformed = {
            key: func(batch[key]) for key, func in self.input_transformers.items()
        }
        x_scaled = self.scale_inputs(x_transformed)
        risk = self.model(x_scaled)
        return risk

    def test(self, data):
        """Score every sample of ``data.dataset`` in place.

        Each sample gets two new entries: ``"golds"`` (copied from
        ``sample["y"]``) and ``"probs"`` (the output of :meth:`forward`).
        When ``self.args.save_predictions`` is truthy the annotated samples
        are written to disk with :meth:`save_predictions`.

        Args:
            data: Object exposing an iterable ``dataset`` of dict samples.
        """
        for sample in tqdm(data.dataset):
            sample["golds"] = sample["y"]
            sample["probs"] = self.forward(sample)

        if self.args.save_predictions:
            self.save_predictions(data.dataset)

    def save_predictions(self, data):
        """Pickle the entries of each sample that are listed in ``save_keys``.

        The output path is ``"<args.results_path>.<save_prefix>.predictions"``.

        Args:
            data: Iterable of dict samples.
        """
        # Read the property once; subclasses may rebuild the list on each access.
        save_keys = self.save_keys
        predictions = [
            {k: v for k, v in d.items() if k in save_keys} for d in data
        ]
        predictions_filename = "{}.{}.predictions".format(
            self.args.results_path, self.save_prefix
        )
        # Context manager guarantees the handle is flushed and closed, even
        # if pickling raises part-way through.
        with open(predictions_filename, "wb") as handle:
            pickle.dump(predictions, handle)

    @property
    def input_coef(self):
        """Placeholder for the model coefficients; returns ``None`` here."""
        pass

    @property
    def input_transformers(self):
        """Placeholder for the per-input transforms; returns ``None`` here."""
        pass
43
44
class PLCOm2012(RiskModel):
    """PLCOm2012 logistic risk model.

    The predicted risk is ``sigmoid(intercept + sum_i beta_i * t_i(input_i))``
    where the transforms ``t_i`` come from :attr:`input_transformers` and the
    coefficients ``beta_i`` from :attr:`input_coef`. ``race`` is categorical:
    its coefficient is looked up by category name instead of being multiplied.
    """

    def __init__(self, args):
        """Forward the run configuration to :class:`RiskModel`."""
        super(PLCOm2012, self).__init__(args)

    def model(self, x):
        """Apply the logistic function to a linear predictor.

        Args:
            x: Linear predictor (a number), as returned by
                :meth:`scale_inputs`.

        Returns:
            ``1 / (1 + exp(-x))``; ``0.0`` when ``exp(-x)`` is too large to
            be represented as a float.
        """
        try:
            return 1 / (1 + math.exp(-x))
        except OverflowError:
            # math.exp overflows only for very negative x (roughly x < -709),
            # where the probability is indistinguishable from zero.
            return 0.0

    def scale_inputs(self, x):
        """Combine transformed inputs into the linear predictor.

        Args:
            x: Mapping from input name to transformed value. ``x["race"]``
                must be one of the keys of ``input_coef["race"]``.

        Returns:
            The model intercept plus the weighted sum of the inputs.

        Raises:
            KeyError: If an input is missing or the race category is unknown.
        """
        # Model intercept.
        running_sum = -4.532506
        for key, beta in self.input_coef.items():
            if key == "race":
                # Categorical input: beta is a {category: coefficient} table.
                running_sum += beta[x["race"]]
            else:
                running_sum += x[key] * beta
        return running_sum

    @property
    def input_coef(self):
        """Coefficients of the linear predictor, keyed by input name.

        ``"race"`` maps to a nested ``{category: coefficient}`` dict; every
        other key maps to a single float.
        """
        coefs = {
            "age": 0.0778868,
            "race": {
                "white": 0,
                "black": 0.3944778,
                "hispanic": -0.7434744,
                "asian": -0.466585,
                # NOTE(review): per the published PLCOm2012 coefficient table
                # (Tammemagi et al., NEJM 2013, Table 2) the 1.027152 term
                # belongs to Native Hawaiian / Pacific Islander, while
                # American Indian / Alaskan Native shares the reference value
                # of 0. Confirm against the paper before editing these two.
                "native_hawaiian_pacific": 1.027152,
                "american_indian_alaskan": 0,
            },
            "education": -0.0812744,
            "bmi": -0.0274194,
            "cancer_hx": 0.4589971,
            "family_lc_hx": 0.587185,
            "copd": 0.3553063,
            "is_smoker": 0.2597431,
            "smoking_intensity": -1.822606,
            "smoking_duration": 0.0317321,
            "years_since_quit_smoking": -0.0308572,
        }
        return coefs

    @property
    def input_transformers(self):
        """Per-input functions applied to raw values before weighting.

        Most inputs are either passed through unchanged or shifted by a
        constant. ``smoking_intensity`` uses ``10 / x - 0.4021541613`` and
        therefore raises ``ZeroDivisionError`` for an intensity of 0.
        """
        funcs = {
            "age": lambda x: x - 62,
            "race": lambda x: x,
            "education": lambda x: x - 4,
            "bmi": lambda x: x - 27,
            "cancer_hx": lambda x: x,
            "family_lc_hx": lambda x: x,
            "copd": lambda x: x,
            "is_smoker": lambda x: x,
            # Non-linear term: inverse of (intensity / 10), then shifted.
            "smoking_intensity": lambda x: 10 / x - 0.4021541613,
            "smoking_duration": lambda x: x - 27,
            "years_since_quit_smoking": lambda x: x - 10,
        }
        return funcs

    @property
    def save_keys(self):
        """Sample keys that ``save_predictions`` writes to disk."""
        return [
            "pid",
            "age",
            "race",
            "education",
            "bmi",
            "cancer_hx",
            "family_lc_hx",
            "copd",
            "is_smoker",
            "smoking_intensity",
            "smoking_duration",
            "years_since_quit_smoking",
            "exam",
            "golds",
            "probs",
            "time_at_event",
            "y_seq",
            "y_mask",
            "screen_timepoint",
        ]