Switch to unified view

a b/SessionIII_MultiOmics/SingleCellML.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "metadata": {},
6
   "source": [
7
    "## Single-cell RNA-seq data on minimal residual disease (MRD) in melanoma\n",
8
    "![MRD](mrd.png \"Minimal residual Disease for melanoma\")"
9
   ]
10
  },
11
  {
12
   "cell_type": "markdown",
13
   "metadata": {},
14
   "source": [
15
    "## Learning Outcome\n",
16
    "\n",
17
    "After this session, you will be able to load and pre-process your multi-omics data and generate lower-dimensional embeddings from it."
18
   ]
19
  },
20
  {
21
   "cell_type": "markdown",
22
   "metadata": {},
23
   "source": [
24
    "### Quality Control\n",
25
    "![qc](mrd1.png \"Quality Control for MRD\")"
26
   ]
27
  },
28
  {
29
   "cell_type": "markdown",
30
   "metadata": {},
31
   "source": [
32
    "### Generate lower-dimensional embeddings\n",
33
    "\n",
34
    "We will perform dimensionality reduction and generate lower-dimensional embeddings of the single-cell RNAseq data using two methods:\n",
35
    "* PCA (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html)\n",
36
    "* Neural Networks (https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html)\n"
37
   ]
38
  },
39
  {
40
   "cell_type": "code",
41
   "execution_count": 2,
42
   "metadata": {},
43
   "outputs": [],
44
   "source": [
45
    "import numpy as np\n",
46
    "import pandas as pd\n",
47
    "import matplotlib\n",
48
    "import matplotlib.pyplot as plt\n",
49
    "import seaborn as sns \n",
50
    "sns.set_style('dark')\n",
51
    "\n",
52
    "# ====== Scikit-learn imports ======\n",
53
    "\n",
54
    "from sklearn.svm import SVC\n",
55
    "from sklearn.metrics import (\n",
56
    "    auc,\n",
57
    "    roc_curve,\n",
58
    "    ConfusionMatrixDisplay,\n",
59
    "    f1_score,\n",
60
    "    balanced_accuracy_score,\n",
61
    ")\n",
62
    "from sklearn.preprocessing import StandardScaler, LabelBinarizer\n",
63
    "from sklearn.model_selection import train_test_split\n",
64
    "from sklearn.decomposition import PCA\n",
65
    "\n",
66
    "## ====== Torch imports ======\n",
67
    "import torch\n",
68
    "import torch.nn as nn \n",
69
    "import torch.nn.functional as F\n",
70
    "import torch.optim as optim\n",
71
    "from lightning.pytorch.utilities.types import OptimizerLRScheduler\n",
72
    "import torch.utils.data\n",
73
    "from torch.utils.data import TensorDataset, DataLoader\n",
74
    "\n",
75
    "import lightning, lightning.pytorch.loggers"
76
   ]
77
  },
78
  {
79
   "cell_type": "markdown",
80
   "metadata": {},
81
   "source": [
82
    "### Read data"
83
   ]
84
  },
85
  {
86
   "cell_type": "code",
87
   "execution_count": 3,
88
   "metadata": {},
89
   "outputs": [
90
    {
91
     "name": "stdout",
92
     "output_type": "stream",
93
     "text": [
94
      "(369, 2002)\n"
95
     ]
96
    },
97
    {
98
     "name": "stderr",
99
     "output_type": "stream",
100
     "text": [
101
      "/tmp/ipykernel_3025929/715312502.py:7: SettingWithCopyWarning: \n",
102
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
103
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
104
      "\n",
105
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
106
      "  labels_filtered['Labels'] = [x.split(' ')[1] for x in labels_filtered['Labels']]\n"
107
     ]
108
    }
109
   ],
110
   "source": [
111
    "df = pd.read_csv(\"../data/GSE116237_forQ.csv\")\n",
112
    "print(df.shape)\n",
113
    "labels = pd.read_csv(\"../data/GSE116237 filtered labels.csv\")\n",
114
    "labels = labels.drop(columns='Unnamed: 0')\n",
115
    "labels.columns = ['Cells', 'Labels']\n",
116
    "labels_filtered = labels[labels['Cells'].isin(df['Cells'])].copy()  # independent copy, so the assignment below is not a write to a slice\n",
117
    "labels_filtered['Labels'] = [x.split(' ')[1] for x in labels_filtered['Labels']]  # keep the second space-separated token\n",
118
    "df = pd.merge(df, labels_filtered, on='Cells', how='inner')\n",
119
    "df['Labels'] = df['Labels'].map({'T0': 0, 'phase2': 1})  # any other label becomes NaN\n",
120
    "y = np.array(df['Labels'])\n",
121
    "X = df[df.columns[1:-1]].values  # everything except the first column and the trailing 'Labels' column\n",
122
    "\n",
123
    "num_samples = X.shape[0]\n",
124
    "num_feats = X.shape[1]"
125
   ]
126
  },
