Diff of /debug.ipynb [000000] .. [d6904d]

--- /dev/null
+++ b/debug.ipynb
@@ -0,0 +1,231 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import torch, torch.nn as nn, torch.nn.functional as F, torch.utils.data as data\n",
+    "import lightning as L\n",
+    "from lightning.pytorch.loggers import CSVLogger\n",
+    "import optuna"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Pipeline(L.LightningModule):\n",
+    "    def __init__(self, config):\n",
+    "        super().__init__()\n",
+    "        self.config = config\n",
+    "        self.save_hyperparameters()\n",
+    "        self.hidden_dim = config[\"hidden_dim\"]\n",
+    "        self.input_dim = config[\"input_dim\"]\n",
+    "        self.out_dim = config[\"out_dim\"]\n",
+    "        self.ehr_encoder = nn.Sequential(nn.Linear(self.input_dim, self.hidden_dim), nn.GELU())\n",
+    "        self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.out_dim), nn.Dropout(0.2))\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        embedding = self.ehr_encoder(x)\n",
+    "        y_hat = self.head(embedding)\n",
+    "        return y_hat, embedding\n",
+    "\n",
+    "    def training_step(self, batch, batch_idx):\n",
+    "        x, y, x_lens, pid = batch\n",
+    "        y_hat, embedding = self(x)\n",
+    "        \n",
+    "        loss = F.binary_cross_entropy_with_logits(y_hat[:,0,0], y[:,0,0])\n",
+    "        self.log(\"train_loss\", loss)\n",
+    "        return loss\n",
+    "    def validation_step(self, batch, batch_idx):\n",
+    "        x, y, x_lens, pid = batch\n",
+    "        y_hat, embedding = self(x)\n",
+    "        \n",
+    "        loss = F.binary_cross_entropy_with_logits(y_hat[:,0,0], y[:,0,0])\n",
+    "        self.log(\"val_loss\", loss)\n",
+    "        return loss\n",
+    "\n",
+    "    def configure_optimizers(self):\n",
+    "        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
+    "        return optimizer\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class EhrDataset(data.Dataset):\n",
+    "    def __init__(self, data_path, mode='train'):\n",
+    "        super().__init__()\n",
+    "        self.data = pd.read_pickle(os.path.join(data_path,f'{mode}_x.pkl'))\n",
+    "        self.label = pd.read_pickle(os.path.join(data_path,f'{mode}_y.pkl'))\n",
+    "        self.pid = pd.read_pickle(os.path.join(data_path,f'{mode}_pid.pkl'))\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.label) # number of patients\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        return self.data[index], self.label[index], self.pid[index]\n",
+    "\n",
+    "\n",
+    "class EhrDataModule(L.LightningDataModule):\n",
+    "    def __init__(self, data_path, batch_size=32):\n",
+    "        super().__init__()\n",
+    "        self.data_path = data_path\n",
+    "        self.batch_size = batch_size\n",
+    "\n",
+    "    def setup(self, stage: str):\n",
+    "        if stage==\"fit\":\n",
+    "            self.train_dataset = EhrDataset(self.data_path, mode=\"train\")\n",
+    "            self.val_dataset = EhrDataset(self.data_path, mode='val')\n",
+    "        if stage==\"test\":\n",
+    "            self.test_dataset = EhrDataset(self.data_path, mode='test')\n",
+    "\n",
+    "    def train_dataloader(self):\n",
+    "        return data.DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.pad_collate)\n",
+    "\n",
+    "    def val_dataloader(self):\n",
+    "        return data.DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.pad_collate)\n",
+    "\n",
+    "    def test_dataloader(self):\n",
+    "        return data.DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.pad_collate)\n",
+    "\n",
+    "    def pad_collate(self, batch):\n",
+    "        xx, yy, pid = zip(*batch)\n",
+    "        x_lens = [len(x) for x in xx]\n",
+    "        # convert to tensor\n",
+    "        xx = [torch.tensor(x) for x in xx]\n",
+    "        yy = [torch.tensor(y) for y in yy]\n",
+    "        xx_pad = torch.nn.utils.rnn.pad_sequence(xx, batch_first=True, padding_value=0)\n",
+    "        yy_pad = torch.nn.utils.rnn.pad_sequence(yy, batch_first=True, padding_value=0)\n",
+    "        return xx_pad, yy_pad, x_lens, pid"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_name = \"mlp\"\n",
+    "stage = \"tune\"\n",
+    "\"\"\"\n",
+    "- tune: hyperparameter search (Only the first fold)\n",
+    "- train: train model with the best hyperparameters (K-fold / repeat with random seeds)\n",
+    "- test: test model on the test set with the saved checkpoints (on best epoch)\n",
+    "\"\"\"\n",
+    "\n",
+    "def objective(trial: optuna.trial.Trial):\n",
+    "    config = {\n",
+    "        \"dataset\": \"tjh\",\n",
+    "        \"fold\": 0,\n",
+    "        \"demo_dim\": 2,\n",
+    "        \"lab_dim\": 73,\n",
+    "        \"input_dim\": 75,\n",
+    "        \"out_dim\": 1,\n",
+    "        \"hidden_dim\": trial.suggest_int(\"hidden_dim\", 16, 1024),\n",
+    "        \"batch_size\": trial.suggest_int(\"batch_size\", 1, 16),\n",
+    "    }\n",
+    "\n",
+    "    dm = EhrDataModule(f'datasets/{config[\"dataset\"]}/processed_data/fold_{config[\"fold\"]}', batch_size=config[\"batch_size\"])\n",
+    "    \n",
+    "    logger = CSVLogger(save_dir=\"logs\", name=config[\"dataset\"], version=f'{model_name}_{stage}_fold{config[\"fold\"]}')\n",
+    "    pipeline = Pipeline(config)\n",
+    "    trainer = L.Trainer(max_epochs=3, logger=logger)\n",
+    "    trainer.fit(pipeline, dm)\n",
+    "\n",
+    "    val_loss = trainer.callback_metrics['val_loss'].item()\n",
+    "    return val_loss\n",
+    "\n",
+    "search_space = {\"hidden_dim\": [16, 32, 64], \"batch_size\": [1, 2, 4, 8, 16]}\n",
+    "study = optuna.create_study(direction=\"minimize\", sampler=optuna.samplers.GridSampler(search_space))\n",
+    "study.optimize(objective, n_trials=100)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"Best trial:\")\n",
+    "trial = study.best_trial\n",
+    "print(\"  Value: \", trial.value)\n",
+    "print(\"  Params: \")\n",
+    "for key, value in trial.params.items():\n",
+    "    print(f\"    {key}: {value}\")"
+   ]
+  }
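+  ,
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch of the \"train\" stage described above, under the same\n",
+    "# single-fold assumptions as the tune stage: rebuild the config from the\n",
+    "# best trial and refit. Checkpointing and the K-fold loop are omitted.\n",
+    "best = study.best_trial.params\n",
+    "config = {\"dataset\": \"tjh\", \"fold\": 0, \"demo_dim\": 2, \"lab_dim\": 73,\n",
+    "          \"input_dim\": 75, \"out_dim\": 1,\n",
+    "          \"hidden_dim\": best[\"hidden_dim\"], \"batch_size\": best[\"batch_size\"]}\n",
+    "dm = EhrDataModule(f'datasets/{config[\"dataset\"]}/processed_data/fold_{config[\"fold\"]}', batch_size=config[\"batch_size\"])\n",
+    "logger = CSVLogger(save_dir=\"logs\", name=config[\"dataset\"], version=f'{model_name}_train_fold{config[\"fold\"]}')\n",
+    "trainer = L.Trainer(max_epochs=3, logger=logger)\n",
+    "trainer.fit(Pipeline(config), dm)"
+   ]
+  }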
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pytorch",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.5"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}