463 lines (462 with data), 32.4 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the path to the `xls` file"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"training_file = \"../TrainDataset2024.xls\"\n",
"# training_file = \"/kaggle/input/dataset/TrainDataset2024.xls\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-12-07 16:54:09.391306: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2024-12-07 16:54:09.571207: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"E0000 00:00:1733590449.640280 43324 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"E0000 00:00:1733590449.660207 43324 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-12-07 16:54:09.824402: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"\n",
"# Add the parent directory to the system path\n",
"sys.path.append(os.path.abspath('../')) # Adjust the path as needed\n",
"\n",
"from my_util import df_to_corr_matrix, remove_outliers\n",
"\n",
"import tensorflow as tf\n",
"import pandas as pd\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import plotly.graph_objects as go\n",
"\n",
"from matplotlib.colors import Normalize\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.model_selection import train_test_split, KFold, cross_val_score, GridSearchCV, cross_val_predict, StratifiedKFold\n",
"from sklearn.preprocessing import StandardScaler, RobustScaler\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, accuracy_score, f1_score, make_scorer, balanced_accuracy_score, r2_score\n",
"from sklearn.metrics import mean_absolute_error\n",
"from sklearn.svm import SVC\n",
"from sklearn.feature_selection import SelectKBest, f_classif, chi2, mutual_info_classif\n",
"from sklearn.impute import KNNImputer\n",
"\n",
"\n",
"from imblearn.over_sampling import SMOTE\n",
"from imblearn.pipeline import Pipeline\n",
"\n",
"from joblib import Parallel, delayed\n",
"\n",
"import xgboost as xgb\n",
"from xgboost import XGBClassifier, XGBRegressor\n",
"\n",
"from pickle import dump , load\n",
"\n",
"import warnings"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded '../FeatureSelection/pkl/regression_features_corr_25_selected_features.pkl' to selected_feature\n",
"(400, 25) (400,)\n",
"['original_firstorder_InterquartileRange', 'original_firstorder_Kurtosis', 'TumourStage', 'original_shape_MajorAxisLength', 'original_firstorder_90Percentile', 'ChemoGrade', 'HER2', 'original_shape_Maximum2DDiameterRow', 'original_shape_LeastAxisLength', 'original_shape_Maximum2DDiameterColumn', 'original_glszm_SmallAreaEmphasis', 'Age', 'original_shape_Sphericity', 'original_firstorder_10Percentile', 'original_glszm_SizeZoneNonUniformityNormalized', 'original_gldm_DependenceEntropy', 'original_ngtdm_Busyness', 'original_glcm_Imc1', 'Gene', 'original_gldm_SmallDependenceEmphasis', 'original_glszm_GrayLevelNonUniformityNormalized', 'PgR', 'TrippleNegative', 'original_shape_Elongation', 'original_glcm_Correlation']\n"
]
}
],
"source": [
"# Name of the saved feature set; it selects which pickle file is loaded below.\n",
"NUM_OF_SELECTED_FEATURES = \"regression_features_corr_25\"\n",
"features_pkl = f'../FeatureSelection/pkl/{NUM_OF_SELECTED_FEATURES}_selected_features.pkl'\n",
"\n",
"# Read the spreadsheet, turn every 999 into NaN, drop the two columns that are\n",
"# not used here, and keep only the rows that have a value for the target.\n",
"data = (\n",
"    pd.read_excel(training_file)\n",
"    .replace(999, np.nan)\n",
"    .drop(columns=[\"ID\", \"pCR (outcome)\"])\n",
"    .dropna(subset=[\"RelapseFreeSurvival (outcome)\"])\n",
")\n",
"\n",
"# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.\n",
"with open(features_pkl, mode='rb') as handle:\n",
"    selected_features = load(handle)\n",
"    print(f\"Loaded '{handle.name}' to selected_feature\")\n",
"\n",
"# Feature matrix restricted to the loaded column list, and the regression target.\n",
"y = data[\"RelapseFreeSurvival (outcome)\"]\n",
"X = data[selected_features]\n",
"print(X.shape, y.shape)\n",
"\n",
"print(selected_features)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Splited the data into train and test. The test will not be used in the training, but just for test the xgb. \n",
"The training data has 320 data. The testing data has 80 data. \n",
"RandomState = 7\n"
]
}
],
"source": [
"# Fixed seed so the train/test partition is the same on every run.\n",
"# (A random draw used to be assigned here and was immediately overwritten by\n",
"# this constant, so it had no effect and has been removed.)\n",
"randomstate = 7\n",
"\n",
"# Plain random 80/20 split: no `stratify` argument is passed.\n",
"X_train_full, X_test_reserved, y_train_full, y_test_reserved = train_test_split(\n",
"    X, y, test_size=0.2, random_state=randomstate\n",
")\n",
"\n",
"# Renumber the rows of each part from 0; reassigning (instead of inplace=True)\n",
"# keeps this cell safe to re-run.\n",
"X_train_full = X_train_full.reset_index(drop=True)\n",
"X_test_reserved = X_test_reserved.reset_index(drop=True)\n",
"y_train_full = y_train_full.reset_index(drop=True)\n",
"y_test_reserved = y_test_reserved.reset_index(drop=True)\n",
"\n",
"print(\"Splited the data into train and test. The test will not be used in the training, but just for test the xgb. \")\n",
"print(f\"The training data has {len(X_train_full)} data. The testing data has {len(X_test_reserved)} data. \")\n",
"print(f\"RandomState = {randomstate}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.00499132, 0.00434028, 0.01540799, 0.01236979, 0.01388889,\n",
" 0.00976562, 0.00499132, 0.0015191 , 0.00065104, 0.0015191 ]),\n",
" array([ 0. , 14.4, 28.8, 43.2, 57.6, 72. , 86.4, 100.8, 115.2,\n",
" 129.6, 144. ]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGeCAYAAABsJvAoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwXElEQVR4nO3dfXBUVZ7/8U/MQwcVIoYlMSOERP0tifGBdFwmwYCuThB0lBpKoquRKR1+ZgcMSVaXJ113qdUGZZDFQNi4GWtYV0jtBsbMVlhpR4kgLUpI4gMpmVkzJsakMmGdbpU1T5zfH/7osu0mpAMYcni/qm4Vffp77z3fTpl8PN33doQxxggAAGCUu2CkJwAAAHAmEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACtEjfQEvk/Hjx/XZ599prFjxyoiImKkpwMAAIbAGKMvvvhCSUlJuuCCQdZjzDBs2rTJTJkyxTgcDpOZmWnefPPNQev37NljMjMzjcPhMCkpKaa8vDzg+Q8++MD85Cc/McnJyUaSee6550Ie59NPPzX33XefufTSS82YMWPMddddZw4ePDjkebe1tRlJbGxsbGxsbKNwa2trG/TvfNgrNVVVVSouLtbmzZs1Y8YM/fM//7PmzJmjw4cPa/LkyUH1LS0tmjt3rhYtWqSXXnpJb731ln7+85/rz/7szzR//nxJ0rFjx5Samqq7775bJSUlIc/7+eefa8aMGbr55pu1a9cuTZw4Uf/93/+tSy65ZMhzHzt2rCSpra1N48aNC7d1AAAwAnw+nyZNmuT/O34yEcaE94WW06dPV2ZmpsrLy/1jaWlpmjdvnlwuV1D9smXLVFNTo+bmZv9YYWGhmpqa5PF4guqnTJmi4uJiFRcXB4wvX75cb731lvbu3RvOdAP4fD7FxcXJ6/USagAAGCWG+vc7rA8K9/b2qr6+Xnl5eQHjeXl52r9/f8h9PB5PUP3s2bN18OBB9fX1DfncNTU1ysrK0t13362JEydq2rRpeuGFFwbdp6enRz6fL2ADAAB2CivUdHd3a2BgQAkJCQHjCQkJ6uzsDLlPZ2dnyPr+/n51d3cP+dwff/yxysvLddVVV+nVV19VYWGhioqKtHXr1pPu43K5FBcX598mTZo05PMBAIDRZViXdH/3yiFjzKBXE4WqDzU+mOPHjyszM1NPP/20pk2bpocffliLFi0KeBvsu1asWCGv1+vf2trahnw+AAAwuoQVaiZMmKDIyMigVZmurq6g1ZgTEhMTQ9ZHRUUpPj5+yOe+7LLLlJ6eHjCWlpam1tbWk+7jcDg0bty4gA0AANgprFATExMjp9Mpt9sdMO52u5WTkxNyn+zs7KD63bt3KysrS9HR0UM+94wZM/TRRx8FjB05ckTJyclDPgYAALBX2G8/lZaW6l/+5V/0y1/+Us3NzSopKVFra6sKCwslffOWzwMPPOCvLyws1CeffKLS0lI1Nzfrl7/8pSorK/Xoo4/6a3p7e9XY2KjGxkb19vaqvb1djY2N+v3vf++vKSkp0dtvv62nn35av//97/Xyyy+roqJCixcvPp3+AQCALYZ857pv2bRpk0lOTjYxMTEmMzPT1NXV+Z9buHChmTVrVkD9nj17zLRp00xMTIyZMmVK0M33WlpaQt5k57vH+c1vfmMyMjKMw+EwU6dONRUVFWHN2+v1GknG6/WGtR8AABg5Q/37HfZ9akYz7lMDAMDoc1buUwMAAHCuItQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALBC1EhPAAjLG66RnkH4bl4x0jMAgPMCKzUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFbgkm6MKp6Pj470FMKWffNIzwAAzg+s1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAA
CAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALDCsELN5s2blZKSotjYWDmdTu3du3fQ+rq6OjmdTsXGxio1NVVbtmwJeP7DDz/U/PnzNWXKFEVERGjDhg2DHs/lcikiIkLFxcXDmT4AALBQ2KGmqqpKxcXFWrVqlRoaGpSbm6s5c+aotbU1ZH1LS4vmzp2r3NxcNTQ0aOXKlSoqKlJ1dbW/5tixY0pNTdWaNWuUmJg46PnfffddVVRU6Nprrw136gAAwGJhh5r169froYce0s9+9jOlpaVpw4YNmjRpksrLy0PWb9myRZMnT9aGDRuUlpamn/3sZ3rwwQe1bt06f80NN9ygZ599Vvfcc48cDsdJz/3ll1/qvvvu0wsvvKDx48efcq49PT3y+XwBGwAAsFNYoaa3t1f19fXKy8sLGM/Ly9P+/ftD7uPxeILqZ8+erYMHD6qvry+syS5evFi33367br311iHVu1wuxcXF+bdJkyaFdT4AADB6hBVquru7NTAwoISEhIDxhIQEdXZ2htyns7MzZH1/f7+6u7uHfO7t27fr0KFDcrlcQ95nxYoV8nq9/q2trW3I+wIAgNElajg7RUREBDw2xgSNnao+1PjJtLW1aenSpdq9e7diY2OHPE+HwzHo21kAAMAeYYWaCRMmKDIyMmhVpqurK2g15oTExMSQ9VFRUYqPjx/Seevr69XV1SWn0+kfGxgY0JtvvqmysjL19PQoMjIynFYAnMobQ18VPWfcvGKkZwBgBIX19lNMTIycTqfcbnfAuNvtVk5OTsh9srOzg+p3796trKwsRUdHD+m8t9xyi95//301Njb6t6ysLN13331qbGwk0AAAgPDffiotLVVBQYGysrKUnZ2tiooKtba2qrCwUNI3n2Npb2/X1q1bJUmFhYUqKytTaWmpFi1aJI/Ho8rKSm3bts1/zN7eXh0+fNj/7/b2djU2Nuriiy/WlVdeqbFjxyojIyNgHhdddJHi4+ODxgEAwPkp7FCTn5+vo0ePavXq1ero6FBGRoZqa2uVnJwsSero6Ai4Z01KSopqa2tVUlKiTZs2KSkpSRs3btT8+fP9NZ999pmmTZvmf7xu3TqtW7dOs2bN0p49e06jPQAAcL6IMCc+tXse8Pl8iouLk9fr1bhx40Z6OhgGT+WjIz2FsGU/tO7URecgXmsA54qh/v3mu58AAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWiBrpCQC2e859ZKSnMCw/HOkJAECYWKkBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACl3QDZ9kPWytGegoAcF5gpQYAAFiBUAMAAKxAqAEAAFYYVqjZvHmzUlJSFBsbK6fTqb179w5aX1dXJ6fTqdjYWKWmpmrLli0Bz3/44YeaP3++pkyZooiICG3YsCHoGC6XSzfccIPGjh2riRMnat68efroo4+GM30AAGChsENNVVWViouLtWrVKjU0NCg3N1dz5sxRa2tryPqWlhbNnTtXubm5amho0MqVK1VUVKTq6mp/zbFjx5Samqo1a9YoMTEx5HHq6uq0ePFivf3223K73erv71deXp6++uqrcFsAAAAWijDGmHB2mD59ujIzM1VeXu4fS0tL07x58+RyuYLqly1bppqaGjU3N/vHCgsL1dTUJI/HE1Q/ZcoUFRcXq7i4eNB5/PGPf9TEiRNVV1enmTNnDmnuPp9PcXFx8nq9Gjdu3JD2wbnFU/noSE8B57Dsh9aN9BQAnAVD/fsd1kpNb2+v6uvrlZeXFzCel5en/fv3h9zH4/
EE1c+ePVsHDx5UX19fOKcP4PV6JUmXXnrpSWt6enrk8/kCNgAAYKewQk13d7cGBgaUkJAQMJ6QkKDOzs6Q+3R2doas7+/vV3d3d5jT/YYxRqWlpbrxxhuVkZFx0jqXy6W4uDj/NmnSpGGdDwAAnPuG9UHhiIiIgMfGmKCxU9WHGh+qJUuW6L333tO2bdsGrVuxYoW8Xq9/a2trG9b5AADAuS+sOwpPmDBBkZGRQasyXV1dQasxJyQmJoasj4qKUnx8fJjTlR555BHV1NTozTff1OWXXz5orcPhkMPhCPscAABg9AlrpSYmJkZOp1Nutztg3O12KycnJ+Q+2dnZQfW7d+9WVlaWoqOjh3xuY4yWLFmiHTt26PXXX1dKSko4UwcAAJYL+7ufSktLVVBQoKysLGVnZ6uiokKtra0qLCyU9M1bPu3t7dq6daukb650KisrU2lpqRYtWiSPx6PKysqAt456e3t1+PBh/7/b29vV2Nioiy++WFdeeaUkafHixXr55Zf1yiuvaOzYsf7Vn7i4OI0ZM+b0XgUAADDqhR1q8vPzdfToUa1evVodHR3KyMhQbW2tkpOTJUkdHR0B96xJSUlRbW2tSkpKtGnTJiUlJWnjxo2aP3++v+azzz7TtGnT/I/XrVundevWadasWdqzZ48k+S8hv+mmmwLm8+KLL+qnP/1puG0AAADLhH2fmtGM+9SMftynBoPhPjWAnc7KfWoAAADOVYQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALDCsELN5s2blZKSotjYWDmdTu3du3fQ+rq6OjmdTsXGxio1NVVbtmwJeP7DDz/U/PnzNWXKFEVERGjDhg1n5LwAAOD8EXaoqaqqUnFxsVatWqWGhgbl5uZqzpw5am1tDVnf0tKiuXPnKjc3Vw0NDVq5cqWKiopUXV3trzl27JhSU1O1Zs0aJSYmnpHzAgCA80uEMcaEs8P06dOVmZmp8vJy/1haWprmzZsnl8sVVL9s2TLV1NSoubnZP1ZYWKimpiZ5PJ6g+ilTpqi4uFjFxcWndd5QfD6f4uLi5PV6NW7cuCHtg3OLp/LRkZ4CzmHZD60b6SkAOAuG+vc7rJWa3t5e1dfXKy8vL2A8Ly9P+/fvD7mPx+MJqp89e7YOHjyovr6+s3ZeSerp6ZHP5wvYAACAncIKNd3d3RoYGFBCQkLAeEJCgjo7O0Pu09nZGbK+v79f3d3dZ+28kuRyuRQXF+ffJk2aNKTzAQCA0WdYHxSOiIgIeGyMCRo7VX2o8TN93hUrVsjr9fq3tra2sM4HAABGj6hwiidMmKDIyMig1ZGurq6gVZQTEhMTQ9ZHRUUpPj7+rJ1XkhwOhxwOx5DOAQAARrewVmpiYmLkdDrldrsDxt1ut3JyckLuk52dHVS/e/duZWVlKTo6+qydFwAAnF/CWqmRpNLSUhUUFCgrK0vZ2dmqqKhQa2urCgsLJX3zlk97e7u2bt0q6ZsrncrKylRaWqpFixbJ4/GosrJS27Zt8x+zt7dXhw8f9v+7vb1djY2Nuvjii3XllVcO6bwAAOD8Fnaoyc/P19GjR7V69Wp1dHQoIyNDtbW1Sk5OliR1dHQE3DsmJSVFtbW1Kikp0aZNm5SUlKSNGzdq/vz5/prPPvtM06ZN8z9et26d1q1bp1mzZmnPnj1DOi8A6I2h3d7hnHLzipGeAWCNsO9TM5pxn5rRj/vUwDbcWwc4tbNynxoAAIBzFaEGAABYgVADAACsQKgBAABWINQAAAArEG
oAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKwwrFCzefNmpaSkKDY2Vk6nU3v37h20vq6uTk6nU7GxsUpNTdWWLVuCaqqrq5Weni6Hw6H09HTt3Lkz4Pn+/n49/vjjSklJ0ZgxY5SamqrVq1fr+PHjw2kBAABYJuxQU1VVpeLiYq1atUoNDQ3Kzc3VnDlz1NraGrK+paVFc+fOVW5urhoaGrRy5UoVFRWpurraX+PxeJSfn6+CggI1NTWpoKBACxYs0IEDB/w1a9eu1ZYtW1RWVqbm5mY988wzevbZZ/X8888Po20AAGCbCGOMCWeH6dOnKzMzU+Xl5f6xtLQ0zZs3Ty6XK6h+2bJlqqmpUXNzs3+ssLBQTU1N8ng8kqT8/Hz5fD7t2rXLX3Pbbbdp/Pjx2rZtmyTpjjvuUEJCgiorK/018+fP14UXXqh//dd/HdLcfT6f4uLi5PV6NW7cuHDaxjnCU/noSE8BOKOyH1o30lMAznlD/fsd1kpNb2+v6uvrlZeXFzCel5en/fv3h9zH4/EE1c+ePVsHDx5UX1/foDXfPuaNN96o3/72tzpy5IgkqampSfv27dPcuXNPOt+enh75fL6ADQAA2CkqnOLu7m4NDAwoISEhYDwhIUGdnZ0h9+ns7AxZ39/fr+7ubl122WUnrfn2MZctWyav16upU6cqMjJSAwMDeuqpp3TvvfeedL4ul0v/8A//EE6LAABglBrWB4UjIiICHhtjgsZOVf/d8VMds6qqSi+99JJefvllHTp0SL/61a+0bt06/epXvzrpeVesWCGv1+vf2traTt0cAAAYlcJaqZkwYYIiIyODVmW6urqCVlpOSExMDFkfFRWl+Pj4QWu+fczHHntMy5cv1z333CNJuuaaa/TJJ5/I5XJp4cKFIc/tcDjkcDjCaREAAIxSYa3UxMTEyOl0yu12B4y73W7l5OSE3Cc7Ozuofvfu3crKylJ0dPSgNd8+5rFjx3TBBYHTjYyM5JJuAAAgKcyVGkkqLS1VQUGBsrKylJ2drYqKCrW2tqqwsFDSN2/5tLe3a+vWrZK+udKprKxMpaWlWrRokTwejyorK/1XNUnS0qVLNXPmTK1du1Z33XWXXnnlFb322mvat2+fv+bHP/6xnnrqKU2ePFlXX321GhoatH79ej344IOn+xoAAAALhB1q8vPzdfToUa1evVodHR3KyMhQbW2tkpOTJUkdHR0B96xJSUlRbW2tSkpKtGnTJiUlJWnjxo2aP3++vyYnJ0fbt2/X448/rieeeEJXXHGFqqqqNH36dH/N888/ryeeeEI///nP1dXVpaSkJD388MP6u7/7u9PpHwAAWCLs+9SMZtynZvTjPjWwDfepAU7trNynBgAA4FxFqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAVhhWqNm8ebNSUlIUGxsrp9OpvXv3DlpfV1cnp9Op2NhYpaamasuWLUE11dXVSk9Pl8PhUHp6unbu3BlU097ervvvv1/x8fG68MILdf3116u+vn44LQAAAMuEHWqqqqpUXFysVatWqaGhQbm5uZozZ45aW1tD1re0tGju3L
nKzc1VQ0ODVq5cqaKiIlVXV/trPB6P8vPzVVBQoKamJhUUFGjBggU6cOCAv+bzzz/XjBkzFB0drV27dunw4cP6xS9+oUsuuST8rgEAgHUijDEmnB2mT5+uzMxMlZeX+8fS0tI0b948uVyuoPply5appqZGzc3N/rHCwkI1NTXJ4/FIkvLz8+Xz+bRr1y5/zW233abx48dr27ZtkqTly5frrbfeOuWq0GB8Pp/i4uLk9Xo1bty4YR8HI8dT+ehITwE4o7IfWjfSUwDOeUP9+x3WSk1vb6/q6+uVl5cXMJ6Xl6f9+/eH3Mfj8QTVz549WwcPHlRfX9+gNd8+Zk1NjbKysnT33Xdr4sSJmjZtml544YVB59vT0yOfzxewAQAAO4UVarq7uzUwMKCEhISA8YSEBHV2dobcp7OzM2R9f3+/uru7B6359jE//vhjlZeX66qrrtKrr76qwsJCFRUVaevWrSedr8vlUlxcnH+bNGlSOO0CAIBRZFgfFI6IiAh4bIwJGjtV/XfHT3XM48ePKzMzU08//bSmTZumhx9+WIsWLQp4G+y7VqxYIa/X69/a2tpO3RwAABiVwgo1EyZMUGRkZNCqTFdXV9BKywmJiYkh66OiohQfHz9ozbePedlllyk9PT2gJi0t7aQfUJYkh8OhcePGBWwAAMBOYYWamJgYOZ1Oud3ugHG3262cnJyQ+2RnZwfV7969W1lZWYqOjh605tvHnDFjhj766KOAmiNHjig5OTmcFgAAgKWiwt2htLRUBQUFysrKUnZ2tioqKtTa2qrCwkJJ37zl097e7v+sS2FhocrKylRaWqpFixbJ4/GosrLSf1WTJC1dulQzZ87U2rVrddddd+mVV17Ra6+9pn379vlrSkpKlJOTo6effloLFizQO++8o4qKClVUVJzuawAAACwQdqjJz8/X0aNHtXr1anV0dCgjI0O1tbX+FZOOjo6At4RSUlJUW1urkpISbdq0SUlJSdq4caPmz5/vr8nJydH27dv1+OOP64knntAVV1yhqqoqTZ8+3V9zww03aOfOnVqxYoVWr16tlJQUbdiwQffdd9/p9A8AACwR9n1qRjPuUzP6cZ8a2Ib71ACndlbuUwMAAHCuItQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBXC/kJLAMCZ85z7yEhPIWwlP/o/Iz0FICRWagAAgBUINQAAwAqEGgAAYAVCDQAAsAIfFD5T3nCN9AzCd/OKkZ4BcN77YWvFSE9hGNaN9ASAkFipAQAAViDUAAAAK/D20xni+fjoSE8hbNk3j/QMAAA4c1ipAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAV+O6n85in8tGRngIAAGcMKzUAAMAKhBoAAGAFQg0AALDCsELN5s2blZKSotjYWDmdTu3du3fQ+rq6OjmdTsXGxio1NVVbtmwJqqmurlZ6erocDofS09O1c+fOkx7P5XIpIiJCxcXFw5k+AACwUNihpqqqSsXFxVq1apUaGhqUm5urOXPmqLW1NWR9S0uL5s6dq9zcXDU0NGjlypUqKipSdXW1v8bj8Sg/P18FBQVqampSQUGBFixYoAMHDgQd791331VFRYWuvfbacKcOAAAsFmGMMeHsMH36dGVmZqq8vNw/lpaWpnnz5snlcgXVL1u2TDU1NWpubvaPFRYWqqmpSR6PR5KUn58vn8+nXbt2+Wtuu+02jR8/Xtu2bfOPffnll8rMzNTmzZv1j//4j7r++uu1YcOGIc/d5/MpLi5OXq9X48aNC6ftU+JKIgDni+yH1o30FHCeGerf77BWanp7e1VfX6+8vLyA8by8PO3fvz/kPh6PJ6h+9uzZOnjwoPr6+gat+e4xFy9erNtvv1233nrrkObb09Mjn88XsAEAADuFFWq6u7
s1MDCghISEgPGEhAR1dnaG3KezszNkfX9/v7q7uwet+fYxt2/frkOHDoVcDToZl8uluLg4/zZp0qQh7wsAAEaXYX1QOCIiIuCxMSZo7FT13x0f7JhtbW1aunSpXnrpJcXGxg55nitWrJDX6/VvbW1tQ94XAACMLmHdUXjChAmKjIwMWpXp6uoKWmk5ITExMWR9VFSU4uPjB605ccz6+np1dXXJ6XT6nx8YGNCbb76psrIy9fT0KDIyMujcDodDDocjnBYBAMAoFdZKTUxMjJxOp9xud8C42+1WTk5OyH2ys7OD6nfv3q2srCxFR0cPWnPimLfccovef/99NTY2+resrCzdd999amxsDBloAADA+SXs734qLS1VQUGBsrKylJ2drYqKCrW2tqqwsFDSN2/5tLe3a+vWrZK+udKprKxMpaWlWrRokTwejyorKwOualq6dKlmzpyptWvX6q677tIrr7yi1157Tfv27ZMkjR07VhkZGQHzuOiiixQfHx80DgAAzk9hh5r8/HwdPXpUq1evVkdHhzIyMlRbW6vk5GRJUkdHR8A9a1JSUlRbW6uSkhJt2rRJSUlJ2rhxo+bPn++vycnJ0fbt2/X444/riSee0BVXXKGqqipNnz79DLQIAADOB2Hfp2Y04z41AHD6uE8Nvm9n5T41AAAA5ypCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYYVihZvPmzUpJSVFsbKycTqf27t07aH1dXZ2cTqdiY2OVmpqqLVu2BNVUV1crPT1dDodD6enp2rlzZ8DzLpdLN9xwg8aOHauJEydq3rx5+uijj4YzfQAAYKGwQ01VVZWKi4u1atUqNTQ0KDc3V3PmzFFra2vI+paWFs2dO1e5ublqaGjQypUrVVRUpOrqan+Nx+NRfn6+CgoK1NTUpIKCAi1YsEAHDhzw19TV1Wnx4sV6++235Xa71d/fr7y8PH311VfDaBsAANgmwhhjwtlh+vTpyszMVHl5uX8sLS1N8+bNk8vlCqpftmyZampq1Nzc7B8rLCxUU1OTPB6PJCk/P18+n0+7du3y19x2220aP368tm3bFnIef/zjHzVx4kTV1dVp5syZQ5q7z+dTXFycvF6vxo0bN6R9hspT+egZPR4AnKuyH1o30lPAeWaof7/DWqnp7e1VfX298vLyAsbz8vK0f//+kPt4PJ6g+tmzZ+vgwYPq6+sbtOZkx5Qkr9crSbr00ktPWtPT0yOfzxewAQAAO4UVarq7uzUwMKCEhISA8YSEBHV2dobcp7OzM2R9f3+/uru7B6052TGNMSotLdWNN96ojIyMk87X5XIpLi7Ov02aNOmUPQIAgNFpWB8UjoiICHhsjAkaO1X9d8fDOeaSJUv03nvvnfStqRNWrFghr9fr39ra2gatBwAAo1dUOMUTJkxQZGRk0ApKV1dX0ErLCYmJiSHro6KiFB8fP2hNqGM+8sgjqqmp0ZtvvqnLL7980Pk6HA45HI5T9gUAGLrn3EdGegphK/nR/xnpKeB7ENZKTUxMjJxOp9xud8C42+1WTk5OyH2ys7OD6nfv3q2srCxFR0cPWvPtYxpjtGTJEu3YsUOvv/66UlJSwpk6AACwXFgrNZJUWlqqgoICZWVlKTs7WxUVFWptbVVhYaGkb97yaW9v19atWyV9c6VTWVmZSktLtWjRInk8HlVWVga8dbR06VLNnDlTa9eu1V133aVXXnlFr732mvbt2+evWbx4sV5++WW98sorGjt2rH9lJy4uTmPGjD
mtFwEAAIx+YYea/Px8HT16VKtXr1ZHR4cyMjJUW1ur5ORkSVJHR0fAPWtSUlJUW1urkpISbdq0SUlJSdq4caPmz5/vr8nJydH27dv1+OOP64knntAVV1yhqqoqTZ8+3V9z4hLym266KWA+L774on7605+G2wYAALBM2PepGc24Tw0AnL63J//fkZ5C2PhMzeh2Vu5TAwAAcK4i1AAAACsQagAAgBUINQAAwAphX/0EADi//bC1YqSnMAyj8Es43wj+kuhz3s0rRvT0rNQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBS7oBADgHeT4+OtJTCFv2zSN7flZqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKfE0CAMB6z7mPjPQUwvbDkZ7AKESoAQBY74etFSM9BXwPePsJAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYYVqjZvHmzUlJSFBsbK6fTqb179w5aX1dXJ6fTqdjYWKWmpmrLli1BNdXV1UpPT5fD4VB6erp27tx52ucFAADnj7BDTVVVlYqLi7Vq1So1NDQoNzdXc+bMUWtra8j6lpYWzZ07V7m5uWpoaNDKlStVVFSk6upqf43H41F+fr4KCgrU1NSkgoICLViwQAcOHBj2eQEAwPklwhhjwtlh+vTpyszMVHl5uX8sLS1N8+bNk8vlCqpftmyZampq1Nzc7B8rLCxUU1OTPB6PJCk/P18+n0+7du3y19x2220aP368tm3bNqzzSlJPT496enr8j71eryZPnqy2tjaNGzcunLZP6Z2tq87o8QAAGG3+4oGnzspxfT6fJk2apD/96U+Ki4s7eaEJQ09Pj4mMjDQ7duwIGC8qKjIzZ84MuU9ubq4pKioKGNuxY4eJiooyvb29xhhjJk2aZNavXx9Qs379ejN58uRhn9cYY5588kkjiY2NjY2Njc2Cra2tbdCcEqUwdHd3a2BgQAkJCQHjCQkJ6uzsDLlPZ2dnyPr+/n51d3frsssuO2nNiWMO57yStGLFCpWWlvofHz9+XP/zP/+j+Ph4RUREnLrhITqRIM/GCtC57nzt/XztW6L387H387Vvid7Pld6NMfriiy+UlJQ0aF1YoeaE7wYCY8ygISFU/XfHh3LMcM/rcDjkcDgCxi655JKT1p+ucePGjfgPfqScr72fr31L9H4+9n6+9i3R+7nQ+6BvO/1/YX1QeMKECYqMjAxaHenq6gpaRTkhMTExZH1UVJTi4+MHrTlxzOGcFwAAnF/CCjUxMTFyOp1yu90B4263Wzk5OSH3yc7ODqrfvXu3srKyFB0dPWjNiWMO57wAAOA8M+gnbkLYvn27iY6ONpWVlebw4cOmuLjYXHTRReYPf/iDMcaY5cuXm4KCAn/9xx9/bC688EJTUlJiDh8+bCorK010dLT5j//4D3/NW2+9ZSIjI82aNWtMc3OzWbNmjYmKijJvv/32kM87kr7++mvz5JNPmq+//nqkp/K9O197P1/7Nobez8fez9e+jaH30dZ72KHGGGM2bdpkkpOTTUxMjMnMzDR1dXX+5xYuXGhmzZoVUL9nzx4zbdo0ExMTY6ZMmWLKy8uDjvnv//7v5s///M9NdHS0mTp1qqmurg7rvAAA4PwW9n1qAAAAzkV89xMAALACoQYAAFiBUAMAAKxAqAEAAFYg1JwBmzdvVkpKimJjY+V0OrV3796RntIZ5XK5dMMNN2js2LGaOHGi5s2bp48++iigxhijv//7v1dSUpLGjBmjm266SR9++OEIzfjscLlcioiIUHFxsX/M5r7b29t1//33Kz4+XhdeeKGuv/561dfX+5+3tff+/n49/vjjSklJ0ZgxY5SamqrVq1fr+PHj/hpben/zzTf14x//WElJSYqIiNCvf/3rgO
eH0mdPT48eeeQRTZgwQRdddJHuvPNOffrpp99jF+EbrO++vj4tW7ZM11xzjS666CIlJSXpgQce0GeffRZwjNHYt3Tqn/m3Pfzww4qIiNCGDRsCxs/l3gk1p6mqqkrFxcVatWqVGhoalJubqzlz5qi1tXWkp3bG1NXVafHixXr77bfldrvV39+vvLw8ffXVV/6aZ555RuvXr1dZWZneffddJSYm6kc/+pG++OKLEZz5mfPuu++qoqJC1157bcC4rX1//vnnmjFjhqKjo7Vr1y4dPnxYv/jFLwK+ZsTW3teuXastW7aorKxMzc3NeuaZZ/Tss8/q+eef99fY0vtXX32l6667TmVlZSGfH0qfxcXF2rlzp7Zv3659+/bpyy+/1B133KGBgYHvq42wDdb3sWPHdOjQIT3xxBM6dOiQduzYoSNHjujOO+8MqBuNfUun/pmf8Otf/1oHDhwI+V1L53TvI3g5uRX+4i/+whQWFgaMTZ061SxfvnyEZnT2dXV1GUn++wQdP37cJCYmmjVr1vhrvv76axMXF2e2bNkyUtM8Y7744gtz1VVXGbfbbWbNmmWWLl1qjLG772XLlpkbb7zxpM/b3Pvtt99uHnzwwYCxn/zkJ+b+++83xtjbuySzc+dO/+Oh9PmnP/3JREdHm+3bt/tr2tvbzQUXXGD+67/+63ub++n4bt+hvPPOO0aS+eSTT4wxdvRtzMl7//TTT80PfvAD88EHH5jk5GTz3HPP+Z8713tnpeY09Pb2qr6+Xnl5eQHjeXl52r9//wjN6uzzer2SpEsvvVSS1NLSos7OzoDXweFwaNasWVa8DosXL9btt9+uW2+9NWDc5r5ramqUlZWlu+++WxMnTtS0adP0wgsv+J+3ufcbb7xRv/3tb3XkyBFJUlNTk/bt26e5c+dKsrv3bxtKn/X19err6wuoSUpKUkZGhlWvhdfrVUREhH+l0ua+jx8/roKCAj322GO6+uqrg54/13sf1rd04xvd3d0aGBgI+lLNhISEoC/ftIUxRqWlpbrxxhuVkZEhSf5eQ70On3zyyfc+xzNp+/btOnTokN59992g52zu++OPP1Z5eblKS0u1cuVKvfPOOyoqKpLD4dADDzxgde/Lli2T1+vV1KlTFRkZqYGBAT311FO69957Jdn9c/+2ofTZ2dmpmJgYjR8/PqjGlt+BX3/9tZYvX66/+qu/8n9Ttc19r127VlFRUSoqKgr5/LneO6HmDIiIiAh4bIwJGrPFkiVL9N5772nfvn1Bz9n2OrS1tWnp0qXavXu3YmNjT1pnW9/SN/+3lpWVpaefflqSNG3aNH344YcqLy/XAw884K+zsfeqqiq99NJLevnll3X11VersbFRxcXFSkpK0sKFC/11NvYeynD6tOW16Ovr0z333KPjx49r8+bNp6wf7X3X19frn/7pn3To0KGw+zhXeuftp9MwYcIERUZGBqXTrq6uoP+7scEjjzyimpoavfHGG7r88sv944mJiZJk3etQX1+vrq4uOZ1ORUVFKSoqSnV1ddq4caOioqL8vdnWtyRddtllSk9PDxhLS0vzfwDe1p+5JD322GNavny57rnnHl1zzTUqKChQSUmJXC6XJLt7/7ah9JmYmKje3l59/vnnJ60Zrfr6+rRgwQK1tLTI7Xb7V2kke/veu3evurq6NHnyZP/vvE8++UR/8zd/oylTpkg693sn1JyGmJgYOZ1Oud3ugHG3262cnJwRmtWZZ4zRkiVLtGPHDr3++utKSUkJeD4lJUWJiYkBr0Nvb6/q6upG9etwyy236P3331djY6N/y8rK0n333afGxkalpqZa2bckzZgxI+iy/SNHjig5OVmSvT9z6ZurXy64IPBXY2RkpP+Sbpt7/7ah9Ol0OhUdHR1Q09HRoQ8++GBUvxYnAs3vfvc7vfbaa4qPjw943ta+CwoK9N577wX8zktKStJjjz2mV199VdIo6H2EPqBsje3bt5vo6GhTWVlpDh8+bI
qLi81FF11k/vCHP4z01M6Yv/7rvzZxcXFmz549pqOjw78dO3bMX7NmzRoTFxdnduzYYd5//31z7733mssuu8z4fL4RnPmZ9+2rn4yxt+933nnHREVFmaeeesr87ne/M//2b/9mLrzwQvPSSy/5a2ztfeHCheYHP/iB+c///E/T0tJiduzYYSZMmGD+9m//1l9jS+9ffPGFaWhoMA0NDUaSWb9+vWloaPBf5TOUPgsLC83ll19uXnvtNXPo0CHzl3/5l+a6664z/f39I9XWKQ3Wd19fn7nzzjvN5ZdfbhobGwN+5/X09PiPMRr7NubUP/Pv+u7VT8ac270Tas6ATZs2meTkZBMTE2MyMzP9lzrbQlLI7cUXX/TXHD9+3Dz55JMmMTHROBwOM3PmTPP++++P3KTPku+GGpv7/s1vfmMyMjKMw+EwU6dONRUVFQHP29q7z+czS5cuNZMnTzaxsbEmNTXVrFq1KuAPmi29v/HGGyH/2164cKExZmh9/u///q9ZsmSJufTSS82YMWPMHXfcYVpbW0egm6EbrO+WlpaT/s574403/McYjX0bc+qf+XeFCjXncu8RxhjzfawIAQAAnE18pgYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAVvh//ay80P2gzYQAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Compare the target distribution of the full data set with that of the\n",
"# training split. Both histograms use the same bin edges (10 equal-width bins\n",
"# over the range of `y`) so that overlapping bars are directly comparable.\n",
"bin_edges = np.histogram_bin_edges(y, bins=10)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.hist(y, bins=bin_edges, density=True, alpha=0.5, label=\"all data\")\n",
"ax.hist(y_train_full, bins=bin_edges, density=True, alpha=0.5, label=\"training split\")\n",
"ax.set(\n",
"    title=\"Target distribution: all data vs. training split\",\n",
"    xlabel=\"RelapseFreeSurvival (outcome)\",\n",
"    ylabel=\"density\",\n",
")\n",
"ax.legend()\n",
"# plt.show() renders the figure without echoing the return value of hist().\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 69120 candidates, totalling 345600 fits\n",
"Best parameter: {'gamma': 0, 'learning_rate': 0.1, 'max_bin': 3, 'max_delta_step': 0, 'max_depth': 2, 'max_leaves': 3, 'min_child_weight': 0.001, 'n_estimators': 100}\n",
"Best score: -20.039419563611347\n"
]
}
],
"source": [
"# Candidate values for each hyper-parameter; GridSearchCV tries every combination.\n",
"param_grid = dict(\n",
"    gamma=[0, 0.1, 0.2],\n",
"    learning_rate=[0.01, 0.1, 0.2, 0.3],\n",
"    max_bin=[2, 3, 4, 5, 10],\n",
"    max_delta_step=[0, 1, 2],\n",
"    max_depth=[1, 2, 4, 6],\n",
"    max_leaves=[0, 1, 2, 3, 4, 5],\n",
"    min_child_weight=[0.001, 0.01, 0.1, 0.5],\n",
"    n_estimators=[10, 50, 100, 200],\n",
")\n",
"\n",
"# Base regressor trained with the absolute-error objective, matching the MAE scoring below.\n",
"model = XGBRegressor(objective=\"reg:absoluteerror\")\n",
"\n",
"# Exhaustive 5-fold search on the training part only, using all CPU cores;\n",
"# fit() returns the search object itself.\n",
"grid_search = GridSearchCV(\n",
"    model,\n",
"    param_grid,\n",
"    scoring=\"neg_mean_absolute_error\",\n",
"    cv=5,\n",
"    n_jobs=-1,\n",
"    verbose=1,\n",
"    return_train_score=True,\n",
").fit(X_train_full, y_train_full)\n",
"\n",
"print(f\"Best parameter: {grid_search.best_params_}\")\n",
"print(f\"Best score: {grid_search.best_score_}\")"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"# Write the per-candidate cross-validation results of the grid search to a CSV file.\n",
"cv_results_table = pd.DataFrame(grid_search.cv_results_)\n",
"cv_results_table.to_csv(\"output.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'gamma': 0, 'learning_rate': 0.1, 'max_bin': 3, 'max_delta_step': 0, 'max_depth': 2, 'max_leaves': 3, 'min_child_weight': 0.001, 'n_estimators': 100}\n",
"-20.039419563611347\n",
"20.508699798583983\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>43.000000</td>\n",
" <td>37.987656</td>\n",
" <td>5.012344</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>9.000000</td>\n",
" <td>53.990318</td>\n",
" <td>-44.990318</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>73.000000</td>\n",
" <td>59.660213</td>\n",
" <td>13.339787</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>16.000000</td>\n",
" <td>53.314426</td>\n",
" <td>-37.314426</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>59.000000</td>\n",
" <td>60.072800</td>\n",
" <td>-1.072800</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75</th>\n",
" <td>53.000000</td>\n",
" <td>55.507435</td>\n",
" <td>-2.507435</td>\n",
" </tr>\n",
" <tr>\n",
" <th>76</th>\n",
" <td>93.000000</td>\n",
" <td>59.416786</td>\n",
" <td>33.583214</td>\n",
" </tr>\n",
" <tr>\n",
" <th>77</th>\n",
" <td>82.416667</td>\n",
" <td>47.146587</td>\n",
" <td>35.270079</td>\n",
" </tr>\n",
" <tr>\n",
" <th>78</th>\n",
" <td>89.000000</td>\n",
" <td>57.640259</td>\n",
" <td>31.359741</td>\n",
" </tr>\n",
" <tr>\n",
" <th>79</th>\n",
" <td>88.000000</td>\n",
" <td>54.131992</td>\n",
" <td>33.868008</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>80 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2\n",
"0 43.000000 37.987656 5.012344\n",
"1 9.000000 53.990318 -44.990318\n",
"2 73.000000 59.660213 13.339787\n",
"3 16.000000 53.314426 -37.314426\n",
"4 59.000000 60.072800 -1.072800\n",
".. ... ... ...\n",
"75 53.000000 55.507435 -2.507435\n",
"76 93.000000 59.416786 33.583214\n",
"77 82.416667 47.146587 35.270079\n",
"78 89.000000 57.640259 31.359741\n",
"79 88.000000 54.131992 33.868008\n",
"\n",
"[80 rows x 3 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print(grid_search.best_params_)\n",
"print(grid_search.best_score_)\n",
"\n",
"# Estimator produced by the grid search for the best parameter combination.\n",
"model = grid_search.best_estimator_\n",
"\n",
"# Score it on the reserved test part, which the search did not see.\n",
"y_pred = model.predict(X_test_reserved)\n",
"\n",
"print(mean_absolute_error(y_test_reserved, y_pred))\n",
"\n",
"# One row per test sample with named columns; residual = actual - predicted.\n",
"test_predictions = pd.DataFrame(\n",
"    {\n",
"        \"actual\": y_test_reserved.to_numpy(),\n",
"        \"predicted\": y_pred,\n",
"    }\n",
")\n",
"test_predictions[\"residual\"] = test_predictions[\"actual\"] - test_predictions[\"predicted\"]\n",
"\n",
"display(test_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CV MAE and R2\n",
"20.039419563611347\n",
"0.05608801676895068\n",
"\n",
"Training set MAE and R2\n",
"17.971198264757792\n",
"0.1730067350076071\n",
"\n",
"Test set MAE and R2\n",
"20.508699798583983\n",
"0.029032418854204156\n"
]
}
],
"source": [
"# Hyper-parameters written out explicitly so this cell can run without repeating the grid search.\n",
"param = dict(\n",
"    gamma=0,\n",
"    learning_rate=0.1,\n",
"    max_bin=3,\n",
"    max_delta_step=0,\n",
"    max_depth=2,\n",
"    max_leaves=3,\n",
"    min_child_weight=0.001,\n",
"    n_estimators=100,\n",
"    objective=\"reg:absoluteerror\",\n",
")\n",
"\n",
"# fit() returns the fitted estimator, so construction and training are chained.\n",
"model = XGBRegressor(**param).fit(X_train_full, y_train_full)\n",
"\n",
"# Cross-validated scores on the training part, averaged over the folds.\n",
"# MAE is reported as a positive number by negating the 'neg_' scorer.\n",
"print(\"CV MAE and R2\")\n",
"cv_mae = -cross_val_score(model, X_train_full, y_train_full, scoring='neg_mean_absolute_error').mean()\n",
"cv_r2 = cross_val_score(model, X_train_full, y_train_full, scoring='r2').mean()\n",
"print(cv_mae)\n",
"print(cv_r2)\n",
"\n",
"# Predictions of the fitted model on both parts of the split.\n",
"y_pred_train = model.predict(X_train_full)\n",
"y_pred_test = model.predict(X_test_reserved)\n",
"\n",
"print(\"\\nTraining set MAE and R2\")\n",
"print(mean_absolute_error(y_train_full, y_pred_train))\n",
"print(r2_score(y_train_full, y_pred_train))\n",
"\n",
"print(\"\\nTest set MAE and R2\")\n",
"print(mean_absolute_error(y_test_reserved, y_pred_test))\n",
"print(r2_score(y_test_reserved, y_pred_test))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "MLEAsm",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}