127
  {
128
   "cell_type": "markdown",
129
   "metadata": {},
130
   "source": [
131
    "### Split into training and testing"
132
   ]
133
  },
134
  {
135
   "cell_type": "code",
136
   "execution_count": 4,
137
   "metadata": {},
138
   "outputs": [],
139
   "source": [
140
    "# Hold out 20% of the cells; the fixed seed makes the split reproducible across runs.\n",
140
    "X_working, X_held_out, y_working, y_held_out = train_test_split(\n",
141
    "    X, y, train_size=0.8, shuffle=True,\n",
142
    "    random_state=0)"
144
   ]
145
  },
146
  {
147
   "cell_type": "markdown",
148
   "metadata": {},
149
   "source": [
150
    "### PCA"
151
   ]
152
  },
153
  {
154
   "cell_type": "code",
155
   "execution_count": 5,
156
   "metadata": {},
157
   "outputs": [
158
    {
159
     "name": "stdout",
160
     "output_type": "stream",
161
     "text": [
162
      "Training data dimensions:  (295, 2001)\n",
163
      "Embedded train dimensions:  (295, 5)\n",
164
      "Testing data dimensions:  (74, 2001)\n",
165
      "Embedded test dimensions:  (74, 5)\n"
166
     ]
167
    }
168
   ],
169
   "source": [
170
    "output_dim = 5\n",
170
    "pca = PCA(n_components=output_dim)\n",
171
    "scaler = StandardScaler()\n",
172
    "X_train_scaled = scaler.fit_transform(X_working)  # learn per-feature mean/std on the training data only\n",
173
    "embedding_train = pca.fit_transform(X_train_scaled)  # learn the principal axes on the training data only\n",
174
    "X_test_scaled = scaler.transform(X_held_out)  # reuse the training statistics; refitting here would leak held-out data\n",
175
    "embedding_test = pca.transform(X_test_scaled)  # project onto the training axes so both embeddings share one space\n",
176
    "print(\"Training data dimensions: \", X_working.shape)\n",
177
    "print(\"Embedded train dimensions: \", embedding_train.shape)\n",
178
    "print(\"Testing data dimensions: \", X_held_out.shape)\n",
179
    "print(\"Embedded test dimensions: \", embedding_test.shape)"
181
   ]
182
  },
183
  {
184
   "cell_type": "markdown",
185
   "metadata": {},
186
   "source": [
187
    "### Neural Networks"
188
   ]
189
  },
190
  {
191
   "cell_type": "markdown",
192
   "metadata": {},
193
   "source": [
194
    "Define the network in PyTorch Lightning."
195
   ]
196
  },
