--- a +++ b/voice_activity_detection.ipynb @@ -0,0 +1,647 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn import tree\n", + "import random\n", + "import copy\n", + "import os\n", + "import graphviz\n", + "import scipy.io.wavfile as wav\n", + "from src.voice_activity_detection.extract_features import extract_features" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<class 'pandas.core.frame.DataFrame'>\n", + "RangeIndex: 1356 entries, 0 to 1355\n", + "Data columns (total 9 columns):\n", + "RMS 1356 non-null int64\n", + "ZCR 1356 non-null float64\n", + "audio 1356 non-null object\n", + "bandwidth 1356 non-null float64\n", + "nwpd 1356 non-null float64\n", + "rse 1353 non-null float64\n", + "spectral_centroid 1356 non-null float64\n", + "spectral_flux 1356 non-null float64\n", + "spectral_rolloff 1356 non-null float64\n", + "dtypes: float64(7), int64(1), object(1)\n", + "memory usage: 95.4+ KB\n" + ] + } + ], + "source": [ + "voice_noise_data = np.load(\"numpy_files/musan_mean_speech_noise.npy\").item()\n", + "voice_noise_df = pd.DataFrame.from_dict(voice_noise_data)\n", + "voice_noise_df.replace([np.inf,-np.inf,np.nan],0)\n", + "voice_noise_df.dropna()\n", + "voice_noise_df.info()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>RMS</th>\n", + " <th>ZCR</th>\n", + " <th>audio</th>\n", + " <th>bandwidth</th>\n", + " <th>nwpd</th>\n", + " <th>rse</th>\n", + " <th>spectral_centroid</th>\n", + " <th>spectral_flux</th>\n", + " <th>spectral_rolloff</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>120</td>\n", + " <td>0.317919</td>\n", + " <td>1</td>\n", + " <td>2.852637e+05</td>\n", + " <td>-1.086645</td>\n", + " <td>-0.190269</td>\n", + " <td>535.891491</td>\n", + " <td>0.039433</td>\n", + " <td>4020.624371</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>2</td>\n", + " <td>0.442350</td>\n", + " <td>1</td>\n", + " <td>6.743143e+05</td>\n", + " <td>10.745700</td>\n", + " <td>-0.875232</td>\n", + " <td>840.023296</td>\n", + " <td>0.061607</td>\n", + " <td>3833.170732</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>13</td>\n", + " <td>0.464951</td>\n", + " <td>1</td>\n", + " <td>8.572244e+05</td>\n", + " <td>1.387262</td>\n", + " <td>-0.322468</td>\n", + " <td>572.973659</td>\n", + " <td>0.018836</td>\n", + " <td>4938.532110</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>130</td>\n", + " <td>0.150850</td>\n", + " <td>1</td>\n", + " <td>3.938587e+06</td>\n", + " <td>0.498670</td>\n", + " <td>-0.290313</td>\n", + " <td>3448.585720</td>\n", + " <td>0.010311</td>\n", + " <td>6988.571429</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>42</td>\n", + " <td>0.250210</td>\n", + " <td>1</td>\n", + " <td>6.657314e+05</td>\n", + " <td>0.925133</td>\n", + " <td>-0.263716</td>\n", + " <td>745.102831</td>\n", + " <td>0.014008</td>\n", + " <td>4590.849315</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>93</td>\n", + " <td>0.481068</td>\n", + " <td>1</td>\n", + " <td>7.152688e+05</td>\n", + " <td>-0.090272</td>\n", + " <td>-0.184233</td>\n", + " <td>738.235253</td>\n", + " <td>0.022868</td>\n", + " <td>4529.035025</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>12</td>\n", + " <td>0.496398</td>\n", + " <td>1</td>\n", + " <td>1.041786e+06</td>\n", + " <td>-0.611150</td>\n", + " <td>-0.405754</td>\n", + " <td>1783.177886</td>\n", + " <td>0.020107</td>\n", + " <td>5412.625995</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>1967</td>\n", + " <td>0.401090</td>\n", + " <td>0</td>\n", + " <td>2.284852e+05</td>\n", + " <td>-0.196838</td>\n", + " <td>-0.301843</td>\n", + " <td>527.226453</td>\n", + " <td>0.034232</td>\n", + " <td>3316.301112</td>\n", + " </tr>\n", + " <tr>\n", + " <th>8</th>\n", + " <td>151</td>\n", + " <td>0.499611</td>\n", + " <td>1</td>\n", + " <td>3.597816e+05</td>\n", + " <td>0.241001</td>\n", + " <td>-0.153348</td>\n", + " <td>524.060903</td>\n", + " <td>0.011216</td>\n", + " <td>4010.773810</td>\n", + " </tr>\n", + " <tr>\n", + " <th>9</th>\n", + " <td>5</td>\n", + " <td>0.500247</td>\n", + " <td>1</td>\n", + " <td>5.324209e+06</td>\n", + " <td>-0.130207</td>\n", + " <td>-0.202091</td>\n", + " <td>2541.108767</td>\n", + " <td>0.006781</td>\n", + " <td>7003.905325</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " RMS ZCR audio bandwidth nwpd rse \\\n", + "0 120 0.317919 1 2.852637e+05 -1.086645 -0.190269 \n", + "1 2 0.442350 1 6.743143e+05 10.745700 -0.875232 \n", + "2 13 0.464951 1 8.572244e+05 1.387262 -0.322468 \n", + "3 130 0.150850 1 3.938587e+06 0.498670 -0.290313 \n", + "4 42 0.250210 1 6.657314e+05 0.925133 -0.263716 \n", + "5 93 0.481068 1 7.152688e+05 -0.090272 -0.184233 \n", + "6 12 0.496398 1 1.041786e+06 -0.611150 -0.405754 \n", + "7 1967 0.401090 0 2.284852e+05 -0.196838 -0.301843 \n", + "8 151 0.499611 1 3.597816e+05 0.241001 -0.153348 \n", + "9 5 0.500247 1 5.324209e+06 -0.130207 -0.202091 \n", + "\n", + " spectral_centroid spectral_flux spectral_rolloff \n", + "0 535.891491 0.039433 4020.624371 \n", + "1 840.023296 0.061607 3833.170732 \n", + "2 572.973659 0.018836 4938.532110 \n", + "3 3448.585720 0.010311 6988.571429 \n", + "4 745.102831 0.014008 4590.849315 \n", + "5 738.235253 0.022868 4529.035025 \n", + "6 1783.177886 0.020107 5412.625995 \n", + "7 527.226453 0.034232 3316.301112 \n", + "8 524.060903 0.011216 4010.773810 \n", + "9 2541.108767 0.006781 7003.905325 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voice_noise_df[\"audio\"] = voice_noise_df[\"audio\"].astype('category')\n", + "voice_noise_df[\"audio\"] = voice_noise_df[\"audio\"].cat.codes\n", + "voice_noise_df.head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(array([132, 229]), array([4, 4]))\n", + "(1356,)\n", + "(1354, 8)\n", + "(array([], dtype=int64), array([], dtype=int64))\n", + "(1083, 8)\n", + "(array([], dtype=int64), array([], dtype=int64))\n", + "(array([], dtype=int64), array([], dtype=int64))\n" + ] + } + ], + "source": [ + "TRAIN_TEST_PERCENTAGE = 0.80\n", + "data = voice_noise_df.loc[:,voice_noise_df.columns != \"audio\"].as_matrix()\n", + "labels = voice_noise_df[\"audio\"].as_matrix()\n", + "\n", + "full_data = np.hstack((data,np.expand_dims(labels,axis = 1)))\n", + "random.shuffle(full_data)\n", + "data = full_data[:,0:-1]\n", + "labels = full_data[:,-1]\n", + "\n", + "nan_indices = np.where(np.isnan(data))\n", + "print(nan_indices)\n", + "all_indices_ = np.ones(len(data),dtype = \"bool\")\n", + "print(all_indices_.shape)\n", + "all_indices_[nan_indices[0]] = False\n", + "data_ = copy.deepcopy(data[all_indices_,:])\n", + "labels_ = copy.deepcopy(labels[all_indices_])\n", + "print(data_.shape)\n", + "print(np.where(np.isnan(data_)))\n", + "\n", + "index_ = int(len(data_)*TRAIN_TEST_PERCENTAGE)\n", + "train_data = data_[0:index_,:]\n", + "train_labels = labels_[0:index_]\n", + "test_data = data_[index_:,:]\n", + "test_labels = labels_[index_:]\n", + "\n", + "print(train_data.shape)\n", + "\n", + "print(np.where(np.isnan(train_data)))\n", + "\n", + "print(np.where(np.isnan(test_data)))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "classifier = tree.DecisionTreeClassifier(max_depth = 3)\n", + "classifier = classifier.fit(train_data,train_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.97416973\n" + ] + } + ], + "source": [ + "prediction = classifier.predict(test_data)\n", + "print(np.mean(np.equal(prediction,test_labels).astype(np.float32)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## For testing new audio" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "TEST_AUDIO_FOLDER = \"/Users/siva/Documents/speaker_recognition/VOD/testwav/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def create_dataset(ROOT_FOLDER,WINDOW_LENGTH = 5,FRAME_LENGTH = 25):\n", + " os.chdir(ROOT_FOLDER)\n", + " all_audio = os.listdir()\n", + " dataset_dict = {\"ZCR\":[],\"RMS\":[],\"spectral_flux\":[],\\\n", + " \"spectral_centroid\":[],\"spectral_rolloff\":[],\\\n", + " \"bandwidth\":[],\"nwpd\":[],\"rse\":[]}\n", + " for audio in all_audio:\n", + " print(\"****************************\")\n", + " print(\"reading:\",audio)\n", + " sampling_rate,sig = wav.read(ROOT_FOLDER+audio)\n", + " print(\"sampling rate:\",sampling_rate,\"signal length\",len(sig))\n", + " index = 0\n", + " while index+(sampling_rate*WINDOW_LENGTH) < len(sig):\n", + " sample = sig[index:(index+(sampling_rate*WINDOW_LENGTH))]\n", + " ef = extract_features(sample,FRAME_LENGTH,sampling_rate)\n", + " ZCR,RMS,sf,sr,sc,bd,nwpd,rse = ef.return_()\n", + " dataset_dict[\"ZCR\"].append(ZCR)\n", + " dataset_dict[\"RMS\"].append(RMS)\n", + " dataset_dict[\"spectral_flux\"].append(np.mean(sf))\n", + " dataset_dict[\"spectral_centroid\"].append(np.mean(sc))\n", + " dataset_dict[\"spectral_rolloff\"].append(np.mean(sr))\n", + " dataset_dict[\"bandwidth\"].append(np.mean(bd))\n", + " dataset_dict[\"nwpd\"].append(np.mean(nwpd))\n", + " dataset_dict[\"rse\"].append(np.mean(rse))\n", + " index += sampling_rate*WINDOW_LENGTH\n", + " values = dataset_dict.values()\n", + " print([len(e) for e in values])\n", + " print(\"finished\")\n", + " return dataset_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "****************************\n", + "reading: audiotest08-06-2018-17-01-26.wav\n", + "sampling rate: 16000 signal length 84480\n", + "****************************\n", + "reading: audiotest08-06-2018-17-01-43.wav\n", + "sampling rate: 16000 signal length 84480\n", + "****************************\n", + "reading: audiotest08-06-2018-17-01-56.wav\n", + "sampling rate: 16000 signal length 84480\n", + "****************************\n", + "reading: audiotest08-06-2018-17-02-11.wav\n", + "sampling rate: 16000 signal length 84480\n", + "****************************\n", + "reading: audiotest08-06-2018-16-49-40.wav\n", + "sampling rate: 16000 signal length 84480\n", + "****************************\n", + "reading: audiotest08-06-2018-17-01-12.wav\n", + "sampling rate: 16000 signal length 84480\n", + "****************************\n", + "reading: audiotest08-06-2018-13-12-39.wav\n", + "sampling rate: 16000 signal length 84480\n", + "****************************\n", + "reading: audiotest08-06-2018-17-02-44.wav\n", + "sampling rate: 16000 signal length 84480\n", + "[8, 8, 8, 8, 8, 8, 8, 8]\n", + "finished\n" + ] + } + ], + "source": [ + "features_test_dict = create_dataset(TEST_AUDIO_FOLDER)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>RMS</th>\n", + " <th>ZCR</th>\n", + " <th>bandwidth</th>\n", + " <th>nwpd</th>\n", + " <th>rse</th>\n", + " <th>spectral_centroid</th>\n", + " <th>spectral_flux</th>\n", + " <th>spectral_rolloff</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>235</td>\n", + " <td>0.148002</td>\n", + " <td>1.800357e+06</td>\n", + " <td>0.323637</td>\n", + " <td>-0.258980</td>\n", + " <td>1043.434918</td>\n", + " <td>0.021088</td>\n", + " <td>6197.853916</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>347</td>\n", + " <td>0.239553</td>\n", + " <td>3.614656e+06</td>\n", + " <td>1.323582</td>\n", + " <td>-0.249343</td>\n", + " <td>1700.323351</td>\n", + " <td>0.008497</td>\n", + " <td>6469.879518</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>271</td>\n", + " <td>0.130902</td>\n", + " <td>1.469745e+06</td>\n", + " <td>1.151613</td>\n", + " <td>-0.253432</td>\n", + " <td>988.019173</td>\n", + " <td>0.014795</td>\n", + " <td>5685.115462</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>140</td>\n", + " <td>0.125277</td>\n", + " <td>1.089120e+06</td>\n", + " <td>-0.566549</td>\n", + " <td>-0.337949</td>\n", + " <td>926.835474</td>\n", + " <td>0.019217</td>\n", + " <td>5463.227912</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>270</td>\n", + " <td>0.093726</td>\n", + " <td>1.016863e+06</td>\n", + " <td>3.289868</td>\n", + " <td>-0.305840</td>\n", + " <td>786.064634</td>\n", + " <td>0.027444</td>\n", + " <td>5227.158635</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>242</td>\n", + " <td>0.179977</td>\n", + " <td>2.751192e+06</td>\n", + " <td>0.520077</td>\n", + " <td>-0.235335</td>\n", + " <td>1297.176396</td>\n", + " <td>0.016272</td>\n", + " <td>6467.871486</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>31</td>\n", + " <td>0.085739</td>\n", + " <td>1.137872e+06</td>\n", + " <td>0.515145</td>\n", + " <td>-0.236153</td>\n", + " <td>641.419533</td>\n", + " <td>0.026151</td>\n", + " <td>4864.646084</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>187</td>\n", + " <td>0.107876</td>\n", + " <td>6.686752e+05</td>\n", + " <td>0.009063</td>\n", + " <td>-0.304043</td>\n", + " <td>802.714148</td>\n", + " <td>0.020602</td>\n", + " <td>5183.985944</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " RMS ZCR bandwidth nwpd rse spectral_centroid \\\n", + "0 235 0.148002 1.800357e+06 0.323637 -0.258980 1043.434918 \n", + "1 347 0.239553 3.614656e+06 1.323582 -0.249343 1700.323351 \n", + "2 271 0.130902 1.469745e+06 1.151613 -0.253432 988.019173 \n", + "3 140 0.125277 1.089120e+06 -0.566549 -0.337949 926.835474 \n", + "4 270 0.093726 1.016863e+06 3.289868 -0.305840 786.064634 \n", + "5 242 0.179977 2.751192e+06 0.520077 -0.235335 1297.176396 \n", + "6 31 0.085739 1.137872e+06 0.515145 -0.236153 641.419533 \n", + "7 187 0.107876 6.686752e+05 0.009063 -0.304043 802.714148 \n", + "\n", + " spectral_flux spectral_rolloff \n", + "0 0.021088 6197.853916 \n", + "1 0.008497 6469.879518 \n", + "2 0.014795 5685.115462 \n", + "3 0.019217 5463.227912 \n", + "4 0.027444 5227.158635 \n", + "5 0.016272 6467.871486 \n", + "6 0.026151 4864.646084 \n", + "7 0.020602 5183.985944 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df = pd.DataFrame.from_dict(features_test_dict)\n", + "test_df.replace([np.inf,-np.inf,np.nan],0)\n", + "test_df.dropna()\n", + "test_df_data = test_df.as_matrix()\n", + "test_df.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1 - Voice and 0 - Noise" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1. 1. 1. 1. 1. 1. 1. 1.]\n" + ] + } + ], + "source": [ + "test_predictions = classifier.predict(test_df_data)\n", + "print(test_predictions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python [default]", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}