{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch, torch.nn as nn, torch.nn.functional as F, torch.utils.data as data\n",
    "import lightning as L\n",
    "from lightning.pytorch.loggers import CSVLogger\n",
    "import optuna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Pipeline(L.LightningModule):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.save_hyperparameters()\n",
    "        self.hidden_dim = config[\"hidden_dim\"]\n",
    "        self.input_dim = config[\"input_dim\"]\n",
    "        self.out_dim = config[\"out_dim\"]\n",
    "        self.ehr_encoder = nn.Sequential(nn.Linear(self.input_dim, self.hidden_dim), nn.GELU())\n",
    "        # Apply dropout to the embedding rather than to the output logits;\n",
    "        # dropout after the final Linear would randomly zero the prediction itself.\n",
    "        self.head = nn.Sequential(nn.Dropout(0.2), nn.Linear(self.hidden_dim, self.out_dim))\n",
    "\n",
    "    def forward(self, x):\n",
    "        embedding = self.ehr_encoder(x)\n",
    "        y_hat = self.head(embedding)\n",
    "        return y_hat, embedding\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y, x_lens, pid = batch\n",
    "        y_hat, embedding = self(x)\n",
    "        # Supervise only the first time step's single logit.\n",
    "        loss = F.binary_cross_entropy_with_logits(y_hat[:, 0, 0], y[:, 0, 0])\n",
    "        self.log(\"train_loss\", loss)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        x, y, x_lens, pid = batch\n",
    "        y_hat, embedding = self(x)\n",
    "        loss = F.binary_cross_entropy_with_logits(y_hat[:, 0, 0], y[:, 0, 0])\n",
    "        self.log(\"val_loss\", loss)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
    "        return optimizer"
   ]
  },
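  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: smoke-test the Pipeline's forward pass on a random batch before\n",
    "# launching a full tuning run. The dims below mirror the tuning config further\n",
    "# down (75 input features, 1 output) and are illustrative, not from real data.\n",
    "_config = {\"input_dim\": 75, \"hidden_dim\": 32, \"out_dim\": 1}\n",
    "_model = Pipeline(_config)\n",
    "_x = torch.randn(4, 10, _config[\"input_dim\"])  # (batch, time steps, features)\n",
    "_y_hat, _embedding = _model(_x)\n",
    "print(_y_hat.shape, _embedding.shape)  # expected: (4, 10, 1) and (4, 10, 32)"
   ]
  },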
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EhrDataset(data.Dataset):\n",
    "    def __init__(self, data_path, mode='train'):\n",
    "        super().__init__()\n",
    "        self.data = pd.read_pickle(os.path.join(data_path, f'{mode}_x.pkl'))\n",
    "        self.label = pd.read_pickle(os.path.join(data_path, f'{mode}_y.pkl'))\n",
    "        self.pid = pd.read_pickle(os.path.join(data_path, f'{mode}_pid.pkl'))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.label)  # number of patients\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.data[index], self.label[index], self.pid[index]\n",
    "\n",
    "\n",
    "class EhrDataModule(L.LightningDataModule):\n",
    "    def __init__(self, data_path, batch_size=32):\n",
    "        super().__init__()\n",
    "        self.data_path = data_path\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def setup(self, stage: str):\n",
    "        if stage == \"fit\":\n",
    "            self.train_dataset = EhrDataset(self.data_path, mode='train')\n",
    "            self.val_dataset = EhrDataset(self.data_path, mode='val')\n",
    "        if stage == \"test\":\n",
    "            self.test_dataset = EhrDataset(self.data_path, mode='test')\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        # Shuffle the training set; validation and test keep their order.\n",
    "        return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.pad_collate)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return data.DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.pad_collate)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        return data.DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.pad_collate)\n",
    "\n",
    "    def pad_collate(self, batch):\n",
    "        # Pad variable-length visit sequences to the longest one in the batch.\n",
    "        xx, yy, pid = zip(*batch)\n",
    "        x_lens = [len(x) for x in xx]\n",
    "        # convert to tensor\n",
    "        xx = [torch.tensor(x) for x in xx]\n",
    "        yy = [torch.tensor(y) for y in yy]\n",
    "        xx_pad = torch.nn.utils.rnn.pad_sequence(xx, batch_first=True, padding_value=0)\n",
    "        yy_pad = torch.nn.utils.rnn.pad_sequence(yy, batch_first=True, padding_value=0)\n",
    "        return xx_pad, yy_pad, x_lens, pid"
   ]
  },
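  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: check that pad_collate pads variable-length visit sequences to\n",
    "# the longest one in the batch. The toy arrays are made up for illustration;\n",
    "# real batches come from the pickled EHR files.\n",
    "_batch = [\n",
    "    (np.zeros((3, 75)), np.ones((3, 1)), \"patient_a\"),\n",
    "    (np.zeros((5, 75)), np.ones((5, 1)), \"patient_b\"),\n",
    "]\n",
    "_dm = EhrDataModule(data_path=\"unused\", batch_size=2)  # path is not touched by pad_collate\n",
    "_xx, _yy, _lens, _pids = _dm.pad_collate(_batch)\n",
    "print(_xx.shape, _yy.shape, _lens, _pids)  # (2, 5, 75), (2, 5, 1), [3, 5], patient ids"
   ]
  },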
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"mlp\"\n",
    "stage = \"tune\"\n",
    "\"\"\"\n",
    "- tune: hyperparameter search (only on the first fold)\n",
    "- train: train the model with the best hyperparameters (K-fold / repeated with random seeds)\n",
    "- test: evaluate on the test set with the saved checkpoints (from the best epoch)\n",
    "\"\"\"\n",
    "\n",
    "def objective(trial: optuna.trial.Trial):\n",
    "    config = {\n",
    "        \"dataset\": \"tjh\",\n",
    "        \"fold\": 0,\n",
    "        \"demo_dim\": 2,\n",
    "        \"lab_dim\": 73,\n",
    "        \"input_dim\": 75,\n",
    "        \"out_dim\": 1,\n",
    "        # With GridSampler, the actual values come from search_space below;\n",
    "        # these ranges only need to contain the grid points.\n",
    "        \"hidden_dim\": trial.suggest_int(\"hidden_dim\", 16, 1024),\n",
    "        \"batch_size\": trial.suggest_int(\"batch_size\", 1, 16),\n",
    "    }\n",
    "\n",
    "    dm = EhrDataModule(f'datasets/{config[\"dataset\"]}/processed_data/fold_{config[\"fold\"]}', batch_size=config[\"batch_size\"])\n",
    "\n",
    "    # Give each trial its own log version so trials do not overwrite each other.\n",
    "    logger = CSVLogger(save_dir=\"logs\", name=config[\"dataset\"], version=f'{model_name}_{stage}_fold{config[\"fold\"]}_trial{trial.number}')\n",
    "    pipeline = Pipeline(config)\n",
    "    trainer = L.Trainer(max_epochs=3, logger=logger)\n",
    "    trainer.fit(pipeline, dm)\n",
    "\n",
    "    val_loss = trainer.callback_metrics[\"val_loss\"].item()\n",
    "    return val_loss\n",
    "\n",
    "search_space = {\"hidden_dim\": [16, 32, 64], \"batch_size\": [1, 2, 4, 8, 16]}\n",
    "study = optuna.create_study(direction=\"minimize\", sampler=optuna.samplers.GridSampler(search_space))\n",
    "study.optimize(objective, n_trials=15)  # 3 hidden_dim values x 5 batch_size values = 15 grid points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Best trial:\")\n",
    "trial = study.best_trial\n",
    "print(\"  Value: \", trial.value)\n",
    "print(\"  Params: \")\n",
    "for key, value in trial.params.items():\n",
    "    print(f\"    {key}: {value}\")"
   ]
  },
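  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: the docstring above describes a 'train' stage that refits the\n",
    "# model with the best hyperparameters. This is a minimal version of that stage\n",
    "# for fold 0 only; the K-fold / multi-seed loop and checkpointing are left out,\n",
    "# and the on-disk layout is assumed to match the tuning cell.\n",
    "L.seed_everything(42)  # fix the seed so the refit is reproducible\n",
    "best_config = {\n",
    "    \"dataset\": \"tjh\",\n",
    "    \"fold\": 0,\n",
    "    \"demo_dim\": 2,\n",
    "    \"lab_dim\": 73,\n",
    "    \"input_dim\": 75,\n",
    "    \"out_dim\": 1,\n",
    "    **study.best_params,  # hidden_dim and batch_size from the grid search\n",
    "}\n",
    "dm = EhrDataModule(f'datasets/{best_config[\"dataset\"]}/processed_data/fold_{best_config[\"fold\"]}', batch_size=best_config[\"batch_size\"])\n",
    "logger = CSVLogger(save_dir=\"logs\", name=best_config[\"dataset\"], version=f'{model_name}_train_fold{best_config[\"fold\"]}')\n",
    "trainer = L.Trainer(max_epochs=3, logger=logger)\n",
    "trainer.fit(Pipeline(best_config), dm)"
   ]
  }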
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}