197
  {
198
   "cell_type": "code",
199
   "execution_count": 6,
200
   "metadata": {},
201
   "outputs": [],
202
   "source": [
203
    "embedder = torch.nn.Sequential(  # num_feats -> 256 -> 64 -> 32 -> 16 -> output_dim\n",
203
    "            nn.Linear(num_feats,256),\n",
204
    "            nn.LeakyReLU(),\n",
205
    "            nn.Linear(256,64),\n",
206
    "            nn.LeakyReLU(),\n",
207
    "            nn.Linear(64,32), \n",
208
    "            nn.LeakyReLU(),\n",
209
    "            nn.Linear(32,16), \n",
210
    "            nn.LeakyReLU(),\n",
211
    "            nn.Linear(16,output_dim),\n",
212
    "            nn.LeakyReLU(), \n",
213
    "            ) \n",
214
    "\n",
215
    "classifier = torch.nn.Sequential(\n",
216
    "            nn.Linear(output_dim,1),\n",
217
    "            nn.Sigmoid()  # a softmax over one logit is always 1.0; sigmoid yields P(label == 1) for BCELoss\n",
218
    "            )\n",
219
    "\n",
220
    "class BinaryClassifierModel(lightning.LightningModule):\n",
221
    "    \"\"\"Embedding network plus a one-unit sigmoid head, trained end to end with binary cross-entropy.\"\"\"\n",
222
    "\n",
223
    "    def __init__(self, embedder, classifier, input_dim, learning_rate=1e-3):\n",
224
    "        super().__init__()\n",
225
    "        self.input_dim = input_dim\n",
226
    "        self.embedder = embedder\n",
227
    "        self.classifier = classifier\n",
228
    "        self.learning_rate = learning_rate\n",
229
    "        self.loss_fun = nn.BCELoss()  # expects probabilities in [0, 1], hence the sigmoid head\n",
230
    "\n",
231
    "    def train_dataloader(self):\n",
232
    "        # self.train_data / self.val_data must be set on the instance before trainer.fit() is called.\n",
233
    "        return DataLoader(self.train_data, batch_size=32, shuffle=True)\n",
234
    "\n",
235
    "    def val_dataloader(self):\n",
236
    "        return DataLoader(self.val_data, batch_size=32)  # No shuffling for validation\n",
237
    "\n",
238
    "    def forward(self, X):\n",
239
    "        \"\"\"Return the classifier output for the embedded input.\"\"\"\n",
240
    "        return self.classifier(self.embedder(X))\n",
241
    "\n",
242
    "    def training_step(self, batch, batch_idx):\n",
243
    "        \"\"\"Compute and log the BCE loss for one training batch.\"\"\"\n",
244
    "        x, y = batch\n",
245
    "        y_float = y.unsqueeze(1).float()  # add a trailing dimension so the targets match y_hat\n",
246
    "        y_hat = self.classifier(self.embedder(x))\n",
247
    "        loss = self.loss_fun(y_hat, y_float)\n",
248
    "        self.log(\"train_loss\", loss, prog_bar=True, logger=True)\n",
249
    "        return loss\n",
250
    "\n",
251
    "    def validation_step(self, batch, batch_idx):\n",
252
    "        \"\"\"Compute and log the BCE loss and the F1 score for one validation batch.\"\"\"\n",
253
    "        x, y = batch\n",
254
    "        y_float = y.unsqueeze(1).float()\n",
255
    "        y_hat = self.classifier(self.embedder(x))\n",
256
    "        val_loss = self.loss_fun(y_hat, y_float)\n",
257
    "        # f1_score takes (y_true, y_pred) as hard labels, so threshold the probabilities at 0.5.\n",
258
    "        y_true = y.int().cpu().numpy().ravel()\n",
259
    "        y_pred = (y_hat.detach().cpu().numpy().ravel() >= 0.5).astype(int)\n",
260
    "        f1score = float(f1_score(y_true, y_pred, zero_division=0))\n",
261
    "        self.log(\"val_f1\", f1score, prog_bar=False, logger=True)\n",
262
    "        self.log(\"val_loss\", val_loss, prog_bar=False, logger=True)  # Log on epoch end\n",
263
    "        return val_loss\n",
264
    "\n",
265
    "    def configure_optimizers(self):\n",
266
    "        # Optimise every parameter: handing Adam only the classifier's parameters would leave the embedder untrained.\n",
267
    "        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
269
    "    \n",
270
    "def prepare_data(X_train, y_train, X_val, y_val):\n",
270
    "    \"\"\"Wrap the train and validation arrays (assumed NumPy) as float32 TensorDatasets.\"\"\"\n",
271
    "    def _to_dataset(features, targets):\n",
272
    "        # Features and targets are both stored as float32 tensors.\n",
273
    "        feature_tensor = torch.tensor(features, dtype=torch.float32)\n",
274
    "        target_tensor = torch.tensor(targets, dtype=torch.float32)\n",
275
    "        return TensorDataset(feature_tensor, target_tensor)\n",
276
    "\n",
277
    "    return _to_dataset(X_train, y_train), _to_dataset(X_val, y_val)"
279
   ]
280
  },
281
  {
282
   "cell_type": "markdown",
283
   "metadata": {},
284
   "source": [
285
    "Train the network, then extract the lower-dimensional embeddings as matrices for downstream QML tasks."
286
   ]
287
  },
