import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import lightning as L
from lightning.pytorch.loggers import CSVLogger
import optuna


class Pipeline(L.LightningModule):
    """Minimal MLP pipeline for EHR outcome prediction.

    Expects ``config`` with keys ``input_dim``, ``hidden_dim``, ``out_dim``;
    optionally ``learning_rate`` (defaults to 1e-3, the original value).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()
        self.hidden_dim = config["hidden_dim"]
        self.input_dim = config["input_dim"]
        self.out_dim = config["out_dim"]
        # Encoder: linear projection followed by a GELU nonlinearity.
        self.ehr_encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.GELU(),
        )
        # FIX: Dropout must regularize the embedding *before* the output
        # projection. The original placed Dropout(0.2) AFTER the final
        # Linear, which randomly zeroes the logits during training and
        # corrupts the BCE-with-logits loss signal.
        self.head = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.hidden_dim, self.out_dim),
        )

    def forward(self, x):
        """Return ``(logits, embedding)`` for input of shape (..., input_dim)."""
        embedding = self.ehr_encoder(x)
        y_hat = self.head(embedding)
        return y_hat, embedding

    def training_step(self, batch, batch_idx):
        x, y, x_lens, pid = batch
        y_hat, embedding = self(x)
        # NOTE(review): the loss uses only the first timestep / first target
        # column ([:, 0, 0]) and ignores x_lens — confirm this matches the
        # labeling scheme before trusting tuning results.
        loss = F.binary_cross_entropy_with_logits(y_hat[:, 0, 0], y[:, 0, 0])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, x_lens, pid = batch
        y_hat, embedding = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat[:, 0, 0], y[:, 0, 0])
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        # Learning rate is now configurable; the default preserves the
        # original hard-coded 1e-3 for all existing callers.
        lr = self.config.get("learning_rate", 1e-3)
        return torch.optim.Adam(self.parameters(), lr=lr)
class EhrDataset(data.Dataset):
    """Map-style dataset over preprocessed, pickled EHR splits.

    Reads ``{mode}_x.pkl`` (features), ``{mode}_y.pkl`` (labels) and
    ``{mode}_pid.pkl`` (patient ids) from ``data_path``.
    """

    def __init__(self, data_path, mode='train'):
        super().__init__()
        self.data = pd.read_pickle(os.path.join(data_path, f'{mode}_x.pkl'))
        self.label = pd.read_pickle(os.path.join(data_path, f'{mode}_y.pkl'))
        self.pid = pd.read_pickle(os.path.join(data_path, f'{mode}_pid.pkl'))

    def __len__(self):
        return len(self.label)  # number of patients

    def __getitem__(self, index):
        # One sample = (variable-length feature sequence, label, patient id).
        return self.data[index], self.label[index], self.pid[index]


class EhrDataModule(L.LightningDataModule):
    """LightningDataModule wiring the three EhrDataset splits to loaders."""

    def __init__(self, data_path, batch_size=32):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size

    def setup(self, stage: str):
        if stage == "fit":
            self.train_dataset = EhrDataset(self.data_path, mode="train")
            self.val_dataset = EhrDataset(self.data_path, mode='val')
        if stage == "test":
            self.test_dataset = EhrDataset(self.data_path, mode='test')

    def train_dataloader(self):
        # FIX: shuffle training batches each epoch. Without shuffle=True the
        # model sees patients in fixed on-disk order every epoch, which biases
        # SGD. Validation/test loaders stay deterministic, as they should.
        return data.DataLoader(self.train_dataset, batch_size=self.batch_size,
                               shuffle=True, collate_fn=self.pad_collate)

    def val_dataloader(self):
        return data.DataLoader(self.val_dataset, batch_size=self.batch_size,
                               collate_fn=self.pad_collate)

    def test_dataloader(self):
        return data.DataLoader(self.test_dataset, batch_size=self.batch_size,
                               collate_fn=self.pad_collate)

    def pad_collate(self, batch):
        """Pad variable-length sequences; return (x_pad, y_pad, x_lens, pid)."""
        xx, yy, pid = zip(*batch)
        # Original (pre-padding) lengths, needed to mask padded timesteps.
        x_lens = [len(x) for x in xx]
        xx = [torch.tensor(x) for x in xx]
        yy = [torch.tensor(y) for y in yy]
        xx_pad = torch.nn.utils.rnn.pad_sequence(xx, batch_first=True, padding_value=0)
        yy_pad = torch.nn.utils.rnn.pad_sequence(yy, batch_first=True, padding_value=0)
        return xx_pad, yy_pad, x_lens, pid
model_name = "mlp"
stage = "tune"
# Stages:
# - tune: hyperparameter search (only the first fold)
# - train: train model with the best hyperparameters (K-fold / repeat with random seeds)
# - test: test model on the test set with the saved checkpoints (on best epoch)


def objective(trial: optuna.trial.Trial):
    """Train one configuration for 3 epochs; return final validation loss."""
    config = {
        "dataset": "tjh",
        "fold": 0,
        "demo_dim": 2,
        "lab_dim": 73,
        "input_dim": 75,
        "out_dim": 1,
        # Ranges are supersets of the grid below; GridSampler only ever
        # proposes values from search_space.
        "hidden_dim": trial.suggest_int("hidden_dim", 16, 1024),
        "batch_size": trial.suggest_int("batch_size", 1, 16),
    }

    dm = EhrDataModule(
        f'datasets/{config["dataset"]}/processed_data/fold_{config["fold"]}',
        batch_size=config["batch_size"],
    )

    logger = CSVLogger(
        save_dir="logs",
        name=config["dataset"],
        version=f'{model_name}_{stage}_fold{config["fold"]}',
    )
    pipeline = Pipeline(config)
    trainer = L.Trainer(max_epochs=3, logger=logger)
    trainer.fit(pipeline, dm)

    val_loss = trainer.callback_metrics['val_loss'].item()
    return val_loss


search_space = {"hidden_dim": [16, 32, 64], "batch_size": [1, 2, 4, 8, 16]}
# FIX: budget the study by the grid size (3 * 5 = 15) instead of a hard-coded
# n_trials=100 — GridSampler has exactly len(hidden_dim) * len(batch_size)
# distinct points, and older Optuna versions would re-run duplicates.
n_grid_trials = len(search_space["hidden_dim"]) * len(search_space["batch_size"])
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.GridSampler(search_space),
)
study.optimize(objective, n_trials=n_grid_trials)

print("Best trial:")
trial = study.best_trial
print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")
"codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}