AICare-baselines/pipelines/dl_pipeline.py

import os

import lightning as L
import torch
import torch.nn as nn

import models
from datasets.loader.unpad import unpad_y
from losses import get_loss
from metrics import get_all_metrics, check_metric_is_better
from models.utils import generate_mask, get_last_visit

class DlPipeline(L.LightningModule):
    """Lightning wrapper around the EHR encoders: builds the encoder plus a task-specific head and defines the train/val/test loops."""

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.demo_dim = config["demo_dim"]
        self.lab_dim = config["lab_dim"]
        self.input_dim = self.demo_dim + self.lab_dim
        config["input_dim"] = self.input_dim
        self.hidden_dim = config["hidden_dim"]
        self.output_dim = config["output_dim"]
        self.learning_rate = config["learning_rate"]
        self.task = config["task"]
        self.los_info = config["los_info"]
        self.model_name = config["model"]
        self.main_metric = config["main_metric"]
        self.time_aware = config.get("time_aware", False)
        self.cur_best_performance = {}
        self.embedding: torch.Tensor

        # StageNet expects its chunk size to match the hidden dimension.
        if self.model_name == "StageNet":
            config["chunk_size"] = self.hidden_dim

        # Instantiate the EHR encoder by name, then attach a head for the configured task.
        model_class = getattr(models, self.model_name)
        self.ehr_encoder = model_class(**config)
        if self.task == "outcome":
            self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim), nn.Dropout(0.5), nn.Sigmoid())
        elif self.task == "los":
            self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim), nn.Dropout(0.0))
        elif self.task == "multitask":
            self.head = models.heads.MultitaskHead(self.hidden_dim, self.output_dim, drop=0.0)

        self.validation_step_outputs = []
        self.test_step_outputs = []
        self.test_performance = {}
        self.test_outputs = {}
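
    # A sketch of the config keys read in __init__ above; the example values are
    # hypothetical and should be replaced with dataset-specific ones:
    # config = {
    #     "demo_dim": 2, "lab_dim": 73, "hidden_dim": 128, "output_dim": 1,
    #     "learning_rate": 1e-3, "task": "outcome", "los_info": ...,  # LOS statistics used by the metrics
    #     "model": "GRU", "main_metric": "auprc", "time_aware": False,
    # }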
|
|
|
    def forward(self, x, lens):
        if self.model_name == "ConCare":
            # ConCare consumes demographics and labs separately and also returns a decorrelation loss.
            x_demo, x_lab, mask = x[:, 0, :self.demo_dim], x[:, :, self.demo_dim:], generate_mask(lens)
            embedding, decov_loss = self.ehr_encoder(x_lab, x_demo, mask)
            embedding, decov_loss = embedding.to(x.device), decov_loss.to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding, decov_loss
        elif self.model_name in ["GRASP", "Agent", "AICare", "MCGRU"]:
            # These encoders also take demographics and labs separately, but return only the embedding.
            x_demo, x_lab, mask = x[:, 0, :self.demo_dim], x[:, :, self.demo_dim:], generate_mask(lens)
            embedding = self.ehr_encoder(x_lab, x_demo, mask).to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding
        elif self.model_name in ["AdaCare", "RETAIN", "TCN", "Transformer", "StageNet", "BiLSTM", "GRU", "LSTM", "RNN", "MLP", "GRUAttention", "MTRHN"]:
            # Sequence encoders that consume the full concatenated feature vector.
            mask = generate_mask(lens)
            embedding = self.ehr_encoder(x, mask).to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding
        else:
            # Fail loudly instead of silently returning None for an unknown model name.
            raise ValueError(f"Unsupported model: {self.model_name}")
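
    # Input layout assumed by forward() above (a sketch inferred from the slicing,
    # not project documentation):
    #   x    : (batch, time, demo_dim + lab_dim) -- the first demo_dim columns hold
    #          static demographics (read from the first visit only); the remaining
    #          columns are per-visit lab features.
    #   lens : (batch,) true sequence lengths, used by generate_mask() to mask padded visits.
    # E.g. with demo_dim=2, x[:, 0, :2] selects demographics and x[:, :, 2:] the lab sequence.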
|
|
|
    def _get_loss(self, x, y, lens):
        if self.model_name == "ConCare":
            y_hat, embedding, decov_loss = self(x, lens)
            y_hat, y = unpad_y(y_hat, y, lens)
            loss = get_loss(y_hat, y, self.task, self.time_aware)
            # ConCare adds its decorrelation regularizer, weighted by a fixed factor of 10.
            loss += 10 * decov_loss
        else:
            y_hat, embedding = self(x, lens)
            y_hat, y = unpad_y(y_hat, y, lens)
            loss = get_loss(y_hat, y, self.task, self.time_aware)
        return loss, y, y_hat

    def training_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        self.log("val_loss", loss)
        outs = {'y_pred': y_hat, 'y_true': y, 'val_loss': loss}
        self.validation_step_outputs.append(outs)
        return loss

    def on_validation_epoch_end(self):
        y_pred = torch.cat([x['y_pred'] for x in self.validation_step_outputs]).detach().cpu()
        y_true = torch.cat([x['y_true'] for x in self.validation_step_outputs]).detach().cpu()
        loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean().detach().cpu()
        self.log("val_loss_epoch", loss)
        metrics = get_all_metrics(y_pred, y_true, self.task, self.los_info)
        for k, v in metrics.items():
            self.log(k, v)
        main_score = metrics[self.main_metric]
        # Track the best epoch so far according to the main metric and log it under a "best_" prefix.
        if check_metric_is_better(self.cur_best_performance, self.main_metric, main_score, self.task):
            self.cur_best_performance = metrics
            for k, v in metrics.items():
                self.log("best_" + k, v)
        self.validation_step_outputs.clear()
        return main_score

    def test_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        outs = {'y_pred': y_hat, 'y_true': y, 'lens': lens}
        self.test_step_outputs.append(outs)
        return loss

    def on_test_epoch_end(self):
        y_pred = torch.cat([x['y_pred'] for x in self.test_step_outputs]).detach().cpu()
        y_true = torch.cat([x['y_true'] for x in self.test_step_outputs]).detach().cpu()
        lens = torch.cat([x['lens'] for x in self.test_step_outputs]).detach().cpu()
        self.test_performance = get_all_metrics(y_pred, y_true, self.task, self.los_info)
        self.test_outputs = {'preds': y_pred, 'labels': y_true, 'lens': lens}
        self.test_step_outputs.clear()
        return self.test_performance

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
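

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the repo's
# `models` package exposes a "GRU" encoder that accepts the config keys below
# as keyword arguments; all dimensions and values here are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        "demo_dim": 2,           # number of static demographic features (hypothetical)
        "lab_dim": 10,           # number of time-varying lab features (hypothetical)
        "hidden_dim": 32,
        "output_dim": 1,
        "learning_rate": 1e-3,
        "task": "outcome",
        "los_info": None,        # only needed by the length-of-stay metrics
        "model": "GRU",
        "main_metric": "auprc",
    }
    pipeline = DlPipeline(config)

    # Fake batch: 4 patients, 16 visits, demographics and labs concatenated per visit.
    x = torch.randn(4, 16, config["demo_dim"] + config["lab_dim"])
    lens = torch.tensor([16, 12, 9, 5])
    y_hat, _ = pipeline(x, lens)
    print(y_hat.shape)  # per-visit predictions, if the encoder returns one embedding per visit

    # A real run would use the project's dataloaders, e.g.:
    # trainer = L.Trainer(max_epochs=10)
    # trainer.fit(pipeline, train_dataloaders=..., val_dataloaders=...)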