--- a
+++ b/classification.ipynb
@@ -0,0 +1,449 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 61,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "from sklearn.model_selection import train_test_split, GridSearchCV, PredefinedSplit\n",
+    "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
+    "from sklearn import metrics\n",
+    "from sklearn.neighbors import KNeighborsClassifier\n",
+    "from sklearn.svm import SVC\n",
+    "\n",
+    "from keras import models, Input\n",
+    "from keras import optimizers as opt\n",
+    "from keras import backend as K\n",
+    "from keras.layers import Dense\n",
+    "from keras_tuner.tuners import RandomSearch\n",
+    "from tensorflow.keras.utils import to_categorical"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 62,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dataset import load_dataset, load_labels, split_data, format_labels\n",
+    "from features import time_series_features, fractal_features, entropy_features, hjorth_features, freq_band_features\n",
+    "import variables as v"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Variables"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_type = \"ica_filtered\"\n",
+    "test_type = \"Arithmetic\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load Dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset_ = load_dataset(data_type=data_type, test_type=test_type)\n",
+    "dataset = split_data(dataset_, v.SFREQ)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "label_ = load_labels()\n",
+    "label = format_labels(label_, test_type=test_type, epochs=dataset.shape[1])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Compute Features"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# features = time_series_features(dataset)\n",
+    "# freq_bands = np.array([1, 4, 8, 12, 30, 50])\n",
+    "# features = freq_band_features(dataset, freq_bands)\n",
+    "# features = hjorth_features(dataset)\n",
+    "# features = entropy_features(dataset)\n",
+    "features = fractal_features(dataset)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 67,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = features"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# k-NN Classifier"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x, x_test, y, y_test = train_test_split(\n",
+    "    data, label, test_size=0.2, random_state=1)\n",
+    "x_train, x_val, y_train, y_val = train_test_split(\n",
+    "    x, y, test_size=0.25, random_state=1)\n",
+    "scaler = MinMaxScaler()\n",
+    "scaler.fit(x_train)\n",
+    "x = scaler.transform(x)\n",
+    "x_train = scaler.transform(x_train)\n",
+    "x_val = scaler.transform(x_val)\n",
+    "x_test = scaler.transform(x_test)\n",
+    "\n",
+    "param_grid = {\n",
+    "    'leaf_size': range(50),\n",
+    "    'n_neighbors': range(1, 10),\n",
+    "    'p': [1, 2]\n",
+    "}\n",
+    "split_index = [-1 if x in range(len(x_train)) else 0 for x in range(len(x))]\n",
+    "ps = PredefinedSplit(test_fold=split_index)\n",
+    "knn_clf = GridSearchCV(KNeighborsClassifier(), param_grid, cv=ps, refit=True)\n",
+    "knn_clf.fit(x, y)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_pred = knn_clf.predict(x_test)\n",
+    "y_true = y_test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "              precision    recall  f1-score   support\n",
+      "\n",
+      "       False       0.57      0.55      0.56       311\n",
+      "        True       0.54      0.56      0.55       289\n",
+      "\n",
+      "    accuracy                           0.56       600\n",
+      "   macro avg       0.56      0.56      0.55       600\n",
+      "weighted avg       0.56      0.56      0.56       600\n",
+      "\n",
+      "[[170 141]\n",
+      " [126 163]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(metrics.classification_report(y_true, y_pred))\n",
+    "print(metrics.confusion_matrix(y_true, y_pred))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# SVM Classifier"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<style>#sk-container-id-6 {color: black;background-color: white;}#sk-container-id-6 pre{padding: 0;}#sk-container-id-6 div.sk-toggleable {background-color: white;}#sk-container-id-6 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-6 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-6 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-6 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-6 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-6 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-6 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-6 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-6 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-6 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-6 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-6 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-6 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-6 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-6 div.sk-item {position: relative;z-index: 1;}#sk-container-id-6 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-6 div.sk-item::before, #sk-container-id-6 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-6 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-6 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-6 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-6 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-6 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-6 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-6 div.sk-label-container {text-align: center;}#sk-container-id-6 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-6 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-6\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ...,  0,  0])),\n",
+       "             estimator=SVC(),\n",
+       "             param_grid={&#x27;C&#x27;: [0.1, 1, 10, 100, 1000], &#x27;kernel&#x27;: [&#x27;rbf&#x27;]})</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-16\" type=\"checkbox\" ><label for=\"sk-estimator-id-16\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ...,  0,  0])),\n",
+       "             estimator=SVC(),\n",
+       "             param_grid={&#x27;C&#x27;: [0.1, 1, 10, 100, 1000], &#x27;kernel&#x27;: [&#x27;rbf&#x27;]})</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-17\" type=\"checkbox\" ><label for=\"sk-estimator-id-17\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: SVC</label><div class=\"sk-toggleable__content\"><pre>SVC()</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-18\" type=\"checkbox\" ><label for=\"sk-estimator-id-18\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">SVC</label><div class=\"sk-toggleable__content\"><pre>SVC()</pre></div></div></div></div></div></div></div></div></div></div>"
+      ],
+      "text/plain": [
+       "GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ...,  0,  0])),\n",
+       "             estimator=SVC(),\n",
+       "             param_grid={'C': [0.1, 1, 10, 100, 1000], 'kernel': ['rbf']})"
+      ]
+     },
+     "execution_count": 51,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x, x_test, y, y_test = train_test_split(\n",
+    "    data, label, test_size=0.2, random_state=1)\n",
+    "x_train, x_val, y_train, y_val = train_test_split(\n",
+    "    x, y, test_size=0.25, random_state=1)\n",
+    "\n",
+    "param_grid = {\n",
+    "    'C': [0.1, 1, 10, 100, 1000],\n",
+    "    'kernel': ['rbf']\n",
+    "}\n",
+    "split_index = [-1 if x in range(len(x_train)) else 0 for x in range(len(x))]\n",
+    "ps = PredefinedSplit(test_fold=split_index)\n",
+    "svm_clf = GridSearchCV(SVC(), param_grid, cv=ps, refit=True)\n",
+    "svm_clf.fit(x, y)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_pred = svm_clf.predict(x_test)\n",
+    "y_true = y_test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "              precision    recall  f1-score   support\n",
+      "\n",
+      "       False       0.60      0.59      0.59       311\n",
+      "        True       0.56      0.57      0.57       289\n",
+      "\n",
+      "    accuracy                           0.58       600\n",
+      "   macro avg       0.58      0.58      0.58       600\n",
+      "weighted avg       0.58      0.58      0.58       600\n",
+      "\n",
+      "[[183 128]\n",
+      " [124 165]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(metrics.classification_report(y_true, y_pred))\n",
+    "print(metrics.confusion_matrix(y_true, y_pred))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Multilayer Perceptron"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "K.clear_session()\n",
+    "y_v = label\n",
+    "y_v = to_categorical(y_v)\n",
+    "x_train, x_test, y_train, y_test = train_test_split(\n",
+    "    data, y_v, test_size=0.2, random_state=1)\n",
+    "x_train, x_val, y_train, y_val = train_test_split(\n",
+    "    x_train, y_train, test_size=0.25, random_state=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def model_builder(hp):\n",
+    "    model = models.Sequential()\n",
+    "    model.add(Input(shape=(x_train.shape[1],)))\n",
+    "\n",
+    "    for i in range(hp.Int('layers', 2, 6)):\n",
+    "        model.add(Dense(units=hp.Int('units_' + str(i), 32, 1024, step=32),\n",
+    "                        activation=hp.Choice('act_' + str(i), ['relu', 'sigmoid'])))\n",
+    "\n",
+    "    model.add(Dense(v.N_CLASSES, activation='softmax', name='out'))\n",
+    "\n",
+    "    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])\n",
+    "\n",
+    "    model.compile(optimizer=opt.adam_v2.Adam(learning_rate=hp_learning_rate),\n",
+    "                  loss=\"binary_crossentropy\",\n",
+    "                  metrics=['accuracy'])\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tuner = RandomSearch(\n",
+    "    model_builder,\n",
+    "    objective='val_accuracy',\n",
+    "    max_trials=15,\n",
+    "    executions_per_trial=2,\n",
+    "    overwrite=True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Trial 15 Complete [00h 01m 08s]\n",
+      "val_accuracy: 0.5450000166893005\n",
+      "\n",
+      "Best val_accuracy So Far: 0.5541666746139526\n",
+      "Total elapsed time: 00h 14m 40s\n",
+      "INFO:tensorflow:Oracle triggered exit\n"
+     ]
+    }
+   ],
+   "source": [
+    "tuner.search(x_train, y_train, epochs=50, validation_data=[x_val, y_val])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = tuner.get_best_models(num_models=1)[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2022-12-11 13:24:08.707790: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
+     ]
+    }
+   ],
+   "source": [
+    "y_pred = model.predict(x_test)\n",
+    "y_true = y_test\n",
+    "y_pred = np.argmax(y_pred, axis=1)\n",
+    "y_true = np.argmax(y_true, axis=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "              precision    recall  f1-score   support\n",
+      "\n",
+      "           0       0.52      0.90      0.66       311\n",
+      "           1       0.52      0.11      0.19       289\n",
+      "\n",
+      "    accuracy                           0.52       600\n",
+      "   macro avg       0.52      0.51      0.43       600\n",
+      "weighted avg       0.52      0.52      0.43       600\n",
+      "\n",
+      "[[281  30]\n",
+      " [256  33]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(metrics.classification_report(y_true, y_pred))\n",
+    "print(metrics.confusion_matrix(y_true, y_pred))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.9.7 ('init')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.7"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "6a0b16b431f91af56543167d2335ade6a4f69621936ac10d0388e1e58aabcd37"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}