648 lines (647 with data), 21.1 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn import tree\n",
"import random\n",
"import copy\n",
"import os\n",
"import graphviz\n",
"import scipy.io.wavfile as wav\n",
"from src.voice_activity_detection.extract_features import extract_features"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1356 entries, 0 to 1355\n",
"Data columns (total 9 columns):\n",
"RMS 1356 non-null int64\n",
"ZCR 1356 non-null float64\n",
"audio 1356 non-null object\n",
"bandwidth 1356 non-null float64\n",
"nwpd 1356 non-null float64\n",
"rse 1353 non-null float64\n",
"spectral_centroid 1356 non-null float64\n",
"spectral_flux 1356 non-null float64\n",
"spectral_rolloff 1356 non-null float64\n",
"dtypes: float64(7), int64(1), object(1)\n",
"memory usage: 95.4+ KB\n"
]
}
],
"source": [
"# The .npy file holds a pickled dict (column name -> values); numpy >= 1.16.3\n",
"# refuses to unpickle unless allow_pickle=True. Only load files you trust.\n",
"voice_noise_data = np.load(\"numpy_files/musan_mean_speech_noise.npy\", allow_pickle=True).item()\n",
"voice_noise_df = pd.DataFrame.from_dict(voice_noise_data)\n",
"\n",
"# replace()/dropna() return new frames, so their result must be assigned or the\n",
"# cleaning is silently lost. Infinite feature values are treated as missing and\n",
"# every row with a missing value is dropped.\n",
"voice_noise_df = (voice_noise_df\n",
"                  .replace([np.inf, -np.inf], np.nan)\n",
"                  .dropna()\n",
"                  .reset_index(drop=True))\n",
"voice_noise_df.info()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RMS</th>\n",
" <th>ZCR</th>\n",
" <th>audio</th>\n",
" <th>bandwidth</th>\n",
" <th>nwpd</th>\n",
" <th>rse</th>\n",
" <th>spectral_centroid</th>\n",
" <th>spectral_flux</th>\n",
" <th>spectral_rolloff</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>120</td>\n",
" <td>0.317919</td>\n",
" <td>1</td>\n",
" <td>2.852637e+05</td>\n",
" <td>-1.086645</td>\n",
" <td>-0.190269</td>\n",
" <td>535.891491</td>\n",
" <td>0.039433</td>\n",
" <td>4020.624371</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>0.442350</td>\n",
" <td>1</td>\n",
" <td>6.743143e+05</td>\n",
" <td>10.745700</td>\n",
" <td>-0.875232</td>\n",
" <td>840.023296</td>\n",
" <td>0.061607</td>\n",
" <td>3833.170732</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>13</td>\n",
" <td>0.464951</td>\n",
" <td>1</td>\n",
" <td>8.572244e+05</td>\n",
" <td>1.387262</td>\n",
" <td>-0.322468</td>\n",
" <td>572.973659</td>\n",
" <td>0.018836</td>\n",
" <td>4938.532110</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>130</td>\n",
" <td>0.150850</td>\n",
" <td>1</td>\n",
" <td>3.938587e+06</td>\n",
" <td>0.498670</td>\n",
" <td>-0.290313</td>\n",
" <td>3448.585720</td>\n",
" <td>0.010311</td>\n",
" <td>6988.571429</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>42</td>\n",
" <td>0.250210</td>\n",
" <td>1</td>\n",
" <td>6.657314e+05</td>\n",
" <td>0.925133</td>\n",
" <td>-0.263716</td>\n",
" <td>745.102831</td>\n",
" <td>0.014008</td>\n",
" <td>4590.849315</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>93</td>\n",
" <td>0.481068</td>\n",
" <td>1</td>\n",
" <td>7.152688e+05</td>\n",
" <td>-0.090272</td>\n",
" <td>-0.184233</td>\n",
" <td>738.235253</td>\n",
" <td>0.022868</td>\n",
" <td>4529.035025</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>12</td>\n",
" <td>0.496398</td>\n",
" <td>1</td>\n",
" <td>1.041786e+06</td>\n",
" <td>-0.611150</td>\n",
" <td>-0.405754</td>\n",
" <td>1783.177886</td>\n",
" <td>0.020107</td>\n",
" <td>5412.625995</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>1967</td>\n",
" <td>0.401090</td>\n",
" <td>0</td>\n",
" <td>2.284852e+05</td>\n",
" <td>-0.196838</td>\n",
" <td>-0.301843</td>\n",
" <td>527.226453</td>\n",
" <td>0.034232</td>\n",
" <td>3316.301112</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>151</td>\n",
" <td>0.499611</td>\n",
" <td>1</td>\n",
" <td>3.597816e+05</td>\n",
" <td>0.241001</td>\n",
" <td>-0.153348</td>\n",
" <td>524.060903</td>\n",
" <td>0.011216</td>\n",
" <td>4010.773810</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>5</td>\n",
" <td>0.500247</td>\n",
" <td>1</td>\n",
" <td>5.324209e+06</td>\n",
" <td>-0.130207</td>\n",
" <td>-0.202091</td>\n",
" <td>2541.108767</td>\n",
" <td>0.006781</td>\n",
" <td>7003.905325</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" RMS ZCR audio bandwidth nwpd rse \\\n",
"0 120 0.317919 1 2.852637e+05 -1.086645 -0.190269 \n",
"1 2 0.442350 1 6.743143e+05 10.745700 -0.875232 \n",
"2 13 0.464951 1 8.572244e+05 1.387262 -0.322468 \n",
"3 130 0.150850 1 3.938587e+06 0.498670 -0.290313 \n",
"4 42 0.250210 1 6.657314e+05 0.925133 -0.263716 \n",
"5 93 0.481068 1 7.152688e+05 -0.090272 -0.184233 \n",
"6 12 0.496398 1 1.041786e+06 -0.611150 -0.405754 \n",
"7 1967 0.401090 0 2.284852e+05 -0.196838 -0.301843 \n",
"8 151 0.499611 1 3.597816e+05 0.241001 -0.153348 \n",
"9 5 0.500247 1 5.324209e+06 -0.130207 -0.202091 \n",
"\n",
" spectral_centroid spectral_flux spectral_rolloff \n",
"0 535.891491 0.039433 4020.624371 \n",
"1 840.023296 0.061607 3833.170732 \n",
"2 572.973659 0.018836 4938.532110 \n",
"3 3448.585720 0.010311 6988.571429 \n",
"4 745.102831 0.014008 4590.849315 \n",
"5 738.235253 0.022868 4529.035025 \n",
"6 1783.177886 0.020107 5412.625995 \n",
"7 527.226453 0.034232 3316.301112 \n",
"8 524.060903 0.011216 4010.773810 \n",
"9 2541.108767 0.006781 7003.905325 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Encode the object-typed label column as integer category codes; categories are\n",
"# inferred in sorted order of the original label values.\n",
"voice_noise_df[\"audio\"] = voice_noise_df[\"audio\"].astype(\"category\").cat.codes\n",
"voice_noise_df.head(10)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(array([132, 229]), array([4, 4]))\n",
"(1356,)\n",
"(1354, 8)\n",
"(array([], dtype=int64), array([], dtype=int64))\n",
"(1083, 8)\n",
"(array([], dtype=int64), array([], dtype=int64))\n",
"(array([], dtype=int64), array([], dtype=int64))\n"
]
}
],
"source": [
"TRAIN_TEST_PERCENTAGE = 0.80\n",
"SEED = 0  # fixed so the shuffle, and therefore the train/test split, is reproducible\n",
"\n",
"# Feature matrix (every column except the label) and label vector.\n",
"# .values replaces DataFrame.as_matrix(), which is deprecated and removed in pandas 1.0.\n",
"data = voice_noise_df.loc[:, voice_noise_df.columns != \"audio\"].values.astype(np.float64)\n",
"labels = voice_noise_df[\"audio\"].values\n",
"\n",
"# Shuffle features and labels together through a single index permutation.\n",
"# random.shuffle() must not be used here: on a 2-D numpy array it swaps rows via\n",
"# views, which duplicates some rows and loses others.\n",
"order = np.random.RandomState(SEED).permutation(len(data))\n",
"data = data[order]\n",
"labels = labels[order]\n",
"\n",
"# Keep only rows without a NaN in any feature (boolean indexing already copies).\n",
"keep = ~np.isnan(data).any(axis=1)\n",
"data_ = data[keep]\n",
"labels_ = labels[keep]\n",
"print(data_.shape)\n",
"\n",
"index_ = int(len(data_) * TRAIN_TEST_PERCENTAGE)\n",
"train_data, test_data = data_[:index_], data_[index_:]\n",
"train_labels, test_labels = labels_[:index_], labels_[index_:]\n",
"\n",
"assert not np.isnan(train_data).any() and not np.isnan(test_data).any()\n",
"assert len(train_data) + len(test_data) == len(data_)\n",
"print(train_data.shape, test_data.shape)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# random_state pins the tie-breaking between equally good splits (scikit-learn\n",
"# randomly permutes the features at every split), so the fitted tree is reproducible.\n",
"classifier = tree.DecisionTreeClassifier(max_depth=3, random_state=0)\n",
"classifier = classifier.fit(train_data, train_labels)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.97416973\n"
]
}
],
"source": [
"# Accuracy = fraction of held-out rows whose predicted label equals the true label.\n",
"prediction = classifier.predict(test_data)\n",
"accuracy = (prediction == test_labels).astype(np.float32).mean()\n",
"print(accuracy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classify new, unlabeled audio"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Folder of .wav files to classify (keep the trailing path separator). Set the\n",
"# VAD_TEST_AUDIO_FOLDER environment variable instead of editing this machine-specific path.\n",
"TEST_AUDIO_FOLDER = os.environ.get(\"VAD_TEST_AUDIO_FOLDER\", \"/Users/siva/Documents/speaker_recognition/VOD/testwav/\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def create_dataset(ROOT_FOLDER, WINDOW_LENGTH=5, FRAME_LENGTH=25):\n",
"    \"\"\"Build a feature dict from every .wav file in ROOT_FOLDER.\n",
"\n",
"    Each file is cut into consecutive, non-overlapping windows of\n",
"    WINDOW_LENGTH seconds (a trailing partial window is skipped) and\n",
"    extract_features(window, FRAME_LENGTH, sampling_rate) is run on each one.\n",
"    ZCR and RMS are stored as returned; the other six features are reduced\n",
"    with np.mean (presumably over per-frame values - confirm in extract_features).\n",
"\n",
"    Returns a dict mapping feature name -> list with one value per window.\n",
"    Raises ValueError if the window length works out to zero samples.\n",
"    \"\"\"\n",
"    dataset_dict = {\"ZCR\": [], \"RMS\": [], \"spectral_flux\": [],\n",
"                    \"spectral_centroid\": [], \"spectral_rolloff\": [],\n",
"                    \"bandwidth\": [], \"nwpd\": [], \"rse\": []}\n",
"    # Read files by full path instead of os.chdir(), which would permanently move the\n",
"    # kernel's working directory and break relative paths used in other cells.\n",
"    # Only .wav files are read (stray files such as .DS_Store crash wav.read), in\n",
"    # sorted order so the rows come out in a reproducible order.\n",
"    all_audio = sorted(name for name in os.listdir(ROOT_FOLDER)\n",
"                       if name.lower().endswith(\".wav\"))\n",
"    for audio in all_audio:\n",
"        print(\"****************************\")\n",
"        print(\"reading:\", audio)\n",
"        sampling_rate, sig = wav.read(os.path.join(ROOT_FOLDER, audio))\n",
"        print(\"sampling rate:\", sampling_rate, \"signal length\", len(sig))\n",
"        window_size = int(sampling_rate * WINDOW_LENGTH)  # samples per window\n",
"        if window_size <= 0:\n",
"            # A zero-sample window would never advance `index` and loop forever.\n",
"            raise ValueError(\"WINDOW_LENGTH must be positive\")\n",
"        index = 0\n",
"        # <= so that a window ending exactly at the end of the signal is kept.\n",
"        while index + window_size <= len(sig):\n",
"            sample = sig[index:index + window_size]\n",
"            ef = extract_features(sample, FRAME_LENGTH, sampling_rate)\n",
"            ZCR, RMS, sf, sr, sc, bd, nwpd, rse = ef.return_()\n",
"            dataset_dict[\"ZCR\"].append(ZCR)\n",
"            dataset_dict[\"RMS\"].append(RMS)\n",
"            dataset_dict[\"spectral_flux\"].append(np.mean(sf))\n",
"            dataset_dict[\"spectral_centroid\"].append(np.mean(sc))\n",
"            dataset_dict[\"spectral_rolloff\"].append(np.mean(sr))\n",
"            dataset_dict[\"bandwidth\"].append(np.mean(bd))\n",
"            dataset_dict[\"nwpd\"].append(np.mean(nwpd))\n",
"            dataset_dict[\"rse\"].append(np.mean(rse))\n",
"            index += window_size\n",
"    print([len(values) for values in dataset_dict.values()])\n",
"    print(\"finished\")\n",
"    return dataset_dict"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************************\n",
"reading: audiotest08-06-2018-17-01-26.wav\n",
"sampling rate: 16000 signal length 84480\n",
"****************************\n",
"reading: audiotest08-06-2018-17-01-43.wav\n",
"sampling rate: 16000 signal length 84480\n",
"****************************\n",
"reading: audiotest08-06-2018-17-01-56.wav\n",
"sampling rate: 16000 signal length 84480\n",
"****************************\n",
"reading: audiotest08-06-2018-17-02-11.wav\n",
"sampling rate: 16000 signal length 84480\n",
"****************************\n",
"reading: audiotest08-06-2018-16-49-40.wav\n",
"sampling rate: 16000 signal length 84480\n",
"****************************\n",
"reading: audiotest08-06-2018-17-01-12.wav\n",
"sampling rate: 16000 signal length 84480\n",
"****************************\n",
"reading: audiotest08-06-2018-13-12-39.wav\n",
"sampling rate: 16000 signal length 84480\n",
"****************************\n",
"reading: audiotest08-06-2018-17-02-44.wav\n",
"sampling rate: 16000 signal length 84480\n",
"[8, 8, 8, 8, 8, 8, 8, 8]\n",
"finished\n"
]
}
],
"source": [
"# 5-second windows with FRAME_LENGTH=25, the same values as create_dataset's defaults.\n",
"features_test_dict = create_dataset(TEST_AUDIO_FOLDER, WINDOW_LENGTH=5, FRAME_LENGTH=25)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RMS</th>\n",
" <th>ZCR</th>\n",
" <th>bandwidth</th>\n",
" <th>nwpd</th>\n",
" <th>rse</th>\n",
" <th>spectral_centroid</th>\n",
" <th>spectral_flux</th>\n",
" <th>spectral_rolloff</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>235</td>\n",
" <td>0.148002</td>\n",
" <td>1.800357e+06</td>\n",
" <td>0.323637</td>\n",
" <td>-0.258980</td>\n",
" <td>1043.434918</td>\n",
" <td>0.021088</td>\n",
" <td>6197.853916</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>347</td>\n",
" <td>0.239553</td>\n",
" <td>3.614656e+06</td>\n",
" <td>1.323582</td>\n",
" <td>-0.249343</td>\n",
" <td>1700.323351</td>\n",
" <td>0.008497</td>\n",
" <td>6469.879518</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>271</td>\n",
" <td>0.130902</td>\n",
" <td>1.469745e+06</td>\n",
" <td>1.151613</td>\n",
" <td>-0.253432</td>\n",
" <td>988.019173</td>\n",
" <td>0.014795</td>\n",
" <td>5685.115462</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>140</td>\n",
" <td>0.125277</td>\n",
" <td>1.089120e+06</td>\n",
" <td>-0.566549</td>\n",
" <td>-0.337949</td>\n",
" <td>926.835474</td>\n",
" <td>0.019217</td>\n",
" <td>5463.227912</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>270</td>\n",
" <td>0.093726</td>\n",
" <td>1.016863e+06</td>\n",
" <td>3.289868</td>\n",
" <td>-0.305840</td>\n",
" <td>786.064634</td>\n",
" <td>0.027444</td>\n",
" <td>5227.158635</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>242</td>\n",
" <td>0.179977</td>\n",
" <td>2.751192e+06</td>\n",
" <td>0.520077</td>\n",
" <td>-0.235335</td>\n",
" <td>1297.176396</td>\n",
" <td>0.016272</td>\n",
" <td>6467.871486</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>31</td>\n",
" <td>0.085739</td>\n",
" <td>1.137872e+06</td>\n",
" <td>0.515145</td>\n",
" <td>-0.236153</td>\n",
" <td>641.419533</td>\n",
" <td>0.026151</td>\n",
" <td>4864.646084</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>187</td>\n",
" <td>0.107876</td>\n",
" <td>6.686752e+05</td>\n",
" <td>0.009063</td>\n",
" <td>-0.304043</td>\n",
" <td>802.714148</td>\n",
" <td>0.020602</td>\n",
" <td>5183.985944</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" RMS ZCR bandwidth nwpd rse spectral_centroid \\\n",
"0 235 0.148002 1.800357e+06 0.323637 -0.258980 1043.434918 \n",
"1 347 0.239553 3.614656e+06 1.323582 -0.249343 1700.323351 \n",
"2 271 0.130902 1.469745e+06 1.151613 -0.253432 988.019173 \n",
"3 140 0.125277 1.089120e+06 -0.566549 -0.337949 926.835474 \n",
"4 270 0.093726 1.016863e+06 3.289868 -0.305840 786.064634 \n",
"5 242 0.179977 2.751192e+06 0.520077 -0.235335 1297.176396 \n",
"6 31 0.085739 1.137872e+06 0.515145 -0.236153 641.419533 \n",
"7 187 0.107876 6.686752e+05 0.009063 -0.304043 802.714148 \n",
"\n",
" spectral_flux spectral_rolloff \n",
"0 0.021088 6197.853916 \n",
"1 0.008497 6469.879518 \n",
"2 0.014795 5685.115462 \n",
"3 0.019217 5463.227912 \n",
"4 0.027444 5227.158635 \n",
"5 0.016272 6467.871486 \n",
"6 0.026151 4864.646084 \n",
"7 0.020602 5183.985944 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Select the training feature columns in the training order. Newer pandas keeps the\n",
"# dict's insertion order in from_dict, which need not match the column order the\n",
"# classifier was fit on; features are passed positionally, so they must line up.\n",
"feature_columns = [column for column in voice_noise_df.columns if column != \"audio\"]\n",
"test_df = pd.DataFrame.from_dict(features_test_dict)[feature_columns]\n",
"# replace() returns a new frame, so the result is assigned. Non-finite values are\n",
"# zero-filled so that every window still receives a prediction.\n",
"test_df = test_df.replace([np.inf, -np.inf, np.nan], 0)\n",
"test_df_data = test_df.values  # .as_matrix() was removed in pandas 1.0\n",
"test_df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predictions on the new audio (1 = voice, 0 = noise)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1. 1. 1. 1. 1. 1. 1. 1.]\n"
]
}
],
"source": [
"# One predicted label per row of test_df_data.\n",
"test_predictions = classifier.predict(X=test_df_data)\n",
"print(test_predictions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}