# AICare-baselines/pipelines/dl_pipeline.py
import os

import lightning as L
import torch
import torch.nn as nn

import models
from datasets.loader.unpad import unpad_y
from losses import get_loss
from metrics import get_all_metrics, check_metric_is_better
from models.utils import generate_mask, get_last_visit


class DlPipeline(L.LightningModule):
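    """Lightning wrapper around the EHR encoders defined in ``models``.

    Instantiates the encoder named by ``config["model"]``, attaches a
    task-specific prediction head ("outcome", "los", or "multitask"), and
    implements the training/validation/test loops, loss computation, and
    best-metric tracking.
    """
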
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.demo_dim = config["demo_dim"]
        self.lab_dim = config["lab_dim"]
        self.input_dim = self.demo_dim + self.lab_dim
        config["input_dim"] = self.input_dim
        self.hidden_dim = config["hidden_dim"]
        self.output_dim = config["output_dim"]
        self.learning_rate = config["learning_rate"]
        self.task = config["task"]
        self.los_info = config["los_info"]
        self.model_name = config["model"]
        self.main_metric = config["main_metric"]
        self.time_aware = config.get("time_aware", False)
        self.cur_best_performance = {}
        self.embedding: torch.Tensor

        # StageNet's encoder expects a chunk_size argument; tie it to the hidden dimension.
        if self.model_name == "StageNet":
            config["chunk_size"] = self.hidden_dim

        model_class = getattr(models, self.model_name)
        self.ehr_encoder = model_class(**config)
        if self.task == "outcome":
            self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim), nn.Dropout(0.5), nn.Sigmoid())
        elif self.task == "los":
            self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim), nn.Dropout(0.0))
        elif self.task == "multitask":
            self.head = models.heads.MultitaskHead(self.hidden_dim, self.output_dim, drop=0.0)
        else:
            raise ValueError(f"Unsupported task: {self.task}")

        self.validation_step_outputs = []
        self.test_step_outputs = []
        self.test_performance = {}
        self.test_outputs = {}

    def forward(self, x, lens):
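        """Dispatch to the encoder's calling convention.

        ConCare takes split demographic/lab inputs and additionally returns a
        decorrelation loss; GRASP/Agent/AICare/MCGRU take the same split
        inputs; the remaining sequence models consume the full feature matrix
        with a padding mask.
        """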
        if self.model_name == "ConCare":
            x_demo, x_lab, mask = x[:, 0, :self.demo_dim], x[:, :, self.demo_dim:], generate_mask(lens)
            embedding, decov_loss = self.ehr_encoder(x_lab, x_demo, mask)
            embedding, decov_loss = embedding.to(x.device), decov_loss.to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding, decov_loss
        elif self.model_name in ["GRASP", "Agent", "AICare", "MCGRU"]:
            x_demo, x_lab, mask = x[:, 0, :self.demo_dim], x[:, :, self.demo_dim:], generate_mask(lens)
            embedding = self.ehr_encoder(x_lab, x_demo, mask).to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding
        elif self.model_name in ["AdaCare", "RETAIN", "TCN", "Transformer", "StageNet", "BiLSTM", "GRU", "LSTM", "RNN", "MLP", "GRUAttention", "MTRHN"]:
            mask = generate_mask(lens)
            embedding = self.ehr_encoder(x, mask).to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding
        else:
            raise ValueError(f"Unsupported model: {self.model_name}")

    def _get_loss(self, x, y, lens):
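        """Run the model and compute the task loss (plus ConCare's weighted decorrelation penalty)."""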
        if self.model_name == "ConCare":
            y_hat, embedding, decov_loss = self(x, lens)
            y_hat, y = unpad_y(y_hat, y, lens)
            loss = get_loss(y_hat, y, self.task, self.time_aware)
            loss += 10 * decov_loss
        else:
            y_hat, embedding = self(x, lens)
            y_hat, y = unpad_y(y_hat, y, lens)
            loss = get_loss(y_hat, y, self.task, self.time_aware)
        return loss, y, y_hat

    def training_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        self.log("val_loss", loss)
        outs = {'y_pred': y_hat, 'y_true': y, 'val_loss': loss}
        self.validation_step_outputs.append(outs)
        return loss

    def on_validation_epoch_end(self):
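        """Aggregate validation outputs, log epoch-level metrics, and track the best scores seen so far."""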
        y_pred = torch.cat([x['y_pred'] for x in self.validation_step_outputs]).detach().cpu()
        y_true = torch.cat([x['y_true'] for x in self.validation_step_outputs]).detach().cpu()
        loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean().detach().cpu()
        self.log("val_loss_epoch", loss)
        metrics = get_all_metrics(y_pred, y_true, self.task, self.los_info)
        for k, v in metrics.items():
            self.log(k, v)
        main_score = metrics[self.main_metric]
        if check_metric_is_better(self.cur_best_performance, self.main_metric, main_score, self.task):
            self.cur_best_performance = metrics
            for k, v in metrics.items():
                self.log("best_" + k, v)
        self.validation_step_outputs.clear()
        return main_score

    def test_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        outs = {'y_pred': y_hat, 'y_true': y, 'lens': lens}
        self.test_step_outputs.append(outs)
        return loss

    def on_test_epoch_end(self):
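        """Aggregate test outputs and store both the metrics and the raw predictions for later inspection."""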
        y_pred = torch.cat([x['y_pred'] for x in self.test_step_outputs]).detach().cpu()
        y_true = torch.cat([x['y_true'] for x in self.test_step_outputs]).detach().cpu()
        lens = torch.cat([x['lens'] for x in self.test_step_outputs]).detach().cpu()
        self.test_performance = get_all_metrics(y_pred, y_true, self.task, self.los_info)
        self.test_outputs = {'preds': y_pred, 'labels': y_true, 'lens': lens}
        self.test_step_outputs.clear()
        return self.test_performance

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
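

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the pipeline itself).
# The config keys mirror what __init__ reads above; the concrete values, the
# "GRU" model choice, and the dataloaders are assumptions. Batches must be
# (x, y, lens, pid) tuples, as consumed in training_step:
#
#     config = {
#         "demo_dim": 2, "lab_dim": 73, "hidden_dim": 64, "output_dim": 1,
#         "learning_rate": 1e-3, "task": "outcome", "los_info": None,
#         "model": "GRU", "main_metric": "auprc",
#     }
#     pipeline = DlPipeline(config)
#     trainer = L.Trainer(max_epochs=50)
#     trainer.fit(pipeline, train_dataloader, val_dataloader)
#     trainer.test(pipeline, test_dataloader)
# ---------------------------------------------------------------------------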