[e6e569]: / XGBRegression / main.ipynb

Download this file

463 lines (462 with data), 32.4 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the path to the `xls` file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_file = \"../TrainDataset2024.xls\"\n",
    "# training_file = \"/kaggle/input/dataset/TrainDataset2024.xls\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-12-07 16:54:09.391306: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-12-07 16:54:09.571207: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1733590449.640280   43324 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1733590449.660207   43324 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-12-07 16:54:09.824402: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "# Add the parent directory to the system path\n",
    "sys.path.append(os.path.abspath('../'))  # Adjust the path as needed\n",
    "\n",
    "from my_util import df_to_corr_matrix, remove_outliers\n",
    "\n",
    "import tensorflow as tf\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "from matplotlib.colors import Normalize\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.model_selection import train_test_split, KFold, cross_val_score, GridSearchCV, cross_val_predict, StratifiedKFold\n",
    "from sklearn.preprocessing import StandardScaler, RobustScaler\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, accuracy_score, f1_score, make_scorer, balanced_accuracy_score, r2_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.feature_selection import SelectKBest, f_classif, chi2, mutual_info_classif\n",
    "from sklearn.impute import KNNImputer\n",
    "\n",
    "\n",
    "from imblearn.over_sampling import SMOTE\n",
    "from imblearn.pipeline import Pipeline\n",
    "\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "import xgboost as xgb\n",
    "from xgboost import XGBClassifier, XGBRegressor\n",
    "\n",
    "from pickle import dump , load\n",
    "\n",
    "import warnings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded '../FeatureSelection/pkl/regression_features_corr_25_selected_features.pkl' to selected_feature\n",
      "(400, 25) (400,)\n",
      "['original_firstorder_InterquartileRange', 'original_firstorder_Kurtosis', 'TumourStage', 'original_shape_MajorAxisLength', 'original_firstorder_90Percentile', 'ChemoGrade', 'HER2', 'original_shape_Maximum2DDiameterRow', 'original_shape_LeastAxisLength', 'original_shape_Maximum2DDiameterColumn', 'original_glszm_SmallAreaEmphasis', 'Age', 'original_shape_Sphericity', 'original_firstorder_10Percentile', 'original_glszm_SizeZoneNonUniformityNormalized', 'original_gldm_DependenceEntropy', 'original_ngtdm_Busyness', 'original_glcm_Imc1', 'Gene', 'original_gldm_SmallDependenceEmphasis', 'original_glszm_GrayLevelNonUniformityNormalized', 'PgR', 'TrippleNegative', 'original_shape_Elongation', 'original_glcm_Correlation']\n"
     ]
    }
   ],
   "source": [
    "NUM_OF_SELECTED_FEATURES = \"regression_features_corr_25\"\n",
    "\n",
    "data = pd.read_excel(training_file)\n",
    "data.replace(999, np.nan, inplace=True)\n",
    "\n",
    "data.drop([\"ID\", \"pCR (outcome)\"], axis=1, inplace=True)\n",
    "data.dropna(subset=[\"RelapseFreeSurvival (outcome)\"], inplace=True)\n",
    "\n",
    "with open(f'../FeatureSelection/pkl/{NUM_OF_SELECTED_FEATURES}_selected_features.pkl', mode='rb') as file:\n",
    "    selected_features = load(file)\n",
    "    print(f\"Loaded '{file.name}' to selected_feature\")\n",
    "\n",
    "X = data[selected_features]\n",
    "y = data[\"RelapseFreeSurvival (outcome)\"]\n",
    "print(X.shape, y.shape)\n",
    "\n",
    "print(selected_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Splited the data into train and test. The test will not be used in the training, but just for test the xgb. \n",
      "The training data has 320 data. The testing data has 80 data. \n",
      "RandomState = 7\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "randomstate = random.randint(0, 1000)\n",
    "randomstate = 7\n",
    "X_train_full, X_test_reserved, y_train_full, y_test_reserved = train_test_split(X, y, test_size=0.2, random_state=randomstate) # similar distribution of 1 and 0\n",
    "\n",
    "X_train_full.reset_index(drop=True, inplace=True)\n",
    "X_test_reserved.reset_index(drop=True, inplace=True)\n",
    "y_train_full.reset_index(drop=True, inplace=True)\n",
    "y_test_reserved.reset_index(drop=True, inplace=True)\n",
    "\n",
    "\n",
    "print(\"Splited the data into train and test. The test will not be used in the training, but just for test the xgb. \")\n",
    "print(f\"The training data has {len(X_train_full)} data. The testing data has {len(X_test_reserved)} data. \")\n",
    "print(f\"RandomState = {randomstate}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0.00499132, 0.00434028, 0.01540799, 0.01236979, 0.01388889,\n",
       "        0.00976562, 0.00499132, 0.0015191 , 0.00065104, 0.0015191 ]),\n",
       " array([  0. ,  14.4,  28.8,  43.2,  57.6,  72. ,  86.4, 100.8, 115.2,\n",
       "        129.6, 144. ]),\n",
       " <BarContainer object of 10 artists>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGeCAYAAABsJvAoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwXElEQVR4nO3dfXBUVZ7/8U/MQwcVIoYlMSOERP0tifGBdFwmwYCuThB0lBpKoquRKR1+ZgcMSVaXJ113qdUGZZDFQNi4GWtYV0jtBsbMVlhpR4kgLUpI4gMpmVkzJsakMmGdbpU1T5zfH/7osu0mpAMYcni/qm4Vffp77z3fTpl8PN33doQxxggAAGCUu2CkJwAAAHAmEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACtEjfQEvk/Hjx/XZ599prFjxyoiImKkpwMAAIbAGKMvvvhCSUlJuuCCQdZjzDBs2rTJTJkyxTgcDpOZmWnefPPNQev37NljMjMzjcPhMCkpKaa8vDzg+Q8++MD85Cc/McnJyUaSee6550Ie59NPPzX33XefufTSS82YMWPMddddZw4ePDjkebe1tRlJbGxsbGxsbKNwa2trG/TvfNgrNVVVVSouLtbmzZs1Y8YM/fM//7PmzJmjw4cPa/LkyUH1LS0tmjt3rhYtWqSXXnpJb731ln7+85/rz/7szzR//nxJ0rFjx5Samqq7775bJSUlIc/7+eefa8aMGbr55pu1a9cuTZw4Uf/93/+tSy65ZMhzHzt2rCSpra1N48aNC7d1AAAwAnw+nyZNmuT/O34yEcaE94WW06dPV2ZmpsrLy/1jaWlpmjdvnlwuV1D9smXLVFNTo+bmZv9YYWGhmpqa5PF4guqnTJmi4uJiFRcXB4wvX75cb731lvbu3RvOdAP4fD7FxcXJ6/USagAAGCWG+vc7rA8K9/b2qr6+Xnl5eQHjeXl52r9/f8h9PB5PUP3s2bN18OBB9fX1DfncNTU1ysrK0t13362JEydq2rRpeuGFFwbdp6enRz6fL2ADAAB2CivUdHd3a2BgQAkJCQHjCQkJ6uzsDLlPZ2dnyPr+/n51d3cP+dwff/yxysvLddVVV+nVV19VYWGhioqKtHXr1pPu43K5FBcX598mTZo05PMBAIDRZViXdH/3yiFjzKBXE4WqDzU+mOPHjyszM1NPP/20pk2bpocffliLFi0KeBvsu1asWCGv1+vf2trahnw+AAAwuoQVaiZMmKDIyMigVZmurq6g1ZgTEhMTQ9ZHRUUpPj5+yOe+7LLLlJ6eHjCWlpam1tbWk+7jcDg0bty4gA0AANgprFATExMjp9Mpt9sdMO52u5WTkxNyn+zs7KD63bt3KysrS9HR0UM+94wZM/TRRx8FjB05ckTJyclDPgYAALBX2G8/lZaW6l/+5V/0y1/+Us3NzSopKVFra6sKCwslffOWzwMPPOCvLyws1CeffKLS0lI1Nzfrl7/8pSorK/Xoo4/6a3p7e9XY2KjGxkb19vaqvb1djY2N+v3vf++vKSkp0dtvv62nn35av//97/Xyyy+roqJCixcvPp3+AQCALYZ857pv2bRpk0lOTjYxMTEmMzPT1NXV+Z9buHChmTVrVkD9nj17zLRp00xMTIyZMmVK0M33WlpaQt5k57vH+c1vfmMyMjKMw+EwU6dONRUVFWHN2+v1GknG6/WGtR8AABg5Q/37HfZ9akYz7lMDAMDoc1buUwMAAHCuItQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALBC1EhPAAjLG66RnkH4bl4x0jMAgPMCKzUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFbgkm6MKp6Pj470FMKWffNIzwAAzg+s1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALDCsELN5s2blZKSotjYWDmdTu3du3fQ+rq6OjmdTsXGxio1NVVbtmwJeP7DDz/U/PnzNWXKFEVERGjDhg2DHs/lcikiIkLFxcXDmT4AALBQ2KGmqqpKxcXFWrVqlRoaGpSbm6s5c+aotbU1ZH1LS4vmzp2r3NxcNTQ0aOXKlSoqKlJ1dbW/5tixY0pNTdWaNWuUmJg46PnfffddVVRU6Nprrw136gAAwGJhh5r169froYce0s9+9jOlpaVpw4YNmjRpksrLy0PWb9myRZMnT9aGDRuUlpamn/3sZ3rwwQe1bt06f80NN9ygZ599Vvfcc48cDsdJz/3ll1/qvvvu0wsvvKDx48efcq49PT3y+XwBGwAAsFNYoaa3t1f19fXKy8sLGM/Ly9P+/ftD7uPxeILqZ8+erYMHD6qvry+syS5evFi33367br311iHVu1wuxcXF+bdJkyaFdT4AADB6hBVquru7NTAwoISEhIDxhIQEdXZ2htyns7MzZH1/f7+6u7uHfO7t27fr0KFDcrlcQ95nxYoV8nq9/q2trW3I+wIAgNElajg7RUREBDw2xgSNnao+1PjJtLW1aenSpdq9e7diY2OHPE+HwzHo21kAAMAeYYWaCRMmKDIyMmhVpqurK2g15oTExMSQ9VFRUYqPjx/Seevr69XV1SWn0+kfGxgY0JtvvqmysjL19PQoMjIynFYAnMobQ18VPWfcvGKkZwBgBIX19lNMTIycTqfcbnfAuNvtVk5OTsh9srOzg+p3796trKwsRUdHD+m8t9xyi95//301Njb6t6ysLN13331qbGwk0AAAgPDffiotLVVBQYGysrKUnZ2tiooKtba2qrCwUNI3n2Npb2/X1q1bJUmFhYUqKytTaWmpFi1aJI/Ho8rKSm3bts1/zN7eXh0+fNj/7/b2djU2Nuriiy/WlVdeqbFjxyojIyNgHhdddJHi4+ODxgEAwPkp7FCTn5+vo0ePavXq1ero6FBGRoZqa2uVnJwsSero6Ai4Z01KSopqa2tVUlKiTZs2KSkpSRs3btT8+fP9NZ999pmmTZvmf7xu3TqtW7dOs2bN0p49e06jPQAAcL6IMCc+tXse8Pl8iouLk9fr1bhx40Z6OhgGT+WjIz2FsGU/tO7URecgXmsA54qh/v3mu58AAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWiBrpCQC2e859ZKSnMCw/HOkJAECYWKkBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACl3QDZ9kPWytGegoAcF5gpQYAAFiBUAMAAKxAqAEAAFYYVqjZvHmzUlJSFBsbK6fTqb179w5aX1dXJ6fTqdjYWKWmpmrLli0Bz3/44YeaP3++pkyZooiICG3YsCHoGC6XSzfccIPGjh2riRMnat68efroo4+GM30AAGChsENNVVWViouLtWrVKjU0NCg3N1dz5sxRa2tryPqWlhbNnTtXubm5amho0MqVK1VUVKTq6mp/zbFjx5Samqo1a9YoMTEx5HHq6uq0ePFivf3223K73erv71deXp6++uqrcFsAAAAWijDGmHB2mD59ujIzM1VeXu4fS0tL07x58+RyuYLqly1bppqaGjU3N/vHCgsL1dTUJI/HE1Q/ZcoUFRcXq7i4eNB5/PGPf9TEiRNVV1enmTNnDmnuPp9PcXFx8nq9Gjdu3JD2wbnFU/noSE8B57Dsh9aN9BQAnAVD/fsd1kpNb2+v6uvrlZeXFzCel5en/fv3h9zH4/EE1c+ePVsHDx5UX19fOKcP4PV6JUmXXnrpSWt6enrk8/kCNgAAYKewQk13d7cGBgaUkJAQMJ6QkKDOzs6Q+3R2doas7+/vV3d3d5jT/YYxRqWlpbrxxhuVkZFx0jqXy6W4uDj/NmnSpGGdDwAAnPuG9UHhiIiIgMfGmKCxU9WHGh+qJUuW6L333tO2bdsGrVuxYoW8Xq9/a2trG9b5AADAuS+sOwpPmDBBkZGRQasyXV1dQasxJyQmJoasj4qKUnx8fJjTlR555BHV1NTozTff1OWXXz5orcPhkMPhCPscAABg9AlrpSYmJkZOp1Nutztg3O12KycnJ+Q+2dnZQfW7d+9WVlaWoqOjh3xuY4yWLFmiHTt26PXXX1dKSko4UwcAAJYL+7ufSktLVVBQoKysLGVnZ6uiokKtra0qLCyU9M1bPu3t7dq6daukb650KisrU2lpqRYtWiSPx6PKysqAt456e3t1+PBh/7/b29vV2Nioiy++WFdeeaUkafHixXr55Zf1yiuvaOzYsf7Vn7i4OI0ZM+b0XgUAADDqhR1q8vPzdfToUa1evVodHR3KyMhQbW2tkpOTJUkdHR0B96xJSUlRbW2tSkpKtGnTJiUlJWnjxo2aP3++v+azzz7TtGnT/I/XrVundevWadasWdqzZ48k+S8hv+mmmwLm8+KLL+qnP/1puG0AAADLhH2fmtGM+9SMftynBoPhPjWAnc7KfWoAAADOVYQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALDCsELN5s2blZKSotjYWDmdTu3du3fQ+rq6OjmdTsXGxio1NVVbtmwJeP7DDz/U/PnzNWXKFEVERGjDhg1n5LwAAOD8EXaoqaqqUnFxsVatWqWGhgbl5uZqzpw5am1tDVnf0tKiuXPnKjc3Vw0NDVq5cqWKiopUXV3trzl27JhSU1O1Zs0aJSYmnpHzAgCA80uEMcaEs8P06dOVmZmp8vJy/1haWprmzZsnl8sVVL9s2TLV1NSoubnZP1ZYWKimpiZ5PJ6g+ilTpqi4uFjFxcWndd5QfD6f4uLi5PV6NW7cuCHtg3OLp/LRkZ4CzmHZD60b6SkAOAuG+vc7rJWa3t5e1dfXKy8vL2A8Ly9P+/fvD7mPx+MJqp89e7YOHjyovr6+s3ZeSerp6ZHP5wvYAACAncIKNd3d3RoYGFBCQkLAeEJCgjo7O0Pu09nZGbK+v79f3d3dZ+28kuRyuRQXF+ffJk2aNKTzAQCA0WdYHxSOiIgIeGyMCRo7VX2o8TN93hUrVsjr9fq3tra2sM4HAABGj6hwiidMmKDIyMig1ZGurq6gVZQTEhMTQ9ZHRUUpPj7+rJ1XkhwOhxwOx5DOAQAARrewVmpiYmLkdDrldrsDxt1ut3JyckLuk52dHVS/e/duZWVlKTo6+qydFwAAnF/CWqmRpNLSUhUUFCgrK0vZ2dmqqKhQa2urCgsLJX3zlk97e7u2bt0q6ZsrncrKylRaWqpFixbJ4/GosrJS27Zt8x+zt7dXhw8f9v+7vb1djY2Nuvjii3XllVcO6bwAAOD8Fnaoyc/P19GjR7V69Wp1dHQoIyNDtbW1Sk5OliR1dHQE3DsmJSVFtbW1Kikp0aZNm5SUlKSNGzdq/vz5/prPPvtM06ZN8z9et26d1q1bp1mzZmnPnj1DOi8A6I2h3d7hnHLzipGeAWCNsO9TM5pxn5rRj/vUwDbcWwc4tbNynxoAAIBzFaEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKwwrFCzefNmpaSkKDY2Vk6nU3v37h20vq6uTk6nU7GxsUpNTdWWLVuCaqqrq5Weni6Hw6H09HTt3Lkz4Pn+/n49/vjjSklJ0ZgxY5SamqrVq1fr+PHjw2kBAABYJuxQU1VVpeLiYq1atUoNDQ3Kzc3VnDlz1NraGrK+paVFc+fOVW5urhoaGrRy5UoVFRWpurraX+PxeJSfn6+CggI1NTWpoKBACxYs0IEDB/w1a9eu1ZYtW1RWVqbm5mY988wzevbZZ/X8888Po20AAGCbCGOMCWeH6dOnKzMzU+Xl5f6xtLQ0zZs3Ty6XK6h+2bJlqqmpUXNzs3+ssLBQTU1N8ng8kqT8/Hz5fD7t2rXLX3Pbbbdp/Pjx2rZtmyTpjjvuUEJCgiorK/018+fP14UXXqh//dd/HdLcfT6f4uLi5PV6NW7cuHDaxjnCU/noSE8BOKOyH1o30lMAznlD/fsd1kpNb2+v6uvrlZeXFzCel5en/fv3h9zH4/EE1c+ePVsHDx5UX1/foDXfPuaNN96o3/72tzpy5IgkqampSfv27dPcuXNPOt+enh75fL6ADQAA2CkqnOLu7m4NDAwoISEhYDwhIUGdnZ0h9+ns7AxZ39/fr+7ubl122WUnrfn2MZctWyav16upU6cqMjJSAwMDeuqpp3TvvfeedL4ul0v/8A//EE6LAABglBrWB4UjIiICHhtjgsZOVf/d8VMds6qqSi+99JJefvllHTp0SL/61a+0bt06/epXvzrpeVesWCGv1+vf2traTt0cAAAYlcJaqZkwYYIiIyODVmW6urqCVlpOSExMDFkfFRWl+Pj4QWu+fczHHntMy5cv1z333CNJuuaaa/TJJ5/I5XJp4cKFIc/tcDjkcDjCaREAAIxSYa3UxMTEyOl0yu12B4y73W7l5OSE3Cc7Ozuofvfu3crKylJ0dPSgNd8+5rFjx3TBBYHTjYyM5JJuAAAgKcyVGkkqLS1VQUGBsrKylJ2drYqKCrW2tqqwsFDSN2/5tLe3a+vWrZK+udKprKxMpaWlWrRokTwejyorK/1XNUnS0qVLNXPmTK1du1Z33XWXXnnlFb322mvat2+fv+bHP/6xnnrqKU2ePFlXX321GhoatH79ej344IOn+xoAAAALhB1q8vPzdfToUa1evVodHR3KyMhQbW2tkpOTJUkdHR0B96xJSUlRbW2tSkpKtGnTJiUlJWnjxo2aP3++vyYnJ0fbt2/X448/rieeeEJXXHGFqqqqNH36dH/N888/ryeeeEI///nP1dXVpaSkJD388MP6u7/7u9PpHwAAWCLs+9SMZtynZvTjPjWwDfepAU7trNynBgAA4FxFqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAVhhWqNm8ebNSUlIUGxsrp9OpvXv3DlpfV1cnp9Op2NhYpaamasuWLUE11dXVSk9Pl8PhUHp6unbu3BlU097ervvvv1/x8fG68MILdf3116u+vn44LQAAAMuEHWqqqqpUXFysVatWqaGhQbm5uZozZ45aW1tD1re0tGju3LnKzc1VQ0ODVq5cqaKiIlVXV/trPB6P8vPzVVBQoKamJhUUFGjBggU6cOCAv+bzzz/XjBkzFB0drV27dunw4cP6xS9+oUsuuST8rgEAgHUijDEmnB2mT5+uzMxMlZeX+8fS0tI0b948uVyuoPply5appqZGzc3N/rHCwkI1NTXJ4/FIkvLz8+Xz+bRr1y5/zW233abx48dr27ZtkqTly5frrbfeOuWq0GB8Pp/i4uLk9Xo1bty4YR8HI8dT+ehITwE4o7IfWjfSUwDOeUP9+x3WSk1vb6/q6+uVl5cXMJ6Xl6f9+/eH3Mfj8QTVz549WwcPHlRfX9+gNd8+Zk1NjbKysnT33Xdr4sSJmjZtml544YVB59vT0yOfzxewAQAAO4UVarq7uzUwMKCEhISA8YSEBHV2dobcp7OzM2R9f3+/uru7B6359jE//vhjlZeX66qrrtKrr76qwsJCFRUVaevWrSedr8vlUlxcnH+bNGlSOO0CAIBRZFgfFI6IiAh4bIwJGjtV/XfHT3XM48ePKzMzU08//bSmTZumhx9+WIsWLQp4G+y7VqxYIa/X69/a2tpO3RwAABiVwgo1EyZMUGRkZNCqTFdXV9BKywmJiYkh66OiohQfHz9ozbePedlllyk9PT2gJi0t7aQfUJYkh8OhcePGBWwAAMBOYYWamJgYOZ1Oud3ugHG3262cnJyQ+2RnZwfV7969W1lZWYqOjh605tvHnDFjhj766KOAmiNHjig5OTmcFgAAgKWiwt2htLRUBQUFysrKUnZ2tioqKtTa2qrCwkJJ37zl097e7v+sS2FhocrKylRaWqpFixbJ4/GosrLSf1WTJC1dulQzZ87U2rVrddddd+mVV17Ra6+9pn379vlrSkpKlJOTo6effloLFizQO++8o4qKClVUVJzuawAAACwQdqjJz8/X0aNHtXr1anV0dCgjI0O1tbX+FZOOjo6At4RSUlJUW1urkpISbdq0SUlJSdq4caPmz5/vr8nJydH27dv1+OOP64knntAVV1yhqqoqTZ8+3V9zww03aOfOnVqxYoVWr16tlJQUbdiwQffdd9/p9A8AACwR9n1qRjPuUzP6cZ8a2Ib71ACndlbuUwMAAHCuItQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBXC/kJLAMCZ85z7yEhPIWwlP/o/Iz0FICRWagAAgBUINQAAwAqEGgAAYAVCDQAAsAIfFD5T3nCN9AzCd/OKkZ4BcN77YWvFSE9hGNaN9ASAkFipAQAAViDUAAAAK/D20xni+fjoSE8hbNk3j/QMAAA4c1ipAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAV+O6n85in8tGRngIAAGcMKzUAAMAKhBoAAGAFQg0AALDCsELN5s2blZKSotjYWDmdTu3du3fQ+rq6OjmdTsXGxio1NVVbtmwJqqmurlZ6erocDofS09O1c+fOkx7P5XIpIiJCxcXFw5k+AACwUNihpqqqSsXFxVq1apUaGhqUm5urOXPmqLW1NWR9S0uL5s6dq9zcXDU0NGjlypUqKipSdXW1v8bj8Sg/P18FBQVqampSQUGBFixYoAMHDgQd791331VFRYWuvfbacKcOAAAsFmGMMeHsMH36dGVmZqq8vNw/lpaWpnnz5snlcgXVL1u2TDU1NWpubvaPFRYWqqmpSR6PR5KUn58vn8+nXbt2+Wtuu+02jR8/Xtu2bfOPffnll8rMzNTmzZv1j//4j7r++uu1YcOGIc/d5/MpLi5OXq9X48aNC6ftU+JKIgDni+yH1o30FHCeGerf77BWanp7e1VfX6+8vLyA8by8PO3fvz/kPh6PJ6h+9uzZOnjwoPr6+gat+e4xFy9erNtvv1233nrrkObb09Mjn88XsAEAADuFFWq6u7s1MDCghISEgPGEhAR1dnaG3KezszNkfX9/v7q7uwet+fYxt2/frkOHDoVcDToZl8uluLg4/zZp0qQh7wsAAEaXYX1QOCIiIuCxMSZo7FT13x0f7JhtbW1aunSpXnrpJcXGxg55nitWrJDX6/VvbW1tQ94XAACMLmHdUXjChAmKjIwMWpXp6uoKWmk5ITExMWR9VFSU4uPjB605ccz6+np1dXXJ6XT6nx8YGNCbb76psrIy9fT0KDIyMujcDodDDocjnBYBAMAoFdZKTUxMjJxOp9xud8C42+1WTk5OyH2ys7OD6nfv3q2srCxFR0cPWnPimLfccovef/99NTY2+resrCzdd999amxsDBloAADA+SXs734qLS1VQUGBsrKylJ2drYqKCrW2tqqwsFDSN2/5tLe3a+vWrZK+udKprKxMpaWlWrRokTwejyorKwOualq6dKlmzpyptWvX6q677tIrr7yi1157Tfv27ZMkjR07VhkZGQHzuOiiixQfHx80DgAAzk9hh5r8/HwdPXpUq1evVkdHhzIyMlRbW6vk5GRJUkdHR8A9a1JSUlRbW6uSkhJt2rRJSUlJ2rhxo+bPn++vycnJ0fbt2/X444/riSee0BVXXKGqqipNnz79DLQIAADOB2Hfp2Y04z41AHD6uE8Nvm9n5T41AAAA5ypCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAViDUAAAAKxBqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYYVihZvPmzUpJSVFsbKycTqf27t07aH1dXZ2cTqdiY2OVmpqqLVu2BNVUV1crPT1dDodD6enp2rlzZ8DzLpdLN9xwg8aOHauJEydq3rx5+uijj4YzfQAAYKGwQ01VVZWKi4u1atUqNTQ0KDc3V3PmzFFra2vI+paWFs2dO1e5ublqaGjQypUrVVRUpOrqan+Nx+NRfn6+CgoK1NTUpIKCAi1YsEAHDhzw19TV1Wnx4sV6++235Xa71d/fr7y8PH311VfDaBsAANgmwhhjwtlh+vTpyszMVHl5uX8sLS1N8+bNk8vlCqpftmyZampq1Nzc7B8rLCxUU1OTPB6PJCk/P18+n0+7du3y19x2220aP368tm3bFnIef/zjHzVx4kTV1dVp5syZQ5q7z+dTXFycvF6vxo0bN6R9hspT+egZPR4AnKuyH1o30lPAeWaof7/DWqnp7e1VfX298vLyAsbz8vK0f//+kPt4PJ6g+tmzZ+vgwYPq6+sbtOZkx5Qkr9crSbr00ktPWtPT0yOfzxewAQAAO4UVarq7uzUwMKCEhISA8YSEBHV2dobcp7OzM2R9f3+/uru7B6052TGNMSotLdWNN96ojIyMk87X5XIpLi7Ov02aNOmUPQIAgNFpWB8UjoiICHhsjAkaO1X9d8fDOeaSJUv03nvvnfStqRNWrFghr9fr39ra2gatBwAAo1dUOMUTJkxQZGRk0ApKV1dX0ErLCYmJiSHro6KiFB8fP2hNqGM+8sgjqqmp0ZtvvqnLL7980Pk6HA45HI5T9gUAGLrn3EdGegphK/nR/xnpKeB7ENZKTUxMjJxOp9xud8C42+1WTk5OyH2ys7OD6nfv3q2srCxFR0cPWvPtYxpjtGTJEu3YsUOvv/66UlJSwpk6AACwXFgrNZJUWlqqgoICZWVlKTs7WxUVFWptbVVhYaGkb97yaW9v19atWyV9c6VTWVmZSktLtWjRInk8HlVWVga8dbR06VLNnDlTa9eu1V133aVXXnlFr732mvbt2+evWbx4sV5++WW98sorGjt2rH9lJy4uTmPGjDmtFwEAAIx+YYea/Px8HT16VKtXr1ZHR4cyMjJUW1ur5ORkSVJHR0fAPWtSUlJUW1urkpISbdq0SUlJSdq4caPmz5/vr8nJydH27dv1+OOP64knntAVV1yhqqoqTZ8+3V9z4hLym266KWA+L774on7605+G2wYAALBM2PepGc24Tw0AnL63J//fkZ5C2PhMzeh2Vu5TAwAAcK4i1AAAACsQagAAgBUINQAAwAphX/0EADi//bC1YqSnMAyj8Es43wj+kuhz3s0rRvT0rNQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBS7oBADgHeT4+OtJTCFv2zSN7flZqAACAFQg1AADACoQaAABgBUINAACwAqEGAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKfE0CAMB6z7mPjPQUwvbDkZ7AKESoAQBY74etFSM9BXwPePsJAABYgVADAACsQKgBAABWINQAAAArEGoAAIAVCDUAAMAKhBoAAGAFQg0AALACoQYAAFiBUAMAAKxAqAEAAFYYVqjZvHmzUlJSFBsbK6fTqb179w5aX1dXJ6fTqdjYWKWmpmrLli1BNdXV1UpPT5fD4VB6erp27tx52ucFAADnj7BDTVVVlYqLi7Vq1So1NDQoNzdXc+bMUWtra8j6lpYWzZ07V7m5uWpoaNDKlStVVFSk6upqf43H41F+fr4KCgrU1NSkgoICLViwQAcOHBj2eQEAwPklwhhjwtlh+vTpyszMVHl5uX8sLS1N8+bNk8vlCqpftmyZampq1Nzc7B8rLCxUU1OTPB6PJCk/P18+n0+7du3y19x2220aP368tm3bNqzzSlJPT496enr8j71eryZPnqy2tjaNGzcunLZP6Z2tq87o8QAAGG3+4oGnzspxfT6fJk2apD/96U+Ki4s7eaEJQ09Pj4mMjDQ7duwIGC8qKjIzZ84MuU9ubq4pKioKGNuxY4eJiooyvb29xhhjJk2aZNavXx9Qs379ejN58uRhn9cYY5588kkjiY2NjY2Njc2Cra2tbdCcEqUwdHd3a2BgQAkJCQHjCQkJ6uzsDLlPZ2dnyPr+/n51d3frsssuO2nNiWMO57yStGLFCpWWlvofHz9+XP/zP/+j+Ph4RUREnLrhITqRIM/GCtC57nzt/XztW6L387H387Vvid7Pld6NMfriiy+UlJQ0aF1YoeaE7wYCY8ygISFU/XfHh3LMcM/rcDjkcDgCxi655JKT1p+ucePGjfgPfqScr72fr31L9H4+9n6+9i3R+7nQ+6BvO/1/YX1QeMKECYqMjAxaHenq6gpaRTkhMTExZH1UVJTi4+MHrTlxzOGcFwAAnF/CCjUxMTFyOp1yu90B4263Wzk5OSH3yc7ODqrfvXu3srKyFB0dPWjNiWMO57wAAOA8M+gnbkLYvn27iY6ONpWVlebw4cOmuLjYXHTRReYPf/iDMcaY5cuXm4KCAn/9xx9/bC688EJTUlJiDh8+bCorK010dLT5j//4D3/NW2+9ZSIjI82aNWtMc3OzWbNmjYmKijJvv/32kM87kr7++mvz5JNPmq+//nqkp/K9O197P1/7Nobez8fez9e+jaH30dZ72KHGGGM2bdpkkpOTTUxMjMnMzDR1dXX+5xYuXGhmzZoVUL9nzx4zbdo0ExMTY6ZMmWLKy8uDjvnv//7v5s///M9NdHS0mTp1qqmurg7rvAAA4PwW9n1qAAAAzkV89xMAALACoQYAAFiBUAMAAKxAqAEAAFYg1JwBmzdvVkpKimJjY+V0OrV3796RntIZ5XK5dMMNN2js2LGaOHGi5s2bp48++iigxhijv//7v1dSUpLGjBmjm266SR9++OEIzfjscLlcioiIUHFxsX/M5r7b29t1//33Kz4+XhdeeKGuv/561dfX+5+3tff+/n49/vjjSklJ0ZgxY5SamqrVq1fr+PHj/hpben/zzTf14x//WElJSYqIiNCvf/3rgOeH0mdPT48eeeQRTZgwQRdddJHuvPNOffrpp99jF+EbrO++vj4tW7ZM11xzjS666CIlJSXpgQce0GeffRZwjNHYt3Tqn/m3Pfzww4qIiNCGDRsCxs/l3gk1p6mqqkrFxcVatWqVGhoalJubqzlz5qi1tXWkp3bG1NXVafHixXr77bfldrvV39+vvLw8ffXVV/6aZ555RuvXr1dZWZneffddJSYm6kc/+pG++OKLEZz5mfPuu++qoqJC1157bcC4rX1//vnnmjFjhqKjo7Vr1y4dPnxYv/jFLwK+ZsTW3teuXastW7aorKxMzc3NeuaZZ/Tss8/q+eef99fY0vtXX32l6667TmVlZSGfH0qfxcXF2rlzp7Zv3659+/bpyy+/1B133KGBgYHvq42wDdb3sWPHdOjQIT3xxBM6dOiQduzYoSNHjujOO+8MqBuNfUun/pmf8Otf/1oHDhwI+V1L53TvI3g5uRX+4i/+whQWFgaMTZ061SxfvnyEZnT2dXV1GUn++wQdP37cJCYmmjVr1vhrvv76axMXF2e2bNkyUtM8Y7744gtz1VVXGbfbbWbNmmWWLl1qjLG772XLlpkbb7zxpM/b3Pvtt99uHnzwwYCxn/zkJ+b+++83xtjbuySzc+dO/+Oh9PmnP/3JREdHm+3bt/tr2tvbzQUXXGD+67/+63ub++n4bt+hvPPOO0aS+eSTT4wxdvRtzMl7//TTT80PfvAD88EHH5jk5GTz3HPP+Z8713tnpeY09Pb2qr6+Xnl5eQHjeXl52r9//wjN6uzzer2SpEsvvVSS1NLSos7OzoDXweFwaNasWVa8DosXL9btt9+uW2+9NWDc5r5ramqUlZWlu+++WxMnTtS0adP0wgsv+J+3ufcbb7xRv/3tb3XkyBFJUlNTk/bt26e5c+dKsrv3bxtKn/X19err6wuoSUpKUkZGhlWvhdfrVUREhH+l0ua+jx8/roKCAj322GO6+uqrg54/13sf1rd04xvd3d0aGBgI+lLNhISEoC/ftIUxRqWlpbrxxhuVkZEhSf5eQ70On3zyyfc+xzNp+/btOnTokN59992g52zu++OPP1Z5eblKS0u1cuVKvfPOOyoqKpLD4dADDzxgde/Lli2T1+vV1KlTFRkZqYGBAT311FO69957Jdn9c/+2ofTZ2dmpmJgYjR8/PqjGlt+BX3/9tZYvX66/+qu/8n9Ttc19r127VlFRUSoqKgr5/LneO6HmDIiIiAh4bIwJGrPFkiVL9N5772nfvn1Bz9n2OrS1tWnp0qXavXu3YmNjT1pnW9/SN/+3lpWVpaefflqSNG3aNH344YcqLy/XAw884K+zsfeqqiq99NJLevnll3X11VersbFRxcXFSkpK0sKFC/11NvYeynD6tOW16Ovr0z333KPjx49r8+bNp6wf7X3X19frn/7pn3To0KGw+zhXeuftp9MwYcIERUZGBqXTrq6uoP+7scEjjzyimpoavfHGG7r88sv944mJiZJk3etQX1+vrq4uOZ1ORUVFKSoqSnV1ddq4caOioqL8vdnWtyRddtllSk9PDxhLS0vzfwDe1p+5JD322GNavny57rnnHl1zzTUqKChQSUmJXC6XJLt7/7ah9JmYmKje3l59/vnnJ60Zrfr6+rRgwQK1tLTI7Xb7V2kke/veu3evurq6NHnyZP/vvE8++UR/8zd/oylTpkg693sn1JyGmJgYOZ1Oud3ugHG3262cnJwRmtWZZ4zRkiVLtGPHDr3++utKSUkJeD4lJUWJiYkBr0Nvb6/q6upG9etwyy236P3331djY6N/y8rK0n333afGxkalpqZa2bckzZgxI+iy/SNHjig5OVmSvT9z6ZurXy64IPBXY2RkpP+Sbpt7/7ah9Ol0OhUdHR1Q09HRoQ8++GBUvxYnAs3vfvc7vfbaa4qPjw943ta+CwoK9N577wX8zktKStJjjz2mV199VdIo6H2EPqBsje3bt5vo6GhTWVlpDh8+bIqLi81FF11k/vCHP4z01M6Yv/7rvzZxcXFmz549pqOjw78dO3bMX7NmzRoTFxdnduzYYd5//31z7733mssuu8z4fL4RnPmZ9+2rn4yxt+933nnHREVFmaeeesr87ne/M//2b/9mLrzwQvPSSy/5a2ztfeHCheYHP/iB+c///E/T0tJiduzYYSZMmGD+9m//1l9jS+9ffPGFaWhoMA0NDUaSWb9+vWloaPBf5TOUPgsLC83ll19uXnvtNXPo0CHzl3/5l+a6664z/f39I9XWKQ3Wd19fn7nzzjvN5ZdfbhobGwN+5/X09PiPMRr7NubUP/Pv+u7VT8ac270Tas6ATZs2meTkZBMTE2MyMzP9lzrbQlLI7cUXX/TXHD9+3Dz55JMmMTHROBwOM3PmTPP++++P3KTPku+GGpv7/s1vfmMyMjKMw+EwU6dONRUVFQHP29q7z+czS5cuNZMnTzaxsbEmNTXVrFq1KuAPmi29v/HGGyH/2164cKExZmh9/u///q9ZsmSJufTSS82YMWPMHXfcYVpbW0egm6EbrO+WlpaT/s574403/McYjX0bc+qf+XeFCjXncu8RxhjzfawIAQAAnE18pgYAAFiBUAMAAKxAqAEAAFYg1AAAACsQagAAgBUINQAAwAqEGgAAYAVCDQAAsAKhBgAAWIFQAwAArECoAQAAVvh//ay80P2gzYQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist(y, density=True, alpha=0.5)\n",
    "plt.hist(y_train_full, density=True, alpha=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 5 folds for each of 69120 candidates, totalling 345600 fits\n",
      "Best parameter: {'gamma': 0, 'learning_rate': 0.1, 'max_bin': 3, 'max_delta_step': 0, 'max_depth': 2, 'max_leaves': 3, 'min_child_weight': 0.001, 'n_estimators': 100}\n",
      "Best score: -20.039419563611347\n"
     ]
    }
   ],
   "source": [
    "model = XGBRegressor(objective=\"reg:absoluteerror\")\n",
    "\n",
    "param_grid = {\n",
    "    \"gamma\": [0, 0.1, 0.2],\n",
    "    \"learning_rate\": [0.01, 0.1, 0.2, 0.3],\n",
    "    \"max_bin\": [2, 3, 4, 5, 10],\n",
    "    \"max_delta_step\": [0, 1, 2],\n",
    "    \"max_depth\": [1, 2, 4, 6],\n",
    "    \"max_leaves\": [0, 1, 2, 3, 4, 5],\n",
    "    \"min_child_weight\": [0.001, 0.01, 0.1, 0.5],\n",
    "    \"n_estimators\": [10, 50, 100, 200],\n",
    "}\n",
    "\n",
    "# Set up the GridSearchCV\n",
    "grid_search = GridSearchCV(\n",
    "    estimator=model,\n",
    "    param_grid=param_grid,\n",
    "    scoring='neg_mean_absolute_error', \n",
    "    cv=5,\n",
    "    verbose=1,\n",
    "    n_jobs=-1,\n",
    "    return_train_score=True,\n",
    ")\n",
    "\n",
    "grid_search.fit(X_train_full, y_train_full)\n",
    "\n",
    "print(f\"Best parameter: {grid_search.best_params_}\")\n",
    "print(f\"Best score: {grid_search.best_score_}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(grid_search.cv_results_).to_csv(\"output.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'gamma': 0, 'learning_rate': 0.1, 'max_bin': 3, 'max_delta_step': 0, 'max_depth': 2, 'max_leaves': 3, 'min_child_weight': 0.001, 'n_estimators': 100}\n",
      "-20.039419563611347\n",
      "20.508699798583983\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>43.000000</td>\n",
       "      <td>37.987656</td>\n",
       "      <td>5.012344</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>9.000000</td>\n",
       "      <td>53.990318</td>\n",
       "      <td>-44.990318</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>73.000000</td>\n",
       "      <td>59.660213</td>\n",
       "      <td>13.339787</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>16.000000</td>\n",
       "      <td>53.314426</td>\n",
       "      <td>-37.314426</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>59.000000</td>\n",
       "      <td>60.072800</td>\n",
       "      <td>-1.072800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75</th>\n",
       "      <td>53.000000</td>\n",
       "      <td>55.507435</td>\n",
       "      <td>-2.507435</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>76</th>\n",
       "      <td>93.000000</td>\n",
       "      <td>59.416786</td>\n",
       "      <td>33.583214</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>77</th>\n",
       "      <td>82.416667</td>\n",
       "      <td>47.146587</td>\n",
       "      <td>35.270079</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>89.000000</td>\n",
       "      <td>57.640259</td>\n",
       "      <td>31.359741</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>88.000000</td>\n",
       "      <td>54.131992</td>\n",
       "      <td>33.868008</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>80 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            0          1          2\n",
       "0   43.000000  37.987656   5.012344\n",
       "1    9.000000  53.990318 -44.990318\n",
       "2   73.000000  59.660213  13.339787\n",
       "3   16.000000  53.314426 -37.314426\n",
       "4   59.000000  60.072800  -1.072800\n",
       "..        ...        ...        ...\n",
       "75  53.000000  55.507435  -2.507435\n",
       "76  93.000000  59.416786  33.583214\n",
       "77  82.416667  47.146587  35.270079\n",
       "78  89.000000  57.640259  31.359741\n",
       "79  88.000000  54.131992  33.868008\n",
       "\n",
       "[80 rows x 3 columns]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(grid_search.best_params_)\n",
    "print(grid_search.best_score_)\n",
    "\n",
    "model = grid_search.best_estimator_\n",
    "\n",
    "y_pred = model.predict(X_test_reserved)\n",
    "\n",
    "print(mean_absolute_error(y_test_reserved, y_pred))\n",
    "\n",
    "l1 = np.array(list(y_test_reserved))\n",
    "l2 = np.array(list(y_pred))\n",
    "l3 = l1 - l2\n",
    "\n",
    "display(pd.DataFrame([l1, l2, l3]).T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CV MAE and R2\n",
      "20.039419563611347\n",
      "0.05608801676895068\n",
      "\n",
      "Training set MAE and R2\n",
      "17.971198264757792\n",
      "0.1730067350076071\n",
      "\n",
      "Test set MAE and R2\n",
      "20.508699798583983\n",
      "0.029032418854204156\n"
     ]
    }
   ],
   "source": [
    "param = {'gamma': 0, 'learning_rate': 0.1, 'max_bin': 3, 'max_delta_step': 0, 'max_depth': 2, 'max_leaves': 3, 'min_child_weight': 0.001, 'n_estimators': 100, \"objective\":\"reg:absoluteerror\"}\n",
    "\n",
    "model = XGBRegressor(**param)\n",
    "\n",
    "model.fit(X_train_full, y_train_full)\n",
    "\n",
    "print(\"CV MAE and R2\")\n",
    "print(-np.mean(cross_val_score(model, X_train_full, y_train_full, scoring='neg_mean_absolute_error')))\n",
    "print(np.mean(cross_val_score(model, X_train_full, y_train_full, scoring='r2')))\n",
    "\n",
    "y_pred_train = model.predict(X_train_full)\n",
    "print(\"\\nTraining set MAE and R2\")\n",
    "print(mean_absolute_error(y_train_full, y_pred_train))\n",
    "print(r2_score(y_train_full, y_pred_train))\n",
    "\n",
    "y_pred_test = model.predict(X_test_reserved)\n",
    "print(\"\\nTest set MAE and R2\")\n",
    "print(mean_absolute_error(y_test_reserved, y_pred_test))\n",
    "print(r2_score(y_test_reserved, y_pred_test))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MLEAsm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}