--- a
+++ b/transfer-learning-plus-descriptors.ipynb
@@ -0,0 +1,367 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "d3426c73-9556-4223-a0f4-afdd5edbd911",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import torch\n",
+    "import xgboost as xgb\n",
+    "from scipy.stats import spearmanr\n",
+    "from sklearn.model_selection import train_test_split, GridSearchCV, KFold\n",
+    "from sklearn.metrics import mean_squared_error\n",
+    "from transformers import AutoTokenizer, AutoModel"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "cc1849cc-efb9-4eb9-946e-3ce2a0480fee",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Number of examples is: 560\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Index(['XLogP', 'LipinskiFailures', 'nRotB', 'MLogP', 'nAtomLAC', 'nAtomP',\n",
+       "       'nAtomLC', 'nBase', 'naAromAtom', 'ALogP', 'ALogp2', 'nSmallRings',\n",
+       "       'nRingBlocks', 'nAromBlocks', 'nRings6', 'WPATH', 'WTPT.2', 'WTPT.4',\n",
+       "       'WTPT.5', 'MDEC.11', 'MDEC.12', 'MDEC.13', 'MDEC.22', 'MDEC.23',\n",
+       "       'MDEC.33', 'MDEO.11', 'MDEO.22', 'MDEN.11', 'MDEN.12', 'MDEN.13',\n",
+       "       'MDEN.22', 'khs.ssCH2', 'khs.dsCH', 'khs.aaCH', 'khs.sssCH', 'khs.tsC',\n",
+       "       'khs.dssC', 'khs.aasC', 'khs.aaaC', 'khs.ssssC', 'khs.sNH2', 'khs.ssNH',\n",
+       "       'khs.aaNH', 'khs.aaN', 'khs.sssN', 'khs.aasN', 'khs.sOH', 'khs.ssO',\n",
+       "       'khs.aaO', 'khs.ssS', 'khs.aaS', 'khs.sCl', 'HybRatio', 'FMF', 'ECCEN',\n",
+       "       'SP.7', 'VP.7', 'SPC.6', 'VPC.6', 'SC.3', 'SC.5', 'VC.3', 'VC.5',\n",
+       "       'SCH.5', 'SCH.6', 'SCH.7', 'VCH.5', 'VCH.6', 'C1SP2', 'C2SP2', 'C3SP2',\n",
+       "       'C1SP3', 'C2SP3', 'C3SP3', 'ATSp5', 'ATSm1', 'tpsaEfficiency.1',\n",
+       "       'nHBDon', 'bpol', 'topoShape.1', 'pIC50', 'CANONICAL_SMILES'],\n",
+       "      dtype='object')"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Load the dataset: one row per compound with its descriptors, pIC50 target and canonical SMILES\n",
+    "data_path = \"data/hDHFR_pIC50_data.csv\"\n",
+    "df = pd.read_csv(data_path)\n",
+    "print(f'Number of examples is: {len(df)}')\n",
+    "df.columns"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "ff644daa-469a-42d8-98e9-41cdc060e3e8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Descriptor feature columns: every column listed in the df.columns output above except the\n",
+    "# regression target 'pIC50' and 'CANONICAL_SMILES' (the SMILES are embedded with ChemBERTa instead).\n",
+    "# Kept as an explicit list, so it must be updated by hand if the CSV schema changes.\n",
+    "descriptors = ['XLogP', 'LipinskiFailures', 'nRotB', 'MLogP', 'nAtomLAC', 'nAtomP',\n",
+    "               'nAtomLC', 'nBase', 'naAromAtom', 'ALogP', 'ALogp2', 'nSmallRings',\n",
+    "               'nRingBlocks', 'nAromBlocks', 'nRings6', 'WPATH', 'WTPT.2', 'WTPT.4',\n",
+    "               'WTPT.5', 'MDEC.11', 'MDEC.12', 'MDEC.13', 'MDEC.22', 'MDEC.23',\n",
+    "               'MDEC.33', 'MDEO.11', 'MDEO.22', 'MDEN.11', 'MDEN.12', 'MDEN.13',\n",
+    "               'MDEN.22', 'khs.ssCH2', 'khs.dsCH', 'khs.aaCH', 'khs.sssCH', 'khs.tsC',\n",
+    "               'khs.dssC', 'khs.aasC', 'khs.aaaC', 'khs.ssssC', 'khs.sNH2', 'khs.ssNH',\n",
+    "               'khs.aaNH', 'khs.aaN', 'khs.sssN', 'khs.aasN', 'khs.sOH', 'khs.ssO',\n",
+    "               'khs.aaO', 'khs.ssS', 'khs.aaS', 'khs.sCl', 'HybRatio', 'FMF', 'ECCEN',\n",
+    "               'SP.7', 'VP.7', 'SPC.6', 'VPC.6', 'SC.3', 'SC.5', 'VC.3', 'VC.5',\n",
+    "               'SCH.5', 'SCH.6', 'SCH.7', 'VCH.5', 'VCH.6', 'C1SP2', 'C2SP2', 'C3SP2',\n",
+    "               'C1SP3', 'C2SP3', 'C3SP3', 'ATSp5', 'ATSm1', 'tpsaEfficiency.1',\n",
+    "               'nHBDon', 'bpol', 'topoShape.1']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "005b3f00-7291-4ea0-9726-18a8217d5a22",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MLM and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']\n",
+      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Pre-trained ChemBERTa checkpoint and its matching tokenizer\n",
+    "model_name = 'DeepChem/ChemBERTa-77M-MLM'\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+    "model = AutoModel.from_pretrained(model_name)\n",
+    "\n",
+    "# Tokenize every SMILES string as one padded batch. Despite its name, `input_ids` holds the\n",
+    "# whole encoding returned by the tokenizer (token ids plus the attention mask).\n",
+    "input_texts = df[\"CANONICAL_SMILES\"].tolist()\n",
+    "input_ids = tokenizer(input_texts, padding=True, truncation=True, return_tensors=\"pt\")\n",
+    "\n",
+    "# Inference-only forward pass (no autograd graph); keep the per-token hidden states\n",
+    "with torch.no_grad():\n",
+    "    outputs = model(**input_ids)\n",
+    "embeddings = outputs.last_hidden_state"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "2157c60e-ff96-44c8-8938-8b2e1773d55d",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([560, 166, 384])"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Token-level embeddings: (n_molecules, padded_sequence_length, hidden_size)\n",
+    "embeddings.size()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "f2f0def2-3c37-46c1-a335-90c70eb7a1a9",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([560, 384])"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Pool token embeddings into one vector per molecule with an attention-masked mean.\n",
+    "# padding=True pads every SMILES to the longest sequence in the batch, so a plain mean over\n",
+    "# dim=1 would also average the hidden states at padding positions, diluting short molecules\n",
+    "# and making each molecule's vector depend on the batch's maximum length.\n",
+    "attention_mask = input_ids[\"attention_mask\"].unsqueeze(-1).to(embeddings.dtype)  # (n, seq_len, 1); 1 = real token\n",
+    "token_counts = attention_mask.sum(dim=1).clamp(min=1)  # guard against division by zero\n",
+    "pooled_embeddings = (embeddings * attention_mask).sum(dim=1) / token_counts  # other pooling methods can be used\n",
+    "pooled_embeddings.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "8250d5d7-fe14-40a5-ad50-0368ab2b39d8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Store the pooled embeddings as columns x0..x{d-1} and append them to the descriptor table.\n",
+    "# Rows of pooled_embeddings follow df's row order, so reuse df.index: pd.concat(axis=1) aligns\n",
+    "# on the index, and a non-default index on df would otherwise yield misaligned, NaN-filled rows.\n",
+    "column_names = [f'x{i}' for i in range(pooled_embeddings.shape[1])]\n",
+    "pooled_embeddings_df = pd.DataFrame(\n",
+    "    data=pooled_embeddings.detach().cpu().numpy(),  # explicit tensor -> ndarray conversion\n",
+    "    columns=column_names,\n",
+    "    index=df.index,\n",
+    ")\n",
+    "data = pd.concat([df, pooled_embeddings_df], axis=1)\n",
+    "assert len(data) == len(df) and not data[column_names].isna().any().any(), \"embedding rows misaligned\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "d9efd79f-ddbc-4941-86c1-c782203fa850",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "There are 448 molecules in Train df.\n",
+      "There are 56 molecules in Val df.\n",
+      "There are 56 molecules in Test df.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 80/10/10 split: hold out 20% of the rows, then halve the hold-out into validation and test\n",
+    "train_df, temp_df = train_test_split(data, test_size=0.2, random_state=21)\n",
+    "val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=21)\n",
+    "for split_name, split_df in [(\"Train\", train_df), (\"Val\", val_df), (\"Test\", test_df)]:\n",
+    "    print(f\"There are {len(split_df)} molecules in {split_name} df.\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "8681a83d-124f-48a2-a904-403bb18785e7",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
+       "             colsample_bylevel=None, colsample_bynode=None,\n",
+       "             colsample_bytree=0.9, device=None, early_stopping_rounds=None,\n",
+       "             enable_categorical=False, eval_metric=None, feature_types=None,\n",
+       "             gamma=None, grow_policy=None, importance_type=None,\n",
+       "             interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
+       "             max_cat_threshold=None, max_cat_to_onehot=None,\n",
+       "             max_delta_step=None, max_depth=4, max_leaves=None,\n",
+       "             min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+       "             multi_strategy=None, n_estimators=200, n_jobs=None,\n",
+       "             num_parallel_tree=None, random_state=42, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBRegressor</label><div class=\"sk-toggleable__content\"><pre>XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
+       "             colsample_bylevel=None, colsample_bynode=None,\n",
+       "             colsample_bytree=0.9, device=None, early_stopping_rounds=None,\n",
+       "             enable_categorical=False, eval_metric=None, feature_types=None,\n",
+       "             gamma=None, grow_policy=None, importance_type=None,\n",
+       "             interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
+       "             max_cat_threshold=None, max_cat_to_onehot=None,\n",
+       "             max_delta_step=None, max_depth=4, max_leaves=None,\n",
+       "             min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+       "             multi_strategy=None, n_estimators=200, n_jobs=None,\n",
+       "             num_parallel_tree=None, random_state=42, ...)</pre></div></div></div></div></div>"
+      ],
+      "text/plain": [
+       "XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
+       "             colsample_bylevel=None, colsample_bynode=None,\n",
+       "             colsample_bytree=0.9, device=None, early_stopping_rounds=None,\n",
+       "             enable_categorical=False, eval_metric=None, feature_types=None,\n",
+       "             gamma=None, grow_policy=None, importance_type=None,\n",
+       "             interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
+       "             max_cat_threshold=None, max_cat_to_onehot=None,\n",
+       "             max_delta_step=None, max_depth=4, max_leaves=None,\n",
+       "             min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+       "             multi_strategy=None, n_estimators=200, n_jobs=None,\n",
+       "             num_parallel_tree=None, random_state=42, ...)"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# XGBoost regressor with hand-picked hyperparameters.\n",
+    "# TODO: hyperparameter tuning (GridSearchCV and KFold are already imported for this).\n",
+    "xgb_params = {\n",
+    "    'objective': 'reg:squarederror',\n",
+    "    'colsample_bytree': 0.9,  # fraction of feature columns sampled per tree\n",
+    "    'learning_rate': 0.1,\n",
+    "    'max_depth': 4,\n",
+    "    'n_estimators': 200,\n",
+    "    'subsample': 0.8,  # fraction of training rows sampled per boosting round\n",
+    "    'random_state': 42,\n",
+    "}\n",
+    "xgb_model = xgb.XGBRegressor(**xgb_params)\n",
+    "\n",
+    "# Fit on the training split only; features = embedding columns followed by descriptor columns\n",
+    "xgb_model.fit(train_df[column_names + descriptors], train_df[\"pIC50\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "4086aff3-6eea-4c1e-8698-5cf7383f24bf",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Train Mean Squared Error: 0.0011\n",
+      "Train Spearman Correlation: 0.9995\n",
+      "Val Mean Squared Error: 0.5704\n",
+      "Val Spearman Correlation: 0.7104\n",
+      "Test Mean Squared Error: 0.61\n",
+      "Test Spearman Correlation: 0.7555\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Calculate and print Mean Squared Error (MSE) and Spearman correlation for one data split\n",
+    "def get_metrics(X_, y_, model, option='Train'):\n",
+    "    \"\"\"Print (and return) the MSE and Spearman correlation of `model` on (X_, y_).\n",
+    "\n",
+    "    Parameters\n",
+    "    ----------\n",
+    "    X_ : pandas.DataFrame\n",
+    "        Feature matrix passed to model.predict (below: embedding + descriptor columns).\n",
+    "    y_ : pandas.Series\n",
+    "        True target values (pIC50) aligned row-wise with X_.\n",
+    "    model : fitted regressor exposing .predict\n",
+    "        The model to evaluate (previously ignored in favour of the global xgb_model).\n",
+    "    option : str\n",
+    "        Split label used as the prefix of the printed lines, e.g. 'Train', 'Val', 'Test'.\n",
+    "\n",
+    "    Returns\n",
+    "    -------\n",
+    "    tuple\n",
+    "        (mse, spearman_corr), each rounded to 4 decimals.\n",
+    "    \"\"\"\n",
+    "    # Predict with the model that was passed in, not a notebook global\n",
+    "    y_pred = model.predict(X_)\n",
+    "\n",
+    "    mse = np.round(mean_squared_error(y_, y_pred), 4)\n",
+    "    print(f\"{option} Mean Squared Error:\", mse)\n",
+    "\n",
+    "    # spearmanr returns (correlation, p-value); only the correlation is reported\n",
+    "    spearman_corr = np.round(spearmanr(y_, y_pred)[0], 4)\n",
+    "    print(f\"{option} Spearman Correlation:\", spearman_corr)\n",
+    "    return mse, spearman_corr\n",
+    "\n",
+    "# Get metrics for train, validation, and test sets (a for-loop, so the cell shows no Out[] value)\n",
+    "for split_name, split_df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:\n",
+    "    get_metrics(split_df[column_names + descriptors], split_df[\"pIC50\"], xgb_model, option=split_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "8332903c-8d23-4517-9177-940025fd7ae3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Ground-truth pIC50 and model predictions for the held-out test split\n",
+    "actual = test_df['pIC50']\n",
+    "pred = xgb_model.predict(test_df.loc[:, column_names + descriptors])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "117d1b8e-bd88-49ea-8545-7f6e0ff2e997",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Parity plot of predicted vs. actual test-set pIC50; points on the dashed y = x line are exact predictions\n",
+    "fig, ax = plt.subplots()\n",
+    "ax.scatter(pred, actual)\n",
+    "# Span the reference line over both predicted and actual values so it is never cut short\n",
+    "lims = [min(pred.min(), actual.min()), max(pred.max(), actual.max())]\n",
+    "ax.plot(lims, lims, 'k--')\n",
+    "ax.set(xlabel='Predicted', ylabel='Actual', title='Test set: predicted vs. actual pIC50')\n",
+    "ax.grid();"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}