a b/AICare-baselines/pipelines/ml_pipeline.py
1
import os
2
from pathlib import Path
3
4
import pandas as pd
5
import lightning as L
6
7
import models
8
from datasets.loader.unpad import unpad_batch
9
from metrics import check_metric_is_better, get_all_metrics
10
11
12
class MlPipeline(L.LightningModule):
    """Lightning module that drives a fit/predict-style model.

    The model class is looked up by name in the ``models`` package and is
    built with the full ``config`` mapping as keyword arguments. Training
    calls ``fit`` on the wrapped model, validation pickles the model to
    ``self.checkpoint_path`` whenever the main metric improves, and testing
    reloads that pickle before predicting.
    """

    def __init__(self, config):
        """Build the wrapped model and prepare the checkpoint location.

        Args:
            config: Mapping read for the keys ``task``, ``los_info``,
                ``model``, ``main_metric``, ``dataset``, ``fold`` and
                ``seed``; it is also forwarded as ``**config`` to the
                model constructor.
        """
        super().__init__()
        # Must be called directly from __init__ so the hyperparameters are
        # captured from this frame.
        self.save_hyperparameters()

        self.task = config["task"]
        self.los_info = config["los_info"]
        self.model_name = config["model"]
        self.main_metric = config["main_metric"]
        # Metrics of the best validation result seen so far (empty until the
        # first improvement is recorded).
        self.cur_best_performance = {}

        estimator_cls = getattr(models, self.model_name)
        self.model = estimator_cls(**config)

        # Filled in by test_step.
        self.test_performance = {}
        self.test_outputs = {}

        ckpt_dir = f'logs/train/{config["dataset"]}/{config["task"]}/{config["model"]}-fold{config["fold"]}-seed{config["seed"]}/checkpoints/'
        Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
        self.checkpoint_path = os.path.join(ckpt_dir, 'best.ckpt')

    def forward(self, x):
        """No-op; prediction goes through ``self.model.predict`` instead."""
        return None

    def training_step(self, batch, batch_idx):
        """Fit the wrapped model on one batch; returns ``None``."""
        # the batch is large enough to contain the whole training set
        features, targets, lengths, _pid = batch
        features, targets = unpad_batch(features, targets, lengths)
        # targets contains both [outcome, los]
        self.model.fit(features, targets)

    def validation_step(self, batch, batch_idx):
        """Score one validation batch and checkpoint the model on improvement.

        Returns:
            The value of ``self.main_metric`` for this batch.
        """
        features, targets, lengths, _pid = batch
        features, targets = unpad_batch(features, targets, lengths)
        # predictions are the outcome or los results
        predictions = self.model.predict(features)
        metrics = get_all_metrics(predictions, targets, self.task, self.los_info)
        main_score = metrics[self.main_metric]

        improved = check_metric_is_better(
            self.cur_best_performance, self.main_metric, main_score, self.task
        )
        if not improved:
            return main_score

        # New best: remember the metrics, log them and persist the model.
        self.cur_best_performance = metrics
        for name, value in metrics.items():
            self.log("best_" + name, value)
        pd.to_pickle(self.model, self.checkpoint_path)
        return main_score

    def test_step(self, batch, batch_idx):
        """Evaluate the checkpointed model on one test batch.

        Stores the metrics in ``self.test_performance`` and the raw
        predictions/labels in ``self.test_outputs``.

        Returns:
            The metrics computed for this batch.
        """
        features, targets, lengths, _pid = batch
        features, targets = unpad_batch(features, targets, lengths)
        # NOTE: read_pickle unpickles arbitrary objects; the checkpoint is
        # expected to be the file written by validation_step.
        self.model = pd.read_pickle(self.checkpoint_path)
        predictions = self.model.predict(features)
        self.test_performance = get_all_metrics(
            predictions, targets, self.task, self.los_info
        )
        self.test_outputs = {'preds': predictions, 'labels': targets}
        return self.test_performance

    def configure_optimizers(self):
        """No optimizer is configured; training is done by ``self.model.fit``."""
        return None