288
  {
289
   "cell_type": "code",
290
   "execution_count": 7,
291
   "metadata": {},
292
   "outputs": [
293
    {
294
     "name": "stderr",
295
     "output_type": "stream",
296
     "text": [
297
      "GPU available: False, used: False\n",
298
      "TPU available: False, using: 0 TPU cores\n",
299
      "HPU available: False, using: 0 HPUs\n",
300
      "2024-07-12 13:12:26.638257: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
301
      "2024-07-12 13:12:26.651571: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
302
      "2024-07-12 13:12:26.671424: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
303
      "2024-07-12 13:12:26.671449: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
304
      "2024-07-12 13:12:26.684350: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
305
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
306
      "2024-07-12 13:12:27.415951: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
307
      "\n",
308
      "  | Name       | Type       | Params | Mode \n",
309
      "--------------------------------------------------\n",
310
      "0 | embedder   | Sequential | 531 K  | train\n",
311
      "1 | classifier | Sequential | 6      | train\n",
312
      "2 | loss_fun   | BCELoss    | 0      | train\n",
313
      "--------------------------------------------------\n",
314
      "531 K     Trainable params\n",
315
      "0         Non-trainable params\n",
316
      "531 K     Total params\n",
317
      "2.127     Total estimated model params size (MB)\n"
318
     ]
319
    },
320
    {
321
     "data": {
322
      "application/vnd.jupyter.widget-view+json": {
323
       "model_id": "afd281a3a6ac4be6ae161d27f15c8e56",
324
       "version_major": 2,
325
       "version_minor": 0
326
      },
327
      "text/plain": [
328
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
329
      ]
330
     },
331
     "metadata": {},
332
     "output_type": "display_data"
333
    },
334
    {
335
     "name": "stderr",
336
     "output_type": "stream",
337
     "text": [
338
      "/home/abose/QMLOmics/SessionIII_MultiOmics/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=191` in the `DataLoader` to improve performance.\n",
339
      "/home/abose/QMLOmics/SessionIII_MultiOmics/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=191` in the `DataLoader` to improve performance.\n",
340
      "/home/abose/QMLOmics/SessionIII_MultiOmics/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n"
341
     ]
342
    },
343
    {
344
     "data": {
345
      "application/vnd.jupyter.widget-view+json": {
346
       "model_id": "2feb697f40ce4fe1a7d2f98f2e328f92",
347
       "version_major": 2,
348
       "version_minor": 0
349
      },
350
      "text/plain": [
351
       "Training: |          | 0/? [00:00<?, ?it/s]"
352
      ]
353
     },
354
     "metadata": {},
355
     "output_type": "display_data"
356
    },
357
    {
358
     "data": {
359
      "application/vnd.jupyter.widget-view+json": {
360
       "model_id": "1866594d0bde4a43aa631377b0d8df6b",
361
       "version_major": 2,
362
       "version_minor": 0
363
      },
364
      "text/plain": [
365
       "Validation: |          | 0/? [00:00<?, ?it/s]"
366
      ]
367
     },
368
     "metadata": {},
369
     "output_type": "display_data"
370
    },
371
    {
372
     "data": {
373
      "application/vnd.jupyter.widget-view+json": {
374
       "model_id": "50fbcc83f53f442cbad46e38913a218b",
375
       "version_major": 2,
376
       "version_minor": 0
377
      },
378
      "text/plain": [
379
       "Validation: |          | 0/? [00:00<?, ?it/s]"
380
      ]
381
     },
382
     "metadata": {},
383
     "output_type": "display_data"
384
    },
385
    {
386
     "data": {
387
      "application/vnd.jupyter.widget-view+json": {
388
       "model_id": "41d8b1ed01d2431696fce99afae9dc71",
389
       "version_major": 2,
390
       "version_minor": 0
391
      },
392
      "text/plain": [
393
       "Validation: |          | 0/? [00:00<?, ?it/s]"
394
      ]
395
     },
396
     "metadata": {},
397
     "output_type": "display_data"
398
    },
399
    {
400
     "data": {
401
      "application/vnd.jupyter.widget-view+json": {
402
       "model_id": "902026fc5c1347b1b4f150d49313c067",
403
       "version_major": 2,
404
       "version_minor": 0
405
      },
406
      "text/plain": [
407
       "Validation: |          | 0/? [00:00<?, ?it/s]"
408
      ]
409
     },
410
     "metadata": {},
411
     "output_type": "display_data"
412
    },
413
    {
414
     "data": {
415
      "application/vnd.jupyter.widget-view+json": {
416
       "model_id": "26bfff9443aa4a2697c707c06b246e2a",
417
       "version_major": 2,
418
       "version_minor": 0
419
      },
420
      "text/plain": [
421
       "Validation: |          | 0/? [00:00<?, ?it/s]"
422
      ]
423
     },
424
     "metadata": {},
425
     "output_type": "display_data"
426
    },
427
    {
428
     "data": {
429
      "application/vnd.jupyter.widget-view+json": {
430
       "model_id": "eaa3399fcdf147bf84b357300828c816",
431
       "version_major": 2,
432
       "version_minor": 0
433
      },
434
      "text/plain": [
435
       "Validation: |          | 0/? [00:00<?, ?it/s]"
436
      ]
437
     },
438
     "metadata": {},
439
     "output_type": "display_data"
440
    },
441
    {
442
     "data": {
443
      "application/vnd.jupyter.widget-view+json": {
444
       "model_id": "1fbaed0d8f4f4575ad03ad4dd78fcd9e",
445
       "version_major": 2,
446
       "version_minor": 0
447
      },
448
      "text/plain": [
449
       "Validation: |          | 0/? [00:00<?, ?it/s]"
450
      ]
451
     },
452
     "metadata": {},
453
     "output_type": "display_data"
454
    },
455
    {
456
     "data": {
457
      "application/vnd.jupyter.widget-view+json": {
458
       "model_id": "1d38ac2789784e20a612a71e38fda98f",
459
       "version_major": 2,
460
       "version_minor": 0
461
      },
462
      "text/plain": [
463
       "Validation: |          | 0/? [00:00<?, ?it/s]"
464
      ]
465
     },
466
     "metadata": {},
467
     "output_type": "display_data"
468
    },
469
    {
470
     "data": {
471
      "application/vnd.jupyter.widget-view+json": {
472
       "model_id": "3ed88b70b68e4421b9be0774799d7a11",
473
       "version_major": 2,
474
       "version_minor": 0
475
      },
476
      "text/plain": [
477
       "Validation: |          | 0/? [00:00<?, ?it/s]"
478
      ]
479
     },
480
     "metadata": {},
481
     "output_type": "display_data"
482
    },
483
    {
484
     "data": {
485
      "application/vnd.jupyter.widget-view+json": {
486
       "model_id": "410e31b0fc8b4947b617d00d181ed5fe",
487
       "version_major": 2,
488
       "version_minor": 0
489
      },
490
      "text/plain": [
491
       "Validation: |          | 0/? [00:00<?, ?it/s]"
492
      ]
493
     },
494
     "metadata": {},
495
     "output_type": "display_data"
496
    },
497
    {
498
     "data": {
499
      "application/vnd.jupyter.widget-view+json": {
500
       "model_id": "3c809906d16c4181adbf366e46774a23",
501
       "version_major": 2,
502
       "version_minor": 0
503
      },
504
      "text/plain": [
505
       "Validation: |          | 0/? [00:00<?, ?it/s]"
506
      ]
507
     },
508
     "metadata": {},
509
     "output_type": "display_data"
510
    },
511
    {
512
     "data": {
513
      "application/vnd.jupyter.widget-view+json": {
514
       "model_id": "9a2c0a33fb524a599cb094dc3a690484",
515
       "version_major": 2,
516
       "version_minor": 0
517
      },
518
      "text/plain": [
519
       "Validation: |          | 0/? [00:00<?, ?it/s]"
520
      ]
521
     },
522
     "metadata": {},
523
     "output_type": "display_data"
524
    },
525
    {
526
     "data": {
527
      "application/vnd.jupyter.widget-view+json": {
528
       "model_id": "0cb0d96880a64b4d94c86ab36f4029cd",
529
       "version_major": 2,
530
       "version_minor": 0
531
      },
532
      "text/plain": [
533
       "Validation: |          | 0/? [00:00<?, ?it/s]"
534
      ]
535
     },
536
     "metadata": {},
537
     "output_type": "display_data"
538
    },
539
    {
540
     "data": {
541
      "application/vnd.jupyter.widget-view+json": {
542
       "model_id": "ae97bce46fbe4f128928d3d06ce481b0",
543
       "version_major": 2,
544
       "version_minor": 0
545
      },
546
      "text/plain": [
547
       "Validation: |          | 0/? [00:00<?, ?it/s]"
548
      ]
549
     },
550
     "metadata": {},
551
     "output_type": "display_data"
552
    },
553
    {
554
     "data": {
555
      "application/vnd.jupyter.widget-view+json": {
556
       "model_id": "b35824fd78e843dd9cd332a139636f4c",
557
       "version_major": 2,
558
       "version_minor": 0
559
      },
560
      "text/plain": [
561
       "Validation: |          | 0/? [00:00<?, ?it/s]"
562
      ]
563
     },
564
     "metadata": {},
565
     "output_type": "display_data"
566
    },
567
    {
568
     "data": {
569
      "application/vnd.jupyter.widget-view+json": {
570
       "model_id": "5da6734d8a90481a9463bfce99e5f786",
571
       "version_major": 2,
572
       "version_minor": 0
573
      },
574
      "text/plain": [
575
       "Validation: |          | 0/? [00:00<?, ?it/s]"
576
      ]
577
     },
578
     "metadata": {},
579
     "output_type": "display_data"
580
    },
581
    {
582
     "data": {
583
      "application/vnd.jupyter.widget-view+json": {
584
       "model_id": "6b36092d92594a59967e21f0ff258005",
585
       "version_major": 2,
586
       "version_minor": 0
587
      },
588
      "text/plain": [
589
       "Validation: |          | 0/? [00:00<?, ?it/s]"
590
      ]
591
     },
592
     "metadata": {},
593
     "output_type": "display_data"
594
    },
595
    {
596
     "data": {
597
      "application/vnd.jupyter.widget-view+json": {
598
       "model_id": "e9f9bd47d5d64909916996ee563c55ce",
599
       "version_major": 2,
600
       "version_minor": 0
601
      },
602
      "text/plain": [
603
       "Validation: |          | 0/? [00:00<?, ?it/s]"
604
      ]
605
     },
606
     "metadata": {},
607
     "output_type": "display_data"
608
    },
609
    {
610
     "data": {
611
      "application/vnd.jupyter.widget-view+json": {
612
       "model_id": "a309af8ebdb043f686d7fca5002b7fc5",
613
       "version_major": 2,
614
       "version_minor": 0
615
      },
616
      "text/plain": [
617
       "Validation: |          | 0/? [00:00<?, ?it/s]"
618
      ]
619
     },
620
     "metadata": {},
621
     "output_type": "display_data"
622
    },
623
    {
624
     "data": {
625
      "application/vnd.jupyter.widget-view+json": {
626
       "model_id": "5916500305484ddba658d94c139f0ce2",
627
       "version_major": 2,
628
       "version_minor": 0
629
      },
630
      "text/plain": [
631
       "Validation: |          | 0/? [00:00<?, ?it/s]"
632
      ]
633
     },
634
     "metadata": {},
635
     "output_type": "display_data"
636
    },
637
    {
638
     "data": {
639
      "application/vnd.jupyter.widget-view+json": {
640
       "model_id": "29397c5350bd4fcca1e3e6ec15289f16",
641
       "version_major": 2,
642
       "version_minor": 0
643
      },
644
      "text/plain": [
645
       "Validation: |          | 0/? [00:00<?, ?it/s]"
646
      ]
647
     },
648
     "metadata": {},
649
     "output_type": "display_data"
650
    },
651
    {
652
     "data": {
653
      "application/vnd.jupyter.widget-view+json": {
654
       "model_id": "0b973052d23d4633b7ad8c857ecab896",
655
       "version_major": 2,
656
       "version_minor": 0
657
      },
658
      "text/plain": [
659
       "Validation: |          | 0/? [00:00<?, ?it/s]"
660
      ]
661
     },
662
     "metadata": {},
663
     "output_type": "display_data"
664
    },
665
    {
666
     "data": {
667
      "application/vnd.jupyter.widget-view+json": {
668
       "model_id": "800621fcc3d447c9bb9aae4612469fe3",
669
       "version_major": 2,
670
       "version_minor": 0
671
      },
672
      "text/plain": [
673
       "Validation: |          | 0/? [00:00<?, ?it/s]"
674
      ]
675
     },
676
     "metadata": {},
677
     "output_type": "display_data"
678
    },
679
    {
680
     "data": {
681
      "application/vnd.jupyter.widget-view+json": {
682
       "model_id": "589ab988ed3a4a9191e91ceeffd95bbb",
683
       "version_major": 2,
684
       "version_minor": 0
685
      },
686
      "text/plain": [
687
       "Validation: |          | 0/? [00:00<?, ?it/s]"
688
      ]
689
     },
690
     "metadata": {},
691
     "output_type": "display_data"
692
    },
693
    {
694
     "data": {
695
      "application/vnd.jupyter.widget-view+json": {
696
       "model_id": "8ab7fdd5b30a4d78a61037b76843c883",
697
       "version_major": 2,
698
       "version_minor": 0
699
      },
700
      "text/plain": [
701
       "Validation: |          | 0/? [00:00<?, ?it/s]"
702
      ]
703
     },
704
     "metadata": {},
705
     "output_type": "display_data"
706
    },
707
    {
708
     "data": {
709
      "application/vnd.jupyter.widget-view+json": {
710
       "model_id": "6953b322b6bb48989d80fcd563914d96",
711
       "version_major": 2,
712
       "version_minor": 0
713
      },
714
      "text/plain": [
715
       "Validation: |          | 0/? [00:00<?, ?it/s]"
716
      ]
717
     },
718
     "metadata": {},
719
     "output_type": "display_data"
720
    },
721
    {
722
     "data": {
723
      "application/vnd.jupyter.widget-view+json": {
724
       "model_id": "b9695946233d4e1d9f67e68bf459b2a4",
725
       "version_major": 2,
726
       "version_minor": 0
727
      },
728
      "text/plain": [
729
       "Validation: |          | 0/? [00:00<?, ?it/s]"
730
      ]
731
     },
732
     "metadata": {},
733
     "output_type": "display_data"
734
    },
735
    {
736
     "data": {
737
      "application/vnd.jupyter.widget-view+json": {
738
       "model_id": "667a451df9eb49ceaece4de0780fca12",
739
       "version_major": 2,
740
       "version_minor": 0
741
      },
742
      "text/plain": [
743
       "Validation: |          | 0/? [00:00<?, ?it/s]"
744
      ]
745
     },
746
     "metadata": {},
747
     "output_type": "display_data"
748
    },
749
    {
750
     "data": {
751
      "application/vnd.jupyter.widget-view+json": {
752
       "model_id": "a938f5b709934ed6a3c46843276d9d19",
753
       "version_major": 2,
754
       "version_minor": 0
755
      },
756
      "text/plain": [
757
       "Validation: |          | 0/? [00:00<?, ?it/s]"
758
      ]
759
     },
760
     "metadata": {},
761
     "output_type": "display_data"
762
    },
763
    {
764
     "data": {
765
      "application/vnd.jupyter.widget-view+json": {
766
       "model_id": "21272ee9f2934f699b97b3cc84f90eef",
767
       "version_major": 2,
768
       "version_minor": 0
769
      },
770
      "text/plain": [
771
       "Validation: |          | 0/? [00:00<?, ?it/s]"
772
      ]
773
     },
774
     "metadata": {},
775
     "output_type": "display_data"
776
    },
777
    {
778
     "data": {
779
      "application/vnd.jupyter.widget-view+json": {
780
       "model_id": "949bd80e3bb9469e8496eb333c4efa79",
781
       "version_major": 2,
782
       "version_minor": 0
783
      },
784
      "text/plain": [
785
       "Validation: |          | 0/? [00:00<?, ?it/s]"
786
      ]
787
     },
788
     "metadata": {},
789
     "output_type": "display_data"
790
    },
791
    {
792
     "data": {
793
      "application/vnd.jupyter.widget-view+json": {
794
       "model_id": "8f06b68975cb4cdea69328626b62d0ce",
795
       "version_major": 2,
796
       "version_minor": 0
797
      },
798
      "text/plain": [
799
       "Validation: |          | 0/? [00:00<?, ?it/s]"
800
      ]
801
     },
802
     "metadata": {},
803
     "output_type": "display_data"
804
    },
805
    {
806
     "data": {
807
      "application/vnd.jupyter.widget-view+json": {
808
       "model_id": "2d4486c1d5094d419cea9b4249ee67ba",
809
       "version_major": 2,
810
       "version_minor": 0
811
      },
812
      "text/plain": [
813
       "Validation: |          | 0/? [00:00<?, ?it/s]"
814
      ]
815
     },
816
     "metadata": {},
817
     "output_type": "display_data"
818
    },
819
    {
820
     "data": {
821
      "application/vnd.jupyter.widget-view+json": {
822
       "model_id": "43f06e9af03549d48c56074332d1b0ce",
823
       "version_major": 2,
824
       "version_minor": 0
825
      },
826
      "text/plain": [
827
       "Validation: |          | 0/? [00:00<?, ?it/s]"
828
      ]
829
     },
830
     "metadata": {},
831
     "output_type": "display_data"
832
    },
833
    {
834
     "data": {
835
      "application/vnd.jupyter.widget-view+json": {
836
       "model_id": "4680640dc1904e7d87c1d02cb817b8bf",
837
       "version_major": 2,
838
       "version_minor": 0
839
      },
840
      "text/plain": [
841
       "Validation: |          | 0/? [00:00<?, ?it/s]"
842
      ]
843
     },
844
     "metadata": {},
845
     "output_type": "display_data"
846
    },
847
    {
848
     "data": {
849
      "application/vnd.jupyter.widget-view+json": {
850
       "model_id": "92ef38a4c8314805a3c27ed94929b46f",
851
       "version_major": 2,
852
       "version_minor": 0
853
      },
854
      "text/plain": [
855
       "Validation: |          | 0/? [00:00<?, ?it/s]"
856
      ]
857
     },
858
     "metadata": {},
859
     "output_type": "display_data"
860
    },
861
    {
862
     "data": {
863
      "application/vnd.jupyter.widget-view+json": {
864
       "model_id": "5b65eb5da1ad451fb01235fc22d83478",
865
       "version_major": 2,
866
       "version_minor": 0
867
      },
868
      "text/plain": [
869
       "Validation: |          | 0/? [00:00<?, ?it/s]"
870
      ]
871
     },
872
     "metadata": {},
873
     "output_type": "display_data"
874
    },
875
    {
876
     "data": {
877
      "application/vnd.jupyter.widget-view+json": {
878
       "model_id": "46cf08a04d00430cb780c555189c56aa",
879
       "version_major": 2,
880
       "version_minor": 0
881
      },
882
      "text/plain": [
883
       "Validation: |          | 0/? [00:00<?, ?it/s]"
884
      ]
885
     },
886
     "metadata": {},
887
     "output_type": "display_data"
888
    },
889
    {
890
     "data": {
891
      "application/vnd.jupyter.widget-view+json": {
892
       "model_id": "19dc01567a744838a98b2ae6ea833c11",
893
       "version_major": 2,
894
       "version_minor": 0
895
      },
896
      "text/plain": [
897
       "Validation: |          | 0/? [00:00<?, ?it/s]"
898
      ]
899
     },
900
     "metadata": {},
901
     "output_type": "display_data"
902
    },
903
    {
904
     "data": {
905
      "application/vnd.jupyter.widget-view+json": {
906
       "model_id": "d4e9d6f7d8214b47ad0beb604847d952",
907
       "version_major": 2,
908
       "version_minor": 0
909
      },
910
      "text/plain": [
911
       "Validation: |          | 0/? [00:00<?, ?it/s]"
912
      ]
913
     },
914
     "metadata": {},
915
     "output_type": "display_data"
916
    },
917
    {
918
     "name": "stderr",
919
     "output_type": "stream",
920
     "text": [
921
      "`Trainer.fit` stopped: `max_epochs=40` reached.\n"
922
     ]
923
    }
924
   ],
925
   "source": [
926
    "f1s = []\n",
927
    "embeddings_train = []\n",
928
    "embeddings_test = []\n",
929
    "train_labels = []\n",
930
    "test_labels = []\n",
931
    "\n",
932
    "num_iter = 1\n",
933
    "num_epochs = 40\n",
934
    "for i in range(num_iter):\n",
935
    "    # Carve a validation split out of the working set; seeding with i keeps every iteration reproducible.\n",
936
    "    X_train, X_test, y_train, y_test = train_test_split(\n",
937
    "        X_working, y_working, train_size=0.9, shuffle=True, random_state=i)\n",
938
    "\n",
939
    "    # NOTE(review): embedder and classifier are shared module-level networks, so with num_iter > 1 each iteration would start from the previous iteration's weights.\n",
940
    "    model = BinaryClassifierModel(embedder, classifier, input_dim=num_feats)\n",
941
    "    model.train_data, model.val_data = prepare_data(X_train, y_train, X_test, y_test)  # Prepare data for training\n",
942
    "    logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=\".\", name=\"original_classifier\")\n",
943
    "    trainer = lightning.Trainer(max_epochs=num_epochs, logger=logger)\n",
944
    "    trainer.fit(model)\n",
945
    "\n",
946
    "    # Inference only from here on: eval mode, and no autograd graph is needed.\n",
947
    "    model.eval()\n",
948
    "    with torch.no_grad():\n",
949
    "        embedded_test = model.embedder(torch.tensor(X_test, dtype=torch.float32))\n",
950
    "        y_pred = model.classifier(embedded_test)\n",
951
    "        embedded_train = model.embedder(torch.tensor(X_train, dtype=torch.float32)).cpu().numpy()\n",
952
    "    y_pred_proba = y_pred.cpu().numpy()\n",
953
    "    y_pred_class = np.round(y_pred_proba)\n",
954
    "\n",
955
    "    f1 = f1_score(y_test, y_pred_class)\n",
956
    "    f1s.append(f1)\n",
957
    "\n",
958
    "    embeddings_train.append(embedded_train)\n",
959
    "    embeddings_test.append(embedded_test.cpu().numpy())\n",
960
    "    train_labels.append(y_train)\n",
961
    "    test_labels.append(y_test)"
962
   ]
963
  },
