Switch to side-by-side view

--- a
+++ b/voice_activity_detection.ipynb
@@ -0,0 +1,647 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from sklearn import tree\n",
+    "import random\n",
+    "import copy\n",
+    "import os\n",
+    "import graphviz\n",
+    "import scipy.io.wavfile as wav\n",
+    "from src.voice_activity_detection.extract_features import extract_features"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<class 'pandas.core.frame.DataFrame'>\n",
+      "RangeIndex: 1356 entries, 0 to 1355\n",
+      "Data columns (total 9 columns):\n",
+      "RMS                  1356 non-null int64\n",
+      "ZCR                  1356 non-null float64\n",
+      "audio                1356 non-null object\n",
+      "bandwidth            1356 non-null float64\n",
+      "nwpd                 1356 non-null float64\n",
+      "rse                  1353 non-null float64\n",
+      "spectral_centroid    1356 non-null float64\n",
+      "spectral_flux        1356 non-null float64\n",
+      "spectral_rolloff     1356 non-null float64\n",
+      "dtypes: float64(7), int64(1), object(1)\n",
+      "memory usage: 95.4+ KB\n"
+     ]
+    }
+   ],
+   "source": [
+    "voice_noise_data = np.load(\"numpy_files/musan_mean_speech_noise.npy\").item()\n",
+    "voice_noise_df = pd.DataFrame.from_dict(voice_noise_data)\n",
+    "voice_noise_df.replace([np.inf,-np.inf,np.nan],0)\n",
+    "voice_noise_df.dropna()\n",
+    "voice_noise_df.info()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>RMS</th>\n",
+       "      <th>ZCR</th>\n",
+       "      <th>audio</th>\n",
+       "      <th>bandwidth</th>\n",
+       "      <th>nwpd</th>\n",
+       "      <th>rse</th>\n",
+       "      <th>spectral_centroid</th>\n",
+       "      <th>spectral_flux</th>\n",
+       "      <th>spectral_rolloff</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>120</td>\n",
+       "      <td>0.317919</td>\n",
+       "      <td>1</td>\n",
+       "      <td>2.852637e+05</td>\n",
+       "      <td>-1.086645</td>\n",
+       "      <td>-0.190269</td>\n",
+       "      <td>535.891491</td>\n",
+       "      <td>0.039433</td>\n",
+       "      <td>4020.624371</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>2</td>\n",
+       "      <td>0.442350</td>\n",
+       "      <td>1</td>\n",
+       "      <td>6.743143e+05</td>\n",
+       "      <td>10.745700</td>\n",
+       "      <td>-0.875232</td>\n",
+       "      <td>840.023296</td>\n",
+       "      <td>0.061607</td>\n",
+       "      <td>3833.170732</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>13</td>\n",
+       "      <td>0.464951</td>\n",
+       "      <td>1</td>\n",
+       "      <td>8.572244e+05</td>\n",
+       "      <td>1.387262</td>\n",
+       "      <td>-0.322468</td>\n",
+       "      <td>572.973659</td>\n",
+       "      <td>0.018836</td>\n",
+       "      <td>4938.532110</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>130</td>\n",
+       "      <td>0.150850</td>\n",
+       "      <td>1</td>\n",
+       "      <td>3.938587e+06</td>\n",
+       "      <td>0.498670</td>\n",
+       "      <td>-0.290313</td>\n",
+       "      <td>3448.585720</td>\n",
+       "      <td>0.010311</td>\n",
+       "      <td>6988.571429</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>42</td>\n",
+       "      <td>0.250210</td>\n",
+       "      <td>1</td>\n",
+       "      <td>6.657314e+05</td>\n",
+       "      <td>0.925133</td>\n",
+       "      <td>-0.263716</td>\n",
+       "      <td>745.102831</td>\n",
+       "      <td>0.014008</td>\n",
+       "      <td>4590.849315</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>93</td>\n",
+       "      <td>0.481068</td>\n",
+       "      <td>1</td>\n",
+       "      <td>7.152688e+05</td>\n",
+       "      <td>-0.090272</td>\n",
+       "      <td>-0.184233</td>\n",
+       "      <td>738.235253</td>\n",
+       "      <td>0.022868</td>\n",
+       "      <td>4529.035025</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>12</td>\n",
+       "      <td>0.496398</td>\n",
+       "      <td>1</td>\n",
+       "      <td>1.041786e+06</td>\n",
+       "      <td>-0.611150</td>\n",
+       "      <td>-0.405754</td>\n",
+       "      <td>1783.177886</td>\n",
+       "      <td>0.020107</td>\n",
+       "      <td>5412.625995</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>7</th>\n",
+       "      <td>1967</td>\n",
+       "      <td>0.401090</td>\n",
+       "      <td>0</td>\n",
+       "      <td>2.284852e+05</td>\n",
+       "      <td>-0.196838</td>\n",
+       "      <td>-0.301843</td>\n",
+       "      <td>527.226453</td>\n",
+       "      <td>0.034232</td>\n",
+       "      <td>3316.301112</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>8</th>\n",
+       "      <td>151</td>\n",
+       "      <td>0.499611</td>\n",
+       "      <td>1</td>\n",
+       "      <td>3.597816e+05</td>\n",
+       "      <td>0.241001</td>\n",
+       "      <td>-0.153348</td>\n",
+       "      <td>524.060903</td>\n",
+       "      <td>0.011216</td>\n",
+       "      <td>4010.773810</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9</th>\n",
+       "      <td>5</td>\n",
+       "      <td>0.500247</td>\n",
+       "      <td>1</td>\n",
+       "      <td>5.324209e+06</td>\n",
+       "      <td>-0.130207</td>\n",
+       "      <td>-0.202091</td>\n",
+       "      <td>2541.108767</td>\n",
+       "      <td>0.006781</td>\n",
+       "      <td>7003.905325</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "    RMS       ZCR  audio     bandwidth       nwpd       rse  \\\n",
+       "0   120  0.317919      1  2.852637e+05  -1.086645 -0.190269   \n",
+       "1     2  0.442350      1  6.743143e+05  10.745700 -0.875232   \n",
+       "2    13  0.464951      1  8.572244e+05   1.387262 -0.322468   \n",
+       "3   130  0.150850      1  3.938587e+06   0.498670 -0.290313   \n",
+       "4    42  0.250210      1  6.657314e+05   0.925133 -0.263716   \n",
+       "5    93  0.481068      1  7.152688e+05  -0.090272 -0.184233   \n",
+       "6    12  0.496398      1  1.041786e+06  -0.611150 -0.405754   \n",
+       "7  1967  0.401090      0  2.284852e+05  -0.196838 -0.301843   \n",
+       "8   151  0.499611      1  3.597816e+05   0.241001 -0.153348   \n",
+       "9     5  0.500247      1  5.324209e+06  -0.130207 -0.202091   \n",
+       "\n",
+       "   spectral_centroid  spectral_flux  spectral_rolloff  \n",
+       "0         535.891491       0.039433       4020.624371  \n",
+       "1         840.023296       0.061607       3833.170732  \n",
+       "2         572.973659       0.018836       4938.532110  \n",
+       "3        3448.585720       0.010311       6988.571429  \n",
+       "4         745.102831       0.014008       4590.849315  \n",
+       "5         738.235253       0.022868       4529.035025  \n",
+       "6        1783.177886       0.020107       5412.625995  \n",
+       "7         527.226453       0.034232       3316.301112  \n",
+       "8         524.060903       0.011216       4010.773810  \n",
+       "9        2541.108767       0.006781       7003.905325  "
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "voice_noise_df[\"audio\"] = voice_noise_df[\"audio\"].astype('category')\n",
+    "voice_noise_df[\"audio\"] = voice_noise_df[\"audio\"].cat.codes\n",
+    "voice_noise_df.head(10)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(array([132, 229]), array([4, 4]))\n",
+      "(1356,)\n",
+      "(1354, 8)\n",
+      "(array([], dtype=int64), array([], dtype=int64))\n",
+      "(1083, 8)\n",
+      "(array([], dtype=int64), array([], dtype=int64))\n",
+      "(array([], dtype=int64), array([], dtype=int64))\n"
+     ]
+    }
+   ],
+   "source": [
+    "TRAIN_TEST_PERCENTAGE = 0.80\n",
+    "data = voice_noise_df.loc[:,voice_noise_df.columns != \"audio\"].as_matrix()\n",
+    "labels = voice_noise_df[\"audio\"].as_matrix()\n",
+    "\n",
+    "full_data = np.hstack((data,np.expand_dims(labels,axis = 1)))\n",
+    "random.shuffle(full_data)\n",
+    "data = full_data[:,0:-1]\n",
+    "labels = full_data[:,-1]\n",
+    "\n",
+    "nan_indices = np.where(np.isnan(data))\n",
+    "print(nan_indices)\n",
+    "all_indices_ = np.ones(len(data),dtype = \"bool\")\n",
+    "print(all_indices_.shape)\n",
+    "all_indices_[nan_indices[0]] = False\n",
+    "data_ = copy.deepcopy(data[all_indices_,:])\n",
+    "labels_ = copy.deepcopy(labels[all_indices_])\n",
+    "print(data_.shape)\n",
+    "print(np.where(np.isnan(data_)))\n",
+    "\n",
+    "index_ = int(len(data_)*TRAIN_TEST_PERCENTAGE)\n",
+    "train_data = data_[0:index_,:]\n",
+    "train_labels = labels_[0:index_]\n",
+    "test_data = data_[index_:,:]\n",
+    "test_labels = labels_[index_:]\n",
+    "\n",
+    "print(train_data.shape)\n",
+    "\n",
+    "print(np.where(np.isnan(train_data)))\n",
+    "\n",
+    "print(np.where(np.isnan(test_data)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "classifier = tree.DecisionTreeClassifier(max_depth = 3)\n",
+    "classifier = classifier.fit(train_data,train_labels)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.97416973\n"
+     ]
+    }
+   ],
+   "source": [
+    "prediction = classifier.predict(test_data)\n",
+    "print(np.mean(np.equal(prediction,test_labels).astype(np.float32)))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## For testing new audio"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "TEST_AUDIO_FOLDER = \"/Users/siva/Documents/speaker_recognition/VOD/testwav/\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_dataset(ROOT_FOLDER,WINDOW_LENGTH = 5,FRAME_LENGTH = 25):\n",
+    "    os.chdir(ROOT_FOLDER)\n",
+    "    all_audio = os.listdir()\n",
+    "    dataset_dict = {\"ZCR\":[],\"RMS\":[],\"spectral_flux\":[],\\\n",
+    "                   \"spectral_centroid\":[],\"spectral_rolloff\":[],\\\n",
+    "                   \"bandwidth\":[],\"nwpd\":[],\"rse\":[]}\n",
+    "    for audio in all_audio:\n",
+    "        print(\"****************************\")\n",
+    "        print(\"reading:\",audio)\n",
+    "        sampling_rate,sig = wav.read(ROOT_FOLDER+audio)\n",
+    "        print(\"sampling rate:\",sampling_rate,\"signal length\",len(sig))\n",
+    "        index = 0\n",
+    "        while index+(sampling_rate*WINDOW_LENGTH) < len(sig):\n",
+    "            sample = sig[index:(index+(sampling_rate*WINDOW_LENGTH))]\n",
+    "            ef = extract_features(sample,FRAME_LENGTH,sampling_rate)\n",
+    "            ZCR,RMS,sf,sr,sc,bd,nwpd,rse = ef.return_()\n",
+    "            dataset_dict[\"ZCR\"].append(ZCR)\n",
+    "            dataset_dict[\"RMS\"].append(RMS)\n",
+    "            dataset_dict[\"spectral_flux\"].append(np.mean(sf))\n",
+    "            dataset_dict[\"spectral_centroid\"].append(np.mean(sc))\n",
+    "            dataset_dict[\"spectral_rolloff\"].append(np.mean(sr))\n",
+    "            dataset_dict[\"bandwidth\"].append(np.mean(bd))\n",
+    "            dataset_dict[\"nwpd\"].append(np.mean(nwpd))\n",
+    "            dataset_dict[\"rse\"].append(np.mean(rse))\n",
+    "            index += sampling_rate*WINDOW_LENGTH\n",
+    "    values = dataset_dict.values()\n",
+    "    print([len(e) for e in values])\n",
+    "    print(\"finished\")\n",
+    "    return dataset_dict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "****************************\n",
+      "reading: audiotest08-06-2018-17-01-26.wav\n",
+      "sampling rate: 16000 signal length 84480\n",
+      "****************************\n",
+      "reading: audiotest08-06-2018-17-01-43.wav\n",
+      "sampling rate: 16000 signal length 84480\n",
+      "****************************\n",
+      "reading: audiotest08-06-2018-17-01-56.wav\n",
+      "sampling rate: 16000 signal length 84480\n",
+      "****************************\n",
+      "reading: audiotest08-06-2018-17-02-11.wav\n",
+      "sampling rate: 16000 signal length 84480\n",
+      "****************************\n",
+      "reading: audiotest08-06-2018-16-49-40.wav\n",
+      "sampling rate: 16000 signal length 84480\n",
+      "****************************\n",
+      "reading: audiotest08-06-2018-17-01-12.wav\n",
+      "sampling rate: 16000 signal length 84480\n",
+      "****************************\n",
+      "reading: audiotest08-06-2018-13-12-39.wav\n",
+      "sampling rate: 16000 signal length 84480\n",
+      "****************************\n",
+      "reading: audiotest08-06-2018-17-02-44.wav\n",
+      "sampling rate: 16000 signal length 84480\n",
+      "[8, 8, 8, 8, 8, 8, 8, 8]\n",
+      "finished\n"
+     ]
+    }
+   ],
+   "source": [
+    "features_test_dict = create_dataset(TEST_AUDIO_FOLDER)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>RMS</th>\n",
+       "      <th>ZCR</th>\n",
+       "      <th>bandwidth</th>\n",
+       "      <th>nwpd</th>\n",
+       "      <th>rse</th>\n",
+       "      <th>spectral_centroid</th>\n",
+       "      <th>spectral_flux</th>\n",
+       "      <th>spectral_rolloff</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>235</td>\n",
+       "      <td>0.148002</td>\n",
+       "      <td>1.800357e+06</td>\n",
+       "      <td>0.323637</td>\n",
+       "      <td>-0.258980</td>\n",
+       "      <td>1043.434918</td>\n",
+       "      <td>0.021088</td>\n",
+       "      <td>6197.853916</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>347</td>\n",
+       "      <td>0.239553</td>\n",
+       "      <td>3.614656e+06</td>\n",
+       "      <td>1.323582</td>\n",
+       "      <td>-0.249343</td>\n",
+       "      <td>1700.323351</td>\n",
+       "      <td>0.008497</td>\n",
+       "      <td>6469.879518</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>271</td>\n",
+       "      <td>0.130902</td>\n",
+       "      <td>1.469745e+06</td>\n",
+       "      <td>1.151613</td>\n",
+       "      <td>-0.253432</td>\n",
+       "      <td>988.019173</td>\n",
+       "      <td>0.014795</td>\n",
+       "      <td>5685.115462</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>140</td>\n",
+       "      <td>0.125277</td>\n",
+       "      <td>1.089120e+06</td>\n",
+       "      <td>-0.566549</td>\n",
+       "      <td>-0.337949</td>\n",
+       "      <td>926.835474</td>\n",
+       "      <td>0.019217</td>\n",
+       "      <td>5463.227912</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>270</td>\n",
+       "      <td>0.093726</td>\n",
+       "      <td>1.016863e+06</td>\n",
+       "      <td>3.289868</td>\n",
+       "      <td>-0.305840</td>\n",
+       "      <td>786.064634</td>\n",
+       "      <td>0.027444</td>\n",
+       "      <td>5227.158635</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>242</td>\n",
+       "      <td>0.179977</td>\n",
+       "      <td>2.751192e+06</td>\n",
+       "      <td>0.520077</td>\n",
+       "      <td>-0.235335</td>\n",
+       "      <td>1297.176396</td>\n",
+       "      <td>0.016272</td>\n",
+       "      <td>6467.871486</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>31</td>\n",
+       "      <td>0.085739</td>\n",
+       "      <td>1.137872e+06</td>\n",
+       "      <td>0.515145</td>\n",
+       "      <td>-0.236153</td>\n",
+       "      <td>641.419533</td>\n",
+       "      <td>0.026151</td>\n",
+       "      <td>4864.646084</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>7</th>\n",
+       "      <td>187</td>\n",
+       "      <td>0.107876</td>\n",
+       "      <td>6.686752e+05</td>\n",
+       "      <td>0.009063</td>\n",
+       "      <td>-0.304043</td>\n",
+       "      <td>802.714148</td>\n",
+       "      <td>0.020602</td>\n",
+       "      <td>5183.985944</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "   RMS       ZCR     bandwidth      nwpd       rse  spectral_centroid  \\\n",
+       "0  235  0.148002  1.800357e+06  0.323637 -0.258980        1043.434918   \n",
+       "1  347  0.239553  3.614656e+06  1.323582 -0.249343        1700.323351   \n",
+       "2  271  0.130902  1.469745e+06  1.151613 -0.253432         988.019173   \n",
+       "3  140  0.125277  1.089120e+06 -0.566549 -0.337949         926.835474   \n",
+       "4  270  0.093726  1.016863e+06  3.289868 -0.305840         786.064634   \n",
+       "5  242  0.179977  2.751192e+06  0.520077 -0.235335        1297.176396   \n",
+       "6   31  0.085739  1.137872e+06  0.515145 -0.236153         641.419533   \n",
+       "7  187  0.107876  6.686752e+05  0.009063 -0.304043         802.714148   \n",
+       "\n",
+       "   spectral_flux  spectral_rolloff  \n",
+       "0       0.021088       6197.853916  \n",
+       "1       0.008497       6469.879518  \n",
+       "2       0.014795       5685.115462  \n",
+       "3       0.019217       5463.227912  \n",
+       "4       0.027444       5227.158635  \n",
+       "5       0.016272       6467.871486  \n",
+       "6       0.026151       4864.646084  \n",
+       "7       0.020602       5183.985944  "
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_df = pd.DataFrame.from_dict(features_test_dict)\n",
+    "test_df.replace([np.inf,-np.inf,np.nan],0)\n",
+    "test_df.dropna()\n",
+    "test_df_data = test_df.as_matrix()\n",
+    "test_df.head(10)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1 - Voice and 0 - Noise"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[1. 1. 1. 1. 1. 1. 1. 1.]\n"
+     ]
+    }
+   ],
+   "source": [
+    "test_predictions = classifier.predict(test_df_data)\n",
+    "print(test_predictions)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "anaconda-cloud": {},
+  "kernelspec": {
+   "display_name": "Python [default]",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}