[0f1df3]: / AICare-baselines / pipelines / dl_pipeline.py

import lightning as L
import torch
import torch.nn as nn

import models
from datasets.loader.unpad import unpad_y
from losses import get_loss
from metrics import get_all_metrics, check_metric_is_better
from models.utils import generate_mask


class DlPipeline(L.LightningModule):
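    """Shared deep-learning pipeline for the AICare baselines: selects an EHR
    encoder from `models` by name and attaches a task-specific head for
    outcome, length-of-stay (los), or multitask prediction."""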

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.demo_dim = config["demo_dim"]
        self.lab_dim = config["lab_dim"]
        # Demographic and lab features arrive concatenated along the feature axis.
        self.input_dim = self.demo_dim + self.lab_dim
        config["input_dim"] = self.input_dim
        self.hidden_dim = config["hidden_dim"]
        self.output_dim = config["output_dim"]
        self.learning_rate = config["learning_rate"]
        self.task = config["task"]
        self.los_info = config["los_info"]
        self.model_name = config["model"]
        self.main_metric = config["main_metric"]
        self.time_aware = config.get("time_aware", False)
        self.cur_best_performance = {}
        self.embedding: torch.Tensor

        if self.model_name == "StageNet":
            config["chunk_size"] = self.hidden_dim
        # Resolve the encoder class by name, e.g. models.GRU or models.ConCare.
        model_class = getattr(models, self.model_name)
        self.ehr_encoder = model_class(**config)
        if self.task == "outcome":
            self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim), nn.Dropout(0.5), nn.Sigmoid())
        elif self.task == "los":
            self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim), nn.Dropout(0.0))
        elif self.task == "multitask":
            self.head = models.heads.MultitaskHead(self.hidden_dim, self.output_dim, drop=0.0)

        self.validation_step_outputs = []
        self.test_step_outputs = []
        self.test_performance = {}
        self.test_outputs = {}

    def forward(self, x, lens):
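        """Encode a padded batch x of shape (batch, time, demo_dim + lab_dim)
        with true sequence lengths `lens`, then apply the task head. ConCare
        additionally returns its decorrelation (decov) auxiliary loss."""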
        if self.model_name == "ConCare":
            # ConCare takes demographics and labs as separate inputs.
            x_demo, x_lab, mask = x[:, 0, :self.demo_dim], x[:, :, self.demo_dim:], generate_mask(lens)
            embedding, decov_loss = self.ehr_encoder(x_lab, x_demo, mask)
            embedding, decov_loss = embedding.to(x.device), decov_loss.to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding, decov_loss
        elif self.model_name in ["GRASP", "Agent", "AICare", "MCGRU"]:
            # These encoders also take demographics and labs separately,
            # but return only the embedding.
            x_demo, x_lab, mask = x[:, 0, :self.demo_dim], x[:, :, self.demo_dim:], generate_mask(lens)
            embedding = self.ehr_encoder(x_lab, x_demo, mask).to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding
        elif self.model_name in ["AdaCare", "RETAIN", "TCN", "Transformer", "StageNet", "BiLSTM", "GRU", "LSTM", "RNN", "MLP", "GRUAttention", "MTRHN"]:
            # The remaining encoders consume the full concatenated feature matrix.
            mask = generate_mask(lens)
            embedding = self.ehr_encoder(x, mask).to(x.device)
            self.embedding = embedding
            y_hat = self.head(embedding)
            return y_hat, embedding
        # Fail loudly instead of silently returning None for unknown models.
        raise ValueError(f"Unsupported model: {self.model_name}")

    def _get_loss(self, x, y, lens):
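        """Run the forward pass, strip padding from predictions and targets,
        and compute the task loss (plus the weighted decov term for ConCare)."""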
        if self.model_name == "ConCare":
            y_hat, embedding, decov_loss = self(x, lens)
            y_hat, y = unpad_y(y_hat, y, lens)
            loss = get_loss(y_hat, y, self.task, self.time_aware)
            # ConCare's decorrelation loss is added with a fixed weight of 10.
            loss += 10 * decov_loss
        else:
            y_hat, embedding = self(x, lens)
            y_hat, y = unpad_y(y_hat, y, lens)
            loss = get_loss(y_hat, y, self.task, self.time_aware)
        return loss, y, y_hat

    def training_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        self.log("val_loss", loss)
        outs = {'y_pred': y_hat, 'y_true': y, 'val_loss': loss}
        self.validation_step_outputs.append(outs)
        return loss

    def on_validation_epoch_end(self):
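        """Aggregate the cached validation outputs, log epoch-level metrics,
        and track the best performance seen so far under `main_metric`."""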
        y_pred = torch.cat([x['y_pred'] for x in self.validation_step_outputs]).detach().cpu()
        y_true = torch.cat([x['y_true'] for x in self.validation_step_outputs]).detach().cpu()
        loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean().detach().cpu()
        self.log("val_loss_epoch", loss)
        metrics = get_all_metrics(y_pred, y_true, self.task, self.los_info)
        for k, v in metrics.items():
            self.log(k, v)
        main_score = metrics[self.main_metric]
        if check_metric_is_better(self.cur_best_performance, self.main_metric, main_score, self.task):
            self.cur_best_performance = metrics
            for k, v in metrics.items():
                self.log("best_" + k, v)
        self.validation_step_outputs.clear()
        return main_score

    def test_step(self, batch, batch_idx):
        x, y, lens, pid = batch
        loss, y, y_hat = self._get_loss(x, y, lens)
        outs = {'y_pred': y_hat, 'y_true': y, 'lens': lens}
        self.test_step_outputs.append(outs)
        return loss

    def on_test_epoch_end(self):
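        """Compute test metrics over the full test set and keep the raw
        predictions, labels, and lengths for downstream analysis."""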
        y_pred = torch.cat([x['y_pred'] for x in self.test_step_outputs]).detach().cpu()
        y_true = torch.cat([x['y_true'] for x in self.test_step_outputs]).detach().cpu()
        lens = torch.cat([x['lens'] for x in self.test_step_outputs]).detach().cpu()
        self.test_performance = get_all_metrics(y_pred, y_true, self.task, self.los_info)
        self.test_outputs = {'preds': y_pred, 'labels': y_true, 'lens': lens}
        self.test_step_outputs.clear()
        return self.test_performance

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
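

# Minimal usage sketch (illustrative only): the config values below are
# assumptions chosen to match the keys __init__ reads, "GRU" is one of the
# encoder names handled in forward(), and train_loader/val_loader stand in
# for project-specific DataLoaders yielding (x, y, lens, pid) batches.
#
#     config = {
#         "demo_dim": 2, "lab_dim": 73, "hidden_dim": 64, "output_dim": 1,
#         "learning_rate": 1e-3, "task": "outcome", "los_info": None,
#         "model": "GRU", "main_metric": "auprc",
#     }
#     pipeline = DlPipeline(config)
#     trainer = L.Trainer(max_epochs=50)
#     trainer.fit(pipeline, train_loader, val_loader)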