[81a4e3]: / ONNX.ipynb

Download this file

390 lines (389 with data), 9.7 kB

{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T21:53:39.108926Z",
     "start_time": "2024-09-19T21:53:38.127937Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from request import y_true\n",
    "from utils import Preprocess, MissingValue\n",
    "import pickle\n",
    "from fastapi import FastAPI, Request\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import onnx\n",
    "import skl2onnx\n",
    "from skl2onnx import convert_sklearn\n",
    "from skl2onnx.common.data_types import FloatTensorType\n",
    "import os\n",
    "import onnxruntime as ort\n",
    "from onnxruntime_tools import optimizer  # Optimization package\n"
   ],
   "id": "739236a5cdfe49bb",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: onnxruntime_tools is deprecated. Use onnxruntime or onnxruntime-gpu instead. For more information, see https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/README.md.\n",
      "\n"
     ]
    }
   ],
   "execution_count": 14
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T21:52:15.433981Z",
     "start_time": "2024-09-19T21:52:15.390202Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def _load_pickle(path):\n",
    "    \"\"\"Return the object stored in the pickle file at `path`.\"\"\"\n",
    "    # NOTE: unpickling can execute arbitrary code, so only load artifacts\n",
    "    # that were produced by a trusted training run.\n",
    "    with open(path, 'rb') as fh:\n",
    "        return pickle.load(fh)\n",
    "\n",
    "\n",
    "# `cols` is later used to select feature columns, `sc` to transform them,\n",
    "# and `model` as the estimator that is converted to ONNX.\n",
    "cols = _load_pickle('model/columns.pkl')\n",
    "sc = _load_pickle('model/scaler.pkl')\n",
    "model = _load_pickle('model/model.pkl')"
   ],
   "id": "a4818dbe2c02b226",
   "outputs": [],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T21:52:20.896294Z",
     "start_time": "2024-09-19T21:52:19.139861Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Destination of the converted graph.\n",
    "onnx_model_path = 'model/random_forest_model.onnx'\n",
    "\n",
    "# Declare the graph input: a float tensor with a free batch dimension and\n",
    "# one column per entry in `cols`.\n",
    "input_spec = FloatTensorType([None, len(cols)])\n",
    "initial_type = [('float_input', input_spec)]\n",
    "\n",
    "# Translate the loaded scikit-learn model into an ONNX graph and write it out.\n",
    "onnx_model = convert_sklearn(model, initial_types=initial_type)\n",
    "onnx.save_model(onnx_model, onnx_model_path)\n",
    "print(f\"ONNX model saved at: {onnx_model_path}\")"
   ],
   "id": "586833f5665f6fd8",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ONNX model saved at: random_forest_model.onnx\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T21:55:43.954548Z",
     "start_time": "2024-09-19T21:55:40.665948Z"
    }
   },
   "cell_type": "code",
   "source": [
    "optimized_model_path = 'model/optimized_random_forest_model.onnx'\n",
    "\n",
    "# Check if the model file exists\n",
    "if os.path.exists(onnx_model_path):\n",
    "    # `onnxruntime_tools.optimizer` is deprecated and is built for transformer\n",
    "    # graphs exported from deep-learning frameworks, so it has nothing to fuse\n",
    "    # in a skl2onnx tree ensemble. Use ONNX Runtime's own graph optimizer in\n",
    "    # offline mode instead: creating a session with `optimized_model_filepath`\n",
    "    # set serializes the optimized graph to that path.\n",
    "    session_options = ort.SessionOptions()\n",
    "    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED\n",
    "    session_options.optimized_model_filepath = optimized_model_path\n",
    "\n",
    "    # The session object itself is not needed; constructing it writes the file.\n",
    "    ort.InferenceSession(onnx_model_path, session_options, providers=['CPUExecutionProvider'])\n",
    "    print(f\"Optimized ONNX model saved at: {optimized_model_path}\")\n",
    "\n",
    "else:\n",
    "    print(f\"ONNX model file not found at: {onnx_model_path}\")"
   ],
   "id": "9c2298cca891971a",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model producer not matched: Expect pytorch,  Got skl2onnx 1.16.0. Please specify correct --model_type parameter.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stopping at incomplete shape inference at TreeEnsembleClassifier: TreeEnsembleClassifier\n",
      "node inputs:\n",
      "name: \"float_input\"\n",
      "type {\n",
      "  tensor_type {\n",
      "    elem_type: 1\n",
      "    shape {\n",
      "      dim {\n",
      "        dim_param: \"float_input_d0\"\n",
      "      }\n",
      "      dim {\n",
      "        dim_value: 38\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n",
      "node outputs:\n",
      "name: \"label\"\n",
      "type {\n",
      "  tensor_type {\n",
      "    elem_type: 7\n",
      "  }\n",
      "}\n",
      "\n",
      "name: \"probabilities\"\n",
      "type {\n",
      "  tensor_type {\n",
      "    elem_type: 0\n",
      "  }\n",
      "}\n",
      "\n",
      "Optimized ONNX model saved at: optimized_random_forest_model.onnx\n"
     ]
    }
   ],
   "execution_count": 17
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T22:00:16.531114Z",
     "start_time": "2024-09-19T22:00:11.422588Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Read the workbook and discard the 'Unnamed: 0' and 'visit id' columns.\n",
    "df = (\n",
    "    pd.read_excel('dataset.xlsx', engine='openpyxl')\n",
    "    .drop(columns=['Unnamed: 0', 'visit id'])\n",
    ")"
   ],
   "id": "70995a8c7e952908",
   "outputs": [],
   "execution_count": 18
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T22:03:48.837560Z",
     "start_time": "2024-09-19T22:03:30.820365Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Run the project's preprocessing on the loaded frame with every threshold set to zero.\n",
    "preprocessor = Preprocess(dataframe=df, missing_value_per=0, variance_threshold=0, min_null_per=0)\n",
    "\n",
    "# Apply the mapping step, then let MissingValue fill the resulting frame.\n",
    "test = MissingValue(preprocessor._mapping(df)).fill_dataframe()\n",
    "\n",
    "# Keep the label column aside; select the model's columns and pass them through `sc`.\n",
    "y_true = test['target label / yes no']\n",
    "x_test = sc.transform(test[cols])"
   ],
   "id": "6954635fc721a11",
   "outputs": [],
   "execution_count": 20
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T22:05:52.982598Z",
     "start_time": "2024-09-19T22:05:52.972599Z"
    }
   },
   "cell_type": "code",
   "source": "x_test = np.array(x_test, dtype=np.float32)",
   "id": "60fe97c2f94d42fb",
   "outputs": [],
   "execution_count": 24
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T22:05:58.059111Z",
     "start_time": "2024-09-19T22:05:57.913089Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Open an inference session on the optimized ONNX file.\n",
    "session = ort.InferenceSession(optimized_model_path)\n",
    "\n",
    "# Look up the names of the graph's first input and first output.\n",
    "first_input, first_output = session.get_inputs()[0], session.get_outputs()[0]\n",
    "input_name, output_name = first_input.name, first_output.name\n",
    "\n",
    "# Score the whole matrix in one call, requesting only the first output.\n",
    "feed = {input_name: x_test}\n",
    "predictions = session.run([output_name], feed)"
   ],
   "id": "731d3382ba7bb0aa",
   "outputs": [],
   "execution_count": 25
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T22:08:15.905866Z",
     "start_time": "2024-09-19T22:08:15.895866Z"
    }
   },
   "cell_type": "code",
   "source": "# Inspect the type of the first (and only requested) output returned by session.run.\ntype(predictions[0])",
   "id": "a582ad4933093a59",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "numpy.ndarray"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 33
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T22:08:32.499981Z",
     "start_time": "2024-09-19T22:08:32.478987Z"
    }
   },
   "cell_type": "code",
   "source": "# Inspect the type of the array behind the `y_true` label column.\ntype(y_true.values)",
   "id": "d5fce887f8591592",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "numpy.ndarray"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 35
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-19T22:08:57.356625Z",
     "start_time": "2024-09-19T22:08:57.336646Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# scikit-learn metrics take the ground truth first and the predictions second.\n",
    "# With average='weighted' each class is weighted by its support in `y_true`,\n",
    "# so passing the arguments in the wrong order changes the reported score.\n",
    "f1_score(\n",
    "    y_true=y_true.values.astype(int),\n",
    "    y_pred=predictions[0].astype(int),\n",
    "    average='weighted',\n",
    ")"
   ],
   "id": "c1ec6d40409971c4",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8934725519012681"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 37
  },
  {
   "metadata": {
    "collapsed": true
   },
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "app = FastAPI()\n",
    "\n",
    "\n",
    "def _payload_to_frame(payload):\n",
    "    \"\"\"Turn a decoded JSON request body into a DataFrame.\n",
    "\n",
    "    `Request.json()` hands back already-parsed Python data, so an array of\n",
    "    records arrives as a list of dicts and goes straight to the DataFrame\n",
    "    constructor. A body that is itself a JSON-encoded string arrives as `str`\n",
    "    and is parsed with `pd.read_json`.\n",
    "    \"\"\"\n",
    "    from io import StringIO  # local import keeps this cell self-contained\n",
    "\n",
    "    if isinstance(payload, str):\n",
    "        # Wrap the text in a buffer so pandas always reads it as JSON content\n",
    "        # and never as a file path or URL chosen by the client.\n",
    "        return pd.read_json(StringIO(payload), orient='records')\n",
    "    return pd.DataFrame(payload)\n",
    "\n",
    "\n",
    "@app.post(\"/predict/\")\n",
    "async def predict(request: Request):\n",
    "    \"\"\"Score the records in the request body with the loaded `model`.\n",
    "\n",
    "    Returns a dict with a `predictions` list on success, or a dict with an\n",
    "    `error` message if any step raises.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        data = await request.json()\n",
    "        df = _payload_to_frame(data)\n",
    "\n",
    "        preprocessor = Preprocess(\n",
    "            dataframe=df,\n",
    "            missing_value_per=0,\n",
    "            variance_threshold=0,\n",
    "            min_null_per=0\n",
    "        )\n",
    "        test = preprocessor._mapping(df)\n",
    "\n",
    "        m = MissingValue(test)\n",
    "        test = m.fill_dataframe()\n",
    "\n",
    "        # Same feature selection and transform as the offline evaluation cells.\n",
    "        x_test = test[cols]\n",
    "        x_test = sc.transform(x_test)\n",
    "\n",
    "        y_pred = model.predict(x_test)\n",
    "\n",
    "        return {\"predictions\": y_pred.tolist()}\n",
    "\n",
    "    except Exception as e:\n",
    "        # NOTE(review): failures are reported in the response body rather than\n",
    "        # through an HTTP error status, so callers must check for the `error` key.\n",
    "        return {\"error\": str(e)}\n",
    "\n",
    "# Run the application using: uvicorn main:app --reload\n"
   ],
   "id": "initial_id"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}