Diff of /fine-tune.ipynb [000000] .. [a108cb]

Switch to unified view

a b/fine-tune.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "id": "6614ed41-7f65-4495-aa3d-8fba22013597",
6
   "metadata": {},
7
   "source": [
8
    "Different learning rates, numbers of epochs, and batch sizes were explored, but the results did not change much."
9
   ]
10
  },
11
  {
12
   "cell_type": "code",
13
   "execution_count": 1,
14
   "id": "9bff167f-c177-4f9f-9408-bd499f2d63e8",
15
   "metadata": {},
16
   "outputs": [],
17
   "source": [
18
    "import pandas as pd\n",
19
    "import numpy as np\n",
20
    "import torch\n",
21
    "from sklearn.model_selection import train_test_split\n",
22
    "from torch.utils.data import TensorDataset, DataLoader\n",
23
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
24
    "from tqdm import tqdm\n",
25
    "import torch.nn as nn"
26
   ]
27
  },
28
  {
29
   "cell_type": "code",
30
   "execution_count": 2,
31
   "id": "c1d30e92-0aea-498c-8d3f-d8b24cf4aeeb",
32
   "metadata": {},
33
   "outputs": [
34
    {
35
     "name": "stdout",
36
     "output_type": "stream",
37
     "text": [
38
      "Number of examples is: 560\n"
39
     ]
40
    },
41
    {
42
     "data": {
43
      "text/html": [
44
       "<div>\n",
45
       "<style scoped>\n",
46
       "    .dataframe tbody tr th:only-of-type {\n",
47
       "        vertical-align: middle;\n",
48
       "    }\n",
49
       "\n",
50
       "    .dataframe tbody tr th {\n",
51
       "        vertical-align: top;\n",
52
       "    }\n",
53
       "\n",
54
       "    .dataframe thead th {\n",
55
       "        text-align: right;\n",
56
       "    }\n",
57
       "</style>\n",
58
       "<table border=\"1\" class=\"dataframe\">\n",
59
       "  <thead>\n",
60
       "    <tr style=\"text-align: right;\">\n",
61
       "      <th></th>\n",
62
       "      <th>CANONICAL_SMILES</th>\n",
63
       "      <th>pIC50</th>\n",
64
       "    </tr>\n",
65
       "  </thead>\n",
66
       "  <tbody>\n",
67
       "    <tr>\n",
68
       "      <th>0</th>\n",
69
       "      <td>Nc1nc(N)c2c(Sc3ccccc3)cccc2n1</td>\n",
70
       "      <td>6.21</td>\n",
71
       "    </tr>\n",
72
       "    <tr>\n",
73
       "      <th>1</th>\n",
74
       "      <td>COc1ccc(OC)c(Cc2sc3nc(N)nc(N)c3c2C)c1</td>\n",
75
       "      <td>6.14</td>\n",
76
       "    </tr>\n",
77
       "    <tr>\n",
78
       "      <th>2</th>\n",
79
       "      <td>CN(Cc1coc2nc(N)nc(N)c12)c3ccc(cc3)C(=O)N[C@@H]...</td>\n",
80
       "      <td>6.66</td>\n",
81
       "    </tr>\n",
82
       "    <tr>\n",
83
       "      <th>3</th>\n",
84
       "      <td>Nc1nc(N)c2nc(CSc3ccc(cc3)C(=O)NC(CCC(=O)O)C(=O...</td>\n",
85
       "      <td>5.57</td>\n",
86
       "    </tr>\n",
87
       "    <tr>\n",
88
       "      <th>4</th>\n",
89
       "      <td>Nc1nc(N)c2nc(CCSc3ccc(cc3)C(=O)NC(CCC(=O)O)C(=...</td>\n",
90
       "      <td>4.60</td>\n",
91
       "    </tr>\n",
92
       "  </tbody>\n",
93
       "</table>\n",
94
       "</div>"
95
      ],
96
      "text/plain": [
97
       "                                    CANONICAL_SMILES  pIC50\n",
98
       "0                      Nc1nc(N)c2c(Sc3ccccc3)cccc2n1   6.21\n",
99
       "1              COc1ccc(OC)c(Cc2sc3nc(N)nc(N)c3c2C)c1   6.14\n",
100
       "2  CN(Cc1coc2nc(N)nc(N)c12)c3ccc(cc3)C(=O)N[C@@H]...   6.66\n",
101
       "3  Nc1nc(N)c2nc(CSc3ccc(cc3)C(=O)NC(CCC(=O)O)C(=O...   5.57\n",
102
       "4  Nc1nc(N)c2nc(CCSc3ccc(cc3)C(=O)NC(CCC(=O)O)C(=...   4.60"
103
      ]
104
     },
105
     "execution_count": 2,
106
     "metadata": {},
107
     "output_type": "execute_result"
108
    }
109
   ],
110
   "source": [
111
    "# Load the hDHFR pIC50 dataset, keeping only the SMILES strings and their pIC50 values\n",
112
    "df = pd.read_csv(\"data/hDHFR_pIC50_data.csv\")[[\"CANONICAL_SMILES\", \"pIC50\"]]\n",
113
    "print(f'Number of examples is: {len(df)}')\n",
114
    "df.head()"
115
   ]
116
  },
