--- a
+++ b/LSTM - Experiments/LSTM_Experiments.ipynb
@@ -0,0 +1 @@
+{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"7zKdtJlIFhOn"},"outputs":[],"source":["import matplotlib.pyplot as plt\n","import pandas as pd\n","import torch\n","import torch.nn as nn\n","\n","import numpy as np\n","import torch.optim as optim\n","import torch.utils.data as data\n","from sklearn.model_selection import train_test_split\n","from sklearn.preprocessing import MinMaxScaler, LabelEncoder\n","import ast\n","from torch.utils.data import DataLoader, TensorDataset\n","from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n","import matplotlib.pyplot as plt"]},{"cell_type":"code","source":["dataset = pd.read_csv('labeled_dataset.csv')"],"metadata":{"id":"MeNmCEtPFzFo"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset.head(15)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":519},"id":"P9w7w8WWF0Nd","executionInfo":{"status":"ok","timestamp":1692311813192,"user_tz":300,"elapsed":388,"user":{"displayName":"César Mosqueira","userId":"11705195256143475621"}},"outputId":"fde1231c-6d0f-44f8-b547-99975dec11e2"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["           video  group  frame  \\\n","0   video_15.mp4      1      5   \n","1   video_15.mp4      1     11   \n","2   video_15.mp4      1     17   \n","3   video_15.mp4      1     23   \n","4   video_15.mp4      1     29   \n","5   video_15.mp4      1     35   \n","6   video_15.mp4      1     41   \n","7   video_15.mp4      1     47   \n","8   video_15.mp4      1     53   \n","9   video_15.mp4      1     59   \n","10  video_15.mp4      2     71   \n","11  video_15.mp4      2     77   \n","12  video_15.mp4      2     83   \n","13  video_15.mp4      2     89   \n","14  video_15.mp4      2     
95   \n","\n","                                            landmarks Label  \n","0   [     334.75      178.55     0.98386      339....   bad  \n","1   [     329.95      181.47     0.99063      334....   bad  \n","2   [      329.7      182.92     0.99079       334...   bad  \n","3   [     329.32      187.55     0.98055      334....   bad  \n","4   [     331.31      194.96       0.985      335....   bad  \n","5   [     326.15      199.31     0.98249      330....   bad  \n","6   [     317.77      207.25     0.97823       321...   bad  \n","7   [     311.16      194.05     0.97995      315....   bad  \n","8   [     308.41      194.28     0.98219      312....   bad  \n","9   [     318.15      184.29     0.98273       323...   bad  \n","10  [     321.08         126     0.99838      332....   bad  \n","11  [     311.02      128.68     0.99853      321....   bad  \n","12  [      301.6       139.2     0.99845      311....   bad  \n","13  [     294.57      158.13      0.9981      303....   bad  \n","14  [     294.98      166.88     0.99792      303....   
bad  "],"text/html":["\n","  <div id=\"df-ef560272-0547-4bdb-ade2-64bf1f701d7d\" class=\"colab-df-container\">\n","    <div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>video</th>\n","      <th>group</th>\n","      <th>frame</th>\n","      <th>landmarks</th>\n","      <th>Label</th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>5</td>\n","      <td>[     334.75      178.55     0.98386      339....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>11</td>\n","      <td>[     329.95      181.47     0.99063      334....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>17</td>\n","      <td>[      329.7      182.92     0.99079       334...</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>23</td>\n","      <td>[     329.32      187.55     0.98055      334....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>29</td>\n","      <td>[     331.31      194.96       0.985      335....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>5</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>35</td>\n","      <td>[     326.15      199.31     0.98249      330....</td>\n","      <td>bad</td>\n","    </tr>\n","    
<tr>\n","      <th>6</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>41</td>\n","      <td>[     317.77      207.25     0.97823       321...</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>7</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>47</td>\n","      <td>[     311.16      194.05     0.97995      315....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>8</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>53</td>\n","      <td>[     308.41      194.28     0.98219      312....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>9</th>\n","      <td>video_15.mp4</td>\n","      <td>1</td>\n","      <td>59</td>\n","      <td>[     318.15      184.29     0.98273       323...</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>10</th>\n","      <td>video_15.mp4</td>\n","      <td>2</td>\n","      <td>71</td>\n","      <td>[     321.08         126     0.99838      332....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>11</th>\n","      <td>video_15.mp4</td>\n","      <td>2</td>\n","      <td>77</td>\n","      <td>[     311.02      128.68     0.99853      321....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>12</th>\n","      <td>video_15.mp4</td>\n","      <td>2</td>\n","      <td>83</td>\n","      <td>[      301.6       139.2     0.99845      311....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>13</th>\n","      <td>video_15.mp4</td>\n","      <td>2</td>\n","      <td>89</td>\n","      <td>[     294.57      158.13      0.9981      303....</td>\n","      <td>bad</td>\n","    </tr>\n","    <tr>\n","      <th>14</th>\n","      <td>video_15.mp4</td>\n","      <td>2</td>\n","      <td>95</td>\n","      <td>[     294.98      166.88     0.99792      303....</td>\n","      <td>bad</td>\n","    </tr>\n","  </tbody>\n","</table>\n","</div>\n","    <div 
class=\"colab-df-buttons\">\n","\n","  <div class=\"colab-df-container\">\n","    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-ef560272-0547-4bdb-ade2-64bf1f701d7d')\"\n","            title=\"Convert this dataframe to an interactive table.\"\n","            style=\"display:none;\">\n","\n","  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n","    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n","  </svg>\n","    </button>\n","\n","  <style>\n","    .colab-df-container {\n","      display:flex;\n","      gap: 12px;\n","    }\n","\n","    .colab-df-convert {\n","      background-color: #E8F0FE;\n","      border: none;\n","      border-radius: 50%;\n","      cursor: pointer;\n","      display: none;\n","      fill: #1967D2;\n","      height: 32px;\n","      padding: 0 0 0 0;\n","      width: 32px;\n","    }\n","\n","    .colab-df-convert:hover {\n","      background-color: #E2EBFA;\n","      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n","      fill: #174EA6;\n","    }\n","\n","    .colab-df-buttons div {\n","      margin-bottom: 4px;\n","    }\n","\n","    [theme=dark] .colab-df-convert {\n","      background-color: #3B4455;\n","      fill: #D2E3FC;\n","    }\n","\n","    [theme=dark] .colab-df-convert:hover {\n","      background-color: #434B5C;\n","      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n","      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n","      fill: #FFFFFF;\n","    }\n","  </style>\n","\n","    <script>\n","      const buttonEl =\n","        document.querySelector('#df-ef560272-0547-4bdb-ade2-64bf1f701d7d button.colab-df-convert');\n","      buttonEl.style.display =\n","        google.colab.kernel.accessAllowed ? 
'block' : 'none';\n","\n","      async function convertToInteractive(key) {\n","        const element = document.querySelector('#df-ef560272-0547-4bdb-ade2-64bf1f701d7d');\n","        const dataTable =\n","          await google.colab.kernel.invokeFunction('convertToInteractive',\n","                                                    [key], {});\n","        if (!dataTable) return;\n","\n","        const docLinkHtml = 'Like what you see? Visit the ' +\n","          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n","          + ' to learn more about interactive tables.';\n","        element.innerHTML = '';\n","        dataTable['output_type'] = 'display_data';\n","        await google.colab.output.renderOutput(dataTable, element);\n","        const docLink = document.createElement('div');\n","        docLink.innerHTML = docLinkHtml;\n","        element.appendChild(docLink);\n","      }\n","    </script>\n","  </div>\n","\n","\n","<div id=\"df-8448df95-760d-4350-a6bb-82553d08a229\">\n","  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-8448df95-760d-4350-a6bb-82553d08a229')\"\n","            title=\"Suggest charts.\"\n","            style=\"display:none;\">\n","\n","<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n","     width=\"24px\">\n","    <g>\n","        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n","    </g>\n","</svg>\n","  </button>\n","\n","<style>\n","  .colab-df-quickchart {\n","    background-color: #E8F0FE;\n","    border: none;\n","    border-radius: 50%;\n","    cursor: pointer;\n","    display: none;\n","    fill: #1967D2;\n","    height: 32px;\n","    padding: 0 0 0 0;\n","    width: 32px;\n","  }\n","\n","  .colab-df-quickchart:hover {\n","    background-color: #E2EBFA;\n","    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px 
rgba(60, 64, 67, 0.15);\n","    fill: #174EA6;\n","  }\n","\n","  [theme=dark] .colab-df-quickchart {\n","    background-color: #3B4455;\n","    fill: #D2E3FC;\n","  }\n","\n","  [theme=dark] .colab-df-quickchart:hover {\n","    background-color: #434B5C;\n","    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n","    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n","    fill: #FFFFFF;\n","  }\n","</style>\n","\n","  <script>\n","    async function quickchart(key) {\n","      const charts = await google.colab.kernel.invokeFunction(\n","          'suggestCharts', [key], {});\n","    }\n","    (() => {\n","      let quickchartButtonEl =\n","        document.querySelector('#df-8448df95-760d-4350-a6bb-82553d08a229 button');\n","      quickchartButtonEl.style.display =\n","        google.colab.kernel.accessAllowed ? 'block' : 'none';\n","    })();\n","  </script>\n","</div>\n","    </div>\n","  </div>\n"]},"metadata":{},"execution_count":7}]},{"cell_type":"code","source":["def is_float(num):\n","    try:\n","        float(num)\n","        return True\n","    except ValueError:\n","        return False"],"metadata":{"id":"qxOmR6ioF1cR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Landmarks are stored as numpy-style strings such as '[  334.75  178.55  0.98386 ...'.\n","# In that format the closing bracket is presumably glued to the last number, so a plain\n","# split() yields a token like '0.97]' that fails is_float() and is silently dropped -\n","# likely why each frame ended up with 50 values instead of a multiple of 3\n","# (x, y, confidence triples); confirm after re-running. Turning the brackets into\n","# whitespace before splitting keeps every value.\n","dataset['landmarks'] = dataset['landmarks'].apply(\n","    lambda arr: np.array([float(n) for n in arr.replace('[', ' ').replace(']', ' ').split() if is_float(n)])\n",")"],"metadata":{"id":"P9smPfHIF2qM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset['Label'].value_counts()"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"EUaaKCMlF4SV","executionInfo":{"status":"ok","timestamp":1692311817060,"user_tz":300,"elapsed":3,"user":{"displayName":"César Mosqueira","userId":"11705195256143475621"}},"outputId":"a3113cb5-4c84-4081-809c-0cca9c575d4c"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["bad     3444\n","good    1205\n","Name: Label, dtype: int64"]},"metadata":{},"execution_count":10}]},{"cell_type":"code","source":["# Group 
the data by 'video' and 'group'\n","grouped_data = dataset.groupby(['video', 'group'])\n","\n","# Define the sequence length\n","sequence_length = 10\n","\n","# Create lists to store the sequences and labels\n","sequences = []\n","labels = []\n","\n","# Iterate over each group\n","for group, data in grouped_data:\n","    landmarks = data['landmarks'].tolist()\n","    group_labels = data['Label'].tolist()\n","\n","    # Create sequences of landmarks\n","    for i in range(len(landmarks) - sequence_length + 1):\n","        sequence = landmarks[i:i+sequence_length]\n","        sequences.append(sequence)\n","        labels.append(group_labels[i+sequence_length-1])"],"metadata":{"id":"A0zDHdCnF5tN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["sequences = np.array(sequences)\n","\n","scaler = MinMaxScaler()\n","normalized_sequences = np.zeros_like(sequences)\n","\n","for i in range(sequences.shape[0]):\n","    for j in range(sequences.shape[1]):\n","        # Flatten the landmarks for each set within the sequence\n","        landmarks_flattened = np.reshape(sequences[i, j], (-1, 1))\n","        # Normalize the landmarks\n","        landmarks_normalized = scaler.fit_transform(landmarks_flattened)\n","        # Reshape the normalized landmarks back to the original shape\n","        normalized_landmarks = np.reshape(landmarks_normalized, sequences[i, j].shape)\n","        # Update the normalized landmarks in the sequences array\n","        normalized_sequences[i, j] = normalized_landmarks"],"metadata":{"id":"wGMZpHdEF8Z2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["label_encoder = LabelEncoder()\n","labels_encoded = 
label_encoder.fit_transform(labels)"],"metadata":{"id":"6SVR-ictF98k"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["label_encoder.transform(['good'])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"O9qvWDJ3F_SC","executionInfo":{"status":"ok","timestamp":1692311835736,"user_tz":300,"elapsed":2,"user":{"displayName":"César Mosqueira","userId":"11705195256143475621"}},"outputId":"7d541e84-4852-4e37-81a8-117a9cc9cd6e"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([1])"]},"metadata":{},"execution_count":14}]},{"cell_type":"code","source":["train_X, test_X, train_y, test_y = train_test_split(normalized_sequences, labels_encoded, test_size=0.2, shuffle=True)"],"metadata":{"id":"10be1q-rGAV1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(train_X.shape)\n","print(train_y.shape)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Hox_iyMwGBq6","executionInfo":{"status":"ok","timestamp":1692311839481,"user_tz":300,"elapsed":2,"user":{"displayName":"César Mosqueira","userId":"11705195256143475621"}},"outputId":"d0a747f3-6383-4034-b00c-4533cc8f337c"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["(355, 10, 50)\n","(355,)\n"]}]},{"cell_type":"code","source":["train_X_tensor = torch.Tensor(train_X)\n","train_y_tensor = torch.Tensor(train_y)"],"metadata":{"id":"kan0J9GVGCrs"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# First iteration of LSTM"],"metadata":{"id":"m6iKzio8JNSL"}},{"cell_type":"code","source":["train_dataset = TensorDataset(train_X_tensor, train_y_tensor)\n","train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)"],"metadata":{"id":"msOwbxKyGDvQ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class LSTMModel(nn.Module):\n","    def __init__(self, input_size, hidden_size, num_classes):\n","        super(LSTMModel, 
self).__init__()\n","        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)\n","        self.dropout = nn.Dropout(0.2)\n","        self.fc1 = nn.Linear(hidden_size, 32)\n","        self.fc2 = nn.Linear(32, num_classes)\n","        self.relu = nn.ReLU()\n","\n","    def forward(self, x):\n","        # Returns raw logits of shape (batch, num_classes); pair with BCEWithLogitsLoss.\n","        # No activation on the output: a ReLU there clamped the logits to >= 0, so\n","        # torch.sigmoid(output) in the training loop was always >= 0.5 and every\n","        # sample was predicted as class 1 (accuracy stuck at the class-1 ratio).\n","        _, (h_n, _) = self.lstm(x)\n","        x = self.dropout(h_n[-1])  # final hidden state of the last LSTM layer\n","        x = self.relu(self.fc1(x))  # non-linearity between the two dense layers\n","        x = self.fc2(x)\n","        return x"],"metadata":{"id":"whZxvNyLGFL5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["input_size = train_X.shape[2]\n","hidden_size = 256\n","num_classes = 1\n","num_epochs = 30\n","learning_rate = 0.00001  # NOTE(review): very small for plain SGD; consider tuning\n","\n","# Instantiate the model\n","model = LSTMModel(input_size, hidden_size, num_classes)"],"metadata":{"id":"Fck9hSdRGKS7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Define the loss function and optimizer.\n","# The model returns raw logits, so use BCEWithLogitsLoss (sigmoid + BCE fused, numerically\n","# stable); nn.BCELoss would require probabilities in [0, 1].\n","criterion = nn.BCEWithLogitsLoss()\n","optimizer = optim.SGD(model.parameters(), lr=learning_rate)"],"metadata":{"id":"QlpoZwz1GKyq"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_X.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6HZW1eFwGL3D","executionInfo":{"status":"ok","timestamp":1692295580820,"user_tz":300,"elapsed":11,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"229b83bd-8e51-4560-eb55-2eacb4fb641d"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(355, 10, 50)"]},"metadata":{},"execution_count":69}]},{"cell_type":"code","source":["model"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"U0MaHSm9GM4W","executionInfo":{"status":"ok","timestamp":1692295581791,"user_tz":300,"elapsed":5,"user":{"displayName":"Francesco 
Bassino","userId":"12214195385968794219"}},"outputId":"1c9b1d16-49d3-435c-a40f-79fa7c7f0eae"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["LSTMModel(\n","  (lstm): LSTM(50, 256, batch_first=True)\n","  (dropout): Dropout(p=0.2, inplace=False)\n","  (fc1): Linear(in_features=256, out_features=32, bias=True)\n","  (fc2): Linear(in_features=32, out_features=1, bias=True)\n","  (relu): ReLU()\n",")"]},"metadata":{},"execution_count":70}]},{"cell_type":"code","source":["epoch_accuracy = []\n","epoch_f1_score = []\n","epoch_recall = []\n","epoch_loss = []\n","\n","for epoch in range(num_epochs):\n","    true_labels = []\n","    predicted_labels = []\n","    for inputs, labels in train_dataloader:\n","        # Zero the gradients\n","        optimizer.zero_grad()\n","\n","        # Forward pass\n","        outputs = model(inputs)\n","        predictions = torch.round(torch.sigmoid(outputs.squeeze()))\n","\n","        true_labels.extend(labels.numpy())\n","        predicted_labels.extend(predictions.detach().numpy())\n","\n","        loss = criterion(outputs.squeeze(), labels)\n","\n","        # Backward pass and optimization\n","        loss.backward()\n","        optimizer.step()\n","\n","    # Calculate metrics\n","    accuracy = accuracy_score(true_labels, predicted_labels)\n","    f1 = f1_score(true_labels, predicted_labels)\n","    recall = recall_score(true_labels, predicted_labels)\n","\n","    # Store metrics for each epoch\n","    epoch_accuracy.append(accuracy)\n","    epoch_f1_score.append(f1)\n","    epoch_recall.append(recall)\n","    epoch_loss.append(loss.item())\n","    # Print the metrics for every epoch\n","    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Accuracy: {accuracy}, F1 Score: {f1}, Recall: 
{recall}')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vZfo2EV7GN8H","executionInfo":{"status":"ok","timestamp":1692295594209,"user_tz":300,"elapsed":7890,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"7a2199eb-6b69-443a-89d1-cf71269c8f49"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/30, Loss: 3.30371356010437, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 2/30, Loss: 0.0640532448887825, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 3/30, Loss: 0.9332453608512878, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 4/30, Loss: 0.9942008852958679, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 5/30, Loss: 2.844069480895996, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 6/30, Loss: 1.0282055139541626, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 7/30, Loss: 0.07186160236597061, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 8/30, Loss: 0.9316868185997009, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 9/30, Loss: 0.07316235452890396, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 10/30, Loss: 0.06532102078199387, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 11/30, Loss: 0.05990065634250641, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 12/30, Loss: 0.9599320292472839, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 13/30, Loss: 1.0268315076828003, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 14/30, Loss: 1.7993221282958984, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 
1.0\n","Epoch 15/30, Loss: 0.9321920275688171, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 16/30, Loss: 0.0565328486263752, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 17/30, Loss: 1.7288103103637695, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 18/30, Loss: 1.0498160123825073, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 19/30, Loss: 0.057037562131881714, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 20/30, Loss: 0.08153844624757767, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 21/30, Loss: 0.9557292461395264, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 22/30, Loss: 0.9206430315971375, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 23/30, Loss: 1.7758244276046753, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 24/30, Loss: 0.8312497138977051, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 25/30, Loss: 1.63068687915802, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 26/30, Loss: 0.8575438857078552, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 27/30, Loss: 0.09230203181505203, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 28/30, Loss: 0.8663346767425537, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 29/30, Loss: 0.07764268666505814, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 30/30, Loss: 1.0366336107254028, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n"]}]},{"cell_type":"markdown","source":["# Seconrd Iteration of 
LSTM"],"metadata":{"id":"6nEKstr5JTxk"}},{"cell_type":"code","source":["train_dataset = TensorDataset(train_X_tensor, train_y_tensor)\n","train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)"],"metadata":{"id":"tMcRmy1mKEB_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class ImprovedLSTMModel(nn.Module):\n","    def __init__(self, input_size, hidden_size, num_classes, num_layers=1, bidirectional=True, dropout_rate=0.2):\n","        super(ImprovedLSTMModel, self).__init__()\n","        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirectional)\n","        self.dropout = nn.Dropout(dropout_rate)\n","\n","        # forward() concatenates the final hidden state of every layer and direction,\n","        # so the first dense layer sees hidden_size * num_directions * num_layers features\n","        fc1_input_size = hidden_size * (2 if bidirectional else 1) * num_layers\n","        self.fc1 = nn.Linear(fc1_input_size, 128)  # Use 128 units for the first dense layer\n","        self.fc2 = nn.Linear(128, num_classes)\n","        self.relu = nn.ReLU()\n","\n","    def forward(self, x):\n","        # Returns raw logits of shape (batch, num_classes); pair with BCEWithLogitsLoss.\n","        _, (h_n, _) = self.lstm(x)\n","        # h_n is (num_layers * num_directions, batch, hidden): move batch first and flatten\n","        h_n_concat = h_n.permute(1, 0, 2).contiguous().view(h_n.shape[1], -1)\n","        x = self.dropout(h_n_concat)\n","        # Non-linearity between the dense layers only. A ReLU on the output clamped the\n","        # logits to >= 0, so sigmoid(output) was always >= 0.5 and every prediction was 1.\n","        x = self.relu(self.fc1(x))\n","        x = self.fc2(x)\n","        return x"],"metadata":{"id":"n1qUiIfSJXXp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["input_size = train_X.shape[2]\n","hidden_size = 256\n","num_classes = 1\n","num_layers = 2\n","bidirectional = True\n","dropout_rate = 0.2\n","learning_rate = 0.00001\n","\n","# Instantiate the improved model\n","model = ImprovedLSTMModel(input_size, hidden_size, num_classes, num_layers, bidirectional, dropout_rate)\n","\n","# Define the loss function and optimizer.\n","# The model returns raw logits, so use BCEWithLogitsLoss; nn.BCELoss expects probabilities.\n","criterion = nn.BCEWithLogitsLoss()\n","optimizer = optim.SGD(model.parameters(), 
lr=learning_rate)"],"metadata":{"id":"ZpxWbsioJg0b"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"LEYM9BBfJpJx","executionInfo":{"status":"ok","timestamp":1692295679473,"user_tz":300,"elapsed":958,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"c9fbd8e6-44f3-44ed-f1cc-2355396a258b"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["ImprovedLSTMModel(\n","  (lstm): LSTM(50, 256, num_layers=2, batch_first=True, bidirectional=True)\n","  (dropout): Dropout(p=0.2, inplace=False)\n","  (fc1): Linear(in_features=1024, out_features=128, bias=True)\n","  (fc2): Linear(in_features=128, out_features=1, bias=True)\n","  (relu): ReLU()\n",")"]},"metadata":{},"execution_count":75}]},{"cell_type":"code","source":["epoch_accuracy = []\n","epoch_f1_score = []\n","epoch_recall = []\n","epoch_loss = []\n","\n","for epoch in range(num_epochs):\n","    true_labels = []\n","    predicted_labels = []\n","    for inputs, labels in train_dataloader:\n","        # Zero the gradients\n","        optimizer.zero_grad()\n","\n","        # Forward pass\n","        outputs = model(inputs)\n","        predictions = torch.round(torch.sigmoid(outputs.squeeze()))\n","\n","        true_labels.extend(labels.numpy())\n","        predicted_labels.extend(predictions.detach().numpy())\n","\n","        loss = criterion(outputs.squeeze(), labels)\n","\n","        # Backward pass and optimization\n","        loss.backward()\n","        optimizer.step()\n","\n","    # Calculate metrics\n","    accuracy = accuracy_score(true_labels, predicted_labels)\n","    f1 = f1_score(true_labels, predicted_labels)\n","    recall = recall_score(true_labels, predicted_labels)\n","\n","    # Store metrics for each epoch\n","    epoch_accuracy.append(accuracy)\n","    epoch_f1_score.append(f1)\n","    epoch_recall.append(recall)\n","    
epoch_loss.append(loss.item())\n","    # Print the metrics for every epoch\n","    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Accuracy: {accuracy}, F1 Score: {f1}, Recall: {recall}')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XEIU1OidKGTR","executionInfo":{"status":"ok","timestamp":1692295728092,"user_tz":300,"elapsed":47843,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"4deb738f-129d-4e5d-e38b-c83aaa3e611d"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/30, Loss: 0.14196285605430603, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 2/30, Loss: 0.7603593468666077, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 3/30, Loss: 0.1483137458562851, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 4/30, Loss: 0.1442556381225586, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 5/30, Loss: 2.0081534385681152, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 6/30, Loss: 2.103062391281128, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 7/30, Loss: 0.7612888216972351, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 8/30, Loss: 0.7601293921470642, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 9/30, Loss: 0.15363752841949463, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 10/30, Loss: 0.1545637995004654, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 11/30, Loss: 0.14706026017665863, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 12/30, Loss: 0.7576022744178772, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 13/30, Loss: 
0.7605314254760742, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 14/30, Loss: 1.419185996055603, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 15/30, Loss: 1.3716388940811157, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 16/30, Loss: 1.329178810119629, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 17/30, Loss: 0.743213951587677, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 18/30, Loss: 0.8101215362548828, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 19/30, Loss: 0.1562623828649521, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 20/30, Loss: 0.1567939668893814, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 21/30, Loss: 1.3488202095031738, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 22/30, Loss: 0.1470808982849121, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 23/30, Loss: 0.1572873443365097, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 24/30, Loss: 0.14903366565704346, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 25/30, Loss: 0.7512195706367493, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 26/30, Loss: 0.711514949798584, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 27/30, Loss: 0.7488934993743896, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 28/30, Loss: 0.15341293811798096, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 29/30, Loss: 0.7657694816589355, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 30/30, Loss: 0.15584520995616913, 
Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n"]}]},{"cell_type":"markdown","source":["# Third iteration of LSTM"],"metadata":{"id":"qcdyA9WGMwUw"}},{"cell_type":"code","source":["train_dataset = TensorDataset(train_X_tensor, train_y_tensor)\n","train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)"],"metadata":{"id":"VqEEGw2oKL5a"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class LSTMModel(nn.Module):\n","    # Binary sequence classifier: LSTM -> dropout -> fc1 -> ReLU6 -> fc2.\n","    # forward() returns raw logits (no activation on the output); pair it with\n","    # nn.BCEWithLogitsLoss for training and torch.sigmoid for predictions.\n","    def __init__(self, input_size, hidden_size, num_classes):\n","        super(LSTMModel, self).__init__()\n","        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)\n","        self.dropout = nn.Dropout(0.2)\n","        self.fc1 = nn.Linear(hidden_size, 32)\n","        # Hidden-layer activation, applied between fc1 and fc2\n","        self.relu = nn.ReLU6()\n","        self.fc2 = nn.Linear(32, num_classes)\n","\n","    def forward(self, x):\n","        # batch_first=True: x is expected as (batch, seq_len, input_size)\n","        _, (h_n, _) = self.lstm(x)\n","        # h_n[-1] is the final hidden state of the last LSTM layer\n","        x = self.dropout(h_n[-1])\n","        x = self.relu(self.fc1(x))\n","        # No activation on the output: a non-negative output would make\n","        # sigmoid(output) >= 0.5 and collapse every prediction to class 1.\n","        return self.fc2(x)"],"metadata":{"id":"1TAwRu-WVzJf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["input_size = train_X.shape[2]\n","hidden_size = 256\n","num_classes = 1\n","num_epochs = 20\n","learning_rate = 0.01\n","\n","# Instantiate the model\n","model = LSTMModel(input_size, hidden_size, num_classes)\n","\n","# Define the loss function and optimizer.\n","# LSTMModel returns logits, so use BCEWithLogitsLoss (sigmoid + BCE in one\n","# numerically stable op); nn.BCELoss requires inputs already in [0, 1].\n","criterion = nn.BCEWithLogitsLoss()\n","optimizer = optim.SGD(model.parameters(), lr=learning_rate)"],"metadata":{"id":"hQjiUGqJM9VF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"469B6ZaHVu31","executionInfo":{"status":"ok","timestamp":1692295813086,"user_tz":300,"elapsed":7,"user":{"displayName":"Francesco 
Bassino","userId":"12214195385968794219"}},"outputId":"f95e0662-4fc6-4a5f-d63c-55cde09f18cb"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["LSTMModel(\n","  (lstm): LSTM(50, 256, batch_first=True)\n","  (dropout): Dropout(p=0.2, inplace=False)\n","  (fc1): Linear(in_features=256, out_features=32, bias=True)\n","  (relu): ReLU6()\n","  (fc2): Linear(in_features=32, out_features=1, bias=True)\n",")"]},"metadata":{},"execution_count":84}]},{"cell_type":"code","source":["epoch_accuracy = []\n","epoch_f1_score = []\n","epoch_recall = []\n","epoch_loss = []\n","\n","for epoch in range(num_epochs):\n","    model.train()\n","    true_labels = []\n","    predicted_labels = []\n","    batch_losses = []\n","    for inputs, labels in train_dataloader:\n","        # Zero the gradients\n","        optimizer.zero_grad()\n","\n","        # Forward pass. squeeze(-1) drops only the output dim, so a final\n","        # batch of size 1 still gives a 1-D tensor (squeeze() would give 0-D).\n","        outputs = model(inputs).squeeze(-1)\n","        # Outputs are treated as logits: sigmoid -> probability -> 0/1 label\n","        predictions = torch.round(torch.sigmoid(outputs))\n","\n","        true_labels.extend(labels.numpy())\n","        predicted_labels.extend(predictions.detach().numpy())\n","\n","        loss = criterion(outputs, labels)\n","\n","        # Backward pass and optimization\n","        loss.backward()\n","        optimizer.step()\n","        batch_losses.append(loss.item())\n","\n","    # Mean loss over every batch of the epoch, not just the last one\n","    mean_loss = sum(batch_losses) / len(batch_losses)\n","\n","    # Calculate metrics (on the training batches, collected with dropout active)\n","    accuracy = accuracy_score(true_labels, predicted_labels)\n","    f1 = f1_score(true_labels, predicted_labels)\n","    recall = recall_score(true_labels, predicted_labels)\n","\n","    # Store metrics for each epoch\n","    epoch_accuracy.append(accuracy)\n","    epoch_f1_score.append(f1)\n","    epoch_recall.append(recall)\n","    epoch_loss.append(mean_loss)\n","    # Print the metrics for every epoch\n","    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}, Accuracy: {accuracy}, F1 Score: {f1}, Recall: 
{recall}')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"gBn4JcRcVsgq","executionInfo":{"status":"ok","timestamp":1692295819755,"user_tz":300,"elapsed":5947,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"ffb30456-1a91-4c7e-8c82-a75f1aa14851"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/20, Loss: 0.6219484806060791, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 2/20, Loss: 0.6616726517677307, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 3/20, Loss: 0.6715094447135925, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 4/20, Loss: 0.32363080978393555, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 5/20, Loss: 0.6457496285438538, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 6/20, Loss: 0.27271386981010437, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 7/20, Loss: 0.28883811831474304, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 8/20, Loss: 0.3118828535079956, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 9/20, Loss: 0.266092449426651, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 10/20, Loss: 0.29886820912361145, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 11/20, Loss: 0.2565235197544098, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 12/20, Loss: 0.2884318232536316, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 13/20, Loss: 0.22864384949207306, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 14/20, Loss: 1.2777477502822876, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 
1.0\n","Epoch 15/20, Loss: 0.3226184546947479, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 16/20, Loss: 0.944756805896759, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 17/20, Loss: 0.989539623260498, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 18/20, Loss: 0.28581663966178894, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 19/20, Loss: 0.6567023396492004, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n","Epoch 20/20, Loss: 0.5800800919532776, Accuracy: 0.2591549295774648, F1 Score: 0.41163310961968674, Recall: 1.0\n"]}]},{"cell_type":"markdown","source":["# KERAS"],"metadata":{"id":"zzbvFblpAW6l"}},{"cell_type":"markdown","source":["## 1 Iteration"],"metadata":{"id":"f2Jn64iP9f7W"}},{"cell_type":"code","source":["train_dataset = TensorDataset(train_X_tensor, train_y_tensor)\n","train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)"],"metadata":{"id":"G6pUQcBFAcCN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import numpy as np\n","from keras.models import Sequential\n","from keras.layers import LSTM, Dense, Dropout"],"metadata":{"id":"GcaXwWAiAbt5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["num_features = normalized_sequences.shape[2]\n","# Define the LSTM model\n","model = Sequential()\n","model.add(LSTM(units=64, input_shape=(sequence_length, num_features), return_sequences=True))\n","model.add(Dropout(0.2))\n","model.add(LSTM(units=64))\n","model.add(Dropout(0.2))\n","model.add(Dense(units=1, activation='sigmoid'))\n","\n","# Compile the model\n","model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n","\n","# Print the model summary\n","model.summary()\n","\n","# Train the model\n","batch_size = 32\n","epochs = 30\n","history = model.fit(train_X, train_y, batch_size=batch_size, 
epochs=epochs, validation_split=0.2)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"qQ0CwkQwV2b5","executionInfo":{"status":"ok","timestamp":1692301346068,"user_tz":300,"elapsed":14789,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"4d832068-bd15-49ee-9c04-ca9005602ee3"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Model: \"sequential_2\"\n","_________________________________________________________________\n"," Layer (type)                Output Shape              Param #   \n","=================================================================\n"," lstm_2 (LSTM)               (None, 10, 64)            29440     \n","                                                                 \n"," dropout_2 (Dropout)         (None, 10, 64)            0         \n","                                                                 \n"," lstm_3 (LSTM)               (None, 64)                33024     \n","                                                                 \n"," dropout_3 (Dropout)         (None, 64)                0         \n","                                                                 \n"," dense_1 (Dense)             (None, 1)                 65        \n","                                                                 \n","=================================================================\n","Total params: 62,529\n","Trainable params: 62,529\n","Non-trainable params: 0\n","_________________________________________________________________\n","Epoch 1/30\n","9/9 [==============================] - 6s 208ms/step - loss: 0.6018 - accuracy: 0.7430 - val_loss: 0.6794 - val_accuracy: 0.6761\n","Epoch 2/30\n","9/9 [==============================] - 0s 42ms/step - loss: 0.5474 - accuracy: 0.7570 - val_loss: 0.6236 - val_accuracy: 0.6761\n","Epoch 3/30\n","9/9 [==============================] - 0s 45ms/step - loss: 0.5487 - accuracy: 0.7570 - val_loss: 0.6342 - 
val_accuracy: 0.6761\n","Epoch 4/30\n","9/9 [==============================] - 0s 45ms/step - loss: 0.5318 - accuracy: 0.7570 - val_loss: 0.6125 - val_accuracy: 0.6761\n","Epoch 5/30\n","9/9 [==============================] - 0s 46ms/step - loss: 0.4928 - accuracy: 0.7570 - val_loss: 0.5275 - val_accuracy: 0.7183\n","Epoch 6/30\n","9/9 [==============================] - 0s 42ms/step - loss: 0.5334 - accuracy: 0.7606 - val_loss: 0.5935 - val_accuracy: 0.6761\n","Epoch 7/30\n","9/9 [==============================] - 0s 38ms/step - loss: 0.4558 - accuracy: 0.7852 - val_loss: 0.4841 - val_accuracy: 0.7183\n","Epoch 8/30\n","9/9 [==============================] - 0s 34ms/step - loss: 0.4072 - accuracy: 0.7852 - val_loss: 0.4071 - val_accuracy: 0.7746\n","Epoch 9/30\n","9/9 [==============================] - 0s 43ms/step - loss: 0.3700 - accuracy: 0.7993 - val_loss: 0.3723 - val_accuracy: 0.8028\n","Epoch 10/30\n","9/9 [==============================] - 0s 44ms/step - loss: 0.3615 - accuracy: 0.8063 - val_loss: 0.4775 - val_accuracy: 0.7887\n","Epoch 11/30\n","9/9 [==============================] - 0s 37ms/step - loss: 0.3962 - accuracy: 0.7993 - val_loss: 0.6988 - val_accuracy: 0.6761\n","Epoch 12/30\n","9/9 [==============================] - 0s 27ms/step - loss: 0.4335 - accuracy: 0.7852 - val_loss: 0.4021 - val_accuracy: 0.9014\n","Epoch 13/30\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3833 - accuracy: 0.8204 - val_loss: 0.4010 - val_accuracy: 0.7887\n","Epoch 14/30\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3861 - accuracy: 0.8063 - val_loss: 0.4407 - val_accuracy: 0.7746\n","Epoch 15/30\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3911 - accuracy: 0.8204 - val_loss: 0.3450 - val_accuracy: 0.9014\n","Epoch 16/30\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3624 - accuracy: 0.8486 - val_loss: 0.3330 - val_accuracy: 0.8169\n","Epoch 17/30\n","9/9 [==============================] - 
0s 25ms/step - loss: 0.3641 - accuracy: 0.8063 - val_loss: 0.5208 - val_accuracy: 0.7746\n","Epoch 18/30\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3684 - accuracy: 0.8239 - val_loss: 0.3172 - val_accuracy: 0.8310\n","Epoch 19/30\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3354 - accuracy: 0.8345 - val_loss: 0.2897 - val_accuracy: 0.9014\n","Epoch 20/30\n","9/9 [==============================] - 0s 26ms/step - loss: 0.3454 - accuracy: 0.8134 - val_loss: 0.3917 - val_accuracy: 0.8028\n","Epoch 21/30\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3410 - accuracy: 0.8345 - val_loss: 0.3124 - val_accuracy: 0.8451\n","Epoch 22/30\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3278 - accuracy: 0.8345 - val_loss: 0.2882 - val_accuracy: 0.9155\n","Epoch 23/30\n","9/9 [==============================] - 0s 25ms/step - loss: 0.3276 - accuracy: 0.8345 - val_loss: 0.3020 - val_accuracy: 0.8592\n","Epoch 24/30\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3348 - accuracy: 0.8310 - val_loss: 0.2724 - val_accuracy: 0.9155\n","Epoch 25/30\n","9/9 [==============================] - 0s 29ms/step - loss: 0.3508 - accuracy: 0.8275 - val_loss: 0.2979 - val_accuracy: 0.9014\n","Epoch 26/30\n","9/9 [==============================] - 0s 26ms/step - loss: 0.3341 - accuracy: 0.8662 - val_loss: 0.4384 - val_accuracy: 0.8028\n","Epoch 27/30\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3112 - accuracy: 0.8521 - val_loss: 0.2568 - val_accuracy: 0.9296\n","Epoch 28/30\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3373 - accuracy: 0.8486 - val_loss: 0.3282 - val_accuracy: 0.8451\n","Epoch 29/30\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3946 - accuracy: 0.7993 - val_loss: 0.4196 - val_accuracy: 0.7465\n","Epoch 30/30\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3253 - accuracy: 0.8345 - val_loss: 0.3035 - 
val_accuracy: 0.8451\n"]}]},{"cell_type":"code","source":["# Evaluate the model\n","test_loss, test_accuracy = model.evaluate(test_X, test_y)\n","print(\"Test Loss:\", test_loss)\n","print(\"Test Accuracy:\", test_accuracy)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mUJzwnRiAjTR","executionInfo":{"status":"ok","timestamp":1692301355734,"user_tz":300,"elapsed":605,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"b2a14763-1112-4427-e14b-bfc11ac33426"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["3/3 [==============================] - 0s 9ms/step - loss: 0.3437 - accuracy: 0.8539\n","Test Loss: 0.3437446355819702\n","Test Accuracy: 0.8539325594902039\n"]}]},{"cell_type":"code","source":["# Plot training and validation metrics\n","plt.figure(figsize=(10, 4))\n","plt.subplot(1, 2, 1)\n","plt.plot(history.history['loss'], label='Training Loss')\n","plt.plot(history.history['val_loss'], label='Validation Loss')\n","plt.xlabel('Epoch')\n","plt.ylabel('Loss')\n","plt.legend()\n","\n","plt.subplot(1, 2, 2)\n","plt.plot(history.history['accuracy'], label='Training Accuracy')\n","plt.plot(history.history['val_accuracy'], label='Validation Accuracy')\n","plt.xlabel('Epoch')\n","plt.ylabel('Accuracy')\n","plt.legend()\n","\n","plt.tight_layout()\n","plt.show()"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":406},"id":"Q2QfFCBcAwAv","executionInfo":{"status":"ok","timestamp":1692301359426,"user_tz":300,"elapsed":969,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"95875b70-9956-49f1-9900-5fecda96acc5"},"execution_count":null,"outputs":[{"output_type":"display_data","data":{"text/plain":["<Figure size 1000x400 with 2 Axes>"],"image/png":"\n"},"metadata":{}}]},{"cell_type":"markdown","source":["## 2 iteration"],"metadata":{"id":"mTInRUCaDOlN"}},{"cell_type":"code","source":["from keras.optimizers import 
Adam"],"metadata":{"id":"jXw8MdvnD0jo"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["num_features = normalized_sequences.shape[2]\n","# Define the LSTM model\n","model = Sequential()\n","model.add(LSTM(units=64, input_shape=(sequence_length, num_features), return_sequences=True))\n","model.add(Dropout(0.5))\n","model.add(LSTM(units=64))\n","model.add(Dropout(0.5))\n","model.add(Dense(units=1, activation='sigmoid'))\n","\n","custom_learning_rate = 0.001\n","optimizer = Adam(learning_rate=custom_learning_rate)\n","\n","# Compile the model\n","model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n","\n","# Print the model summary\n","model.summary()\n","\n","# Train the model\n","batch_size = 32\n","epochs = 60\n","history = model.fit(train_X, train_y, batch_size=batch_size, epochs=epochs, validation_split=0.2, verbose=1)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NM-mOMJUBG9d","executionInfo":{"status":"ok","timestamp":1692302229233,"user_tz":300,"elapsed":24949,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"b17c4bb3-e63c-4d03-f517-5447fe886c73"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Model: \"sequential_9\"\n","_________________________________________________________________\n"," Layer (type)                Output Shape              Param #   \n","=================================================================\n"," lstm_16 (LSTM)              (None, 10, 64)            29440     \n","                                                                 \n"," dropout_16 (Dropout)        (None, 10, 64)            0         \n","                                                                 \n"," lstm_17 (LSTM)              (None, 64)                33024     \n","                                                                 \n"," dropout_17 (Dropout)        (None, 64)                0         \n","        
                                                         \n"," dense_8 (Dense)             (None, 1)                 65        \n","                                                                 \n","=================================================================\n","Total params: 62,529\n","Trainable params: 62,529\n","Non-trainable params: 0\n","_________________________________________________________________\n","Epoch 1/60\n","9/9 [==============================] - 5s 137ms/step - loss: 0.5819 - accuracy: 0.7500 - val_loss: 0.6935 - val_accuracy: 0.6761\n","Epoch 2/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.5605 - accuracy: 0.7570 - val_loss: 0.6287 - val_accuracy: 0.6761\n","Epoch 3/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.5663 - accuracy: 0.7570 - val_loss: 0.6188 - val_accuracy: 0.6761\n","Epoch 4/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.5395 - accuracy: 0.7570 - val_loss: 0.6263 - val_accuracy: 0.6761\n","Epoch 5/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.5264 - accuracy: 0.7606 - val_loss: 0.5813 - val_accuracy: 0.6761\n","Epoch 6/60\n","9/9 [==============================] - 0s 21ms/step - loss: 0.4948 - accuracy: 0.7570 - val_loss: 0.5450 - val_accuracy: 0.7183\n","Epoch 7/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.4424 - accuracy: 0.7676 - val_loss: 0.5160 - val_accuracy: 0.7465\n","Epoch 8/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.4390 - accuracy: 0.7958 - val_loss: 0.5872 - val_accuracy: 0.7465\n","Epoch 9/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.4254 - accuracy: 0.7782 - val_loss: 0.5325 - val_accuracy: 0.7465\n","Epoch 10/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.4564 - accuracy: 0.7394 - val_loss: 0.5347 - val_accuracy: 0.7324\n","Epoch 11/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.4075 - accuracy: 0.7923 - 
val_loss: 0.3847 - val_accuracy: 0.7887\n","Epoch 12/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.4177 - accuracy: 0.7923 - val_loss: 0.4642 - val_accuracy: 0.7465\n","Epoch 13/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.4260 - accuracy: 0.7923 - val_loss: 0.5127 - val_accuracy: 0.7465\n","Epoch 14/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3976 - accuracy: 0.7852 - val_loss: 0.3617 - val_accuracy: 0.8169\n","Epoch 15/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3906 - accuracy: 0.8063 - val_loss: 0.3605 - val_accuracy: 0.8732\n","Epoch 16/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3821 - accuracy: 0.7958 - val_loss: 0.4222 - val_accuracy: 0.7887\n","Epoch 17/60\n","9/9 [==============================] - 0s 26ms/step - loss: 0.3660 - accuracy: 0.7887 - val_loss: 0.4211 - val_accuracy: 0.7887\n","Epoch 18/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3776 - accuracy: 0.7923 - val_loss: 0.3312 - val_accuracy: 0.8732\n","Epoch 19/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3680 - accuracy: 0.8134 - val_loss: 0.3497 - val_accuracy: 0.9296\n","Epoch 20/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3578 - accuracy: 0.8380 - val_loss: 0.5053 - val_accuracy: 0.7606\n","Epoch 21/60\n","9/9 [==============================] - 0s 26ms/step - loss: 0.4147 - accuracy: 0.8063 - val_loss: 0.4343 - val_accuracy: 0.7606\n","Epoch 22/60\n","9/9 [==============================] - 0s 39ms/step - loss: 0.3787 - accuracy: 0.7887 - val_loss: 0.3581 - val_accuracy: 0.8732\n","Epoch 23/60\n","9/9 [==============================] - 0s 41ms/step - loss: 0.3746 - accuracy: 0.8063 - val_loss: 0.4687 - val_accuracy: 0.7606\n","Epoch 24/60\n","9/9 [==============================] - 0s 40ms/step - loss: 0.3559 - accuracy: 0.8239 - val_loss: 0.3165 - val_accuracy: 0.8873\n","Epoch 25/60\n","9/9 
[==============================] - 0s 33ms/step - loss: 0.3335 - accuracy: 0.8169 - val_loss: 0.3215 - val_accuracy: 0.8310\n","Epoch 26/60\n","9/9 [==============================] - 0s 34ms/step - loss: 0.3496 - accuracy: 0.8134 - val_loss: 0.4068 - val_accuracy: 0.8028\n","Epoch 27/60\n","9/9 [==============================] - 0s 37ms/step - loss: 0.3751 - accuracy: 0.7993 - val_loss: 0.5081 - val_accuracy: 0.7606\n","Epoch 28/60\n","9/9 [==============================] - 0s 37ms/step - loss: 0.3561 - accuracy: 0.8415 - val_loss: 0.3360 - val_accuracy: 0.9014\n","Epoch 29/60\n","9/9 [==============================] - 0s 35ms/step - loss: 0.3496 - accuracy: 0.8415 - val_loss: 0.3543 - val_accuracy: 0.8028\n","Epoch 30/60\n","9/9 [==============================] - 0s 40ms/step - loss: 0.3512 - accuracy: 0.8275 - val_loss: 0.4070 - val_accuracy: 0.7887\n","Epoch 31/60\n","9/9 [==============================] - 0s 28ms/step - loss: 0.3476 - accuracy: 0.8134 - val_loss: 0.2978 - val_accuracy: 0.9155\n","Epoch 32/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3246 - accuracy: 0.8275 - val_loss: 0.2865 - val_accuracy: 0.9296\n","Epoch 33/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3283 - accuracy: 0.8662 - val_loss: 0.3079 - val_accuracy: 0.8310\n","Epoch 34/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3164 - accuracy: 0.8556 - val_loss: 0.4392 - val_accuracy: 0.8028\n","Epoch 35/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3153 - accuracy: 0.8486 - val_loss: 0.3293 - val_accuracy: 0.8310\n","Epoch 36/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.3461 - accuracy: 0.8415 - val_loss: 0.3007 - val_accuracy: 0.9155\n","Epoch 37/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3288 - accuracy: 0.8451 - val_loss: 0.3310 - val_accuracy: 0.8310\n","Epoch 38/60\n","9/9 [==============================] - 0s 25ms/step - loss: 0.3282 - accuracy: 
0.8310 - val_loss: 0.3204 - val_accuracy: 0.8592\n","Epoch 39/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3088 - accuracy: 0.8239 - val_loss: 0.4563 - val_accuracy: 0.7887\n","Epoch 40/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3047 - accuracy: 0.8486 - val_loss: 0.2776 - val_accuracy: 0.9014\n","Epoch 41/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.3029 - accuracy: 0.8380 - val_loss: 0.2660 - val_accuracy: 0.9155\n","Epoch 42/60\n","9/9 [==============================] - 0s 21ms/step - loss: 0.2952 - accuracy: 0.8380 - val_loss: 0.3441 - val_accuracy: 0.8592\n","Epoch 43/60\n","9/9 [==============================] - 0s 21ms/step - loss: 0.3133 - accuracy: 0.8451 - val_loss: 0.2432 - val_accuracy: 0.9296\n","Epoch 44/60\n","9/9 [==============================] - 0s 25ms/step - loss: 0.3076 - accuracy: 0.8345 - val_loss: 0.3246 - val_accuracy: 0.8451\n","Epoch 45/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.2683 - accuracy: 0.8768 - val_loss: 0.3763 - val_accuracy: 0.8310\n","Epoch 46/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.2869 - accuracy: 0.8592 - val_loss: 0.2761 - val_accuracy: 0.9014\n","Epoch 47/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.2748 - accuracy: 0.8768 - val_loss: 0.2327 - val_accuracy: 0.9296\n","Epoch 48/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.2771 - accuracy: 0.8627 - val_loss: 0.3769 - val_accuracy: 0.8451\n","Epoch 49/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.2532 - accuracy: 0.8908 - val_loss: 0.3743 - val_accuracy: 0.8592\n","Epoch 50/60\n","9/9 [==============================] - 0s 25ms/step - loss: 0.2605 - accuracy: 0.8838 - val_loss: 0.3267 - val_accuracy: 0.8592\n","Epoch 51/60\n","9/9 [==============================] - 0s 22ms/step - loss: 0.2753 - accuracy: 0.8662 - val_loss: 0.3762 - val_accuracy: 0.8732\n","Epoch 52/60\n","9/9 
[==============================] - 0s 23ms/step - loss: 0.2641 - accuracy: 0.8627 - val_loss: 0.2680 - val_accuracy: 0.8873\n","Epoch 53/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.2692 - accuracy: 0.8944 - val_loss: 0.2402 - val_accuracy: 0.9014\n","Epoch 54/60\n","9/9 [==============================] - 0s 25ms/step - loss: 0.2545 - accuracy: 0.8627 - val_loss: 0.3228 - val_accuracy: 0.8169\n","Epoch 55/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.3025 - accuracy: 0.8592 - val_loss: 0.2707 - val_accuracy: 0.9155\n","Epoch 56/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.2981 - accuracy: 0.8451 - val_loss: 0.2414 - val_accuracy: 0.9155\n","Epoch 57/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.2904 - accuracy: 0.8592 - val_loss: 0.3366 - val_accuracy: 0.8873\n","Epoch 58/60\n","9/9 [==============================] - 0s 29ms/step - loss: 0.2521 - accuracy: 0.8627 - val_loss: 0.2486 - val_accuracy: 0.9014\n","Epoch 59/60\n","9/9 [==============================] - 0s 23ms/step - loss: 0.2530 - accuracy: 0.8732 - val_loss: 0.3621 - val_accuracy: 0.8732\n","Epoch 60/60\n","9/9 [==============================] - 0s 24ms/step - loss: 0.2533 - accuracy: 0.8838 - val_loss: 0.2578 - val_accuracy: 0.9014\n"]}]},{"cell_type":"code","source":["# Evaluate the model\n","test_loss, test_accuracy = model.evaluate(test_X, test_y)\n","print(\"Test Loss:\", test_loss)\n","print(\"Test Accuracy:\", test_accuracy)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9-zQxYTLDfbi","executionInfo":{"status":"ok","timestamp":1692302229234,"user_tz":300,"elapsed":10,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"47d2073c-2ee9-420b-8c8e-19d33d346c15"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["3/3 [==============================] - 0s 12ms/step - loss: 0.3524 - accuracy: 0.8764\n","Test Loss: 
0.3523918390274048\n","Test Accuracy: 0.8764045238494873\n"]}]},{"cell_type":"code","source":["# Plot training and validation metrics\n","plt.figure(figsize=(10, 4))\n","plt.subplot(1, 2, 1)\n","plt.plot(history.history['loss'], label='Training Loss')\n","plt.plot(history.history['val_loss'], label='Validation Loss')\n","plt.xlabel('Epoch')\n","plt.ylabel('Loss')\n","plt.legend()\n","\n","plt.subplot(1, 2, 2)\n","plt.plot(history.history['accuracy'], label='Training Accuracy')\n","plt.plot(history.history['val_accuracy'], label='Validation Accuracy')\n","plt.xlabel('Epoch')\n","plt.ylabel('Accuracy')\n","plt.legend()\n","\n","plt.tight_layout()\n","plt.show()"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":406},"id":"0MPVbGXzDf47","executionInfo":{"status":"ok","timestamp":1692302233675,"user_tz":300,"elapsed":946,"user":{"displayName":"Francesco Bassino","userId":"12214195385968794219"}},"outputId":"89493b24-7385-4ba2-90c5-295ff425b05f"},"execution_count":null,"outputs":[{"output_type":"display_data","data":{"text/plain":["<Figure size 1000x400 with 2 Axes>"],"image/png":"\n"},"metadata":{}}]},{"cell_type":"code","source":["model.save('lstm_model_v01.h5')"],"metadata":{"id":"kJzUsHTqEbJZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":[],"metadata":{"id":"08qQAUfb9YvN"}}]}
\ No newline at end of file