964
  {
965
   "cell_type": "markdown",
966
   "metadata": {},
967
   "source": [
968
    "### Save the embeddings"
969
   ]
970
  },
971
  {
972
   "cell_type": "code",
973
   "execution_count": 11,
974
   "metadata": {},
975
   "outputs": [],
976
   "source": [
977
    "import os\n",
977
    "\n",
978
    "fname = \"MRD\"\n",
979
    "out_dir = f\"checkpoints/{fname}\"\n",
980
    "os.makedirs(out_dir, exist_ok=True)  # np.save does not create missing directories\n",
981
    "\n",
982
    "# Each artifact is written once per iteration as <fname>_iter<i>_<suffix>.npy\n",
983
    "artifacts = {\n",
984
    "    \"train_embedding\": embeddings_train,\n",
985
    "    \"train_labels\": train_labels,\n",
986
    "    \"test_embedding\": embeddings_test,\n",
987
    "    \"test_labels\": test_labels,\n",
988
    "}\n",
989
    "\n",
990
    "for suffix, arrays in artifacts.items():\n",
991
    "    for i, x in enumerate(arrays):\n",
992
    "        np.save(f\"{out_dir}/{fname}_iter{i}_{suffix}\", x)\n"
994
   ]
995
  }
996
 ],
997
 "metadata": {
998
  "kernelspec": {
999
   "display_name": "Python 3",
1000
   "language": "python",
1001
   "name": "python3"
1002
  },
1003
  "language_info": {
1004
   "codemirror_mode": {
1005
    "name": "ipython",
1006
    "version": 3
1007
   },
1008
   "file_extension": ".py",
1009
   "mimetype": "text/x-python",
1010
   "name": "python",
1011
   "nbconvert_exporter": "python",
1012
   "pygments_lexer": "ipython3",
1013
   "version": "3.10.12"
1014
  }
1015
 },
1016
 "nbformat": 4,
1017
 "nbformat_minor": 2
1018
}