117
  {
118
   "cell_type": "code",
119
   "execution_count": 3,
120
   "id": "04851bc1-5707-4bd5-b39d-0d5d618dd116",
121
   "metadata": {},
122
   "outputs": [
123
    {
124
     "name": "stderr",
125
     "output_type": "stream",
126
     "text": [
127
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MLM and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.bias']\n",
128
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
129
     ]
130
    },
131
    {
132
     "name": "stdout",
133
     "output_type": "stream",
134
     "text": [
135
      "512\n"
136
     ]
137
    }
138
   ],
139
   "source": [
140
    "# Load the pre-trained model and tokenizer\n",
141
    "model_name = \"DeepChem/ChemBERTa-77M-MLM\"\n",
142
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)\n",
143
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
144
    "\n",
145
    "# Determine the maximum sequence length\n",
146
    "max_length = tokenizer.model_max_length\n",
147
    "print(max_length)"
148
   ]
149
  },
150
  {
151
   "cell_type": "code",
152
   "execution_count": 4,
153
   "id": "755634da-6e15-46c3-a9d2-fdc7bd2aa75e",
154
   "metadata": {},
155
   "outputs": [
156
    {
157
     "name": "stderr",
158
     "output_type": "stream",
159
     "text": [
160
      "100%|██████████████████████████████████████| 560/560 [00:00<00:00, 15447.41it/s]\n"
161
     ]
162
    }
163
   ],
164
   "source": [
165
    "# Tokenization function\n",
166
    "def tokenize(string):\n",
167
    "    \"\"\"\n",
168
    "    Tokenize and encode a string using the provided tokenizer.\n",
169
    "    \n",
170
    "    Parameters:\n",
171
    "        string (str): Input string to be tokenized.\n",
172
    "    \n",
173
    "    Returns:\n",
174
    "        Tuple of input_ids and attention_mask.\n",
175
    "    \"\"\"\n",
176
    "    encodings = tokenizer.encode_plus(\n",
177
    "        string,\n",
178
    "        add_special_tokens=True,\n",
179
    "        truncation=True,\n",
180
    "        padding=\"max_length\",\n",
181
    "        max_length=max_length,\n",
182
    "        return_attention_mask=True\n",
183
    "    )\n",
184
    "    input_ids = encodings[\"input_ids\"]\n",
185
    "    attention_mask = encodings[\"attention_mask\"]\n",
186
    "    return input_ids, attention_mask\n",
187
    "\n",
188
    "# Tokenize the 'CANONICAL_SMILES' column and create new columns 'input_ids' and 'attention_mask'\n",
189
    "tqdm.pandas()\n",
190
    "df[[\"input_ids\", \"attention_mask\"]] = df[\"CANONICAL_SMILES\"].progress_apply(lambda x: tokenize(x)).apply(pd.Series)"
191
   ]
192
  },
