--- a +++ b/classification.ipynb @@ -0,0 +1,449 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from sklearn.model_selection import train_test_split, GridSearchCV, PredefinedSplit\n", + "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", + "from sklearn import metrics\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.svm import SVC\n", + "\n", + "from keras import models, Input\n", + "from keras import optimizers as opt\n", + "from keras import backend as K\n", + "from keras.layers import Dense\n", + "from keras_tuner.tuners import RandomSearch\n", + "from tensorflow.keras.utils import to_categorical" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [], + "source": [ + "from dataset import load_dataset, load_labels, split_data, format_labels\n", + "from features import time_series_features, fractal_features, entropy_features, hjorth_features, freq_band_features\n", + "import variables as v" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Variables" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "data_type = \"ica_filtered\"\n", + "test_type = \"Arithmetic\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_ = load_dataset(data_type=data_type, test_type=test_type)\n", + "dataset = split_data(dataset_, v.SFREQ)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [], + "source": [ + "label_ = load_labels()\n", + "label = format_labels(label_, test_type=test_type, epochs=dataset.shape[1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compute Features" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "# features = time_series_features(dataset)\n", + "# freq_bands = np.array([1, 4, 8, 12, 30, 50])\n", + "# features = freq_band_features(dataset, freq_bands)\n", + "# features = hjorth_features(dataset)\n", + "# features = entropy_features(dataset)\n", + "features = fractal_features(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "data = features" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# k-NN Classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [], + "source": [ + "x, x_test, y, y_test = train_test_split(\n", + " data, label, test_size=0.2, random_state=1)\n", + "x_train, x_val, y_train, y_val = train_test_split(\n", + " x, y, test_size=0.25, random_state=1)\n", + "scaler = MinMaxScaler()\n", + "scaler.fit(x_train)\n", + "x = scaler.transform(x)\n", + "x_train = scaler.transform(x_train)\n", + "x_val = scaler.transform(x_val)\n", + "x_test = scaler.transform(x_test)\n", + "\n", + "param_grid = {\n", + " 'leaf_size': range(50),\n", + " 'n_neighbors': range(1, 10),\n", + " 'p': [1, 2]\n", + "}\n", + "split_index = [-1 if x in range(len(x_train)) else 0 for x in range(len(x))]\n", + "ps = PredefinedSplit(test_fold=split_index)\n", + "knn_clf = GridSearchCV(KNeighborsClassifier(), param_grid, cv=ps, refit=True)\n", + "knn_clf.fit(x, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = knn_clf.predict(x_test)\n", + "y_true = y_test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " False 0.57 0.55 0.56 311\n", + " True 0.54 0.56 0.55 289\n", + "\n", + " accuracy 0.56 600\n", + " macro avg 0.56 0.56 0.55 600\n", + "weighted avg 0.56 0.56 0.56 600\n", + "\n", + "[[170 141]\n", + " [126 163]]\n" + ] + } + ], + "source": [ + "print(metrics.classification_report(y_true, y_pred))\n", + "print(metrics.confusion_matrix(y_true, y_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SVM Classifier" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<style>#sk-container-id-6 {color: black;background-color: white;}#sk-container-id-6 pre{padding: 0;}#sk-container-id-6 div.sk-toggleable {background-color: white;}#sk-container-id-6 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-6 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-6 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-6 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-6 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-6 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-6 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-6 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-6 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-6 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-6 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-6 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-6 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-6 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-6 div.sk-item {position: relative;z-index: 1;}#sk-container-id-6 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-6 div.sk-item::before, #sk-container-id-6 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-6 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-6 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-6 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-6 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-6 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-6 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-6 div.sk-label-container {text-align: center;}#sk-container-id-6 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-6 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-6\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ..., 0, 0])),\n", + " estimator=SVC(),\n", + " param_grid={'C': [0.1, 1, 10, 100, 1000], 'kernel': ['rbf']})</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-16\" type=\"checkbox\" ><label for=\"sk-estimator-id-16\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ..., 0, 0])),\n", + " estimator=SVC(),\n", + " param_grid={'C': [0.1, 1, 10, 100, 1000], 'kernel': ['rbf']})</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-17\" type=\"checkbox\" ><label for=\"sk-estimator-id-17\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: SVC</label><div class=\"sk-toggleable__content\"><pre>SVC()</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-18\" type=\"checkbox\" ><label for=\"sk-estimator-id-18\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">SVC</label><div class=\"sk-toggleable__content\"><pre>SVC()</pre></div></div></div></div></div></div></div></div></div></div>" + ], + "text/plain": [ + "GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ..., 0, 0])),\n", + " estimator=SVC(),\n", + " param_grid={'C': [0.1, 1, 10, 100, 1000], 'kernel': ['rbf']})" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x, x_test, y, y_test = train_test_split(\n", + " data, label, test_size=0.2, random_state=1)\n", + "x_train, x_val, y_train, y_val = train_test_split(\n", + " x, y, test_size=0.25, random_state=1)\n", + "\n", + "param_grid = {\n", + " 'C': [0.1, 1, 10, 100, 1000],\n", + " 'kernel': ['rbf']\n", + "}\n", + "split_index = [-1 if x in range(len(x_train)) else 0 for x in range(len(x))]\n", + "ps = PredefinedSplit(test_fold=split_index)\n", + "svm_clf = GridSearchCV(SVC(), param_grid, cv=ps, refit=True)\n", + "svm_clf.fit(x, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = svm_clf.predict(x_test)\n", + "y_true = y_test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " False 0.60 0.59 0.59 311\n", + " True 0.56 0.57 0.57 289\n", + "\n", + " accuracy 0.58 600\n", + " macro avg 0.58 0.58 0.58 600\n", + "weighted avg 0.58 0.58 0.58 600\n", + "\n", + "[[183 128]\n", + " [124 165]]\n" + ] + } + ], + "source": [ + "print(metrics.classification_report(y_true, y_pred))\n", + "print(metrics.confusion_matrix(y_true, y_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multilayer Perceptron" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "K.clear_session()\n", + "y_v = label\n", + "y_v = to_categorical(y_v)\n", + "x_train, x_test, y_train, y_test = train_test_split(\n", + " data, y_v, test_size=0.2, random_state=1)\n", + "x_train, x_val, y_train, y_val = train_test_split(\n", + " x_train, y_train, test_size=0.25, random_state=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def model_builder(hp):\n", + " model = models.Sequential()\n", + " model.add(Input(shape=(x_train.shape[1],)))\n", + "\n", + " for i in range(hp.Int('layers', 2, 6)):\n", + " model.add(Dense(units=hp.Int('units_' + str(i), 32, 1024, step=32),\n", + " activation=hp.Choice('act_' + str(i), ['relu', 'sigmoid'])))\n", + "\n", + " model.add(Dense(v.N_CLASSES, activation='softmax', name='out'))\n", + "\n", + " hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])\n", + "\n", + " model.compile(optimizer=opt.adam_v2.Adam(learning_rate=hp_learning_rate),\n", + " loss=\"binary_crossentropy\",\n", + " metrics=['accuracy'])\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tuner = RandomSearch(\n", + " model_builder,\n", + " objective='val_accuracy',\n", + " max_trials=15,\n", + " executions_per_trial=2,\n", + " overwrite=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trial 15 Complete [00h 01m 08s]\n", + "val_accuracy: 0.5450000166893005\n", + "\n", + "Best val_accuracy So Far: 0.5541666746139526\n", + "Total elapsed time: 00h 14m 40s\n", + "INFO:tensorflow:Oracle triggered exit\n" + ] + } + ], + "source": [ + "tuner.search(x_train, y_train, epochs=50, validation_data=[x_val, y_val])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = tuner.get_best_models(num_models=1)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2022-12-11 13:24:08.707790: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n" + ] + } + ], + "source": [ + "y_pred = model.predict(x_test)\n", + "y_true = y_test\n", + "y_pred = np.argmax(y_pred, axis=1)\n", + "y_true = np.argmax(y_true, axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.52 0.90 0.66 311\n", + " 1 0.52 0.11 0.19 289\n", + "\n", + " accuracy 0.52 600\n", + " macro avg 0.52 0.51 0.43 600\n", + "weighted avg 0.52 0.52 0.43 600\n", + "\n", + "[[281 30]\n", + " [256 33]]\n" + ] + } + ], + "source": [ + "print(metrics.classification_report(y_true, y_pred))\n", + "print(metrics.confusion_matrix(y_true, y_pred))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.7 ('init')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "6a0b16b431f91af56543167d2335ade6a4f69621936ac10d0388e1e58aabcd37" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}