In [None]:
import os

import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F, torch.utils.data as data
import lightning as L
from lightning.pytorch.loggers import CSVLogger
import optuna

In [None]:
class Pipeline(L.LightningModule):
    """One-hidden-layer MLP classifier over EHR features, trained with Lightning.

    ``config`` must provide ``input_dim``, ``hidden_dim`` and ``out_dim``. An
    optional ``learning_rate`` (default ``1e-3``) sets the Adam step size. The
    whole ``config`` is recorded via ``save_hyperparameters``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()
        self.hidden_dim = config["hidden_dim"]
        self.input_dim = config["input_dim"]
        self.out_dim = config["out_dim"]
        # Optional key so existing configs / saved hparams without it keep working.
        self.learning_rate = config.get("learning_rate", 1e-3)
        self.ehr_encoder = nn.Sequential(nn.Linear(self.input_dim, self.hidden_dim), nn.GELU())
        # NOTE(review): this Dropout acts on the output logits, not on the hidden
        # features. The order is kept so state_dict keys (head.0.*) stay compatible
        # with existing checkpoints; consider moving it before the Linear layer.
        self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.out_dim), nn.Dropout(0.2))

    def forward(self, x):
        """Encode ``x`` and apply the head; returns ``(logits, embedding)``."""
        embedding = self.ehr_encoder(x)
        y_hat = self.head(embedding)
        return y_hat, embedding

    def _shared_step(self, batch, stage):
        """Compute and log the BCE-with-logits loss as ``f"{stage}_loss"``.

        ``batch`` is ``(x, y, x_lens, pid)``. The ``[:, 0, 0]`` indexing assumes
        ``x``/``y`` are 3-D, presumably ``(batch, seq, features)`` — TODO confirm.
        Only the first step / first output is supervised.
        """
        x, y, x_lens, pid = batch
        y_hat, _ = self(x)
        # NOTE(review): x_lens (padding) is not used; index 0 is a real step only
        # if every sequence is non-empty.
        logits = y_hat[:, 0, 0]
        # BCE-with-logits needs floating targets with the same dtype as the logits.
        target = y[:, 0, 0].to(logits.dtype)
        loss = F.binary_cross_entropy_with_logits(logits, target)
        # The batch also holds a list (x_lens) and a tuple (pid), which makes
        # Lightning's automatic batch-size inference ambiguous; pass it explicitly.
        self.log(f"{stage}_loss", loss, batch_size=x.size(0))
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``val_loss``)."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Adam over all parameters with ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer




In [None]:
class EhrDataset(data.Dataset):
    """Map-style dataset over one pickled EHR split.

    Loads ``{mode}_x.pkl``, ``{mode}_y.pkl`` and ``{mode}_pid.pkl`` from
    ``data_path`` (in that order). Item ``i`` is
    ``(data[i], label[i], pid[i])``.
    """

    def __init__(self, data_path, mode='train'):
        super().__init__()
        # NOTE: pd.read_pickle unpickles arbitrary objects — only load trusted files.
        sources = (
            ("data", f'{mode}_x.pkl'),
            ("label", f'{mode}_y.pkl'),
            ("pid", f'{mode}_pid.pkl'),
        )
        for attr_name, file_name in sources:
            setattr(self, attr_name, pd.read_pickle(os.path.join(data_path, file_name)))

    def __len__(self):
        # Size is taken from the labels (one entry per patient).
        return len(self.label)

    def __getitem__(self, index):
        parts = (self.data, self.label, self.pid)
        return tuple(part[index] for part in parts)


class EhrDataModule(L.LightningDataModule):
    """Lightning data module serving padded EHR batches from one split directory.

    ``data_path`` must contain the ``{train,val,test}_{x,y,pid}.pkl`` files read
    by ``EhrDataset``. Every batch is ``(x_pad, y_pad, x_lens, pid)``.
    """

    def __init__(self, data_path, batch_size=32):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size

    def setup(self, stage: str):
        """Load the datasets required for ``stage`` ("fit", "validate" or "test")."""
        if stage == "fit":
            self.train_dataset = EhrDataset(self.data_path, mode="train")
        # trainer.validate() calls setup("validate"); load val there as well so
        # val_dataloader() does not fail on a missing attribute.
        if stage in ("fit", "validate"):
            self.val_dataset = EhrDataset(self.data_path, mode='val')
        if stage == "test":
            self.test_dataset = EhrDataset(self.data_path, mode='test')

    def train_dataloader(self):
        # Shuffle training samples each epoch; seed the global RNG for reproducibility.
        return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.pad_collate)

    def val_dataloader(self):
        return data.DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.pad_collate)

    def test_dataloader(self):
        return data.DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.pad_collate)

    def pad_collate(self, batch):
        """Collate ``(x, y, pid)`` samples into zero-padded float32 tensors.

        Returns ``(x_pad, y_pad, x_lens, pid)``:
        - ``x_pad`` / ``y_pad`` are padded along each sample's first dimension to
          the longest sample in the batch.
        - ``x_lens`` is the list of original ``len(x)`` values.
        - ``pid`` is the tuple of per-sample ids.
        """
        xx, yy, pid = zip(*batch)
        x_lens = [len(x) for x in xx]
        # float32 so features match nn.Linear weights and labels suit BCE; float64
        # numpy input would otherwise raise a dtype mismatch inside the model.
        xx = [torch.as_tensor(x, dtype=torch.float32) for x in xx]
        yy = [torch.as_tensor(y, dtype=torch.float32) for y in yy]
        xx_pad = torch.nn.utils.rnn.pad_sequence(xx, batch_first=True, padding_value=0)
        yy_pad = torch.nn.utils.rnn.pad_sequence(yy, batch_first=True, padding_value=0)
        return xx_pad, yy_pad, x_lens, pid

In [None]:
# Run configuration. `stage` selects what this notebook run does:
#   - tune:  hyperparameter search (only the first fold)
#   - train: train the model with the best hyperparameters (K-fold / repeat with random seeds)
#   - test:  evaluate on the test set with the saved checkpoints (best epoch)
model_name = "mlp"
stage = "tune"

def objective(trial: optuna.trial.Trial):
    """Train one hyperparameter configuration and return its validation loss.

    Returns the ``val_loss`` left in ``trainer.callback_metrics`` after fitting
    (from the last validation run, not necessarily the best epoch). Optuna
    minimises this value. Each trial logs to its own CSVLogger version, so trials
    no longer overwrite each other's metrics or checkpoints.
    """
    config = {
        "dataset": "tjh",
        "fold": 0,
        "demo_dim": 2,
        "lab_dim": 73,
        "input_dim": 75,
        "out_dim": 1,
        # With GridSampler the actual values come from the grid; these ranges only
        # need to contain every grid value.
        "hidden_dim": trial.suggest_int("hidden_dim", 16, 1024),
        "batch_size": trial.suggest_int("batch_size", 1, 16),
        "seed": 42,
    }
    # Same seed for every trial, so configurations differ only in their
    # hyperparameters (weight init, shuffling order).
    L.seed_everything(config["seed"])

    dm = EhrDataModule(f'datasets/{config["dataset"]}/processed_data/fold_{config["fold"]}', batch_size=config["batch_size"])

    # trial.number makes the log/checkpoint directory unique per trial.
    logger = CSVLogger(
        save_dir="logs",
        name=config["dataset"],
        version=f'{model_name}_{stage}_fold{config["fold"]}_trial{trial.number}',
    )
    pipeline = Pipeline(config)
    trainer = L.Trainer(max_epochs=3, logger=logger)
    trainer.fit(pipeline, dm)

    val_loss = trainer.callback_metrics['val_loss'].item()
    return val_loss

# Hyperparameter grid; GridSampler evaluates every combination exactly once.
search_space = {"hidden_dim": [16, 32, 64], "batch_size": [1, 2, 4, 8, 16]}
# Trial budget = grid size (3 x 5 = 15) rather than an unrelated 100. GridSampler
# stops the study when the grid is exhausted anyway; this keeps the budget explicit
# and correct if the grid is edited.
n_grid_points = int(np.prod([len(values) for values in search_space.values()]))
study = optuna.create_study(direction="minimize", sampler=optuna.samplers.GridSampler(search_space))
study.optimize(objective, n_trials=n_grid_points)

In [None]:
# Report the best trial of the study run in the previous cell: its objective
# value (final val_loss) and the hyperparameters that produced it.
print("Best trial:")
# study.best_trial raises if no trial has completed.
trial = study.best_trial
print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")