193
  {
194
   "cell_type": "code",
195
   "execution_count": 5,
196
   "id": "63fade54-27fa-4ff1-af69-ea8d96c4ec1d",
197
   "metadata": {},
198
   "outputs": [
199
    {
200
     "name": "stdout",
201
     "output_type": "stream",
202
     "text": [
203
      "There are 448 molecules in Train df.\n",
204
      "There are 56 molecules in Val df.\n",
205
      "There are 56 molecules in Test df.\n"
206
     ]
207
    }
208
   ],
209
   "source": [
210
    "# Split the dataset 80/10/10 into train, validation, and test sets (fixed random_state for a repeatable split)\n",
211
    "train_df, temp_df = train_test_split(df, test_size=0.2, random_state=21)\n",
212
    "val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=21)\n",
213
    "print(f\"There are {len(train_df)} molecules in Train df.\")\n",
214
    "print(f\"There are {len(val_df)} molecules in Val df.\")\n",
215
    "print(f\"There are {len(test_df)} molecules in Test df.\")"
216
   ]
217
  },
218
  {
219
   "cell_type": "code",
220
   "execution_count": 6,
221
   "id": "a7d9638c-b3fc-4063-8ea8-a1303178466f",
222
   "metadata": {},
223
   "outputs": [],
224
   "source": [
225
    "# Function to convert data to PyTorch tensors\n",
226
    "def get_tensor_data(data):\n",
227
    "    \"\"\"\n",
228
    "    Convert data to PyTorch tensors.\n",
229
    "    \n",
230
    "    Parameters:\n",
231
    "        data (DataFrame): Input data containing 'input_ids', 'attention_mask', and 'pIC50' columns.\n",
232
    "    \n",
233
    "    Returns:\n",
234
    "        TensorDataset containing input_ids, attention_mask, and labels tensors.\n",
235
    "    \"\"\"\n",
236
    "    input_ids_tensor = torch.tensor(data[\"input_ids\"].tolist(), dtype=torch.int32)\n",
237
    "    attention_mask_tensor = torch.tensor(data[\"attention_mask\"].tolist(), dtype=torch.int32)\n",
238
    "    labels_tensor = torch.tensor(data[\"pIC50\"].tolist(), dtype=torch.float32)\n",
239
    "    return TensorDataset(input_ids_tensor, attention_mask_tensor, labels_tensor)\n",
240
    "\n",
241
    "# Create datasets and data loaders (one of each per split)\n",
242
    "train_dataset = get_tensor_data(train_df)\n",
243
    "val_dataset = get_tensor_data(val_df)\n",
244
    "test_dataset = get_tensor_data(test_df)\n",
245
    "# Only the training loader is created with shuffle=True; the val/test loaders use the DataLoader defaults\n",
246
    "batch_size = 32\n",
247
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
248
    "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
249
    "test_loader = DataLoader(test_dataset, batch_size=batch_size)"
250
   ]
251
  },
