[a989eb]: / voice_activity_detection.ipynb

Download this file

648 lines (647 with data), 21.1 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn import tree\n",
    "import random\n",
    "import copy\n",
    "import os\n",
    "import graphviz\n",
    "import scipy.io.wavfile as wav\n",
    "from src.voice_activity_detection.extract_features import extract_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 1356 entries, 0 to 1355\n",
      "Data columns (total 9 columns):\n",
      "RMS                  1356 non-null int64\n",
      "ZCR                  1356 non-null float64\n",
      "audio                1356 non-null object\n",
      "bandwidth            1356 non-null float64\n",
      "nwpd                 1356 non-null float64\n",
      "rse                  1353 non-null float64\n",
      "spectral_centroid    1356 non-null float64\n",
      "spectral_flux        1356 non-null float64\n",
      "spectral_rolloff     1356 non-null float64\n",
      "dtypes: float64(7), int64(1), object(1)\n",
      "memory usage: 95.4+ KB\n"
     ]
    }
   ],
   "source": [
    "voice_noise_data = np.load(\"numpy_files/musan_mean_speech_noise.npy\").item()\n",
    "voice_noise_df = pd.DataFrame.from_dict(voice_noise_data)\n",
    "voice_noise_df.replace([np.inf,-np.inf,np.nan],0)\n",
    "voice_noise_df.dropna()\n",
    "voice_noise_df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RMS</th>\n",
       "      <th>ZCR</th>\n",
       "      <th>audio</th>\n",
       "      <th>bandwidth</th>\n",
       "      <th>nwpd</th>\n",
       "      <th>rse</th>\n",
       "      <th>spectral_centroid</th>\n",
       "      <th>spectral_flux</th>\n",
       "      <th>spectral_rolloff</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>120</td>\n",
       "      <td>0.317919</td>\n",
       "      <td>1</td>\n",
       "      <td>2.852637e+05</td>\n",
       "      <td>-1.086645</td>\n",
       "      <td>-0.190269</td>\n",
       "      <td>535.891491</td>\n",
       "      <td>0.039433</td>\n",
       "      <td>4020.624371</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.442350</td>\n",
       "      <td>1</td>\n",
       "      <td>6.743143e+05</td>\n",
       "      <td>10.745700</td>\n",
       "      <td>-0.875232</td>\n",
       "      <td>840.023296</td>\n",
       "      <td>0.061607</td>\n",
       "      <td>3833.170732</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>13</td>\n",
       "      <td>0.464951</td>\n",
       "      <td>1</td>\n",
       "      <td>8.572244e+05</td>\n",
       "      <td>1.387262</td>\n",
       "      <td>-0.322468</td>\n",
       "      <td>572.973659</td>\n",
       "      <td>0.018836</td>\n",
       "      <td>4938.532110</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>130</td>\n",
       "      <td>0.150850</td>\n",
       "      <td>1</td>\n",
       "      <td>3.938587e+06</td>\n",
       "      <td>0.498670</td>\n",
       "      <td>-0.290313</td>\n",
       "      <td>3448.585720</td>\n",
       "      <td>0.010311</td>\n",
       "      <td>6988.571429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>42</td>\n",
       "      <td>0.250210</td>\n",
       "      <td>1</td>\n",
       "      <td>6.657314e+05</td>\n",
       "      <td>0.925133</td>\n",
       "      <td>-0.263716</td>\n",
       "      <td>745.102831</td>\n",
       "      <td>0.014008</td>\n",
       "      <td>4590.849315</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>93</td>\n",
       "      <td>0.481068</td>\n",
       "      <td>1</td>\n",
       "      <td>7.152688e+05</td>\n",
       "      <td>-0.090272</td>\n",
       "      <td>-0.184233</td>\n",
       "      <td>738.235253</td>\n",
       "      <td>0.022868</td>\n",
       "      <td>4529.035025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>12</td>\n",
       "      <td>0.496398</td>\n",
       "      <td>1</td>\n",
       "      <td>1.041786e+06</td>\n",
       "      <td>-0.611150</td>\n",
       "      <td>-0.405754</td>\n",
       "      <td>1783.177886</td>\n",
       "      <td>0.020107</td>\n",
       "      <td>5412.625995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1967</td>\n",
       "      <td>0.401090</td>\n",
       "      <td>0</td>\n",
       "      <td>2.284852e+05</td>\n",
       "      <td>-0.196838</td>\n",
       "      <td>-0.301843</td>\n",
       "      <td>527.226453</td>\n",
       "      <td>0.034232</td>\n",
       "      <td>3316.301112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>151</td>\n",
       "      <td>0.499611</td>\n",
       "      <td>1</td>\n",
       "      <td>3.597816e+05</td>\n",
       "      <td>0.241001</td>\n",
       "      <td>-0.153348</td>\n",
       "      <td>524.060903</td>\n",
       "      <td>0.011216</td>\n",
       "      <td>4010.773810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>5</td>\n",
       "      <td>0.500247</td>\n",
       "      <td>1</td>\n",
       "      <td>5.324209e+06</td>\n",
       "      <td>-0.130207</td>\n",
       "      <td>-0.202091</td>\n",
       "      <td>2541.108767</td>\n",
       "      <td>0.006781</td>\n",
       "      <td>7003.905325</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    RMS       ZCR  audio     bandwidth       nwpd       rse  \\\n",
       "0   120  0.317919      1  2.852637e+05  -1.086645 -0.190269   \n",
       "1     2  0.442350      1  6.743143e+05  10.745700 -0.875232   \n",
       "2    13  0.464951      1  8.572244e+05   1.387262 -0.322468   \n",
       "3   130  0.150850      1  3.938587e+06   0.498670 -0.290313   \n",
       "4    42  0.250210      1  6.657314e+05   0.925133 -0.263716   \n",
       "5    93  0.481068      1  7.152688e+05  -0.090272 -0.184233   \n",
       "6    12  0.496398      1  1.041786e+06  -0.611150 -0.405754   \n",
       "7  1967  0.401090      0  2.284852e+05  -0.196838 -0.301843   \n",
       "8   151  0.499611      1  3.597816e+05   0.241001 -0.153348   \n",
       "9     5  0.500247      1  5.324209e+06  -0.130207 -0.202091   \n",
       "\n",
       "   spectral_centroid  spectral_flux  spectral_rolloff  \n",
       "0         535.891491       0.039433       4020.624371  \n",
       "1         840.023296       0.061607       3833.170732  \n",
       "2         572.973659       0.018836       4938.532110  \n",
       "3        3448.585720       0.010311       6988.571429  \n",
       "4         745.102831       0.014008       4590.849315  \n",
       "5         738.235253       0.022868       4529.035025  \n",
       "6        1783.177886       0.020107       5412.625995  \n",
       "7         527.226453       0.034232       3316.301112  \n",
       "8         524.060903       0.011216       4010.773810  \n",
       "9        2541.108767       0.006781       7003.905325  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "voice_noise_df[\"audio\"] = voice_noise_df[\"audio\"].astype('category')\n",
    "voice_noise_df[\"audio\"] = voice_noise_df[\"audio\"].cat.codes\n",
    "voice_noise_df.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(array([132, 229]), array([4, 4]))\n",
      "(1356,)\n",
      "(1354, 8)\n",
      "(array([], dtype=int64), array([], dtype=int64))\n",
      "(1083, 8)\n",
      "(array([], dtype=int64), array([], dtype=int64))\n",
      "(array([], dtype=int64), array([], dtype=int64))\n"
     ]
    }
   ],
   "source": [
    "TRAIN_TEST_PERCENTAGE = 0.80\n",
    "data = voice_noise_df.loc[:,voice_noise_df.columns != \"audio\"].as_matrix()\n",
    "labels = voice_noise_df[\"audio\"].as_matrix()\n",
    "\n",
    "full_data = np.hstack((data,np.expand_dims(labels,axis = 1)))\n",
    "random.shuffle(full_data)\n",
    "data = full_data[:,0:-1]\n",
    "labels = full_data[:,-1]\n",
    "\n",
    "nan_indices = np.where(np.isnan(data))\n",
    "print(nan_indices)\n",
    "all_indices_ = np.ones(len(data),dtype = \"bool\")\n",
    "print(all_indices_.shape)\n",
    "all_indices_[nan_indices[0]] = False\n",
    "data_ = copy.deepcopy(data[all_indices_,:])\n",
    "labels_ = copy.deepcopy(labels[all_indices_])\n",
    "print(data_.shape)\n",
    "print(np.where(np.isnan(data_)))\n",
    "\n",
    "index_ = int(len(data_)*TRAIN_TEST_PERCENTAGE)\n",
    "train_data = data_[0:index_,:]\n",
    "train_labels = labels_[0:index_]\n",
    "test_data = data_[index_:,:]\n",
    "test_labels = labels_[index_:]\n",
    "\n",
    "print(train_data.shape)\n",
    "\n",
    "print(np.where(np.isnan(train_data)))\n",
    "\n",
    "print(np.where(np.isnan(test_data)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier = tree.DecisionTreeClassifier(max_depth = 3)\n",
    "classifier = classifier.fit(train_data,train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.97416973\n"
     ]
    }
   ],
   "source": [
    "prediction = classifier.predict(test_data)\n",
    "print(np.mean(np.equal(prediction,test_labels).astype(np.float32)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For testing new audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEST_AUDIO_FOLDER = \"/Users/siva/Documents/speaker_recognition/VOD/testwav/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataset(ROOT_FOLDER,WINDOW_LENGTH = 5,FRAME_LENGTH = 25):\n",
    "    os.chdir(ROOT_FOLDER)\n",
    "    all_audio = os.listdir()\n",
    "    dataset_dict = {\"ZCR\":[],\"RMS\":[],\"spectral_flux\":[],\\\n",
    "                   \"spectral_centroid\":[],\"spectral_rolloff\":[],\\\n",
    "                   \"bandwidth\":[],\"nwpd\":[],\"rse\":[]}\n",
    "    for audio in all_audio:\n",
    "        print(\"****************************\")\n",
    "        print(\"reading:\",audio)\n",
    "        sampling_rate,sig = wav.read(ROOT_FOLDER+audio)\n",
    "        print(\"sampling rate:\",sampling_rate,\"signal length\",len(sig))\n",
    "        index = 0\n",
    "        while index+(sampling_rate*WINDOW_LENGTH) < len(sig):\n",
    "            sample = sig[index:(index+(sampling_rate*WINDOW_LENGTH))]\n",
    "            ef = extract_features(sample,FRAME_LENGTH,sampling_rate)\n",
    "            ZCR,RMS,sf,sr,sc,bd,nwpd,rse = ef.return_()\n",
    "            dataset_dict[\"ZCR\"].append(ZCR)\n",
    "            dataset_dict[\"RMS\"].append(RMS)\n",
    "            dataset_dict[\"spectral_flux\"].append(np.mean(sf))\n",
    "            dataset_dict[\"spectral_centroid\"].append(np.mean(sc))\n",
    "            dataset_dict[\"spectral_rolloff\"].append(np.mean(sr))\n",
    "            dataset_dict[\"bandwidth\"].append(np.mean(bd))\n",
    "            dataset_dict[\"nwpd\"].append(np.mean(nwpd))\n",
    "            dataset_dict[\"rse\"].append(np.mean(rse))\n",
    "            index += sampling_rate*WINDOW_LENGTH\n",
    "    values = dataset_dict.values()\n",
    "    print([len(e) for e in values])\n",
    "    print(\"finished\")\n",
    "    return dataset_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "****************************\n",
      "reading: audiotest08-06-2018-17-01-26.wav\n",
      "sampling rate: 16000 signal length 84480\n",
      "****************************\n",
      "reading: audiotest08-06-2018-17-01-43.wav\n",
      "sampling rate: 16000 signal length 84480\n",
      "****************************\n",
      "reading: audiotest08-06-2018-17-01-56.wav\n",
      "sampling rate: 16000 signal length 84480\n",
      "****************************\n",
      "reading: audiotest08-06-2018-17-02-11.wav\n",
      "sampling rate: 16000 signal length 84480\n",
      "****************************\n",
      "reading: audiotest08-06-2018-16-49-40.wav\n",
      "sampling rate: 16000 signal length 84480\n",
      "****************************\n",
      "reading: audiotest08-06-2018-17-01-12.wav\n",
      "sampling rate: 16000 signal length 84480\n",
      "****************************\n",
      "reading: audiotest08-06-2018-13-12-39.wav\n",
      "sampling rate: 16000 signal length 84480\n",
      "****************************\n",
      "reading: audiotest08-06-2018-17-02-44.wav\n",
      "sampling rate: 16000 signal length 84480\n",
      "[8, 8, 8, 8, 8, 8, 8, 8]\n",
      "finished\n"
     ]
    }
   ],
   "source": [
    "features_test_dict = create_dataset(TEST_AUDIO_FOLDER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RMS</th>\n",
       "      <th>ZCR</th>\n",
       "      <th>bandwidth</th>\n",
       "      <th>nwpd</th>\n",
       "      <th>rse</th>\n",
       "      <th>spectral_centroid</th>\n",
       "      <th>spectral_flux</th>\n",
       "      <th>spectral_rolloff</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>235</td>\n",
       "      <td>0.148002</td>\n",
       "      <td>1.800357e+06</td>\n",
       "      <td>0.323637</td>\n",
       "      <td>-0.258980</td>\n",
       "      <td>1043.434918</td>\n",
       "      <td>0.021088</td>\n",
       "      <td>6197.853916</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>347</td>\n",
       "      <td>0.239553</td>\n",
       "      <td>3.614656e+06</td>\n",
       "      <td>1.323582</td>\n",
       "      <td>-0.249343</td>\n",
       "      <td>1700.323351</td>\n",
       "      <td>0.008497</td>\n",
       "      <td>6469.879518</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>271</td>\n",
       "      <td>0.130902</td>\n",
       "      <td>1.469745e+06</td>\n",
       "      <td>1.151613</td>\n",
       "      <td>-0.253432</td>\n",
       "      <td>988.019173</td>\n",
       "      <td>0.014795</td>\n",
       "      <td>5685.115462</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>140</td>\n",
       "      <td>0.125277</td>\n",
       "      <td>1.089120e+06</td>\n",
       "      <td>-0.566549</td>\n",
       "      <td>-0.337949</td>\n",
       "      <td>926.835474</td>\n",
       "      <td>0.019217</td>\n",
       "      <td>5463.227912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>270</td>\n",
       "      <td>0.093726</td>\n",
       "      <td>1.016863e+06</td>\n",
       "      <td>3.289868</td>\n",
       "      <td>-0.305840</td>\n",
       "      <td>786.064634</td>\n",
       "      <td>0.027444</td>\n",
       "      <td>5227.158635</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>242</td>\n",
       "      <td>0.179977</td>\n",
       "      <td>2.751192e+06</td>\n",
       "      <td>0.520077</td>\n",
       "      <td>-0.235335</td>\n",
       "      <td>1297.176396</td>\n",
       "      <td>0.016272</td>\n",
       "      <td>6467.871486</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>31</td>\n",
       "      <td>0.085739</td>\n",
       "      <td>1.137872e+06</td>\n",
       "      <td>0.515145</td>\n",
       "      <td>-0.236153</td>\n",
       "      <td>641.419533</td>\n",
       "      <td>0.026151</td>\n",
       "      <td>4864.646084</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>187</td>\n",
       "      <td>0.107876</td>\n",
       "      <td>6.686752e+05</td>\n",
       "      <td>0.009063</td>\n",
       "      <td>-0.304043</td>\n",
       "      <td>802.714148</td>\n",
       "      <td>0.020602</td>\n",
       "      <td>5183.985944</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   RMS       ZCR     bandwidth      nwpd       rse  spectral_centroid  \\\n",
       "0  235  0.148002  1.800357e+06  0.323637 -0.258980        1043.434918   \n",
       "1  347  0.239553  3.614656e+06  1.323582 -0.249343        1700.323351   \n",
       "2  271  0.130902  1.469745e+06  1.151613 -0.253432         988.019173   \n",
       "3  140  0.125277  1.089120e+06 -0.566549 -0.337949         926.835474   \n",
       "4  270  0.093726  1.016863e+06  3.289868 -0.305840         786.064634   \n",
       "5  242  0.179977  2.751192e+06  0.520077 -0.235335        1297.176396   \n",
       "6   31  0.085739  1.137872e+06  0.515145 -0.236153         641.419533   \n",
       "7  187  0.107876  6.686752e+05  0.009063 -0.304043         802.714148   \n",
       "\n",
       "   spectral_flux  spectral_rolloff  \n",
       "0       0.021088       6197.853916  \n",
       "1       0.008497       6469.879518  \n",
       "2       0.014795       5685.115462  \n",
       "3       0.019217       5463.227912  \n",
       "4       0.027444       5227.158635  \n",
       "5       0.016272       6467.871486  \n",
       "6       0.026151       4864.646084  \n",
       "7       0.020602       5183.985944  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df = pd.DataFrame.from_dict(features_test_dict)\n",
    "test_df.replace([np.inf,-np.inf,np.nan],0)\n",
    "test_df.dropna()\n",
    "test_df_data = test_df.as_matrix()\n",
    "test_df.head(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 - Voice and 0 - Noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 1. 1. 1. 1. 1. 1. 1.]\n"
     ]
    }
   ],
   "source": [
    "test_predictions = classifier.predict(test_df_data)\n",
    "print(test_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}