03-Experiments/12-AutoGluon.ipynb

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Experiment: artifact_location='/Users/arham/Downloads/Projects/03-Experiments/mlruns/3', creation_time=1713916994501, experiment_id='3', last_update_time=1713916994501, lifecycle_stage='active', name='AutoGluon', tags={}>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import mlflow\n",
    "\n",
    "# Set the MLflow tracking URI to a new SQLite URI\n",
    "mlflow.set_tracking_uri(\"sqlite:///new_mlflow.db\")\n",
    "mlflow.set_experiment(\"AutoGluon\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Experiment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "import lightgbm as lgb\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "import lightgbm as lgb\n",
    "from sklearn.metrics import accuracy_score\n",
    "import warnings\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score\n",
    "import xgboost as xgb\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score\n",
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "\n",
    "def load_data(path):\n",
    "    df = pd.read_csv(path)\n",
    "    train_df, test_df = train_test_split(df, test_size=0.35, random_state=42)\n",
    "    train_df, val_df,  = train_test_split(train_df, test_size=0.20, random_state=42)\n",
    "    train_df = train_df.drop(['id'], axis=1).drop_duplicates().reset_index(drop=True)\n",
    "    test_df = test_df.drop(['id'], axis=1).drop_duplicates().reset_index(drop=True)\n",
    "    val_df = val_df.drop(['id'], axis=1).drop_duplicates().reset_index(drop=True)\n",
    "    return train_df, val_df, test_df\n",
    "\n",
    "def encode_target(train):\n",
    "    target_key = {'Insufficient_Weight': 0, 'Normal_Weight': 1, 'Overweight_Level_I': 2, 'Overweight_Level_II': 3, 'Obesity_Type_I': 4,'Obesity_Type_II' : 5, 'Obesity_Type_III': 6}\n",
    "    train['NObeyesdad'] = train['NObeyesdad'].map(target_key)\n",
    "    return train\n",
    "\n",
    "def make_gender_binary(train):\n",
    "    train['Gender'] = train['Gender'].map({'Male':0, 'Female':1})\n",
    "\n",
    "def datatypes(train):\n",
    "    train['Weight'] = train['Weight'].astype(float)\n",
    "    train['Age'] = train['Age'].astype(float)\n",
    "    train['Height'] = train['Height'].astype(float)\n",
    "    return train\n",
    "\n",
    "# def age_binning(train_df):\n",
    "#     # train_df['Age_Group'] = pd.cut(train_df['Age'], bins=[0, 20, 30, 40, 50, train_df['Age'].max()], labels=['0-20', '21-30', '31-40', '41-50', '50+'])\n",
    "#     train_df['Age_Group'] = pd.cut(train_df['Age'], bins=[0, 20, 30, 40, 50, train_df['Age'].max()], labels=[1, 2, 3, 4, 5])\n",
    "#     train_df['Age_Group'] = train_df['Age_Group'].astype(int)\n",
    "#     return train_df\n",
    "\n",
    "def age_binning(df):\n",
    "    age_groups = []\n",
    "    for age in df['Age']:\n",
    "        if age <= 20:\n",
    "            age_group = 1\n",
    "        elif age <= 30:\n",
    "            age_group = 2\n",
    "        elif age <= 40:\n",
    "            age_group = 3\n",
    "        elif age <= 50:\n",
    "            age_group = 4\n",
    "        else:\n",
    "            age_group = 5\n",
    "        age_groups.append(age_group)\n",
    "    df['Age_Group'] = age_groups\n",
    "    return df\n",
    "\n",
    "def age_scaling_log(train_df):\n",
    "    train_df['Age'] = train_df['Age'].astype(float)\n",
    "    train_df['Log_Age'] = np.log1p(train_df['Age'])\n",
    "    return train_df\n",
    "\n",
    "def age_scaling_minmax(train_df):\n",
    "    train_df['Age'] = train_df['Age'].astype(float)\n",
    "    scaler_age = MinMaxScaler()\n",
    "    train_df['Scaled_Age'] = scaler_age.fit_transform(train_df['Age'].values.reshape(-1, 1))\n",
    "    return train_df, scaler_age\n",
    "\n",
    "def weight_scaling_log(train_df):\n",
    "    train_df['Weight'] = train_df['Weight'].astype(float)\n",
    "    train_df['Log_Weight'] = np.log1p(train_df['Weight'])\n",
    "    return train_df\n",
    "\n",
    "def weight_scaling_minmax(train_df):\n",
    "    train_df['Weight'] = train_df['Weight'].astype(float)\n",
    "    scaler_weight = MinMaxScaler()\n",
    "    train_df['Scaled_Weight'] = scaler_weight.fit_transform(train_df['Weight'].values.reshape(-1, 1))\n",
    "    return train_df, scaler_weight\n",
    "\n",
    "def height_scaling_log(train_df):\n",
    "    train_df['Log_Height'] = np.log1p(train_df['Height'])\n",
    "    return train_df\n",
    "\n",
    "def height_scaling_minmax(train_df):\n",
    "    scaler_height = MinMaxScaler()\n",
    "    train_df['Scaled_Height'] = scaler_height.fit_transform(train_df['Height'].values.reshape(-1, 1))\n",
    "    return train_df, scaler_height\n",
    "\n",
    "def make_gender_binary(train):\n",
    "    train['Gender'] = train['Gender'].map({'Female':1, 'Male':0})\n",
    "    return train\n",
    "\n",
    "def fix_binary_columns(train):\n",
    "    Binary_Cols = ['family_history_with_overweight','FAVC', 'SCC','SMOKE']\n",
    "    # if yes then 1 else 0\n",
    "    for col in Binary_Cols:\n",
    "        train[col] = train[col].map({'yes': 1, 'no': 0})\n",
    "        # column datatype integer\n",
    "        train[col] = train[col].astype(int)\n",
    "    return train\n",
    "\n",
    "def freq_cat_cols(train):\n",
    "    # One hot encoding\n",
    "    cat_cols = ['CAEC', 'CALC']\n",
    "    for col in cat_cols:\n",
    "        train[col] = train[col].map({'no': 0, 'Sometimes': 1, 'Frequently': 2, 'Always': 3})\n",
    "    return train\n",
    "\n",
    "def Mtrans(train):\n",
    "    \"\"\"\n",
    "    Public_Transportation    8692\n",
    "    Automobile               1835\n",
    "    Walking                   231\n",
    "    Motorbike                  19\n",
    "    Bike                       16\n",
    "    \"\"\"\n",
    "    # train['MTRANS'] = train['MTRANS'].map({'Public_Transportation': 3, 'Automobile': 5, 'Walking': 1, 'Motorbike': 4, 'Bike': 2})\n",
    "    # dummify column\n",
    "    train = pd.get_dummies(train, columns=['MTRANS'])\n",
    "    # convert these columns to integer\n",
    "    train['MTRANS_Automobile'] = train['MTRANS_Automobile'].astype(int)\n",
    "    train['MTRANS_Walking'] = train['MTRANS_Walking'].astype(int)\n",
    "    train['MTRANS_Motorbike'] = train['MTRANS_Motorbike'].astype(int)\n",
    "    train['MTRANS_Bike'] = train['MTRANS_Bike'].astype(int)\n",
    "    train['MTRANS_Public_Transportation'] = train['MTRANS_Public_Transportation'].astype(int)\n",
    "    return train\n",
    "\n",
    "\n",
    "def other_features(train):\n",
    "    train['BMI'] = train['Weight'] / (train['Height'] ** 2)\n",
    "    # train['Age'*'Gender'] = train['Age'] * train['Gender']\n",
    "    polynomial_features = PolynomialFeatures(degree=2)\n",
    "    X_poly = polynomial_features.fit_transform(train[['Age', 'BMI']])\n",
    "    poly_features_df = pd.DataFrame(X_poly, columns=['Age^2', 'Age^3', 'BMI^2', 'Age * BMI', 'Age * BMI^2', 'Age^2 * BMI^2'])\n",
    "    train = pd.concat([train, poly_features_df], axis=1)\n",
    "    return train\n",
    "\n",
    "\n",
    "def test_pipeline(test, scaler_age, scaler_weight, scaler_height):\n",
    "    test = datatypes(test)\n",
    "    test = encode_target(test)\n",
    "    test = age_binning(test)\n",
    "    test = age_scaling_log(test)\n",
    "    test['Scaled_Age'] = scaler_age.transform(test['Age'].values.reshape(-1, 1))\n",
    "    test = weight_scaling_log(test)\n",
    "    test['Scaled_Weight'] = scaler_weight.transform(test['Weight'].values.reshape(-1, 1))\n",
    "    test = height_scaling_log(test)\n",
    "    test['Scaled_Height'] = scaler_height.transform(test['Height'].values.reshape(-1, 1))\n",
    "    test = make_gender_binary(test)\n",
    "    test = fix_binary_columns(test)\n",
    "    test = freq_cat_cols(test)\n",
    "    test = Mtrans(test)\n",
    "    test = other_features(test)\n",
    "\n",
    "    return test\n",
    "\n",
    "def train_model(params, X_train, y_train):\n",
    "    lgb_train = lgb.Dataset(X_train, y_train)\n",
    "    model = lgb.train(params, lgb_train, num_boost_round=1000)\n",
    "    return model\n",
    "\n",
    "def evaluate_model(model, X_val, y_val):\n",
    "    y_pred = model.predict(X_val)\n",
    "    y_pred = [np.argmax(y) for y in y_pred]\n",
    "    accuracy = accuracy_score(y_val, y_pred)\n",
    "    return accuracy\n",
    "\n",
    "def objective(trial, X_train, y_train):\n",
    "    params = {\n",
    "        'objective': 'multiclass',\n",
    "        'num_class': 7,\n",
    "        'metric': 'multi_logloss',\n",
    "        'boosting_type': 'gbdt',\n",
    "        'learning_rate': trial.suggest_loguniform('learning_rate', 0.005, 0.5),\n",
    "        'num_leaves': trial.suggest_int('num_leaves', 10, 1000),\n",
    "        'max_depth': trial.suggest_int('max_depth', -1, 20),\n",
    "        'bagging_fraction': trial.suggest_uniform('bagging_fraction', 0.6, 0.95),\n",
    "        'feature_fraction': trial.suggest_uniform('feature_fraction', 0.6, 0.95),\n",
    "        'verbosity': -1\n",
    "    }\n",
    "\n",
    "    n_splits = 5\n",
    "    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)\n",
    "    scores = []\n",
    "\n",
    "    for train_index, val_index in kf.split(X_train, y_train):\n",
    "        X_tr, X_val = X_train.iloc[train_index], X_train.iloc[val_index]\n",
    "        y_tr, y_val = y_train.iloc[train_index], y_train.iloc[val_index]\n",
    "\n",
    "        model = train_model(params, X_tr, y_tr)\n",
    "        accuracy = evaluate_model(model, X_val, y_val)\n",
    "        scores.append(accuracy)\n",
    "\n",
    "    return np.mean(scores)\n",
    "\n",
    "def optimize_hyperparameters(X_train, y_train, n_trials=2):\n",
    "    study = optuna.create_study(direction='maximize')\n",
    "    study.optimize(lambda trial: objective(trial, X_train, y_train), n_trials=n_trials)\n",
    "    return study.best_params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '/Users/arham/Downloads/Projects/01-Dataset/01-Data-for-model-building/train.csv'\n",
    "train_df, val_df, test_df = load_data(path)\n",
    "\n",
    "train_df = datatypes(train_df)\n",
    "train_df = encode_target(train_df)\n",
    "train_df = age_binning(train_df)\n",
    "train_df, scaler_age = age_scaling_minmax(train_df)\n",
    "train_df = age_scaling_log(train_df)\n",
    "train_df, scaler_weight = weight_scaling_minmax(train_df)\n",
    "train_df = weight_scaling_log(train_df)\n",
    "train_df, scaler_height = height_scaling_minmax(train_df)\n",
    "train_df = height_scaling_log(train_df)\n",
    "train_df = make_gender_binary(train_df)\n",
    "train_df = fix_binary_columns(train_df)\n",
    "train_df = freq_cat_cols(train_df)\n",
    "train_df = Mtrans(train_df)\n",
    "train_df = other_features(train_df)\n",
    "\n",
    "val_df = test_pipeline(val_df, scaler_age, scaler_weight, scaler_height)\n",
    "test_df = test_pipeline(test_df, scaler_age, scaler_weight, scaler_height)\n",
    "\n",
    "Target = 'NObeyesdad'\n",
    "features = train_df.columns.drop(Target)\n",
    "\n",
    "features = ['Gender', 'Age', 'Height', 'Weight', 'family_history_with_overweight',\n",
    "       'FAVC', 'FCVC', 'NCP', 'CAEC', 'SMOKE', 'CH2O', 'SCC', 'FAF', 'TUE',\n",
    "       'CALC', \n",
    "       'MTRANS_Automobile', 'MTRANS_Bike', 'MTRANS_Motorbike',\n",
    "       'MTRANS_Public_Transportation', 'MTRANS_Walking']\n",
    "       #   'BMI', 'Age^2',\n",
    "       # 'Age^3', 'BMI^2', 'Age * BMI', 'Age * BMI^2', 'Age^2 * BMI^2'] \n",
    "#'Scaled_Age', 'Log_Age', 'Scaled_Weight', 'Log_Weight', 'Scaled_Height', 'Log_Height',\n",
    "\n",
    "X_train = train_df[features]\n",
    "y_train = train_df[Target]\n",
    "X_val = val_df[features]\n",
    "y_val = val_df[Target]\n",
    "X_test = test_df[features]\n",
    "y_test = test_df[Target]\n",
    "\n",
    "#combine X_train and y_train as one dataframe\n",
    "tr = pd.concat([X_train, y_train], axis=1)\n",
    "te = pd.concat([X_test, y_test], axis =1)\n",
    "va = pd.concat([X_val, y_val], axis = 1)"
   ]
  },
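  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Optuna helpers from the setup cell (`objective`, `optimize_hyperparameters`) are defined but never called in this notebook. A minimal sketch of how they could be wired up on this split (not part of the logged runs; assumes the small default budget of `n_trials=2`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: tune LightGBM on the training split using the helpers defined above,\n",
    "# then refit on the full training data with the best parameters found.\n",
    "best_params = optimize_hyperparameters(X_train, y_train, n_trials=2)\n",
    "lgb_model = train_model({'objective': 'multiclass', 'num_class': 7, 'metric': 'multi_logloss', 'verbosity': -1, **best_params}, X_train, y_train)\n",
    "print('Best CV params:', best_params)"
   ]
  },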
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{6: 0.19790605021773372,\n",
       " 5: 0.1545446122486797,\n",
       " 1: 0.15056054850365977,\n",
       " 4: 0.13879366255906606,\n",
       " 0: 0.12202353377188918,\n",
       " 3: 0.12081904938385991,\n",
       " 2: 0.11535254331511165}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.value_counts()/len(y_train)\n",
    "# to dict as weights\n",
    "weights = y_train.value_counts(normalize=True).to_dict()\n",
    "weights"
   ]
  },
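  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `weights` dict above is computed but not consumed anywhere below. One way it could be used, sketched as an assumption rather than something this experiment ran: turn the class frequencies into inverse-frequency per-row sample weights that a weighted learner (e.g. LightGBM's `sample_weight`) could consume."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (not part of the logged run): inverse-frequency sample weight per row.\n",
    "sample_weights = y_train.map(lambda c: 1.0 / weights[c])\n",
    "sample_weights.head()"
   ]
  },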
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No path specified. Models will be saved in: \"AutogluonModels/ag-20240426_034247\"\n",
      "No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets.\n",
      "\tRecommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):\n",
      "\tpresets='best_quality'   : Maximize accuracy. Default time_limit=3600.\n",
      "\tpresets='high_quality'   : Strong accuracy with fast inference speed. Default time_limit=3600.\n",
      "\tpresets='good_quality'   : Good accuracy with very fast inference speed. Default time_limit=3600.\n",
      "\tpresets='medium_quality' : Fast training time, ideal for initial prototyping.\n",
      "Beginning AutoGluon training ...\n",
      "AutoGluon will save models to \"AutogluonModels/ag-20240426_034247\"\n",
      "=================== System Info ===================\n",
      "AutoGluon Version:  1.1.0\n",
      "Python Version:     3.10.13\n",
      "Operating System:   Darwin\n",
      "Platform Machine:   arm64\n",
      "Platform Version:   Darwin Kernel Version 23.0.0: Fri Sep 15 14:42:57 PDT 2023; root:xnu-10002.1.13~1/RELEASE_ARM64_T8112\n",
      "CPU Count:          8\n",
      "Memory Avail:       1.25 GB / 8.00 GB (15.7%)\n",
      "Disk Space Avail:   24.83 GB / 228.27 GB (10.9%)\n",
      "===================================================\n",
      "Train Data Rows:    10793\n",
      "Train Data Columns: 20\n",
      "Label Column:       NObeyesdad\n",
      "AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).\n",
      "\t7 unique label values:  [0, 3, 1, 6, 4, 2, 5]\n",
      "\tIf 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n",
      "Problem Type:       multiclass\n",
      "Preprocessing data ...\n",
      "Train Data Class Count: 7\n",
      "Using Feature Generators to preprocess the data ...\n",
      "Fitting AutoMLPipelineFeatureGenerator...\n",
      "\tAvailable Memory:                    1285.60 MB\n",
      "\tTrain Data (Original)  Memory Usage: 1.65 MB (0.1% of available memory)\n",
      "\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
      "\tStage 1 Generators:\n",
      "\t\tFitting AsTypeFeatureGenerator...\n",
      "\t\t\tNote: Converting 10 features to boolean dtype as they only contain 2 unique values.\n",
      "\tStage 2 Generators:\n",
      "\t\tFitting FillNaFeatureGenerator...\n",
      "\tStage 3 Generators:\n",
      "\t\tFitting IdentityFeatureGenerator...\n",
      "\tStage 4 Generators:\n",
      "\t\tFitting DropUniqueFeatureGenerator...\n",
      "\tStage 5 Generators:\n",
      "\t\tFitting DropDuplicatesFeatureGenerator...\n",
      "\tTypes of features in original data (raw dtype, special dtypes):\n",
      "\t\t('float', []) :  8 | ['Age', 'Height', 'Weight', 'FCVC', 'NCP', ...]\n",
      "\t\t('int', [])   : 12 | ['Gender', 'family_history_with_overweight', 'FAVC', 'CAEC', 'SMOKE', ...]\n",
      "\tTypes of features in processed data (raw dtype, special dtypes):\n",
      "\t\t('float', [])     :  8 | ['Age', 'Height', 'Weight', 'FCVC', 'NCP', ...]\n",
      "\t\t('int', [])       :  2 | ['CAEC', 'CALC']\n",
      "\t\t('int', ['bool']) : 10 | ['Gender', 'family_history_with_overweight', 'FAVC', 'SMOKE', 'SCC', ...]\n",
      "\t0.1s = Fit runtime\n",
      "\t20 features in original data used to generate 20 features in processed data.\n",
      "\tTrain Data (Processed) Memory Usage: 0.93 MB (0.1% of available memory)\n",
      "Data preprocessing and feature engineering runtime = 0.07s ...\n",
      "AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n",
      "\tTo change this, specify the eval_metric parameter of Predictor()\n",
      "Automatically generating train/validation split with holdout_frac=0.1, Train Rows: 9713, Val Rows: 1080\n",
      "User-specified model hyperparameters to be fit:\n",
      "{\n",
      "\t'NN_TORCH': {},\n",
      "\t'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, 'GBMLarge'],\n",
      "\t'CAT': {},\n",
      "\t'XGB': {},\n",
      "\t'FASTAI': {},\n",
      "\t'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n",
      "\t'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n",
      "\t'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],\n",
      "}\n",
      "Fitting 13 L1 models ...\n",
      "Fitting model: KNeighborsUnif ...\n",
      "\t0.837\t = Validation score   (accuracy)\n",
      "\t0.04s\t = Training   runtime\n",
      "\t0.02s\t = Validation runtime\n",
      "Fitting model: KNeighborsDist ...\n",
      "\t0.8361\t = Validation score   (accuracy)\n",
      "\t0.02s\t = Training   runtime\n",
      "\t0.01s\t = Validation runtime\n",
      "Fitting model: NeuralNetFastAI ...\n",
      "\t0.8907\t = Validation score   (accuracy)\n",
      "\t6.55s\t = Training   runtime\n",
      "\t0.02s\t = Validation runtime\n",
      "Fitting model: LightGBMXT ...\n",
      "\t0.9019\t = Validation score   (accuracy)\n",
      "\t7.58s\t = Training   runtime\n",
      "\t0.06s\t = Validation runtime\n",
      "Fitting model: LightGBM ...\n",
      "\t0.9204\t = Validation score   (accuracy)\n",
      "\t9.07s\t = Training   runtime\n",
      "\t0.03s\t = Validation runtime\n",
      "Fitting model: RandomForestGini ...\n",
      "\t0.9037\t = Validation score   (accuracy)\n",
      "\t1.18s\t = Training   runtime\n",
      "\t0.06s\t = Validation runtime\n",
      "Fitting model: RandomForestEntr ...\n",
      "\t0.9083\t = Validation score   (accuracy)\n",
      "\t1.13s\t = Training   runtime\n",
      "\t0.05s\t = Validation runtime\n",
      "Fitting model: CatBoost ...\n",
      "\t0.913\t = Validation score   (accuracy)\n",
      "\t9.03s\t = Training   runtime\n",
      "\t0.01s\t = Validation runtime\n",
      "Fitting model: ExtraTreesGini ...\n",
      "\t0.8778\t = Validation score   (accuracy)\n",
      "\t0.98s\t = Training   runtime\n",
      "\t0.08s\t = Validation runtime\n",
      "Fitting model: ExtraTreesEntr ...\n",
      "\t0.875\t = Validation score   (accuracy)\n",
      "\t0.98s\t = Training   runtime\n",
      "\t0.07s\t = Validation runtime\n",
      "Fitting model: XGBoost ...\n",
      "\t0.9176\t = Validation score   (accuracy)\n",
      "\t6.6s\t = Training   runtime\n",
      "\t0.02s\t = Validation runtime\n",
      "Fitting model: NeuralNetTorch ...\n",
      "\tWarning: Exception caused NeuralNetTorch to fail during training... Skipping this model.\n",
      "\t\tmodule 'torch.utils._pytree' has no attribute 'register_pytree_node'\n",
      "Detailed Traceback:\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1904, in _train_and_save\n",
      "    model = self._train_single(X, y, model, X_val, y_val, total_resources=total_resources, **model_fit_kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1844, in _train_single\n",
      "    model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, total_resources=total_resources, **model_fit_kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 855, in fit\n",
      "    out = self._fit(**kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py\", line 196, in _fit\n",
      "    self.optimizer = self._init_optimizer(**optimizer_kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py\", line 553, in _init_optimizer\n",
      "    optimizer = torch.optim.Adam(params=self.model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/optim/adam.py\", line 45, in __init__\n",
      "    super().__init__(params, defaults)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/optim/optimizer.py\", line 266, in __init__\n",
      "    self.add_param_group(cast(dict, param_group))\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/_compile.py\", line 22, in inner\n",
      "    import torch._dynamo\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/_dynamo/__init__.py\", line 2, in <module>\n",
      "    from . import allowed_functions, convert_frame, eval_frame, resume_execution\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/_dynamo/allowed_functions.py\", line 26, in <module>\n",
      "    from . import config\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/_dynamo/config.py\", line 49, in <module>\n",
      "    torch.onnx.is_in_onnx_export: False,\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/__init__.py\", line 1831, in __getattr__\n",
      "    return importlib.import_module(f\".{name}\", __name__)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/importlib/__init__.py\", line 126, in import_module\n",
      "    return _bootstrap._gcd_import(name[level:], package, level)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/onnx/__init__.py\", line 46, in <module>\n",
      "    from ._internal.exporter import (  # usort:skip. needs to be last to avoid circular import\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py\", line 42, in <module>\n",
      "    from torch.onnx._internal.fx import (\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/onnx/_internal/fx/__init__.py\", line 1, in <module>\n",
      "    from .patcher import ONNXTorchPatcher\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/onnx/_internal/fx/patcher.py\", line 11, in <module>\n",
      "    import transformers  # type: ignore[import]\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/transformers/__init__.py\", line 26, in <module>\n",
      "    from . import dependency_versions_check\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/transformers/dependency_versions_check.py\", line 16, in <module>\n",
      "    from .utils.versions import require_version, require_version_core\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/transformers/utils/__init__.py\", line 33, in <module>\n",
      "    from .generic import (\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/transformers/utils/generic.py\", line 455, in <module>\n",
      "    _torch_pytree.register_pytree_node(\n",
      "AttributeError: module 'torch.utils._pytree' has no attribute 'register_pytree_node'. Did you mean: '_register_pytree_node'?\n",
      "Fitting model: LightGBMLarge ...\n",
      "\t0.912\t = Validation score   (accuracy)\n",
      "\t75.08s\t = Training   runtime\n",
      "\t0.09s\t = Validation runtime\n",
      "Fitting model: WeightedEnsemble_L2 ...\n",
      "\tEnsemble Weights: {'LightGBM': 0.3, 'RandomForestEntr': 0.3, 'KNeighborsUnif': 0.2, 'NeuralNetFastAI': 0.1, 'XGBoost': 0.1}\n",
      "\t0.9259\t = Validation score   (accuracy)\n",
      "\t0.13s\t = Training   runtime\n",
      "\t0.0s\t = Validation runtime\n",
      "AutoGluon training complete, total runtime = 121.18s ... Best model: \"WeightedEnsemble_L2\"\n",
      "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"AutogluonModels/ag-20240426_034247\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall for class 0: 0.9462025316455697\n",
      "Recall for class 1: 0.8957219251336899\n",
      "Recall for class 2: 0.755223880597015\n",
      "Recall for class 3: 0.8449848024316109\n",
      "Recall for class 4: 0.8788598574821853\n",
      "Recall for class 5: 0.9688995215311005\n",
      "Recall for class 6: 0.9960474308300395\n"
     ]
    }
   ],
   "source": [
    "from autogluon.tabular import TabularDataset, TabularPredictor\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "import mlflow\n",
    "\n",
    "# Load your data into AutoGluon TabularDataset format\n",
    "train_data = TabularDataset(X_train.join(y_train))\n",
    "val_data = TabularDataset(X_val.join(y_val))\n",
    "\n",
    "# Define the label column\n",
    "label_column = Target # Replace 'Target' with your actual label column name\n",
    "\n",
    "# Specify the task and run AutoGluon\n",
    "predictor = TabularPredictor(label=label_column).fit(train_data=train_data)\n",
    "\n",
    "# Make predictions on the validation set\n",
    "y_val_pred_autogluon = predictor.predict(val_data.drop(columns=[label_column]))\n",
    "\n",
    "# Evaluate performance\n",
    "precision, recall, f1, support = precision_recall_fscore_support(y_val, y_val_pred_autogluon, average='weighted')\n",
    "\n",
    "# Log metrics and model using MLflow\n",
    "with mlflow.start_run(run_name=\"AutoGluon_Without_Feature_Engineering\"):\n",
    "    # Log AutoGluon model\n",
    "    mlflow.sklearn.log_model(predictor, \"autogluon_model\")\n",
    "    \n",
    "    # Log metrics\n",
    "    mlflow.log_metric('accuracy', accuracy_score(y_val, y_val_pred_autogluon))\n",
    "    mlflow.log_metric('precision', precision)\n",
    "    mlflow.log_metric('recall', recall)\n",
    "    mlflow.log_metric('f1', f1)\n",
    "\n",
    "    # Log recall per class\n",
    "    recall_per_class = recall_score(y_val, y_val_pred_autogluon, average=None)\n",
    "    for i, recall_class in enumerate(recall_per_class):\n",
    "        print(f\"Recall for class {i}: {recall_class}\")\n",
    "        mlflow.log_metric(f'recall_class_{i}', recall_class)\n",
    "\n",
    "    mlflow.set_tag('experiments', 'Arham A.')\n",
    "    mlflow.set_tag('model_name', 'AutoGluon')\n",
    "    mlflow.set_tag('preprocessing', 'Yes')"
   ]
  },
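  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a quick look at how the individual L1 models stack up against the winning `WeightedEnsemble_L2`, AutoGluon's `leaderboard` method scores every fitted model on a given dataset (inspection only, not logged to MLflow):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank all fitted models by accuracy on our held-out validation split\n",
    "predictor.leaderboard(val_data)"
   ]
  },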
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-04-26 05:38:37 -0400] [14055] [INFO] Starting gunicorn 21.2.0\n",
      "[2024-04-26 05:38:37 -0400] [14055] [INFO] Listening at: http://127.0.0.1:5000 (14055)\n",
      "[2024-04-26 05:38:37 -0400] [14055] [INFO] Using worker: sync\n",
      "[2024-04-26 05:38:37 -0400] [14056] [INFO] Booting worker with pid: 14056\n",
      "[2024-04-26 05:38:37 -0400] [14057] [INFO] Booting worker with pid: 14057\n",
      "[2024-04-26 05:38:37 -0400] [14058] [INFO] Booting worker with pid: 14058\n",
      "[2024-04-26 05:38:37 -0400] [14059] [INFO] Booting worker with pid: 14059\n",
      "[2024-04-26 06:53:19 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:14056)\n",
      "[2024-04-26 06:53:19 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:14057)\n",
      "[2024-04-26 06:53:19 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:14058)\n",
      "[2024-04-26 06:53:19 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:14059)\n",
      "[2024-04-26 06:53:19 -0400] [14059] [INFO] Worker exiting (pid: 14059)\n",
      "[2024-04-26 06:53:19 -0400] [14058] [INFO] Worker exiting (pid: 14058)\n",
      "[2024-04-26 06:53:19 -0400] [14057] [INFO] Worker exiting (pid: 14057)\n",
      "[2024-04-26 06:53:19 -0400] [14056] [INFO] Worker exiting (pid: 14056)\n",
      "[2024-04-26 06:53:20 -0400] [14055] [ERROR] Worker (pid:14058) exited with code 1\n",
      "[2024-04-26 06:53:20 -0400] [14055] [ERROR] Worker (pid:14058) exited with code 1.\n",
      "[2024-04-26 06:53:20 -0400] [14055] [ERROR] Worker (pid:14059) exited with code 1\n",
      "[2024-04-26 06:53:20 -0400] [14055] [ERROR] Worker (pid:14059) exited with code 1.\n",
      "[2024-04-26 06:53:20 -0400] [14055] [ERROR] Worker (pid:14056) exited with code 1\n",
      "[2024-04-26 06:53:20 -0400] [14055] [ERROR] Worker (pid:14056) exited with code 1.\n",
      "[2024-04-26 06:53:20 -0400] [14055] [ERROR] Worker (pid:14057) exited with code 1\n",
      "[2024-04-26 06:53:20 -0400] [14055] [ERROR] Worker (pid:14057) exited with code 1.\n",
      "[2024-04-26 06:53:20 -0400] [16284] [INFO] Booting worker with pid: 16284\n",
      "[2024-04-26 06:53:20 -0400] [16285] [INFO] Booting worker with pid: 16285\n",
      "[2024-04-26 06:53:20 -0400] [16286] [INFO] Booting worker with pid: 16286\n",
      "[2024-04-26 06:53:20 -0400] [16287] [INFO] Booting worker with pid: 16287\n",
      "[2024-04-26 07:10:32 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16284)\n",
      "[2024-04-26 07:10:32 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16285)\n",
      "[2024-04-26 07:10:32 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16286)\n",
      "[2024-04-26 07:10:32 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16287)\n",
      "[2024-04-26 07:10:32 -0400] [16287] [INFO] Worker exiting (pid: 16287)\n",
      "[2024-04-26 07:10:32 -0400] [16285] [INFO] Worker exiting (pid: 16285)\n",
      "[2024-04-26 07:10:32 -0400] [16286] [INFO] Worker exiting (pid: 16286)\n",
      "[2024-04-26 07:10:32 -0400] [16284] [INFO] Worker exiting (pid: 16284)\n",
      "[2024-04-26 07:10:33 -0400] [14055] [ERROR] Worker (pid:16287) exited with code 1\n",
      "[2024-04-26 07:10:33 -0400] [14055] [ERROR] Worker (pid:16287) exited with code 1.\n",
      "[2024-04-26 07:10:33 -0400] [14055] [ERROR] Worker (pid:16284) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 07:10:33 -0400] [14055] [ERROR] Worker (pid:16285) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 07:10:33 -0400] [14055] [ERROR] Worker (pid:16286) exited with code 1\n",
      "[2024-04-26 07:10:33 -0400] [14055] [ERROR] Worker (pid:16286) exited with code 1.\n",
      "[2024-04-26 07:10:33 -0400] [16374] [INFO] Booting worker with pid: 16374\n",
      "[2024-04-26 07:10:33 -0400] [16375] [INFO] Booting worker with pid: 16375\n",
      "[2024-04-26 07:10:33 -0400] [16376] [INFO] Booting worker with pid: 16376\n",
      "[2024-04-26 07:10:33 -0400] [16377] [INFO] Booting worker with pid: 16377\n",
      "[2024-04-26 07:48:01 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16374)\n",
      "[2024-04-26 07:48:01 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16375)\n",
      "[2024-04-26 07:48:01 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16376)\n",
      "[2024-04-26 07:48:01 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16377)\n",
      "[2024-04-26 07:48:01 -0400] [16376] [INFO] Worker exiting (pid: 16376)\n",
      "[2024-04-26 07:48:01 -0400] [16377] [INFO] Worker exiting (pid: 16377)\n",
      "[2024-04-26 07:48:01 -0400] [16375] [INFO] Worker exiting (pid: 16375)\n",
      "[2024-04-26 07:48:01 -0400] [16374] [INFO] Worker exiting (pid: 16374)\n",
      "[2024-04-26 07:48:02 -0400] [14055] [ERROR] Worker (pid:16374) exited with code 1\n",
      "[2024-04-26 07:48:02 -0400] [14055] [ERROR] Worker (pid:16374) exited with code 1.\n",
      "[2024-04-26 07:48:02 -0400] [16404] [INFO] Booting worker with pid: 16404\n",
      "[2024-04-26 07:48:02 -0400] [14055] [ERROR] Worker (pid:16375) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 07:48:02 -0400] [14055] [ERROR] Worker (pid:16376) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 07:48:02 -0400] [14055] [ERROR] Worker (pid:16377) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 07:48:02 -0400] [16405] [INFO] Booting worker with pid: 16405\n",
      "[2024-04-26 07:48:02 -0400] [16406] [INFO] Booting worker with pid: 16406\n",
      "[2024-04-26 07:48:02 -0400] [16407] [INFO] Booting worker with pid: 16407\n",
      "[2024-04-26 07:54:51 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16404)\n",
      "[2024-04-26 07:54:51 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16405)\n",
      "[2024-04-26 07:54:51 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16406)\n",
      "[2024-04-26 07:54:51 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16407)\n",
      "[2024-04-26 07:54:51 -0400] [16406] [INFO] Worker exiting (pid: 16406)\n",
      "[2024-04-26 07:54:51 -0400] [16405] [INFO] Worker exiting (pid: 16405)\n",
      "[2024-04-26 07:54:51 -0400] [16404] [INFO] Worker exiting (pid: 16404)\n",
      "[2024-04-26 07:54:51 -0400] [16407] [INFO] Worker exiting (pid: 16407)\n",
      "[2024-04-26 07:54:52 -0400] [14055] [ERROR] Worker (pid:16404) exited with code 1\n",
      "[2024-04-26 07:54:52 -0400] [14055] [ERROR] Worker (pid:16404) exited with code 1.\n",
      "[2024-04-26 07:54:52 -0400] [14055] [ERROR] Worker (pid:16407) exited with code 1\n",
      "[2024-04-26 07:54:52 -0400] [14055] [ERROR] Worker (pid:16407) exited with code 1.\n",
      "[2024-04-26 07:54:52 -0400] [14055] [ERROR] Worker (pid:16406) exited with code 1\n",
      "[2024-04-26 07:54:52 -0400] [14055] [ERROR] Worker (pid:16406) exited with code 1.\n",
      "[2024-04-26 07:54:52 -0400] [16450] [INFO] Booting worker with pid: 16450\n",
      "[2024-04-26 07:54:52 -0400] [14055] [ERROR] Worker (pid:16405) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 07:54:52 -0400] [16451] [INFO] Booting worker with pid: 16451\n",
      "[2024-04-26 07:54:52 -0400] [16452] [INFO] Booting worker with pid: 16452\n",
      "[2024-04-26 07:54:52 -0400] [16453] [INFO] Booting worker with pid: 16453\n",
      "[2024-04-26 08:04:43 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16450)\n",
      "[2024-04-26 08:04:43 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16451)\n",
      "[2024-04-26 08:04:43 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16452)\n",
      "[2024-04-26 08:04:43 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16453)\n",
      "[2024-04-26 08:04:43 -0400] [16450] [INFO] Worker exiting (pid: 16450)\n",
      "[2024-04-26 08:04:43 -0400] [16451] [INFO] Worker exiting (pid: 16451)\n",
      "[2024-04-26 08:04:43 -0400] [16453] [INFO] Worker exiting (pid: 16453)\n",
      "[2024-04-26 08:04:43 -0400] [16452] [INFO] Worker exiting (pid: 16452)\n",
      "[2024-04-26 08:04:44 -0400] [14055] [ERROR] Worker (pid:16453) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:04:44 -0400] [16510] [INFO] Booting worker with pid: 16510\n",
      "[2024-04-26 08:04:44 -0400] [14055] [ERROR] Worker (pid:16452) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:04:44 -0400] [14055] [ERROR] Worker (pid:16451) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:04:44 -0400] [14055] [ERROR] Worker (pid:16450) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:04:44 -0400] [16511] [INFO] Booting worker with pid: 16511\n",
      "[2024-04-26 08:04:44 -0400] [16512] [INFO] Booting worker with pid: 16512\n",
      "[2024-04-26 08:04:44 -0400] [16514] [INFO] Booting worker with pid: 16514\n",
      "[2024-04-26 08:23:29 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16510)\n",
      "[2024-04-26 08:23:29 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16511)\n",
      "[2024-04-26 08:23:29 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16512)\n",
      "[2024-04-26 08:23:29 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16514)\n",
      "[2024-04-26 08:23:29 -0400] [16514] [INFO] Worker exiting (pid: 16514)\n",
      "[2024-04-26 08:23:29 -0400] [16511] [INFO] Worker exiting (pid: 16511)\n",
      "[2024-04-26 08:23:29 -0400] [16512] [INFO] Worker exiting (pid: 16512)\n",
      "[2024-04-26 08:23:29 -0400] [16510] [INFO] Worker exiting (pid: 16510)\n",
      "[2024-04-26 08:23:29 -0400] [14055] [ERROR] Worker (pid:16514) exited with code 1\n",
      "[2024-04-26 08:23:29 -0400] [14055] [ERROR] Worker (pid:16514) exited with code 1.\n",
      "[2024-04-26 08:23:29 -0400] [14055] [ERROR] Worker (pid:16512) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:23:29 -0400] [14055] [ERROR] Worker (pid:16511) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:23:29 -0400] [14055] [ERROR] Worker (pid:16510) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:23:29 -0400] [16747] [INFO] Booting worker with pid: 16747\n",
      "[2024-04-26 08:23:29 -0400] [16748] [INFO] Booting worker with pid: 16748\n",
      "[2024-04-26 08:23:29 -0400] [16749] [INFO] Booting worker with pid: 16749\n",
      "[2024-04-26 08:23:29 -0400] [16750] [INFO] Booting worker with pid: 16750\n",
      "[2024-04-26 08:31:37 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16747)\n",
      "[2024-04-26 08:31:37 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16748)\n",
      "[2024-04-26 08:31:37 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16749)\n",
      "[2024-04-26 08:31:37 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16750)\n",
      "[2024-04-26 08:31:37 -0400] [16749] [INFO] Worker exiting (pid: 16749)\n",
      "[2024-04-26 08:31:37 -0400] [16750] [INFO] Worker exiting (pid: 16750)\n",
      "[2024-04-26 08:31:37 -0400] [16747] [INFO] Worker exiting (pid: 16747)\n",
      "[2024-04-26 08:31:37 -0400] [16748] [INFO] Worker exiting (pid: 16748)\n",
      "[2024-04-26 08:31:37 -0400] [14055] [ERROR] Worker (pid:16750) exited with code 1\n",
      "[2024-04-26 08:31:37 -0400] [14055] [ERROR] Worker (pid:16750) exited with code 1.\n",
      "[2024-04-26 08:31:37 -0400] [16768] [INFO] Booting worker with pid: 16768\n",
      "[2024-04-26 08:31:37 -0400] [14055] [ERROR] Worker (pid:16748) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:31:37 -0400] [14055] [ERROR] Worker (pid:16747) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:31:37 -0400] [14055] [ERROR] Worker (pid:16749) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:31:37 -0400] [16769] [INFO] Booting worker with pid: 16769\n",
      "[2024-04-26 08:31:37 -0400] [16771] [INFO] Booting worker with pid: 16771\n",
      "[2024-04-26 08:31:37 -0400] [16772] [INFO] Booting worker with pid: 16772\n",
      "[2024-04-26 08:37:33 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16768)\n",
      "[2024-04-26 08:37:33 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16769)\n",
      "[2024-04-26 08:37:33 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16771)\n",
      "[2024-04-26 08:37:33 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16772)\n",
      "[2024-04-26 08:37:33 -0400] [16769] [INFO] Worker exiting (pid: 16769)\n",
      "[2024-04-26 08:37:33 -0400] [16768] [INFO] Worker exiting (pid: 16768)\n",
      "[2024-04-26 08:37:33 -0400] [16771] [INFO] Worker exiting (pid: 16771)\n",
      "[2024-04-26 08:37:33 -0400] [16772] [INFO] Worker exiting (pid: 16772)\n",
      "[2024-04-26 08:37:33 -0400] [14055] [ERROR] Worker (pid:16772) exited with code 1\n",
      "[2024-04-26 08:37:33 -0400] [14055] [ERROR] Worker (pid:16772) exited with code 1.\n",
      "[2024-04-26 08:37:33 -0400] [14055] [ERROR] Worker (pid:16771) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:37:33 -0400] [14055] [ERROR] Worker (pid:16768) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:37:33 -0400] [14055] [ERROR] Worker (pid:16769) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:37:33 -0400] [16901] [INFO] Booting worker with pid: 16901\n",
      "[2024-04-26 08:37:33 -0400] [16902] [INFO] Booting worker with pid: 16902\n",
      "[2024-04-26 08:37:33 -0400] [16903] [INFO] Booting worker with pid: 16903\n",
      "[2024-04-26 08:37:33 -0400] [16904] [INFO] Booting worker with pid: 16904\n",
      "[2024-04-26 08:38:21 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16901)\n",
      "[2024-04-26 08:38:21 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16902)\n",
      "[2024-04-26 08:38:21 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16903)\n",
      "[2024-04-26 08:38:21 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16904)\n",
      "[2024-04-26 08:38:21 -0400] [16901] [INFO] Worker exiting (pid: 16901)\n",
      "[2024-04-26 08:38:21 -0400] [16902] [INFO] Worker exiting (pid: 16902)\n",
      "[2024-04-26 08:38:21 -0400] [16903] [INFO] Worker exiting (pid: 16903)\n",
      "[2024-04-26 08:38:21 -0400] [16904] [INFO] Worker exiting (pid: 16904)\n",
      "[2024-04-26 08:38:21 -0400] [14055] [ERROR] Worker (pid:16904) exited with code 1\n",
      "[2024-04-26 08:38:21 -0400] [14055] [ERROR] Worker (pid:16904) exited with code 1.\n",
      "[2024-04-26 08:38:21 -0400] [14055] [ERROR] Worker (pid:16902) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:38:21 -0400] [16965] [INFO] Booting worker with pid: 16965\n",
      "[2024-04-26 08:38:21 -0400] [14055] [ERROR] Worker (pid:16901) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:38:21 -0400] [14055] [ERROR] Worker (pid:16903) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:38:21 -0400] [16966] [INFO] Booting worker with pid: 16966\n",
      "[2024-04-26 08:38:21 -0400] [16967] [INFO] Booting worker with pid: 16967\n",
      "[2024-04-26 08:38:21 -0400] [16968] [INFO] Booting worker with pid: 16968\n",
      "[2024-04-26 08:39:34 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16965)\n",
      "[2024-04-26 08:39:34 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16966)\n",
      "[2024-04-26 08:39:34 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16967)\n",
      "[2024-04-26 08:39:34 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:16968)\n",
      "[2024-04-26 08:39:34 -0400] [16965] [INFO] Worker exiting (pid: 16965)\n",
      "[2024-04-26 08:39:34 -0400] [16966] [INFO] Worker exiting (pid: 16966)\n",
      "[2024-04-26 08:39:34 -0400] [16967] [INFO] Worker exiting (pid: 16967)\n",
      "[2024-04-26 08:39:34 -0400] [16968] [INFO] Worker exiting (pid: 16968)\n",
      "[2024-04-26 08:39:34 -0400] [14055] [ERROR] Worker (pid:16965) exited with code 1\n",
      "[2024-04-26 08:39:34 -0400] [14055] [ERROR] Worker (pid:16965) exited with code 1.\n",
      "[2024-04-26 08:39:34 -0400] [17085] [INFO] Booting worker with pid: 17085\n",
      "[2024-04-26 08:39:34 -0400] [14055] [ERROR] Worker (pid:16966) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:39:34 -0400] [14055] [ERROR] Worker (pid:16967) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:39:34 -0400] [14055] [ERROR] Worker (pid:16968) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:39:34 -0400] [17086] [INFO] Booting worker with pid: 17086\n",
      "[2024-04-26 08:39:34 -0400] [17087] [INFO] Booting worker with pid: 17087\n",
      "[2024-04-26 08:39:34 -0400] [17088] [INFO] Booting worker with pid: 17088\n",
      "[2024-04-26 08:40:44 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:17085)\n",
      "[2024-04-26 08:40:44 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:17086)\n",
      "[2024-04-26 08:40:44 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:17087)\n",
      "[2024-04-26 08:40:44 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:17088)\n",
      "[2024-04-26 08:40:44 -0400] [17086] [INFO] Worker exiting (pid: 17086)\n",
      "[2024-04-26 08:40:44 -0400] [17087] [INFO] Worker exiting (pid: 17087)\n",
      "[2024-04-26 08:40:44 -0400] [17085] [INFO] Worker exiting (pid: 17085)\n",
      "[2024-04-26 08:40:44 -0400] [17088] [INFO] Worker exiting (pid: 17088)\n",
      "[2024-04-26 08:40:44 -0400] [14055] [ERROR] Worker (pid:17086) exited with code 1\n",
      "[2024-04-26 08:40:44 -0400] [14055] [ERROR] Worker (pid:17086) exited with code 1.\n",
      "[2024-04-26 08:40:44 -0400] [17139] [INFO] Booting worker with pid: 17139\n",
      "[2024-04-26 08:40:44 -0400] [14055] [ERROR] Worker (pid:17088) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:40:44 -0400] [14055] [ERROR] Worker (pid:17085) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:40:44 -0400] [14055] [ERROR] Worker (pid:17087) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:40:44 -0400] [17140] [INFO] Booting worker with pid: 17140\n",
      "[2024-04-26 08:40:44 -0400] [17141] [INFO] Booting worker with pid: 17141\n",
      "[2024-04-26 08:40:44 -0400] [17142] [INFO] Booting worker with pid: 17142\n",
      "[2024-04-26 08:41:32 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:17139)\n",
      "[2024-04-26 08:41:32 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:17140)\n",
      "[2024-04-26 08:41:32 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:17141)\n",
      "[2024-04-26 08:41:32 -0400] [14055] [CRITICAL] WORKER TIMEOUT (pid:17142)\n",
      "[2024-04-26 08:41:32 -0400] [17142] [INFO] Worker exiting (pid: 17142)\n",
      "[2024-04-26 08:41:32 -0400] [17140] [INFO] Worker exiting (pid: 17140)\n",
      "[2024-04-26 08:41:32 -0400] [17141] [INFO] Worker exiting (pid: 17141)\n",
      "[2024-04-26 08:41:32 -0400] [17139] [INFO] Worker exiting (pid: 17139)\n",
      "[2024-04-26 08:41:32 -0400] [14055] [ERROR] Worker (pid:17140) exited with code 1\n",
      "[2024-04-26 08:41:32 -0400] [14055] [ERROR] Worker (pid:17140) exited with code 1.\n",
      "[2024-04-26 08:41:32 -0400] [14055] [ERROR] Worker (pid:17139) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:41:32 -0400] [14055] [ERROR] Worker (pid:17141) exited with code 1\n",
      "[2024-04-26 08:41:32 -0400] [14055] [ERROR] Worker (pid:17141) exited with code 1.\n",
      "[2024-04-26 08:41:32 -0400] [14055] [ERROR] Worker (pid:17142) was sent SIGKILL! Perhaps out of memory?\n",
      "[2024-04-26 08:41:32 -0400] [17204] [INFO] Booting worker with pid: 17204\n",
      "[2024-04-26 08:41:32 -0400] [17205] [INFO] Booting worker with pid: 17205\n",
      "[2024-04-26 08:41:33 -0400] [17206] [INFO] Booting worker with pid: 17206\n",
      "[2024-04-26 08:41:33 -0400] [17207] [INFO] Booting worker with pid: 17207\n",
      "^C\n",
      "[2024-04-26 08:55:21 -0400] [14055] [INFO] Handling signal: int\n",
      "[2024-04-26 08:55:21 -0400] [17207] [INFO] Worker exiting (pid: 17207)\n",
      "[2024-04-26 08:55:21 -0400] [17205] [INFO] Worker exiting (pid: 17205)\n",
      "[2024-04-26 08:55:21 -0400] [17206] [INFO] Worker exiting (pid: 17206)\n",
      "[2024-04-26 08:55:21 -0400] [17204] [INFO] Worker exiting (pid: 17204)\n"
     ]
    }
   ],
   "source": [
    "!mlflow ui --backend-store-uri \"sqlite:////Users/arham/Downloads/Projects/03-Experiments/new_mlflow.db\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving probabilities for stacked model later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make predictions on train, validation, and test data\n",
    "y_pred_train = predictor.predict(tr)\n",
    "y_pred_va = predictor.predict(va)\n",
    "y_pred_test = predictor.predict(te)\n",
    "\n",
    "# Get prediction probabilities for each class\n",
    "y_probabilities_train = predictor.predict_proba(tr)\n",
    "y_probabilities_va = predictor.predict_proba(va)\n",
    "y_probabilities_test = predictor.predict_proba(te)\n",
    "\n",
    "# add to tr, te, va\n",
    "test = pd.concat([te, y_probabilities_test.iloc[:, -7:]], axis=1)\n",
    "train = pd.concat([tr, y_probabilities_train.iloc[:, -7:]], axis=1)\n",
    "val = pd.concat([va, y_probabilities_va.iloc[:, -7:]], axis=1)\n"
   ]
  },
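  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To hand these probability-augmented splits to the later stacked model, one option is simply persisting them to disk. The output directory below is an assumption; adjust it to the project layout:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Persist each split with the 7 AutoGluon class-probability columns appended.\n",
    "# 'stacking_data' is a placeholder directory, not a path from the original project.\n",
    "out_dir = 'stacking_data'\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "train.to_csv(os.path.join(out_dir, 'train_with_autogluon_proba.csv'), index=False)\n",
    "val.to_csv(os.path.join(out_dir, 'val_with_autogluon_proba.csv'), index=False)\n",
    "test.to_csv(os.path.join(out_dir, 'test_with_autogluon_proba.csv'), index=False)"
   ]
  },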
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Gender</th>\n",
       "      <th>Age</th>\n",
       "      <th>Height</th>\n",
       "      <th>Weight</th>\n",
       "      <th>family_history_with_overweight</th>\n",
       "      <th>FAVC</th>\n",
       "      <th>FCVC</th>\n",
       "      <th>NCP</th>\n",
       "      <th>CAEC</th>\n",
       "      <th>SMOKE</th>\n",
       "      <th>...</th>\n",
       "      <th>Age * BMI^2</th>\n",
       "      <th>Age^2 * BMI^2</th>\n",
       "      <th>NObeyesdad</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>21.000000</td>\n",
       "      <td>1.550000</td>\n",
       "      <td>51.000000</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>445.785640</td>\n",
       "      <td>450.623213</td>\n",
       "      <td>0</td>\n",
       "      <td>0.797175</td>\n",
       "      <td>0.196596</td>\n",
       "      <td>0.005617</td>\n",
       "      <td>0.000520</td>\n",
       "      <td>0.000032</td>\n",
       "      <td>0.000019</td>\n",
       "      <td>0.000040</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>20.000000</td>\n",
       "      <td>1.700000</td>\n",
       "      <td>80.000000</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>553.633218</td>\n",
       "      <td>766.274350</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000041</td>\n",
       "      <td>0.001926</td>\n",
       "      <td>0.047782</td>\n",
       "      <td>0.948167</td>\n",
       "      <td>0.002041</td>\n",
       "      <td>0.000018</td>\n",
       "      <td>0.000025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>18.000000</td>\n",
       "      <td>1.600000</td>\n",
       "      <td>60.000000</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>421.875000</td>\n",
       "      <td>549.316406</td>\n",
       "      <td>1</td>\n",
       "      <td>0.011628</td>\n",
       "      <td>0.953612</td>\n",
       "      <td>0.033047</td>\n",
       "      <td>0.001538</td>\n",
       "      <td>0.000150</td>\n",
       "      <td>0.000005</td>\n",
       "      <td>0.000020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>26.000000</td>\n",
       "      <td>1.632983</td>\n",
       "      <td>111.720238</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1089.285877</td>\n",
       "      <td>1755.242193</td>\n",
       "      <td>6</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>0.000002</td>\n",
       "      <td>0.000008</td>\n",
       "      <td>0.000018</td>\n",
       "      <td>0.000172</td>\n",
       "      <td>0.000040</td>\n",
       "      <td>0.999759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>21.682636</td>\n",
       "      <td>1.748524</td>\n",
       "      <td>133.845064</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>949.229536</td>\n",
       "      <td>1916.541944</td>\n",
       "      <td>6</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>0.000004</td>\n",
       "      <td>0.000012</td>\n",
       "      <td>0.000092</td>\n",
       "      <td>0.000058</td>\n",
       "      <td>0.999831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10788</th>\n",
       "      <td>0</td>\n",
       "      <td>18.000000</td>\n",
       "      <td>1.780000</td>\n",
       "      <td>108.000000</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>613.558894</td>\n",
       "      <td>1161.896656</td>\n",
       "      <td>4</td>\n",
       "      <td>0.000007</td>\n",
       "      <td>0.000014</td>\n",
       "      <td>0.000035</td>\n",
       "      <td>0.000287</td>\n",
       "      <td>0.998704</td>\n",
       "      <td>0.000949</td>\n",
       "      <td>0.000005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10789</th>\n",
       "      <td>1</td>\n",
       "      <td>26.000000</td>\n",
       "      <td>1.641601</td>\n",
       "      <td>111.830924</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1078.946835</td>\n",
       "      <td>1722.080284</td>\n",
       "      <td>6</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>0.000002</td>\n",
       "      <td>0.000007</td>\n",
       "      <td>0.000020</td>\n",
       "      <td>0.000154</td>\n",
       "      <td>0.000059</td>\n",
       "      <td>0.999757</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10790</th>\n",
       "      <td>0</td>\n",
       "      <td>21.000000</td>\n",
       "      <td>1.770000</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>502.729101</td>\n",
       "      <td>573.098750</td>\n",
       "      <td>2</td>\n",
       "      <td>0.000212</td>\n",
       "      <td>0.341897</td>\n",
       "      <td>0.657261</td>\n",
       "      <td>0.000391</td>\n",
       "      <td>0.000052</td>\n",
       "      <td>0.000019</td>\n",
       "      <td>0.000168</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10791</th>\n",
       "      <td>0</td>\n",
       "      <td>29.669219</td>\n",
       "      <td>1.774644</td>\n",
       "      <td>105.966894</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2.934671</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>998.283353</td>\n",
       "      <td>1132.127734</td>\n",
       "      <td>4</td>\n",
       "      <td>0.000621</td>\n",
       "      <td>0.000067</td>\n",
       "      <td>0.000338</td>\n",
       "      <td>0.004232</td>\n",
       "      <td>0.936736</td>\n",
       "      <td>0.057409</td>\n",
       "      <td>0.000597</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10792</th>\n",
       "      <td>1</td>\n",
       "      <td>26.000000</td>\n",
       "      <td>1.656504</td>\n",
       "      <td>111.884535</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1060.128308</td>\n",
       "      <td>1662.532588</td>\n",
       "      <td>6</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>0.000002</td>\n",
       "      <td>0.000007</td>\n",
       "      <td>0.000021</td>\n",
       "      <td>0.000173</td>\n",
       "      <td>0.000076</td>\n",
       "      <td>0.999719</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10793 rows × 36 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       Gender        Age    Height      Weight  \\\n",
       "0           1  21.000000  1.550000   51.000000   \n",
       "1           0  20.000000  1.700000   80.000000   \n",
       "2           1  18.000000  1.600000   60.000000   \n",
       "3           1  26.000000  1.632983  111.720238   \n",
       "4           1  21.682636  1.748524  133.845064   \n",
       "...       ...        ...       ...         ...   \n",
       "10788       0  18.000000  1.780000  108.000000   \n",
       "10789       1  26.000000  1.641601  111.830924   \n",
       "10790       0  21.000000  1.770000   75.000000   \n",
       "10791       0  29.669219  1.774644  105.966894   \n",
       "10792       1  26.000000  1.656504  111.884535   \n",
       "\n",
       "       family_history_with_overweight  FAVC      FCVC  NCP  CAEC  SMOKE  ...  \\\n",
       "0                                   0     1  3.000000  1.0     2      0  ...   \n",
       "1                                   1     1  2.000000  3.0     1      0  ...   \n",
       "2                                   1     1  2.000000  3.0     1      0  ...   \n",
       "3                                   1     1  3.000000  3.0     1      0  ...   \n",
       "4                                   1     1  3.000000  3.0     1      0  ...   \n",
       "...                               ...   ...       ...  ...   ...    ...  ...   \n",
       "10788                               1     1  2.000000  3.0     1      0  ...   \n",
       "10789                               1     1  3.000000  3.0     1      0  ...   \n",
       "10790                               0     1  3.000000  3.0     2      0  ...   \n",
       "10791                               1     1  2.934671  3.0     1      0  ...   \n",
       "10792                               1     1  3.000000  3.0     1      0  ...   \n",
       "\n",
       "       Age * BMI^2  Age^2 * BMI^2  NObeyesdad         0         1         2  \\\n",
       "0       445.785640     450.623213           0  0.797175  0.196596  0.005617   \n",
       "1       553.633218     766.274350           3  0.000041  0.001926  0.047782   \n",
       "2       421.875000     549.316406           1  0.011628  0.953612  0.033047   \n",
       "3      1089.285877    1755.242193           6  0.000001  0.000002  0.000008   \n",
       "4       949.229536    1916.541944           6  0.000001  0.000001  0.000004   \n",
       "...            ...            ...         ...       ...       ...       ...   \n",
       "10788   613.558894    1161.896656           4  0.000007  0.000014  0.000035   \n",
       "10789  1078.946835    1722.080284           6  0.000001  0.000002  0.000007   \n",
       "10790   502.729101     573.098750           2  0.000212  0.341897  0.657261   \n",
       "10791   998.283353    1132.127734           4  0.000621  0.000067  0.000338   \n",
       "10792  1060.128308    1662.532588           6  0.000001  0.000002  0.000007   \n",
       "\n",
       "              3         4         5         6  \n",
       "0      0.000520  0.000032  0.000019  0.000040  \n",
       "1      0.948167  0.002041  0.000018  0.000025  \n",
       "2      0.001538  0.000150  0.000005  0.000020  \n",
       "3      0.000018  0.000172  0.000040  0.999759  \n",
       "4      0.000012  0.000092  0.000058  0.999831  \n",
       "...         ...       ...       ...       ...  \n",
       "10788  0.000287  0.998704  0.000949  0.000005  \n",
       "10789  0.000020  0.000154  0.000059  0.999757  \n",
       "10790  0.000391  0.000052  0.000019  0.000168  \n",
       "10791  0.004232  0.936736  0.057409  0.000597  \n",
       "10792  0.000021  0.000173  0.000076  0.999719  \n",
       "\n",
       "[10793 rows x 36 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train"
   ]
  },
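  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the frame above carries seven extra columns labelled `0`-`6` whose values sum to roughly 1 per row, i.e. per-class probability scores sitting alongside the engineered features and the encoded target. A quick sanity check, as a sketch (it assumes `train` is the frame printed above and that the probability columns are labelled with the integers 0-6 as displayed):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the argmax over the probability columns should agree with\n",
    "# the encoded NObeyesdad target on most rows.\n",
    "# Assumes integer column labels 0..6; use ['0', ..., '6'] instead if they\n",
    "# were stored as strings.\n",
    "prob_cols = list(range(7))\n",
    "agreement = (train[prob_cols].values.argmax(axis=1) == train['NObeyesdad']).mean()\n",
    "print(f\"argmax of probability columns matches target on {agreement:.1%} of rows\")"
   ]
  },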
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_960f5_row8_col1 {\n",
       "  background-color: lightgreen;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_960f5\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_960f5_level0_col0\" class=\"col_heading level0 col0\" >Description</th>\n",
       "      <th id=\"T_960f5_level0_col1\" class=\"col_heading level0 col1\" >Value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_960f5_row0_col0\" class=\"data row0 col0\" >Session id</td>\n",
       "      <td id=\"T_960f5_row0_col1\" class=\"data row0 col1\" >5119</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "      <td id=\"T_960f5_row1_col0\" class=\"data row1 col0\" >Target</td>\n",
       "      <td id=\"T_960f5_row1_col1\" class=\"data row1 col1\" >NObeyesdad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "      <td id=\"T_960f5_row2_col0\" class=\"data row2 col0\" >Target type</td>\n",
       "      <td id=\"T_960f5_row2_col1\" class=\"data row2 col1\" >Multiclass</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
       "      <td id=\"T_960f5_row3_col0\" class=\"data row3 col0\" >Original data shape</td>\n",
       "      <td id=\"T_960f5_row3_col1\" class=\"data row3 col1\" >(10793, 35)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "      <td id=\"T_960f5_row4_col0\" class=\"data row4 col0\" >Transformed data shape</td>\n",
       "      <td id=\"T_960f5_row4_col1\" class=\"data row4 col1\" >(10793, 35)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
       "      <td id=\"T_960f5_row5_col0\" class=\"data row5 col0\" >Transformed train set shape</td>\n",
       "      <td id=\"T_960f5_row5_col1\" class=\"data row5 col1\" >(7555, 35)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row6\" class=\"row_heading level0 row6\" >6</th>\n",
       "      <td id=\"T_960f5_row6_col0\" class=\"data row6 col0\" >Transformed test set shape</td>\n",
       "      <td id=\"T_960f5_row6_col1\" class=\"data row6 col1\" >(3238, 35)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row7\" class=\"row_heading level0 row7\" >7</th>\n",
       "      <td id=\"T_960f5_row7_col0\" class=\"data row7 col0\" >Numeric features</td>\n",
       "      <td id=\"T_960f5_row7_col1\" class=\"data row7 col1\" >34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row8\" class=\"row_heading level0 row8\" >8</th>\n",
       "      <td id=\"T_960f5_row8_col0\" class=\"data row8 col0\" >Preprocess</td>\n",
       "      <td id=\"T_960f5_row8_col1\" class=\"data row8 col1\" >True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row9\" class=\"row_heading level0 row9\" >9</th>\n",
       "      <td id=\"T_960f5_row9_col0\" class=\"data row9 col0\" >Imputation type</td>\n",
       "      <td id=\"T_960f5_row9_col1\" class=\"data row9 col1\" >simple</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row10\" class=\"row_heading level0 row10\" >10</th>\n",
       "      <td id=\"T_960f5_row10_col0\" class=\"data row10 col0\" >Numeric imputation</td>\n",
       "      <td id=\"T_960f5_row10_col1\" class=\"data row10 col1\" >mean</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row11\" class=\"row_heading level0 row11\" >11</th>\n",
       "      <td id=\"T_960f5_row11_col0\" class=\"data row11 col0\" >Categorical imputation</td>\n",
       "      <td id=\"T_960f5_row11_col1\" class=\"data row11 col1\" >mode</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row12\" class=\"row_heading level0 row12\" >12</th>\n",
       "      <td id=\"T_960f5_row12_col0\" class=\"data row12 col0\" >Fold Generator</td>\n",
       "      <td id=\"T_960f5_row12_col1\" class=\"data row12 col1\" >StratifiedKFold</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row13\" class=\"row_heading level0 row13\" >13</th>\n",
       "      <td id=\"T_960f5_row13_col0\" class=\"data row13 col0\" >Fold Number</td>\n",
       "      <td id=\"T_960f5_row13_col1\" class=\"data row13 col1\" >10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row14\" class=\"row_heading level0 row14\" >14</th>\n",
       "      <td id=\"T_960f5_row14_col0\" class=\"data row14 col0\" >CPU Jobs</td>\n",
       "      <td id=\"T_960f5_row14_col1\" class=\"data row14 col1\" >-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row15\" class=\"row_heading level0 row15\" >15</th>\n",
       "      <td id=\"T_960f5_row15_col0\" class=\"data row15 col0\" >Use GPU</td>\n",
       "      <td id=\"T_960f5_row15_col1\" class=\"data row15 col1\" >False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row16\" class=\"row_heading level0 row16\" >16</th>\n",
       "      <td id=\"T_960f5_row16_col0\" class=\"data row16 col0\" >Log Experiment</td>\n",
       "      <td id=\"T_960f5_row16_col1\" class=\"data row16 col1\" >False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row17\" class=\"row_heading level0 row17\" >17</th>\n",
       "      <td id=\"T_960f5_row17_col0\" class=\"data row17 col0\" >Experiment Name</td>\n",
       "      <td id=\"T_960f5_row17_col1\" class=\"data row17 col1\" >clf-default-name</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_960f5_level0_row18\" class=\"row_heading level0 row18\" >18</th>\n",
       "      <td id=\"T_960f5_row18_col0\" class=\"data row18 col0\" >USI</td>\n",
       "      <td id=\"T_960f5_row18_col1\" class=\"data row18 col1\" >2ccb</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x143ae0a90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_d099c_row10_col0, #T_d099c_row10_col1, #T_d099c_row10_col2, #T_d099c_row10_col3, #T_d099c_row10_col4, #T_d099c_row10_col5, #T_d099c_row10_col6 {\n",
       "  background: yellow;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_d099c\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_d099c_level0_col0\" class=\"col_heading level0 col0\" >Accuracy</th>\n",
       "      <th id=\"T_d099c_level0_col1\" class=\"col_heading level0 col1\" >AUC</th>\n",
       "      <th id=\"T_d099c_level0_col2\" class=\"col_heading level0 col2\" >Recall</th>\n",
       "      <th id=\"T_d099c_level0_col3\" class=\"col_heading level0 col3\" >Prec.</th>\n",
       "      <th id=\"T_d099c_level0_col4\" class=\"col_heading level0 col4\" >F1</th>\n",
       "      <th id=\"T_d099c_level0_col5\" class=\"col_heading level0 col5\" >Kappa</th>\n",
       "      <th id=\"T_d099c_level0_col6\" class=\"col_heading level0 col6\" >MCC</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >Fold</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "      <th class=\"blank col3\" >&nbsp;</th>\n",
       "      <th class=\"blank col4\" >&nbsp;</th>\n",
       "      <th class=\"blank col5\" >&nbsp;</th>\n",
       "      <th class=\"blank col6\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_d099c_row0_col0\" class=\"data row0 col0\" >0.8307</td>\n",
       "      <td id=\"T_d099c_row0_col1\" class=\"data row0 col1\" >0.9024</td>\n",
       "      <td id=\"T_d099c_row0_col2\" class=\"data row0 col2\" >0.8307</td>\n",
       "      <td id=\"T_d099c_row0_col3\" class=\"data row0 col3\" >0.8307</td>\n",
       "      <td id=\"T_d099c_row0_col4\" class=\"data row0 col4\" >0.8304</td>\n",
       "      <td id=\"T_d099c_row0_col5\" class=\"data row0 col5\" >0.8013</td>\n",
       "      <td id=\"T_d099c_row0_col6\" class=\"data row0 col6\" >0.8014</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "      <td id=\"T_d099c_row1_col0\" class=\"data row1 col0\" >0.8241</td>\n",
       "      <td id=\"T_d099c_row1_col1\" class=\"data row1 col1\" >0.8991</td>\n",
       "      <td id=\"T_d099c_row1_col2\" class=\"data row1 col2\" >0.8241</td>\n",
       "      <td id=\"T_d099c_row1_col3\" class=\"data row1 col3\" >0.8259</td>\n",
       "      <td id=\"T_d099c_row1_col4\" class=\"data row1 col4\" >0.8248</td>\n",
       "      <td id=\"T_d099c_row1_col5\" class=\"data row1 col5\" >0.7937</td>\n",
       "      <td id=\"T_d099c_row1_col6\" class=\"data row1 col6\" >0.7938</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "      <td id=\"T_d099c_row2_col0\" class=\"data row2 col0\" >0.8571</td>\n",
       "      <td id=\"T_d099c_row2_col1\" class=\"data row2 col1\" >0.9178</td>\n",
       "      <td id=\"T_d099c_row2_col2\" class=\"data row2 col2\" >0.8571</td>\n",
       "      <td id=\"T_d099c_row2_col3\" class=\"data row2 col3\" >0.8616</td>\n",
       "      <td id=\"T_d099c_row2_col4\" class=\"data row2 col4\" >0.8588</td>\n",
       "      <td id=\"T_d099c_row2_col5\" class=\"data row2 col5\" >0.8324</td>\n",
       "      <td id=\"T_d099c_row2_col6\" class=\"data row2 col6\" >0.8325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
       "      <td id=\"T_d099c_row3_col0\" class=\"data row3 col0\" >0.8347</td>\n",
       "      <td id=\"T_d099c_row3_col1\" class=\"data row3 col1\" >0.9048</td>\n",
       "      <td id=\"T_d099c_row3_col2\" class=\"data row3 col2\" >0.8347</td>\n",
       "      <td id=\"T_d099c_row3_col3\" class=\"data row3 col3\" >0.8336</td>\n",
       "      <td id=\"T_d099c_row3_col4\" class=\"data row3 col4\" >0.8334</td>\n",
       "      <td id=\"T_d099c_row3_col5\" class=\"data row3 col5\" >0.8059</td>\n",
       "      <td id=\"T_d099c_row3_col6\" class=\"data row3 col6\" >0.8061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "      <td id=\"T_d099c_row4_col0\" class=\"data row4 col0\" >0.8254</td>\n",
       "      <td id=\"T_d099c_row4_col1\" class=\"data row4 col1\" >0.8996</td>\n",
       "      <td id=\"T_d099c_row4_col2\" class=\"data row4 col2\" >0.8254</td>\n",
       "      <td id=\"T_d099c_row4_col3\" class=\"data row4 col3\" >0.8254</td>\n",
       "      <td id=\"T_d099c_row4_col4\" class=\"data row4 col4\" >0.8253</td>\n",
       "      <td id=\"T_d099c_row4_col5\" class=\"data row4 col5\" >0.7951</td>\n",
       "      <td id=\"T_d099c_row4_col6\" class=\"data row4 col6\" >0.7951</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
       "      <td id=\"T_d099c_row5_col0\" class=\"data row5 col0\" >0.8477</td>\n",
       "      <td id=\"T_d099c_row5_col1\" class=\"data row5 col1\" >0.9123</td>\n",
       "      <td id=\"T_d099c_row5_col2\" class=\"data row5 col2\" >0.8477</td>\n",
       "      <td id=\"T_d099c_row5_col3\" class=\"data row5 col3\" >0.8471</td>\n",
       "      <td id=\"T_d099c_row5_col4\" class=\"data row5 col4\" >0.8473</td>\n",
       "      <td id=\"T_d099c_row5_col5\" class=\"data row5 col5\" >0.8213</td>\n",
       "      <td id=\"T_d099c_row5_col6\" class=\"data row5 col6\" >0.8213</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row6\" class=\"row_heading level0 row6\" >6</th>\n",
       "      <td id=\"T_d099c_row6_col0\" class=\"data row6 col0\" >0.8278</td>\n",
       "      <td id=\"T_d099c_row6_col1\" class=\"data row6 col1\" >0.9007</td>\n",
       "      <td id=\"T_d099c_row6_col2\" class=\"data row6 col2\" >0.8278</td>\n",
       "      <td id=\"T_d099c_row6_col3\" class=\"data row6 col3\" >0.8268</td>\n",
       "      <td id=\"T_d099c_row6_col4\" class=\"data row6 col4\" >0.8271</td>\n",
       "      <td id=\"T_d099c_row6_col5\" class=\"data row6 col5\" >0.7979</td>\n",
       "      <td id=\"T_d099c_row6_col6\" class=\"data row6 col6\" >0.7980</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row7\" class=\"row_heading level0 row7\" >7</th>\n",
       "      <td id=\"T_d099c_row7_col0\" class=\"data row7 col0\" >0.8291</td>\n",
       "      <td id=\"T_d099c_row7_col1\" class=\"data row7 col1\" >0.9019</td>\n",
       "      <td id=\"T_d099c_row7_col2\" class=\"data row7 col2\" >0.8291</td>\n",
       "      <td id=\"T_d099c_row7_col3\" class=\"data row7 col3\" >0.8330</td>\n",
       "      <td id=\"T_d099c_row7_col4\" class=\"data row7 col4\" >0.8298</td>\n",
       "      <td id=\"T_d099c_row7_col5\" class=\"data row7 col5\" >0.7997</td>\n",
       "      <td id=\"T_d099c_row7_col6\" class=\"data row7 col6\" >0.8002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row8\" class=\"row_heading level0 row8\" >8</th>\n",
       "      <td id=\"T_d099c_row8_col0\" class=\"data row8 col0\" >0.8424</td>\n",
       "      <td id=\"T_d099c_row8_col1\" class=\"data row8 col1\" >0.9093</td>\n",
       "      <td id=\"T_d099c_row8_col2\" class=\"data row8 col2\" >0.8424</td>\n",
       "      <td id=\"T_d099c_row8_col3\" class=\"data row8 col3\" >0.8451</td>\n",
       "      <td id=\"T_d099c_row8_col4\" class=\"data row8 col4\" >0.8434</td>\n",
       "      <td id=\"T_d099c_row8_col5\" class=\"data row8 col5\" >0.8151</td>\n",
       "      <td id=\"T_d099c_row8_col6\" class=\"data row8 col6\" >0.8152</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row9\" class=\"row_heading level0 row9\" >9</th>\n",
       "      <td id=\"T_d099c_row9_col0\" class=\"data row9 col0\" >0.8331</td>\n",
       "      <td id=\"T_d099c_row9_col1\" class=\"data row9 col1\" >0.9038</td>\n",
       "      <td id=\"T_d099c_row9_col2\" class=\"data row9 col2\" >0.8331</td>\n",
       "      <td id=\"T_d099c_row9_col3\" class=\"data row9 col3\" >0.8334</td>\n",
       "      <td id=\"T_d099c_row9_col4\" class=\"data row9 col4\" >0.8329</td>\n",
       "      <td id=\"T_d099c_row9_col5\" class=\"data row9 col5\" >0.8041</td>\n",
       "      <td id=\"T_d099c_row9_col6\" class=\"data row9 col6\" >0.8043</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row10\" class=\"row_heading level0 row10\" >Mean</th>\n",
       "      <td id=\"T_d099c_row10_col0\" class=\"data row10 col0\" >0.8352</td>\n",
       "      <td id=\"T_d099c_row10_col1\" class=\"data row10 col1\" >0.9052</td>\n",
       "      <td id=\"T_d099c_row10_col2\" class=\"data row10 col2\" >0.8352</td>\n",
       "      <td id=\"T_d099c_row10_col3\" class=\"data row10 col3\" >0.8363</td>\n",
       "      <td id=\"T_d099c_row10_col4\" class=\"data row10 col4\" >0.8353</td>\n",
       "      <td id=\"T_d099c_row10_col5\" class=\"data row10 col5\" >0.8066</td>\n",
       "      <td id=\"T_d099c_row10_col6\" class=\"data row10 col6\" >0.8068</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d099c_level0_row11\" class=\"row_heading level0 row11\" >Std</th>\n",
       "      <td id=\"T_d099c_row11_col0\" class=\"data row11 col0\" >0.0101</td>\n",
       "      <td id=\"T_d099c_row11_col1\" class=\"data row11 col1\" >0.0058</td>\n",
       "      <td id=\"T_d099c_row11_col2\" class=\"data row11 col2\" >0.0101</td>\n",
       "      <td id=\"T_d099c_row11_col3\" class=\"data row11 col3\" >0.0110</td>\n",
       "      <td id=\"T_d099c_row11_col4\" class=\"data row11 col4\" >0.0105</td>\n",
       "      <td id=\"T_d099c_row11_col5\" class=\"data row11 col5\" >0.0119</td>\n",
       "      <td id=\"T_d099c_row11_col6\" class=\"data row11 col6\" >0.0119</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x143b24130>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "</style>\n",
       "<table id=\"T_d319f\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_d319f_level0_col0\" class=\"col_heading level0 col0\" >Model</th>\n",
       "      <th id=\"T_d319f_level0_col1\" class=\"col_heading level0 col1\" >Accuracy</th>\n",
       "      <th id=\"T_d319f_level0_col2\" class=\"col_heading level0 col2\" >AUC</th>\n",
       "      <th id=\"T_d319f_level0_col3\" class=\"col_heading level0 col3\" >Recall</th>\n",
       "      <th id=\"T_d319f_level0_col4\" class=\"col_heading level0 col4\" >Prec.</th>\n",
       "      <th id=\"T_d319f_level0_col5\" class=\"col_heading level0 col5\" >F1</th>\n",
       "      <th id=\"T_d319f_level0_col6\" class=\"col_heading level0 col6\" >Kappa</th>\n",
       "      <th id=\"T_d319f_level0_col7\" class=\"col_heading level0 col7\" >MCC</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_d319f_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_d319f_row0_col0\" class=\"data row0 col0\" >Decision Tree Classifier</td>\n",
       "      <td id=\"T_d319f_row0_col1\" class=\"data row0 col1\" >0.8222</td>\n",
       "      <td id=\"T_d319f_row0_col2\" class=\"data row0 col2\" >0.8973</td>\n",
       "      <td id=\"T_d319f_row0_col3\" class=\"data row0 col3\" >0.8222</td>\n",
       "      <td id=\"T_d319f_row0_col4\" class=\"data row0 col4\" >0.8227</td>\n",
       "      <td id=\"T_d319f_row0_col5\" class=\"data row0 col5\" >0.8222</td>\n",
       "      <td id=\"T_d319f_row0_col6\" class=\"data row0 col6\" >0.7916</td>\n",
       "      <td id=\"T_d319f_row0_col7\" class=\"data row0 col7\" >0.7917</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x28324fa90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "ValueError",
     "evalue": "Classification metrics can't handle a mix of multiclass and continuous-multioutput targets",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[27], line 64\u001b[0m\n\u001b[1;32m     61\u001b[0m y_val_pred_pycaret \u001b[38;5;241m=\u001b[39m predict_model(model, data\u001b[38;5;241m=\u001b[39mval_data)\n\u001b[1;32m     63\u001b[0m \u001b[38;5;66;03m# Evaluate performance\u001b[39;00m\n\u001b[0;32m---> 64\u001b[0m precision, recall, f1, support \u001b[38;5;241m=\u001b[39m \u001b[43mprecision_recall_fscore_support\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_val_pred_pycaret\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maverage\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mweighted\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     66\u001b[0m \u001b[38;5;66;03m# Log metrics and model using MLflow\u001b[39;00m\n\u001b[1;32m     67\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m mlflow\u001b[38;5;241m.\u001b[39mstart_run(run_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPyCaret_Extended_Engineering\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m     68\u001b[0m     \u001b[38;5;66;03m# Log PyCaret model\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/utils/_param_validation.py:213\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    207\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    208\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m    209\u001b[0m         skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m    210\u001b[0m             prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m    211\u001b[0m         )\n\u001b[1;32m    212\u001b[0m     ):\n\u001b[0;32m--> 213\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    214\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    215\u001b[0m     \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m    216\u001b[0m     \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m    217\u001b[0m     \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m    218\u001b[0m     \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m    219\u001b[0m     msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[1;32m    220\u001b[0m         \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    221\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    222\u001b[0m         \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m    223\u001b[0m     )\n",
      "File \u001b[0;32m~/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1755\u001b[0m, in \u001b[0;36mprecision_recall_fscore_support\u001b[0;34m(y_true, y_pred, beta, labels, pos_label, average, warn_for, sample_weight, zero_division)\u001b[0m\n\u001b[1;32m   1592\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Compute precision, recall, F-measure and support for each class.\u001b[39;00m\n\u001b[1;32m   1593\u001b[0m \n\u001b[1;32m   1594\u001b[0m \u001b[38;5;124;03mThe precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1752\u001b[0m \u001b[38;5;124;03m array([2, 2, 2]))\u001b[39;00m\n\u001b[1;32m   1753\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m   1754\u001b[0m _check_zero_division(zero_division)\n\u001b[0;32m-> 1755\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[43m_check_set_wise_labels\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maverage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpos_label\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1757\u001b[0m \u001b[38;5;66;03m# Calculate tp_sum, pred_sum, true_sum ###\u001b[39;00m\n\u001b[1;32m   1758\u001b[0m samplewise \u001b[38;5;241m=\u001b[39m average \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
      "File \u001b[0;32m~/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1527\u001b[0m, in \u001b[0;36m_check_set_wise_labels\u001b[0;34m(y_true, y_pred, average, labels, pos_label)\u001b[0m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m average \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m average_options \u001b[38;5;129;01mand\u001b[39;00m average \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m   1525\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maverage has to be one of \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(average_options))\n\u001b[0;32m-> 1527\u001b[0m y_type, y_true, y_pred \u001b[38;5;241m=\u001b[39m \u001b[43m_check_targets\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1528\u001b[0m \u001b[38;5;66;03m# Convert to Python primitive type to avoid NumPy type / Python str\u001b[39;00m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;66;03m# comparison. See https://github.com/numpy/numpy/issues/6784\u001b[39;00m\n\u001b[1;32m   1530\u001b[0m present_labels \u001b[38;5;241m=\u001b[39m unique_labels(y_true, y_pred)\u001b[38;5;241m.\u001b[39mtolist()\n",
      "File \u001b[0;32m~/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/metrics/_classification.py:94\u001b[0m, in \u001b[0;36m_check_targets\u001b[0;34m(y_true, y_pred)\u001b[0m\n\u001b[1;32m     91\u001b[0m     y_type \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmulticlass\u001b[39m\u001b[38;5;124m\"\u001b[39m}\n\u001b[1;32m     93\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(y_type) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m---> 94\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m     95\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mClassification metrics can\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt handle a mix of \u001b[39m\u001b[38;5;132;01m{0}\u001b[39;00m\u001b[38;5;124m and \u001b[39m\u001b[38;5;132;01m{1}\u001b[39;00m\u001b[38;5;124m targets\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m     96\u001b[0m             type_true, type_pred\n\u001b[1;32m     97\u001b[0m         )\n\u001b[1;32m     98\u001b[0m     )\n\u001b[1;32m    100\u001b[0m \u001b[38;5;66;03m# We can't have more than one value on y_type => The set is no more needed\u001b[39;00m\n\u001b[1;32m    101\u001b[0m y_type \u001b[38;5;241m=\u001b[39m y_type\u001b[38;5;241m.\u001b[39mpop()\n",
      "\u001b[0;31mValueError\u001b[0m: Classification metrics can't handle a mix of multiclass and continuous-multioutput targets"
     ]
    }
   ],
   "source": [
    "path = '/Users/arham/Downloads/Projects/01-Dataset/01-Data-for-model-building/train.csv'\n",
    "train_df, val_df, test_df = load_data(path)\n",
    "\n",
    "train_df = datatypes(train_df)\n",
    "train_df = encode_target(train_df)\n",
    "train_df = age_binning(train_df)\n",
    "train_df, scaler_age = age_scaling_minmax(train_df)\n",
    "train_df = age_scaling_log(train_df)\n",
    "train_df, scaler_weight = weight_scaling_minmax(train_df)\n",
    "train_df = weight_scaling_log(train_df)\n",
    "train_df, scaler_height = height_scaling_minmax(train_df)\n",
    "train_df = height_scaling_log(train_df)\n",
    "train_df = make_gender_binary(train_df)\n",
    "train_df = fix_binary_columns(train_df)\n",
    "train_df = freq_cat_cols(train_df)\n",
    "train_df = Mtrans(train_df)\n",
    "train_df = other_features(train_df)\n",
    "\n",
    "val_df = test_pipeline(val_df, scaler_age, scaler_weight, scaler_height)\n",
    "test_df = test_pipeline(test_df, scaler_age, scaler_weight, scaler_height)\n",
    "\n",
    "Target = 'NObeyesdad'\n",
    "features = train_df.columns.drop(Target)\n",
    "\n",
    "features = ['Gender', 'Age', 'Height', 'Weight', 'family_history_with_overweight',\n",
    "       'FAVC', 'FCVC', 'NCP', 'CAEC', 'SMOKE', 'CH2O', 'SCC', 'FAF', 'TUE',\n",
    "       'CALC', 'Age_Group', \n",
    "       'MTRANS_Automobile', 'MTRANS_Bike', 'MTRANS_Motorbike',\n",
    "       'MTRANS_Public_Transportation', 'MTRANS_Walking', 'BMI', 'Age^2',\n",
    "       'Age^3', 'BMI^2', 'Age * BMI', 'Age * BMI^2', 'Age^2 * BMI^2', \n",
    "       'Scaled_Age', 'Log_Age', 'Scaled_Weight', 'Log_Weight', 'Scaled_Height', 'Log_Height']\n",
    "#'Scaled_Age', 'Log_Age', 'Scaled_Weight', 'Log_Weight', 'Scaled_Height', 'Log_Height',\n",
    "\n",
    "X_train = train_df[features]\n",
    "y_train = train_df[Target]\n",
    "X_val = val_df[features]\n",
    "y_val = val_df[Target]\n",
    "X_test = test_df[features]\n",
    "y_test = test_df[Target]\n",
    "\n",
    "#combine X_train and y_train as one dataframe\n",
    "tr = pd.concat([X_train, y_train], axis=1)\n",
    "te = pd.concat([X_test, y_test], axis =1)\n",
    "va = pd.concat([X_val, y_val], axis = 1)\n"
   ]
  },
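  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `ValueError` above is raised because `predict_model` returns the whole input frame plus prediction columns, and that full DataFrame is handed to `precision_recall_fscore_support`, which expects a 1-D array of class labels. A minimal sketch of the fix, assuming PyCaret 3.x column naming (`prediction_label`; 2.x used `Label`) and the `model` / `val_data` objects from the failing cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pycaret.classification import predict_model\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "\n",
    "# predict_model appends prediction columns to the input frame; select the\n",
    "# hard-label column before computing classification metrics.\n",
    "preds = predict_model(model, data=val_data)\n",
    "y_val_pred_labels = preds['prediction_label']  # 'Label' on PyCaret 2.x\n",
    "precision, recall, f1, _ = precision_recall_fscore_support(\n",
    "    y_val, y_val_pred_labels, average='weighted')\n",
    "print(precision, recall, f1)"
   ]
  },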
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zero Shot Learning"
   ]
  },
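  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\"Zero-shot\" here means fitting AutoGluon's `TabularPredictor` with library defaults: no presets, no time limit, no custom hyperparameters. A sketch of the call pattern implied by the training log below, assuming the `tr` frame built above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.tabular import TabularPredictor\n",
    "\n",
    "# Fit with defaults; AutoGluon infers 'multiclass' from the integer label\n",
    "# column and selects/ensembles models on an internal holdout split.\n",
    "predictor = TabularPredictor(label='NObeyesdad').fit(tr)"
   ]
  },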
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No path specified. Models will be saved in: \"AutogluonModels/ag-20240426_052138\"\n",
      "No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets.\n",
      "\tRecommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):\n",
      "\tpresets='best_quality'   : Maximize accuracy. Default time_limit=3600.\n",
      "\tpresets='high_quality'   : Strong accuracy with fast inference speed. Default time_limit=3600.\n",
      "\tpresets='good_quality'   : Good accuracy with very fast inference speed. Default time_limit=3600.\n",
      "\tpresets='medium_quality' : Fast training time, ideal for initial prototyping.\n",
      "Beginning AutoGluon training ...\n",
      "AutoGluon will save models to \"AutogluonModels/ag-20240426_052138\"\n",
      "=================== System Info ===================\n",
      "AutoGluon Version:  1.1.0\n",
      "Python Version:     3.10.13\n",
      "Operating System:   Darwin\n",
      "Platform Machine:   arm64\n",
      "Platform Version:   Darwin Kernel Version 23.0.0: Fri Sep 15 14:42:57 PDT 2023; root:xnu-10002.1.13~1/RELEASE_ARM64_T8112\n",
      "CPU Count:          8\n",
      "Memory Avail:       1.23 GB / 8.00 GB (15.3%)\n",
      "Disk Space Avail:   14.94 GB / 228.27 GB (6.5%)\n",
      "===================================================\n",
      "Train Data Rows:    10793\n",
      "Train Data Columns: 34\n",
      "Label Column:       NObeyesdad\n",
      "AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).\n",
      "\t7 unique label values:  [0, 3, 1, 6, 4, 2, 5]\n",
      "\tIf 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n",
      "Problem Type:       multiclass\n",
      "Preprocessing data ...\n",
      "Train Data Class Count: 7\n",
      "Using Feature Generators to preprocess the data ...\n",
      "Fitting AutoMLPipelineFeatureGenerator...\n",
      "\tAvailable Memory:                    1252.16 MB\n",
      "\tTrain Data (Original)  Memory Usage: 2.80 MB (0.2% of available memory)\n",
      "\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
      "\tStage 1 Generators:\n",
      "\t\tFitting AsTypeFeatureGenerator...\n",
      "\t\t\tNote: Converting 10 features to boolean dtype as they only contain 2 unique values.\n",
      "\tStage 2 Generators:\n",
      "\t\tFitting FillNaFeatureGenerator...\n",
      "\tStage 3 Generators:\n",
      "\t\tFitting IdentityFeatureGenerator...\n",
      "\tStage 4 Generators:\n",
      "\t\tFitting DropUniqueFeatureGenerator...\n",
      "\tStage 5 Generators:\n",
      "\t\tFitting DropDuplicatesFeatureGenerator...\n",
      "\tUseless Original Features (Count: 1): ['Age^2']\n",
      "\t\tThese features carry no predictive signal and should be manually investigated.\n",
      "\t\tThis is typically a feature which has the same value for all rows.\n",
      "\t\tThese features do not need to be present at inference time.\n",
      "\tUnused Original Features (Count: 2): ['Age^3', 'BMI^2']\n",
      "\t\tThese features were not used to generate any of the output features. Add a feature generator compatible with these features to utilize them.\n",
      "\t\tFeatures can also be unused if they carry very little information, such as being categorical but having almost entirely unique values or being duplicates of other features.\n",
      "\t\tThese features do not need to be present at inference time.\n",
      "\t\t('float', []) : 2 | ['Age^3', 'BMI^2']\n",
      "\tTypes of features in original data (raw dtype, special dtypes):\n",
      "\t\t('float', []) : 18 | ['Age', 'Height', 'Weight', 'FCVC', 'NCP', ...]\n",
      "\t\t('int', [])   : 13 | ['Gender', 'family_history_with_overweight', 'FAVC', 'CAEC', 'SMOKE', ...]\n",
      "\tTypes of features in processed data (raw dtype, special dtypes):\n",
      "\t\t('float', [])     : 18 | ['Age', 'Height', 'Weight', 'FCVC', 'NCP', ...]\n",
      "\t\t('int', [])       :  3 | ['CAEC', 'CALC', 'Age_Group']\n",
      "\t\t('int', ['bool']) : 10 | ['Gender', 'family_history_with_overweight', 'FAVC', 'SMOKE', 'SCC', ...]\n",
      "\t0.6s = Fit runtime\n",
      "\t31 features in original data used to generate 31 features in processed data.\n",
      "\tTrain Data (Processed) Memory Usage: 1.83 MB (0.1% of available memory)\n",
      "Data preprocessing and feature engineering runtime = 0.67s ...\n",
      "AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n",
      "\tTo change this, specify the eval_metric parameter of Predictor()\n",
      "Automatically generating train/validation split with holdout_frac=0.1, Train Rows: 9713, Val Rows: 1080\n",
      "User-specified model hyperparameters to be fit:\n",
      "{\n",
      "\t'NN_TORCH': {},\n",
      "\t'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, 'GBMLarge'],\n",
      "\t'CAT': {},\n",
      "\t'XGB': {},\n",
      "\t'FASTAI': {},\n",
      "\t'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n",
      "\t'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n",
      "\t'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],\n",
      "}\n",
      "Fitting 13 L1 models ...\n",
      "Fitting model: KNeighborsUnif ...\n",
      "\tWarning: Exception caused KNeighborsUnif to fail during training... Skipping this model.\n",
      "\t\t'NoneType' object has no attribute 'split'\n",
      "Detailed Traceback:\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1926, in _train_and_save\n",
      "    y_pred_proba_val = model.predict_proba(X_val)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 950, in predict_proba\n",
      "    y_pred_proba = self._predict_proba(X=X, **kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 992, in _predict_proba\n",
      "    y_pred_proba = self.model.predict_proba(X)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/neighbors/_classification.py\", line 366, in predict_proba\n",
      "    neigh_ind = self.kneighbors(X, return_distance=False)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/neighbors/_base.py\", line 850, in kneighbors\n",
      "    results = ArgKmin.compute(\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py\", line 279, in compute\n",
      "    return ArgKmin32.compute(\n",
      "  File \"sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx\", line 575, in sklearn.metrics._pairwise_distances_reduction._argkmin.ArgKmin32.compute\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/utils/fixes.py\", line 94, in threadpool_limits\n",
      "    return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 171, in __init__\n",
      "    self._original_info = self._set_threadpool_limits()\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 268, in _set_threadpool_limits\n",
      "    modules = _ThreadpoolInfo(prefixes=self._prefixes,\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 340, in __init__\n",
      "    self._load_modules()\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 371, in _load_modules\n",
      "    self._find_modules_with_dyld()\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 428, in _find_modules_with_dyld\n",
      "    self._make_module_from_path(filepath)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 515, in _make_module_from_path\n",
      "    module = module_class(filepath, prefix, user_api, internal_api)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 606, in __init__\n",
      "    self.version = self.get_version()\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 646, in get_version\n",
      "    config = get_config().split()\n",
      "AttributeError: 'NoneType' object has no attribute 'split'\n",
      "Fitting model: KNeighborsDist ...\n",
      "\tWarning: Exception caused KNeighborsDist to fail during training... Skipping this model.\n",
      "\t\t'NoneType' object has no attribute 'split'\n",
      "Detailed Traceback:\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1926, in _train_and_save\n",
      "    y_pred_proba_val = model.predict_proba(X_val)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 950, in predict_proba\n",
      "    y_pred_proba = self._predict_proba(X=X, **kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 992, in _predict_proba\n",
      "    y_pred_proba = self.model.predict_proba(X)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/neighbors/_classification.py\", line 369, in predict_proba\n",
      "    neigh_dist, neigh_ind = self.kneighbors(X)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/neighbors/_base.py\", line 850, in kneighbors\n",
      "    results = ArgKmin.compute(\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py\", line 279, in compute\n",
      "    return ArgKmin32.compute(\n",
      "  File \"sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx\", line 575, in sklearn.metrics._pairwise_distances_reduction._argkmin.ArgKmin32.compute\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/sklearn/utils/fixes.py\", line 94, in threadpool_limits\n",
      "    return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 171, in __init__\n",
      "    self._original_info = self._set_threadpool_limits()\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 268, in _set_threadpool_limits\n",
      "    modules = _ThreadpoolInfo(prefixes=self._prefixes,\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 340, in __init__\n",
      "    self._load_modules()\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 371, in _load_modules\n",
      "    self._find_modules_with_dyld()\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 428, in _find_modules_with_dyld\n",
      "    self._make_module_from_path(filepath)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 515, in _make_module_from_path\n",
      "    module = module_class(filepath, prefix, user_api, internal_api)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 606, in __init__\n",
      "    self.version = self.get_version()\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/threadpoolctl.py\", line 646, in get_version\n",
      "    config = get_config().split()\n",
      "AttributeError: 'NoneType' object has no attribute 'split'\n",
      "Fitting model: NeuralNetFastAI ...\n",
      "\t0.8944\t = Validation score   (accuracy)\n",
      "\t8.91s\t = Training   runtime\n",
      "\t0.03s\t = Validation runtime\n",
      "Fitting model: LightGBMXT ...\n",
      "\t0.9148\t = Validation score   (accuracy)\n",
      "\t17.49s\t = Training   runtime\n",
      "\t0.03s\t = Validation runtime\n",
      "Fitting model: LightGBM ...\n",
      "\t0.9139\t = Validation score   (accuracy)\n",
      "\t28.87s\t = Training   runtime\n",
      "\t0.05s\t = Validation runtime\n",
      "Fitting model: RandomForestGini ...\n",
      "\t0.9093\t = Validation score   (accuracy)\n",
      "\t2.86s\t = Training   runtime\n",
      "\t0.08s\t = Validation runtime\n",
      "Fitting model: RandomForestEntr ...\n",
      "\t0.9028\t = Validation score   (accuracy)\n",
      "\t2.62s\t = Training   runtime\n",
      "\t0.08s\t = Validation runtime\n",
      "Fitting model: CatBoost ...\n",
      "\t0.9102\t = Validation score   (accuracy)\n",
      "\t12.96s\t = Training   runtime\n",
      "\t0.01s\t = Validation runtime\n",
      "Fitting model: ExtraTreesGini ...\n",
      "\t0.9065\t = Validation score   (accuracy)\n",
      "\t1.1s\t = Training   runtime\n",
      "\t0.09s\t = Validation runtime\n",
      "Fitting model: ExtraTreesEntr ...\n",
      "\t0.9083\t = Validation score   (accuracy)\n",
      "\t0.94s\t = Training   runtime\n",
      "\t0.06s\t = Validation runtime\n",
      "Fitting model: XGBoost ...\n",
      "\t0.9111\t = Validation score   (accuracy)\n",
      "\t15.38s\t = Training   runtime\n",
      "\t0.05s\t = Validation runtime\n",
      "Fitting model: NeuralNetTorch ...\n",
      "\tWarning: Exception caused NeuralNetTorch to fail during training... Skipping this model.\n",
      "\t\tmodule 'torch.utils._pytree' has no attribute 'register_pytree_node'\n",
      "Detailed Traceback:\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1904, in _train_and_save\n",
      "    model = self._train_single(X, y, model, X_val, y_val, total_resources=total_resources, **model_fit_kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1844, in _train_single\n",
      "    model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, total_resources=total_resources, **model_fit_kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 855, in fit\n",
      "    out = self._fit(**kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py\", line 196, in _fit\n",
      "    self.optimizer = self._init_optimizer(**optimizer_kwargs)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py\", line 553, in _init_optimizer\n",
      "    optimizer = torch.optim.Adam(params=self.model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/optim/adam.py\", line 45, in __init__\n",
      "    super().__init__(params, defaults)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/optim/optimizer.py\", line 266, in __init__\n",
      "    self.add_param_group(cast(dict, param_group))\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/_compile.py\", line 22, in inner\n",
      "    import torch._dynamo\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/_dynamo/__init__.py\", line 2, in <module>\n",
      "    from . import allowed_functions, convert_frame, eval_frame, resume_execution\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/_dynamo/allowed_functions.py\", line 26, in <module>\n",
      "    from . import config\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/_dynamo/config.py\", line 49, in <module>\n",
      "    torch.onnx.is_in_onnx_export: False,\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/__init__.py\", line 1831, in __getattr__\n",
      "    return importlib.import_module(f\".{name}\", __name__)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/importlib/__init__.py\", line 126, in import_module\n",
      "    return _bootstrap._gcd_import(name[level:], package, level)\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/onnx/__init__.py\", line 46, in <module>\n",
      "    from ._internal.exporter import (  # usort:skip. needs to be last to avoid circular import\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py\", line 42, in <module>\n",
      "    from torch.onnx._internal.fx import (\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/onnx/_internal/fx/__init__.py\", line 1, in <module>\n",
      "    from .patcher import ONNXTorchPatcher\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/torch/onnx/_internal/fx/patcher.py\", line 11, in <module>\n",
      "    import transformers  # type: ignore[import]\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/transformers/__init__.py\", line 26, in <module>\n",
      "    from . import dependency_versions_check\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/transformers/dependency_versions_check.py\", line 16, in <module>\n",
      "    from .utils.versions import require_version, require_version_core\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/transformers/utils/__init__.py\", line 33, in <module>\n",
      "    from .generic import (\n",
      "  File \"/Users/arham/anaconda3/envs/DataScience/lib/python3.10/site-packages/transformers/utils/generic.py\", line 455, in <module>\n",
      "    _torch_pytree.register_pytree_node(\n",
      "AttributeError: module 'torch.utils._pytree' has no attribute 'register_pytree_node'. Did you mean: '_register_pytree_node'?\n",
      "Fitting model: LightGBMLarge ...\n",
      "\t0.913\t = Validation score   (accuracy)\n",
      "\t95.73s\t = Training   runtime\n",
      "\t0.23s\t = Validation runtime\n",
      "Fitting model: WeightedEnsemble_L2 ...\n",
      "\tEnsemble Weights: {'ExtraTreesGini': 0.286, 'NeuralNetFastAI': 0.19, 'LightGBMXT': 0.19, 'CatBoost': 0.143, 'XGBoost': 0.143, 'LightGBMLarge': 0.048}\n",
      "\t0.9204\t = Validation score   (accuracy)\n",
      "\t0.21s\t = Training   runtime\n",
      "\t0.0s\t = Validation runtime\n",
      "AutoGluon training complete, total runtime = 194.17s ... Best model: \"WeightedEnsemble_L2\"\n",
      "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"AutogluonModels/ag-20240426_052138\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall for class 0: 0.9335443037974683\n",
      "Recall for class 1: 0.8983957219251337\n",
      "Recall for class 2: 0.7582089552238805\n",
      "Recall for class 3: 0.8389057750759878\n",
      "Recall for class 4: 0.8646080760095012\n",
      "Recall for class 5: 0.9641148325358851\n",
      "Recall for class 6: 0.9960474308300395\n"
     ]
    }
   ],
   "source": [
    "from autogluon.tabular import TabularDataset, TabularPredictor\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "import mlflow\n",
    "\n",
    "# Load your data into AutoGluon TabularDataset format\n",
    "train_data = TabularDataset(X_train.join(y_train))\n",
    "val_data = TabularDataset(X_val.join(y_val))\n",
    "\n",
    "# Define the label column\n",
    "label_column = Target  # Replace 'Target' with your actual label column name\n",
    "\n",
    "# Specify the task and run AutoGluon\n",
    "predictor = TabularPredictor(label=label_column).fit(train_data=train_data)\n",
    "\n",
    "# Make predictions on the validation set\n",
    "y_val_pred_autogluon = predictor.predict(val_data.drop(columns=[label_column]))\n",
    "\n",
    "# Evaluate performance\n",
    "precision, recall, f1, support = precision_recall_fscore_support(y_val, y_val_pred_autogluon, average='weighted')\n",
    "\n",
    "# Log metrics and model using MLflow\n",
    "with mlflow.start_run(run_name=\"AutoGluon_Without_Feature_Engineering\"):\n",
    "    # Log AutoGluon model\n",
    "    mlflow.sklearn.log_model(predictor, \"autogluon_model\")\n",
    "    \n",
    "    # Log metrics\n",
    "    mlflow.log_metric('accuracy', accuracy_score(y_val, y_val_pred_autogluon))\n",
    "    mlflow.log_metric('precision', precision)\n",
    "    mlflow.log_metric('recall', recall)\n",
    "    mlflow.log_metric('f1', f1)\n",
    "\n",
    "    # Log recall per class\n",
    "    recall_per_class = recall_score(y_val, y_val_pred_autogluon, average=None)\n",
    "    for i, recall_class in enumerate(recall_per_class):\n",
    "        print(f\"Recall for class {i}: {recall_class}\")\n",
    "        mlflow.log_metric(f'recall_class_{i}', recall_class)\n",
    "\n",
    "    mlflow.set_tag('experiments', 'Arham A.')\n",
    "    mlflow.set_tag('model_name', 'AutoGluon')\n",
    "    mlflow.set_tag('preprocessing', 'Yes')\n"
   ]
  },
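  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "AutoGluon trained its default model zoo and ensembled the strongest members into `WeightedEnsemble_L2` (validation accuracy 0.9204). `NeuralNetTorch` was skipped: the installed `transformers` package calls `torch.utils._pytree.register_pytree_node`, which this `torch` build still names `_register_pytree_node`, so the two package versions are likely mismatched; aligning them should restore that model. The cell below is a minimal sketch that reloads the saved predictor (path taken from the training log above) and inspects the per-model leaderboard on the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: reload the saved predictor and compare the trained models.\n",
    "# The load path is copied from the training log above.\n",
    "from autogluon.tabular import TabularPredictor\n",
    "\n",
    "predictor = TabularPredictor.load(\"AutogluonModels/ag-20240426_052138\")\n",
    "\n",
    "# leaderboard() returns a DataFrame with per-model validation scores,\n",
    "# training runtimes, and inference runtimes\n",
    "leaderboard = predictor.leaderboard(val_data)\n",
    "print(leaderboard.head())\n",
    "\n",
    "# Aggregate metrics for the best model (WeightedEnsemble_L2) on val_data\n",
    "print(predictor.evaluate(val_data))\n"
   ]
  },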
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DataScience",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}