252
  {
253
   "cell_type": "code",
254
   "execution_count": 7,
255
   "id": "6ebfebf6-b6db-4082-ab9a-70b79884014f",
256
   "metadata": {},
257
   "outputs": [
258
    {
259
     "name": "stderr",
260
     "output_type": "stream",
261
     "text": [
262
      " 10%|████▍                                       | 1/10 [00:04<00:37,  4.12s/it]"
263
     ]
264
    },
265
    {
266
     "name": "stdout",
267
     "output_type": "stream",
268
     "text": [
269
      "epoch 1: Train Loss 13.7115, Val Loss 3.4070\n"
270
     ]
271
    },
272
    {
273
     "name": "stderr",
274
     "output_type": "stream",
275
     "text": [
276
      " 20%|████████▊                                   | 2/10 [00:07<00:29,  3.68s/it]"
277
     ]
278
    },
279
    {
280
     "name": "stdout",
281
     "output_type": "stream",
282
     "text": [
283
      "epoch 2: Train Loss 1.8040, Val Loss 1.3049\n"
284
     ]
285
    },
286
    {
287
     "name": "stderr",
288
     "output_type": "stream",
289
     "text": [
290
      " 30%|█████████████▏                              | 3/10 [00:10<00:24,  3.54s/it]"
291
     ]
292
    },
293
    {
294
     "name": "stdout",
295
     "output_type": "stream",
296
     "text": [
297
      "epoch 3: Train Loss 1.4318, Val Loss 1.2696\n"
298
     ]
299
    },
300
    {
301
     "name": "stderr",
302
     "output_type": "stream",
303
     "text": [
304
      " 40%|█████████████████▌                          | 4/10 [00:14<00:20,  3.48s/it]"
305
     ]
306
    },
307
    {
308
     "name": "stdout",
309
     "output_type": "stream",
310
     "text": [
311
      "epoch 4: Train Loss 1.4239, Val Loss 1.2848\n"
312
     ]
313
    },
314
    {
315
     "name": "stderr",
316
     "output_type": "stream",
317
     "text": [
318
      " 50%|██████████████████████                      | 5/10 [00:17<00:17,  3.44s/it]"
319
     ]
320
    },
321
    {
322
     "name": "stdout",
323
     "output_type": "stream",
324
     "text": [
325
      "epoch 5: Train Loss 1.4157, Val Loss 1.3262\n"
326
     ]
327
    },
328
    {
329
     "name": "stderr",
330
     "output_type": "stream",
331
     "text": [
332
      " 60%|██████████████████████████▍                 | 6/10 [00:21<00:13,  3.42s/it]"
333
     ]
334
    },
335
    {
336
     "name": "stdout",
337
     "output_type": "stream",
338
     "text": [
339
      "epoch 6: Train Loss 1.3806, Val Loss 1.2986\n"
340
     ]
341
    },
342
    {
343
     "name": "stderr",
344
     "output_type": "stream",
345
     "text": [
346
      " 70%|██████████████████████████████▊             | 7/10 [00:24<00:10,  3.44s/it]"
347
     ]
348
    },
349
    {
350
     "name": "stdout",
351
     "output_type": "stream",
352
     "text": [
353
      "epoch 7: Train Loss 1.3577, Val Loss 1.2738\n"
354
     ]
355
    },
356
    {
357
     "name": "stderr",
358
     "output_type": "stream",
359
     "text": [
360
      " 80%|███████████████████████████████████▏        | 8/10 [00:27<00:06,  3.46s/it]"
361
     ]
362
    },
363
    {
364
     "name": "stdout",
365
     "output_type": "stream",
366
     "text": [
367
      "epoch 8: Train Loss 1.3578, Val Loss 1.2921\n"
368
     ]
369
    },
370
    {
371
     "name": "stderr",
372
     "output_type": "stream",
373
     "text": [
374
      " 90%|███████████████████████████████████████▌    | 9/10 [00:31<00:03,  3.47s/it]"
375
     ]
376
    },
377
    {
378
     "name": "stdout",
379
     "output_type": "stream",
380
     "text": [
381
      "epoch 9: Train Loss 1.3547, Val Loss 1.3023\n"
382
     ]
383
    },
384
    {
385
     "name": "stderr",
386
     "output_type": "stream",
387
     "text": [
388
      "100%|███████████████████████████████████████████| 10/10 [00:34<00:00,  3.49s/it]"
389
     ]
390
    },
391
    {
392
     "name": "stdout",
393
     "output_type": "stream",
394
     "text": [
395
      "epoch 10: Train Loss 1.3551, Val Loss 1.3414\n"
396
     ]
397
    },
398
    {
399
     "name": "stderr",
400
     "output_type": "stream",
401
     "text": [
402
      "\n"
403
     ]
404
    }
405
   ],
406
   "source": [
407
    "# Loss criterion and optimizer\n",
408
    "criterion = nn.MSELoss()\n",
409
    "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n",
410
    "device = torch.device(\"mps\")\n",
411
    "model.to(device)\n",
412
    "\n",
413
    "epochs = 10\n",
414
    "torch.manual_seed(12345)\n",
415
    "for epoch in tqdm(range(epochs)):\n",
416
    "    # Training loop\n",
417
    "    model.train()\n",
418
    "    total_train_loss = 0\n",
419
    "    for batch in train_loader:\n",
420
    "        optimizer.zero_grad(set_to_none=True)\n",
421
    "        input_ids, attention_mask, labels = batch\n",
422
    "        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)\n",
423
    "        output_dict = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
424
    "        predictions = output_dict.logits.squeeze(dim=1)\n",
425
    "        loss = criterion(predictions, labels)\n",
426
    "        loss.backward()\n",
427
    "        optimizer.step()\n",
428
    "        total_train_loss += loss.item()\n",
429
    "    avg_train_loss = total_train_loss / len(train_loader)\n",
430
    "\n",
431
    "    # Validation loop\n",
432
    "    model.eval()\n",
433
    "    total_val_loss = 0\n",
434
    "    with torch.no_grad():\n",
435
    "        for batch in val_loader:\n",
436
    "            input_ids, attention_mask, labels = batch\n",
437
    "            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)\n",
438
    "            output_dict = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
439
    "            predictions = output_dict.logits.squeeze(dim=1)\n",
440
    "            loss = criterion(predictions, labels)\n",
441
    "            total_val_loss += loss.item()\n",
442
    "    avg_val_loss = total_val_loss / len(val_loader)\n",
443
    "    print(f\"epoch {epoch + 1}: Train Loss {avg_train_loss:.4f}, Val Loss {avg_val_loss:.4f}\")\n"
444
   ]
445
  },
