--- a +++ b/test_preds.ipynb @@ -0,0 +1,328 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: efficientnet in /home/surya/anaconda3/lib/python3.6/site-packages (1.0.0)\n", + "Requirement already satisfied: scikit-image in /home/surya/anaconda3/lib/python3.6/site-packages (from efficientnet) (0.15.0)\n", + "Requirement already satisfied: keras-applications<=1.0.8,>=1.0.7 in /home/surya/anaconda3/lib/python3.6/site-packages (from efficientnet) (1.0.8)\n", + "Requirement already satisfied: imageio>=2.0.1 in /home/surya/anaconda3/lib/python3.6/site-packages (from scikit-image->efficientnet) (2.3.0)\n", + "Requirement already satisfied: scipy>=0.17.0 in /home/surya/anaconda3/lib/python3.6/site-packages (from scikit-image->efficientnet) (1.1.0)\n", + "Requirement already satisfied: networkx>=2.0 in /home/surya/anaconda3/lib/python3.6/site-packages (from scikit-image->efficientnet) (2.1)\n", + "Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in /home/surya/anaconda3/lib/python3.6/site-packages (from scikit-image->efficientnet) (2.2.2)\n", + "Requirement already satisfied: PyWavelets>=0.4.0 in /home/surya/anaconda3/lib/python3.6/site-packages (from scikit-image->efficientnet) (0.5.2)\n", + "Requirement already satisfied: pillow>=4.3.0 in /home/surya/anaconda3/lib/python3.6/site-packages (from scikit-image->efficientnet) (6.2.0)\n", + "Requirement already satisfied: h5py in /home/surya/anaconda3/lib/python3.6/site-packages (from keras-applications<=1.0.8,>=1.0.7->efficientnet) (2.7.1)\n", + "Requirement already satisfied: numpy>=1.9.1 in /home/surya/anaconda3/lib/python3.6/site-packages (from keras-applications<=1.0.8,>=1.0.7->efficientnet) (1.17.3)\n", + "Requirement already satisfied: decorator>=4.1.0 in /home/surya/anaconda3/lib/python3.6/site-packages (from networkx>=2.0->scikit-image->efficientnet) (4.3.0)\n", + "Requirement already satisfied: cycler>=0.10 in /home/surya/anaconda3/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (0.10.0)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /home/surya/anaconda3/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (2.2.0)\n", + "Requirement already satisfied: python-dateutil>=2.1 in /home/surya/.local/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (2.7.3)\n", + "Requirement already satisfied: pytz in /home/surya/.local/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (2018.5)\n", + "Requirement already satisfied: six>=1.10 in /home/surya/.local/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (1.11.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /home/surya/anaconda3/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (1.0.1)\n", + "Requirement already satisfied: setuptools in /home/surya/anaconda3/lib/python3.6/site-packages (from kiwisolver>=1.0.1->matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (39.1.0)\n", + "\u001b[33mWARNING: You are using pip version 19.1.1, however version 19.3.1 is available.\n", + "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n", + "Requirement already satisfied: iterative-stratification in /home/surya/anaconda3/lib/python3.6/site-packages (0.1.6)\n", + "Requirement already satisfied: scikit-learn in /home/surya/anaconda3/lib/python3.6/site-packages (from iterative-stratification) (0.20.3)\n", + "Requirement already satisfied: numpy in /home/surya/anaconda3/lib/python3.6/site-packages (from iterative-stratification) (1.17.3)\n", + "Requirement already satisfied: scipy in /home/surya/anaconda3/lib/python3.6/site-packages (from iterative-stratification) (1.1.0)\n", + "\u001b[33mWARNING: You are using pip version 19.1.1, however version 19.3.1 is available.\n", + "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n", + "Requirement already satisfied: gdown in /home/surya/anaconda3/lib/python3.6/site-packages (3.8.1)\n", + "Requirement already satisfied: six in /home/surya/.local/lib/python3.6/site-packages (from gdown) (1.11.0)\n", + "Requirement already satisfied: filelock in /home/surya/anaconda3/lib/python3.6/site-packages (from gdown) (3.0.4)\n", + "Requirement already satisfied: requests in /home/surya/anaconda3/lib/python3.6/site-packages (from gdown) (2.18.4)\n", + "Requirement already satisfied: tqdm in /home/surya/anaconda3/lib/python3.6/site-packages (from gdown) (4.28.1)\n", + "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /home/surya/anaconda3/lib/python3.6/site-packages (from requests->gdown) (3.0.4)\n", + "Requirement already satisfied: idna<2.7,>=2.5 in /home/surya/anaconda3/lib/python3.6/site-packages (from requests->gdown) (2.6)\n", + "Requirement already satisfied: urllib3<1.23,>=1.21.1 in /home/surya/anaconda3/lib/python3.6/site-packages (from requests->gdown) (1.22)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/surya/anaconda3/lib/python3.6/site-packages (from requests->gdown) (2018.4.16)\n", + "\u001b[33mWARNING: You are using pip version 19.1.1, however version 19.3.1 is available.\n", + "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install efficientnet\n", + "!pip install iterative-stratification\n", + "!pip install gdown" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "if not os.path.isfile('model_effnet_bo_087.h5'):\n", + " !gdown https://drive.google.com/uc?id=1FXF1HymYbRf3OlThMTXAa74TRup3AhD_" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/surya/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:523: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", + "/home/surya/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:524: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", + "/home/surya/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", + "/home/surya/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", + "/home/surya/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", + "/home/surya/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:532: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", + "/home/surya/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", + " from ._conv import register_converters as _register_converters\n", + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import pydicom\n", + "import cv2\n", + "import tensorflow as tf\n", + "import multiprocessing\n", + "from math import ceil, floor\n", + "import keras\n", + "import keras.backend as K\n", + "from keras.callbacks import Callback, ModelCheckpoint\n", + "from keras.layers import Dense, Flatten, Dropout\n", + "from keras.models import Model, load_model\n", + "from keras.utils import Sequence\n", + "from keras.losses import binary_crossentropy\n", + "from keras.optimizers import Adam\n", + "import efficientnet.keras as efn \n", + "from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def correct_dcm(dcm):\n", + " x = dcm.pixel_array + 1000\n", + " px_mode = 4096\n", + " x[x>=px_mode] = x[x>=px_mode] - px_mode\n", + " dcm.PixelData = x.tobytes()\n", + " dcm.RescaleIntercept = -1000\n", + "\n", + "def window_image(dcm, window_center, window_width): \n", + " if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):\n", + " correct_dcm(dcm)\n", + " img = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept\n", + " \n", + " # Resize\n", + " img = cv2.resize(img, SHAPE[:2], interpolation = cv2.INTER_LINEAR)\n", + " \n", + " img_min = window_center - window_width // 2\n", + " img_max = window_center + window_width // 2\n", + " img = np.clip(img, img_min, img_max)\n", + " return img\n", + "\n", + "def bsb_window(dcm):\n", + " brain_img = window_image(dcm, 40, 80)\n", + " subdural_img = window_image(dcm, 80, 200)\n", + " soft_img = window_image(dcm, 40, 380)\n", + " \n", + " brain_img = (brain_img - 0) / 80\n", + " subdural_img = (subdural_img - (-20)) / 200\n", + " soft_img = (soft_img - (-150)) / 380\n", + " bsb_img = np.array([brain_img, subdural_img, soft_img]).transpose(1,2,0)\n", + " return bsb_img\n", + "\n", + "def _read(path, SHAPE):\n", + " dcm = pydicom.dcmread(path)\n", + " try:\n", + " img = bsb_window(dcm)\n", + " except:\n", + " img = np.zeros(SHAPE)\n", + " return img" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "base_model = efn.EfficientNetB0(weights = 'imagenet', include_top = False, \\\n", + " pooling = 'avg', input_shape = (256, 256, 3))\n", + "\n", + "x = base_model.output\n", + "x = Dropout(0.125)(x)\n", + "output_layer = Dense(6, activation = 'sigmoid')(x)\n", + "model = Model(inputs=base_model.input, outputs=output_layer)\n", + "model.compile(optimizer = Adam(lr = 0.0001), \n", + " loss = 'binary_crossentropy',\n", + " metrics = ['acc'])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "model.load_weights('model_effnet_bo_087.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(256, 256)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "img_data = _read('ID_000012eaf.dcm', (256, 256))\n", + "img_data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.image.AxesImage at 0x7f1265f52588>" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADIlJREFUeJzt3E+MnPV9x/H3p9gYhRCBS0CusQqJXKnkUMdaARJVRIWagC8mByo4BCtCcg5GSqT04CSHcEyrJpGQWiRHQTFVCkVJED7QNmBFQj1AMIgYG5ewIS5sbNlNiQhqJAfIt4d53Az+zXqXnXl2Zqv3S1rN7G+fmf36YfXmeeZfqgpJGvYH0x5A0uwxDJIahkFSwzBIahgGSQ3DIKnRWxiS3JLk5STzSfb29XskTV76eB1DkguAnwJ/CSwAzwJ3VtVLE/9lkiauryOG64D5qnq1qn4LPAzs7Ol3SZqwdT3d72bg9aHvF4DrF9v4wmyoi7i4p1EkAbzFr35ZVR9ezrZ9hSEj1t5zzpJkN7Ab4CI+wPW5uadRJAE8Wd/7z+Vu29epxAKwZej7q4ATwxtU1b6qmququfVs6GkMSSvRVxieBbYmuSbJhcAdwIGefpekCevlVKKq3klyD/BvwAXAA1V1tI/fJWny+nqMgap6HHi8r/uX1B9f+SipYRgkNQyDpIZhkNQwDJIahkFSwzBIahgGSQ3DIKlhGCQ1DIOkhmGQ1DAMkhqGQVLDMEhqGAZJDcMgqWEYJDUMg6SGYZDUMAySGoZBUsMwSGoYBkkNwyCpYRgkNQyDpIZhkNQwDJIahkFSwzBIahgGSQ3DIKlhGCQ11o1z4yTHgbeAd4F3qmouyUbgn4GrgePAX1XVr8YbU9JqmsQRw19U1baqmuu+3wscrKqtwMHue0lrSB+nEjuB/d31/cBtPfwOST0aNwwF/DDJc0l2d2tXVtVJgO7yilE3TLI7yaEkh97mzJhjSJqksR5jAG6sqhNJrgCeSPIfy71hVe0D9gF8KBtrzDkkTdBYRwxVdaK7PA08ClwHnEqyCaC7PD3ukJJW14rDkOTiJJecvQ58EjgCHAB2dZvtAh4bd0hJq2ucU4krgUeTnL2ff6qqf03yLPBIkruB14Dbxx9T0mpacRiq6lXgz0as/zdw8zhDSZouX/koqWEYJDUMg6SGYZDUMAySGoZBUsMwSGoYBkkNwyCpYRgkNQyDpIZhkNQwDJIahkFSwzBIahgGSQ3DIKlhGCQ1DIOkhmGQ1DAMkhqGQVLDMEhqGAZJDcMgqWEYJDUMg6SGYZDUMAySGoZBUsMwSGoYBkkNwyCpsWQYkjyQ5HSSI0NrG5M8keSV7vKybj1J7ksyn+Rwku19Di+pH8s5YvgOcMs5a3uBg1W1FTjYfQ9wK7C1+9oN3D+ZMSWtpiXDUFVPAW+cs7wT2N9d3w/cNrT+YA08DVyaZNOkhpW0Olb6GMOVVXUSoLu8olvfDLw+tN1CtyZpDVk34fvLiLUauWGym8HpBhfxgQmPIWkcKz1iOHX2FKG7PN2tLwBbhra7Cjgx6g6qal9VzVXV3Ho2rHAMSX1YaRgOALu667uAx4bW7+qenbgBePPsKYektWPJU4kkDwE3AZcnWQC+CnwNeCTJ3cBrwO3d5o8DO4B54DfAZ3uYWVLPlgxDVd25yI9uHrFtAXvGHUrSdPnKR0kNwyCpYRgkNQyDpIZhkNQwDJIahkFSwzBIahgGSQ3DIKlhGCQ1DIOkhmGQ1DAMkhqGQVLDMEhqGAZJDcMgqWEYJDUMg6SGYZDUMAySGoZBUsMwSGoYBkkNwyCpYRgkNQyDpIZhkNQwDJIahkFSwzBIahgGSQ3DIKmxZBiSPJDkdJIjQ2v3JvlFkhe6rx1DP/tSkvkkLyf5VF+DS+rPco4YvgPcMmL9m1W1rft6HCDJtcAdwMe62/xDkgsmNayk1bFkGKrqKeCNZd7fTuDhqjpTVT8H5oHrxphP0hSM8xjDPUkOd6cal3Vrm4HXh7ZZ6NYaSXYnOZTk0NucGWMMSZO20jDcD3wU2AacBL7erWfEtjXqDqpqX1XNVdXcejascAxJfVhRGKrqVFW9W1W/A77F708XFoAtQ5teBZwYb0RJq21FYUiyaejbTwNnn7E4ANyRZEOSa4CtwI/HG1HSalu31AZJHgJuAi5PsgB8FbgpyTYGpwnHgc8BVNXRJI8ALwHvAHuq6t1+RpfUl1SNfAhgVX0oG+v63DztMaT/156s7z1XVXPL2dZXPkpqGAZJDcMgqWEYJDUMg6SGYZDUMAySGoZBUsMwSGoYBkkNwyCpYRgkNQyDpIZhkNQwDJIahkFSwzBIahgGSQ3DIKlhGCQ1DIOkhmGQ1DAMkhqGQVLDMEhqGAZJDcMgqWEYJDUMg6SGYZDUMAySGoZBUsMwSGosGYYkW5L8KMmxJEeTfL5b35jkiSSvdJeXdetJcl+S+SSHk2zv+x8habKWc8TwDvDFqvpT4AZgT5Jrgb3AwaraChzsvge4Fdjafe0G7p/41JJ6tWQYqupkVT3fXX8LOAZsBnYC+7vN9gO3ddd3Ag/WwNPApUk2TXxySb15X48xJLka+DjwDHBlVZ2EQTyAK7rNNgOvD91soVuTtEYsOwxJPgh8H/hCVf36fJuOWKsR97c7yaEkh97mzHLHkLQKlhWGJOsZROG7VfWDbvnU2VOE7vJ0t74AbBm6+VXAiXPvs6r2VdVcVc2tZ8NK55fUg+U8KxHg28CxqvrG0I8OALu667uAx4bW7+qenbgBePPsKYektWHdMra5EfgM8GKSF7q1LwNfAx5JcjfwGnB797PHgR3APPAb4LMTnVhS75YMQ1X9O6MfNwC4ecT2BewZcy5JU+QrHyU1DIOkhmGQ1DAMkhqGQVLDMEhqGAZJDcMgqWEYJDUMg6SGYZDUMAySGoZBUsMwSGoYBkkNwyCpYRgkNQyDpIZhkNQwDJIahkFSwzBIahgGSQ3DIKlhGCQ1DIOkhmGQ1DAMkhqGQVLDMEhqGAZJDcMgqWEYJDUMg6TGkmFIsiXJj5IcS3I0yee79XuT/CLJC93XjqHbfCnJfJKXk3yqz3+ApMlbt4xt3gG+WFXPJ7kEeC7JE93PvllVfze8cZJrgTuAjwF/BDyZ5E+q6t1JDi6pP0seMVTVyap6vrv+FnAM2Hyem+wEHq6qM1X1c2AeuG4Sw0paHe/rMYYkVwMfB57plu5JcjjJA0ku69Y2A68P3WyBESFJsjvJoSSH3ubM+x5cUn+WHYYkHwS+D3yhqn4N3A98FNgGnAS+fnbTETevZqFqX1XNVdXceja878El9WdZYUiynkEUvltVPwCoqlNV9W5V/Q74Fr8/XVgAtgzd/CrgxORGltS35TwrEeDbwLGq+sbQ+qahzT4NHOmuHwDuSLIhyTXAVuDHkxtZUt+W86zEjcBngBeTvNCtfRm4M8k2BqcJx4HPAVTV0SSPAC8xeEZjj89ISGtLqprT/9UfIvkv4H+AX057lmW4nLUxJ6ydWZ1z8kbN+sdV9eHl3HgmwgCQ5FBVzU17jqWslTlh7czqnJM37qy+JFpSwzBIasxSGPZNe4BlWitzwtqZ1Tknb6xZZ+YxBkmzY5aOGCTNiKmHIckt3duz55PsnfY850pyPMmL3VvLD3VrG5M8keSV7vKype6nh7keSHI6yZGhtZFzZeC+bh8fTrJ9Bmadubftn+cjBmZqv67KRyFU1dS+gAuAnwEfAS4EfgJcO82ZRsx4HLj8nLW/BfZ21/cCfzOFuT4BbAeOLDUXsAP4FwbvY7kBeGYGZr0X+OsR217b/R1sAK7p/j4uWKU5NwHbu+uXAD/t5pmp/XqeOSe2T6d9xHAdMF9Vr1bVb4GHGbxte9btBPZ31/cDt632AFX1FPDGOcuLzbUTeLAGngYuPecl7b1aZNbFTO1t+7X4RwzM1H49z5yLed/7dNphWNZbtKesgB8meS7J7m7tyqo6CYP/SMAVU5vuvRaba1b384rftt+3cz5iYGb36yQ/CmHYtMOwrLdoT9mNVbUduBXYk+QT0x5oBWZxP4/1tv0+jfiIgUU3HbG2arNO+qMQhk07DDP/Fu2qOtFdngYeZXAIdursIWN3eXp6E77HYnPN3H6uGX3b/qiPGGAG92vfH4Uw7TA8C2xNck2SCxl8VuSBKc/0f5Jc3H3OJUkuBj7J4O3lB4Bd3Wa7gMemM2FjsbkOAHd1j6LfALx59tB4WmbxbfuLfcQAM7ZfF5tzovt0NR5FXeIR1h0MHlX9GfCVac9zzmwfYfBo7k+Ao2fnA/4QOAi80l1unMJsDzE4XHybwf8R7l5sLgaHkn/f7eMXgbkZmPUfu1kOd3+4m4a2/0o368vAras4558zOMQ+DLzQfe2Ytf16njkntk995aOkxrRPJSTNIMMgqWEYJDUMg6SGYZDUMAySGoZBUsMwSGr8L5WBYebTjJswAAAAAElFTkSuQmCC\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.imshow(img_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 256, 256, 3)" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.empty((1, 256,256, 3))\n", + "X[0] = img_data\n", + "X.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.02540314, 0.00152788, 0.0010443 , 0.0023855 , 0.00326541,\n", + " 0.00998502]], dtype=float32)" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "preds = model.predict(X)\n", + "preds" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 6)" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "preds.shape" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}