450 lines (449 with data), 17.9 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from sklearn.model_selection import train_test_split, GridSearchCV, PredefinedSplit\n",
"from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
"from sklearn import metrics\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.svm import SVC\n",
"\n",
"from keras import models, Input\n",
"from keras import optimizers as opt\n",
"from keras import backend as K\n",
"from keras.layers import Dense\n",
"from keras_tuner.tuners import RandomSearch\n",
"from tensorflow.keras.utils import to_categorical"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"from dataset import load_dataset, load_labels, split_data, format_labels\n",
"from features import time_series_features, fractal_features, entropy_features, hjorth_features, freq_band_features\n",
"import variables as v"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Variables"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"# Both values are forwarded to load_dataset below; test_type is also used by format_labels.\n",
"data_type = \"ica_filtered\"\n",
"test_type = \"Arithmetic\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load Dataset"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"# Load the recordings selected above, then hand them to split_data together with\n",
"# v.SFREQ (presumably the sampling frequency -- confirm in variables.py).\n",
"dataset_ = load_dataset(test_type=test_type, data_type=data_type)\n",
"dataset = split_data(dataset_, v.SFREQ)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"# Raw labels are converted by format_labels for the chosen task; `epochs` is read\n",
"# from the second axis of `dataset`.\n",
"label_ = load_labels()\n",
"label = format_labels(label_, epochs=dataset.shape[1], test_type=test_type)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Compute Features"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"# Exactly one feature extractor is active at a time. To switch feature sets,\n",
"# comment out the active line and uncomment one of the alternatives.\n",
"# features = time_series_features(dataset)\n",
"# features = freq_band_features(dataset, np.array([1, 4, 8, 12, 30, 50]))\n",
"# features = hjorth_features(dataset)\n",
"# features = entropy_features(dataset)\n",
"features = fractal_features(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"# `data` is the input that the k-NN, SVM and MLP sections below all split and train on.\n",
"data = features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# k-NN Classifier"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"# 60/20/20 split: hold out 20% for testing, then take 25% of the remainder for validation.\n",
"x, x_test, y, y_test = train_test_split(\n",
"    data, label, test_size=0.2, random_state=1)\n",
"x_train, x_val, y_train, y_val = train_test_split(\n",
"    x, y, test_size=0.25, random_state=1)\n",
"\n",
"# Fit the scaler on the training rows only, then apply it to every subset.\n",
"scaler = MinMaxScaler()\n",
"scaler.fit(x_train)\n",
"x_train = scaler.transform(x_train)\n",
"x_val = scaler.transform(x_val)\n",
"x_test = scaler.transform(x_test)\n",
"\n",
"# train_test_split shuffles, so the first len(x_train) rows of the pre-split array are\n",
"# not the rows in x_train. Stack train and validation explicitly so that the fold\n",
"# indices below select exactly x_train for fitting and x_val for scoring.\n",
"x = np.concatenate([x_train, x_val])\n",
"y = np.concatenate([y_train, y_val])\n",
"\n",
"param_grid = {\n",
"    # scikit-learn requires leaf_size >= 1, so the grid starts at 1 rather than 0.\n",
"    'leaf_size': range(1, 50),\n",
"    'n_neighbors': range(1, 10),\n",
"    'p': [1, 2]\n",
"}\n",
"# PredefinedSplit: -1 keeps a row out of every test fold, 0 puts it in the single validation fold.\n",
"split_index = [-1] * len(x_train) + [0] * len(x_val)\n",
"ps = PredefinedSplit(test_fold=split_index)\n",
"# refit=True retrains the best configuration on train + validation before it is used for prediction.\n",
"knn_clf = GridSearchCV(KNeighborsClassifier(), param_grid, cv=ps, refit=True)\n",
"knn_clf.fit(x, y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Ground truth and k-NN predictions for the held-out test set.\n",
"y_true = y_test\n",
"y_pred = knn_clf.predict(x_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" False 0.57 0.55 0.56 311\n",
" True 0.54 0.56 0.55 289\n",
"\n",
" accuracy 0.56 600\n",
" macro avg 0.56 0.56 0.55 600\n",
"weighted avg 0.56 0.56 0.56 600\n",
"\n",
"[[170 141]\n",
" [126 163]]\n"
]
}
],
"source": [
"# Per-class precision/recall/F1, then the confusion matrix, for the current y_true / y_pred.\n",
"for summary in (metrics.classification_report, metrics.confusion_matrix):\n",
"    print(summary(y_true, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SVM Classifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-6 {color: black;background-color: white;}#sk-container-id-6 pre{padding: 0;}#sk-container-id-6 div.sk-toggleable {background-color: white;}#sk-container-id-6 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-6 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-6 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-6 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-6 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-6 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-6 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-6 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-6 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-6 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-6 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-6 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-6 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-6 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-6 div.sk-item {position: relative;z-index: 1;}#sk-container-id-6 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-6 div.sk-item::before, #sk-container-id-6 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-6 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-6 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-6 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-6 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-6 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-6 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-6 div.sk-label-container {text-align: center;}#sk-container-id-6 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-6 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-6\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ..., 0, 0])),\n",
" estimator=SVC(),\n",
" param_grid={'C': [0.1, 1, 10, 100, 1000], 'kernel': ['rbf']})</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-16\" type=\"checkbox\" ><label for=\"sk-estimator-id-16\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ..., 0, 0])),\n",
" estimator=SVC(),\n",
" param_grid={'C': [0.1, 1, 10, 100, 1000], 'kernel': ['rbf']})</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-17\" type=\"checkbox\" ><label for=\"sk-estimator-id-17\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: SVC</label><div class=\"sk-toggleable__content\"><pre>SVC()</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-18\" type=\"checkbox\" ><label for=\"sk-estimator-id-18\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">SVC</label><div class=\"sk-toggleable__content\"><pre>SVC()</pre></div></div></div></div></div></div></div></div></div></div>"
],
"text/plain": [
"GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ..., 0, 0])),\n",
" estimator=SVC(),\n",
" param_grid={'C': [0.1, 1, 10, 100, 1000], 'kernel': ['rbf']})"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same 60/20/20 split (and seed) as the k-NN section so the two models are comparable.\n",
"x, x_test, y, y_test = train_test_split(\n",
"    data, label, test_size=0.2, random_state=1)\n",
"x_train, x_val, y_train, y_val = train_test_split(\n",
"    x, y, test_size=0.25, random_state=1)\n",
"\n",
"# The RBF kernel is distance based, so features are scaled exactly as in the k-NN\n",
"# section; the scaler is fitted on the training rows only.\n",
"scaler = MinMaxScaler()\n",
"scaler.fit(x_train)\n",
"x_train = scaler.transform(x_train)\n",
"x_val = scaler.transform(x_val)\n",
"x_test = scaler.transform(x_test)\n",
"\n",
"# train_test_split shuffles, so the fold indices must be built from an explicit\n",
"# train-then-validation stack rather than from positions in the pre-split array.\n",
"x = np.concatenate([x_train, x_val])\n",
"y = np.concatenate([y_train, y_val])\n",
"\n",
"param_grid = {\n",
"    'C': [0.1, 1, 10, 100, 1000],\n",
"    'kernel': ['rbf']\n",
"}\n",
"# PredefinedSplit: -1 keeps a row out of every test fold, 0 puts it in the single validation fold.\n",
"split_index = [-1] * len(x_train) + [0] * len(x_val)\n",
"ps = PredefinedSplit(test_fold=split_index)\n",
"svm_clf = GridSearchCV(SVC(), param_grid, cv=ps, refit=True)\n",
"svm_clf.fit(x, y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Ground truth and SVM predictions for the held-out test set.\n",
"y_true = y_test\n",
"y_pred = svm_clf.predict(x_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" False 0.60 0.59 0.59 311\n",
" True 0.56 0.57 0.57 289\n",
"\n",
" accuracy 0.58 600\n",
" macro avg 0.58 0.58 0.58 600\n",
"weighted avg 0.58 0.58 0.58 600\n",
"\n",
"[[183 128]\n",
" [124 165]]\n"
]
}
],
"source": [
"# Per-class precision/recall/F1, then the confusion matrix, for the current y_true / y_pred.\n",
"for summary in (metrics.classification_report, metrics.confusion_matrix):\n",
"    print(summary(y_true, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multilayer Perceptron"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Drop any Keras state left over from earlier runs of this section.\n",
"K.clear_session()\n",
"\n",
"# One-hot encode the labels to match the softmax output layer built in model_builder.\n",
"y_v = to_categorical(label)\n",
"\n",
"# Same 60/20/20 split (and seed) as the k-NN and SVM sections.\n",
"x_trainval, x_test, y_trainval, y_test = train_test_split(\n",
"    data, y_v, test_size=0.2, random_state=1)\n",
"x_train, x_val, y_train, y_val = train_test_split(\n",
"    x_trainval, y_trainval, test_size=0.25, random_state=1)\n",
"\n",
"# Scale the inputs as in the k-NN section so all three models see comparable features;\n",
"# the scaler is fitted on the training rows only.\n",
"scaler = MinMaxScaler()\n",
"scaler.fit(x_train)\n",
"x_train = scaler.transform(x_train)\n",
"x_val = scaler.transform(x_val)\n",
"x_test = scaler.transform(x_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model_builder(hp):\n",
"    \"\"\"Build and compile one candidate MLP for KerasTuner.\n",
"\n",
"    The search space covers the number of hidden layers (2-6), the width\n",
"    (32-1024 in steps of 32) and activation (relu/sigmoid) of each layer,\n",
"    and the Adam learning rate.\n",
"\n",
"    Args:\n",
"        hp: KerasTuner HyperParameters object used to sample the values above.\n",
"\n",
"    Returns:\n",
"        A compiled keras Sequential model with a v.N_CLASSES-way softmax output.\n",
"    \"\"\"\n",
"    model = models.Sequential()\n",
"    # Input width is read from the global x_train, so the data cell must run first.\n",
"    model.add(Input(shape=(x_train.shape[1],)))\n",
"\n",
"    for i in range(hp.Int('layers', 2, 6)):\n",
"        model.add(Dense(units=hp.Int('units_' + str(i), 32, 1024, step=32),\n",
"                        activation=hp.Choice('act_' + str(i), ['relu', 'sigmoid'])))\n",
"\n",
"    model.add(Dense(v.N_CLASSES, activation='softmax', name='out'))\n",
"\n",
"    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])\n",
"\n",
"    # Targets are one-hot encoded and the output is a softmax, so the matching loss is\n",
"    # categorical cross-entropy. It equals binary cross-entropy for two classes and\n",
"    # stays correct if v.N_CLASSES is ever larger than 2.\n",
"    # NOTE(review): opt.adam_v2 is a version-specific module path -- confirm it exists\n",
"    # in the pinned Keras release before upgrading.\n",
"    model.compile(optimizer=opt.adam_v2.Adam(learning_rate=hp_learning_rate),\n",
"                  loss=\"categorical_crossentropy\",\n",
"                  metrics=['accuracy'])\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Random search over the space defined in model_builder: 15 configurations, each\n",
"# trained twice. The seed makes the sampled hyperparameters repeatable across runs\n",
"# (weight initialisation is still unseeded). overwrite=True discards results of any\n",
"# previous search.\n",
"tuner = RandomSearch(\n",
"    model_builder,\n",
"    objective='val_accuracy',\n",
"    max_trials=15,\n",
"    executions_per_trial=2,\n",
"    seed=1,\n",
"    overwrite=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial 15 Complete [00h 01m 08s]\n",
"val_accuracy: 0.5450000166893005\n",
"\n",
"Best val_accuracy So Far: 0.5541666746139526\n",
"Total elapsed time: 00h 14m 40s\n",
"INFO:tensorflow:Oracle triggered exit\n"
]
}
],
"source": [
"# Keras documents validation_data as an (x, y) tuple; a list is not a documented form.\n",
"tuner.search(x_train, y_train, epochs=50, validation_data=(x_val, y_val))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Keep only the top-ranked model from the search.\n",
"best_models = tuner.get_best_models(num_models=1)\n",
"model = best_models[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-12-11 13:24:08.707790: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
]
}
],
"source": [
"# Collapse softmax outputs and one-hot targets to class indices for the metrics below.\n",
"class_scores = model.predict(x_test)\n",
"y_pred = np.argmax(class_scores, axis=1)\n",
"y_true = np.argmax(y_test, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.52 0.90 0.66 311\n",
" 1 0.52 0.11 0.19 289\n",
"\n",
" accuracy 0.52 600\n",
" macro avg 0.52 0.51 0.43 600\n",
"weighted avg 0.52 0.52 0.43 600\n",
"\n",
"[[281 30]\n",
" [256 33]]\n"
]
}
],
"source": [
"# Per-class precision/recall/F1, then the confusion matrix, for the current y_true / y_pred.\n",
"for summary in (metrics.classification_report, metrics.confusion_matrix):\n",
"    print(summary(y_true, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 ('init')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "6a0b16b431f91af56543167d2335ade6a4f69621936ac10d0388e1e58aabcd37"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}