446
  {
447
   "cell_type": "code",
448
   "execution_count": 8,
449
   "id": "e15f9a1f-80de-4e97-bb52-03ce819f9ca4",
450
   "metadata": {},
451
   "outputs": [
452
    {
453
     "name": "stdout",
454
     "output_type": "stream",
455
     "text": [
456
      "Test Loss 1.4417\n"
457
     ]
458
    }
459
   ],
460
   "source": [
461
    "# Testing loop\n",
462
    "total_test_loss = 0\n",
463
    "test_labels = []\n",
464
    "test_predictions = []\n",
465
    "with torch.no_grad():\n",
466
    "    for batch in test_loader:\n",
467
    "        input_ids, attention_mask, labels = batch\n",
468
    "        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)\n",
469
    "        output_dict = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
470
    "        predictions = output_dict.logits.squeeze(dim=1)\n",
471
    "        loss = criterion(predictions, labels)\n",
472
    "        total_test_loss += loss.item()\n",
473
    "        test_labels.extend(labels.tolist())\n",
474
    "        test_predictions.extend(predictions.tolist())\n",
475
    "avg_test_loss = total_test_loss / len(test_loader)\n",
476
    "print(f\"Test Loss {avg_test_loss:.4f}\")"
477
   ]
478
  },
479
  {
480
   "cell_type": "code",
481
   "execution_count": 9,
482
   "id": "a5325e19-2edd-4732-8732-e6b679fe7561",
483
   "metadata": {},
484
   "outputs": [
485
    {
486
     "data": {
487
      "image/png": "",
488
      "text/plain": [
489
       "<Figure size 640x480 with 1 Axes>"
490
      ]
491
     },
492
     "metadata": {},
493
     "output_type": "display_data"
494
    }
495
   ],
496
   "source": [
497
    "import matplotlib.pyplot as plt\n",
498
    "plt.scatter(test_predictions, test_labels)\n",
499
    "plt.plot([4, 8], [4, 8], 'k--')\n",
500
    "plt.xlabel('Predicted')\n",
501
    "plt.ylabel('Actual');"
502
   ]
503
  }
504
 ],
505
 "metadata": {
506
  "kernelspec": {
507
   "display_name": "Python 3 (ipykernel)",
508
   "language": "python",
509
   "name": "python3"
510
  },
511
  "language_info": {
512
   "codemirror_mode": {
513
    "name": "ipython",
514
    "version": 3
515
   },
516
   "file_extension": ".py",
517
   "mimetype": "text/x-python",
518
   "name": "python",
519
   "nbconvert_exporter": "python",
520
   "pygments_lexer": "ipython3",
521
   "version": "3.9.18"
522
  }
523
 },
524
 "nbformat": 4,
525
 "nbformat_minor": 5
526
}