[367703]: / .ipynb_checkpoints / EEGLearn-checkpoint.ipynb

Download this file

3238 lines (3237 with data), 276.5 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import scipy.io as sio\n",
    "import matplotlib.pyplot as plt \n",
    "import seaborn as sn\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import os \n",
    "\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset\n",
    "from torch.utils.data import DataLoader,random_split\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "torch.manual_seed(1234)\n",
    "np.random.seed(1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instancing on the GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2670, 3, 32, 32)\n",
      "(2670, 7, 3, 32, 32)\n",
      "(2670,)\n",
      "(2670,)\n"
     ]
    }
   ],
   "source": [
    "Mean_Images = sio.loadmat(\"images.mat\")[\"img\"] #corresponding to the images mean for all the seven windows\n",
    "print(np.shape(Mean_Images)) \n",
    "\n",
    "Images = sio.loadmat(\"images_time.mat\")[\"img\"] #corresponding to the images mean for all the seven windows\n",
    "print(np.shape(Images)) \n",
    "\n",
    "\n",
    "Label = (sio.loadmat(\"FeatureMat_timeWin\")[\"features\"][:,-1]-1).astype(int)\n",
    "print(np.shape(Label)) \n",
    "\n",
    "Patient_id = sio.loadmat(\"trials_subNums.mat\")['subjectNum'][0]\n",
    "print(np.shape(Patient_id))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EEGImagesDataset(Dataset):\n",
    "    \"\"\"EEGLearn Images Dataset from EEG.\"\"\"\n",
    "    \n",
    "    def __init__(self, label, image):\n",
    "        self.label = Label\n",
    "        self.Images = image\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.label)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        if torch.is_tensor(idx):\n",
    "            idx = idx.tolist()\n",
    "        image = self.Images[idx]\n",
    "        label = self.label[idx]\n",
    "        sample = (image, label)\n",
    "        \n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### K-Fold Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kfold(length, n_fold):\n",
    "    tot_id = np.arange(length)\n",
    "    np.random.shuffle(tot_id)\n",
    "    len_fold = int(length/n_fold)\n",
    "    train_id = []\n",
    "    test_id = []\n",
    "    for i in range(n_fold):\n",
    "        test_id.append(tot_id[i*len_fold:(i+1)*len_fold])\n",
    "        train_id.append(np.hstack([tot_id[0:i*len_fold],tot_id[(i+1)*len_fold:-1]]))\n",
    "    return train_id, test_id"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mean Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicCNN(nn.Module):\n",
    "    '''\n",
    "    Build the  Mean Basic model performing a classification with CNN \n",
    "\n",
    "    param input_image: list of EEG image [batch_size, n_window, n_channel, h, w]\n",
    "    param kernel: kernel size used for the convolutional layers\n",
    "    param stride: stride apply during the convolutions\n",
    "    param padding: padding used during the convolutions\n",
    "    param max_kernel: kernel used for the maxpooling steps\n",
    "    param n_classes: number of classes\n",
    "    return x: output of the last layers after the log softmax\n",
    "    '''\n",
    "    def __init__(self, input_image=torch.zeros(1, 3, 32, 32), kernel=(3,3), stride=1, padding=1,max_kernel=(2,2), n_classes=4):\n",
    "        super(BasicCNN, self).__init__()\n",
    "\n",
    "        n_window = input_image.shape[1]\n",
    "        n_channel = input_image.shape[2]\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3,32,kernel,stride=stride, padding=padding)\n",
    "        self.conv2 = nn.Conv2d(32,32,kernel,stride=stride, padding=padding)\n",
    "        self.conv3 = nn.Conv2d(32,32,kernel,stride=stride, padding=padding)\n",
    "        self.conv4 = nn.Conv2d(32,32,kernel,stride=stride, padding=padding)\n",
    "        self.pool1 = nn.MaxPool2d(max_kernel)\n",
    "        self.conv5 = nn.Conv2d(32,64,kernel,stride=stride,padding=padding)\n",
    "        self.conv6 = nn.Conv2d(64,64,kernel,stride=stride,padding=padding)\n",
    "        self.conv7 = nn.Conv2d(64,128,kernel,stride=stride,padding=padding)\n",
    "\n",
    "        self.pool = nn.MaxPool2d((1,1))\n",
    "        self.drop = nn.Dropout(p=0.5)\n",
    "\n",
    "        self.fc1 = nn.Linear(2048,512)\n",
    "        self.fc2 = nn.Linear(512,n_classes)\n",
    "        self.max = nn.LogSoftmax()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = F.relu(self.conv4(x))\n",
    "        x = self.pool1(x)\n",
    "        x = F.relu(self.conv5(x))\n",
    "        x = F.relu(self.conv6(x))\n",
    "        x = self.pool1(x)\n",
    "        x = F.relu(self.conv7(x))\n",
    "        x = self.pool1(x)\n",
    "        x = x.reshape(x.shape[0],x.shape[1], -1)\n",
    "        x = self.pool(x)\n",
    "        x = x.reshape(x.shape[0],-1)\n",
    "        x = self.fc1(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "EEG = EEGImagesDataset(label=Label, image=Mean_Images)\n",
    "\n",
    "lengths = [int(2670*0.8), int(2670*0.2)]\n",
    "Train, Test = random_split(EEG, lengths)\n",
    "\n",
    "Trainloader = DataLoader(Train,batch_size=32)\n",
    "Testloader = DataLoader(Test, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/vdelv/anaconda3/envs/Pytorch_EEG/lib/python3.7/site-packages/ipykernel_launcher.py:52: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished Training\n"
     ]
    }
   ],
   "source": [
    "net = BasicCNN().cuda()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "for epoch in range(30):  # loop over the dataset multiple times\n",
    "    running_loss = 0.0\n",
    "    evaluation = []\n",
    "    for i, data in enumerate(Trainloader, 0):\n",
    "        # get the inputs; data is a list of [inputs, labels]\n",
    "        inputs, labels = data\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = net(inputs.to(torch.float32).cuda())\n",
    "        _, predicted = torch.max(outputs.cpu().data, 1)\n",
    "        evaluation.append((predicted==labels).tolist())\n",
    "        loss = criterion(outputs, labels.cuda())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        running_loss += loss.item()\n",
    "    running_loss = running_loss/(i+1)\n",
    "    evaluation = [item for sublist in evaluation for item in sublist]\n",
    "    running_acc = sum(evaluation)/len(evaluation)\n",
    "    validation_loss, validation_acc = Test_Model(net, Testloader,True)\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/vdelv/anaconda3/envs/Pytorch_EEG/lib/python3.7/site-packages/ipykernel_launcher.py:52: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5,  10]\tloss: 1.384\tAccuracy : 0.277\t\tval-loss: 1.383\tval-Accuracy : 0.311\n",
      "[10,  10]\tloss: 1.383\tAccuracy : 0.277\t\tval-loss: 1.382\tval-Accuracy : 0.311\n",
      "Finished Training \n",
      " loss: 1.383\tAccuracy : 0.277\t\tval-loss: 1.382\tval-Accuracy : 0.311\n"
     ]
    }
   ],
   "source": [
    "res = TrainTest_Model(BasicCNN, Trainloader, Testloader, n_epoch=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Test_Model(net, Testloader, is_cuda=True):\n",
    "    running_loss = 0.0 \n",
    "    evaluation = []\n",
    "    for i, data in enumerate(Testloader, 0):\n",
    "        input_img, labels = data\n",
    "        optimizer.zero_grad()\n",
    "        input_img = input_img.to(torch.float32)\n",
    "        if is_cuda:\n",
    "            input_img = input_img.cuda()\n",
    "        outputs = net(input_img)\n",
    "        _, predicted = torch.max(outputs.cpu().data, 1)\n",
    "        evaluation.append((predicted==labels).tolist())\n",
    "        loss = criterion(outputs, labels.cuda())\n",
    "        running_loss += loss.item()\n",
    "    running_loss = running_loss/(i+1)\n",
    "    evaluation = [item for sublist in evaluation for item in sublist]\n",
    "    running_acc = sum(evaluation)/len(evaluation)\n",
    "    return running_loss, running_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [],
   "source": [
    "def TrainTest_Model(model, trainloader, testloader, n_epoch=30, opti='SGD', learning_rate=0.0001, is_cuda=True, print_epoch =5):\n",
    "    if is_cuda:\n",
    "        net = model().cuda()\n",
    "    else :\n",
    "        net = model()\n",
    "        \n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    \n",
    "    if opti=='SGD':\n",
    "        optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    elif opti =='Adam':\n",
    "        optimizer = optim.Adam(CNN.parameters(), lr=learning_rate)\n",
    "    else: \n",
    "        print(\"Optimizer: \"+optim+\" not implemented.\")\n",
    "    \n",
    "    for epoch in range(n_epoch):\n",
    "        running_loss = 0.0\n",
    "        evaluation = []\n",
    "        for i, data in enumerate(Trainloader, 0):\n",
    "            # get the inputs; data is a list of [inputs, labels]\n",
    "            inputs, labels = data\n",
    "            # zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward + backward + optimize\n",
    "            outputs = net(inputs.to(torch.float32).cuda())\n",
    "            _, predicted = torch.max(outputs.cpu().data, 1)\n",
    "            evaluation.append((predicted==labels).tolist())\n",
    "            loss = criterion(outputs, labels.cuda())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        running_loss = running_loss/(i+1)\n",
    "        evaluation = [item for sublist in evaluation for item in sublist]\n",
    "        running_acc = sum(evaluation)/len(evaluation)\n",
    "        validation_loss, validation_acc = Test_Model(net, Testloader,True)\n",
    "        \n",
    "        if epoch%print_epoch==(print_epoch-1):\n",
    "            print('[%d, %3d]\\tloss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epoch, running_loss, running_acc, validation_loss, validation_acc))\n",
    "    \n",
    "    print('Finished Training \\n loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (running_loss, running_acc, validation_loss,validation_acc))\n",
    "    \n",
    "    return (running_loss, running_acc, validation_loss,validation_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 477,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training Fold 1/5\t of Patient 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:19: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finish Training Fold 1/5\t of Patient 1\n",
      "Begin Training Fold 2/5\t of Patient 1\n",
      "Finish Training Fold 2/5\t of Patient 1\n",
      "Begin Training Fold 3/5\t of Patient 1\n",
      "Finish Training Fold 3/5\t of Patient 1\n",
      "Begin Training Fold 4/5\t of Patient 1\n",
      "Finish Training Fold 4/5\t of Patient 1\n",
      "Begin Training Fold 5/5\t of Patient 1\n",
      "Finish Training Fold 5/5\t of Patient 1\n",
      "loss: 0.854\tAccuracy : 0.855\t\tval-loss: 0.956\tval-Accuracy : 0.805\n",
      "Begin Training Fold 1/5\t of Patient 2\n",
      "Finish Training Fold 1/5\t of Patient 2\n",
      "Begin Training Fold 2/5\t of Patient 2\n",
      "Finish Training Fold 2/5\t of Patient 2\n",
      "Begin Training Fold 3/5\t of Patient 2\n",
      "Finish Training Fold 3/5\t of Patient 2\n",
      "Begin Training Fold 4/5\t of Patient 2\n",
      "Finish Training Fold 4/5\t of Patient 2\n",
      "Begin Training Fold 5/5\t of Patient 2\n",
      "Finish Training Fold 5/5\t of Patient 2\n",
      "loss: 0.824\tAccuracy : 0.872\t\tval-loss: 0.897\tval-Accuracy : 0.852\n",
      "Begin Training Fold 1/5\t of Patient 3\n",
      "Finish Training Fold 1/5\t of Patient 3\n",
      "Begin Training Fold 2/5\t of Patient 3\n",
      "Finish Training Fold 2/5\t of Patient 3\n",
      "Begin Training Fold 3/5\t of Patient 3\n",
      "Finish Training Fold 3/5\t of Patient 3\n",
      "Begin Training Fold 4/5\t of Patient 3\n",
      "Finish Training Fold 4/5\t of Patient 3\n",
      "Begin Training Fold 5/5\t of Patient 3\n",
      "Finish Training Fold 5/5\t of Patient 3\n",
      "loss: 0.778\tAccuracy : 0.883\t\tval-loss: 0.900\tval-Accuracy : 0.841\n",
      "Begin Training Fold 1/5\t of Patient 4\n",
      "Finish Training Fold 1/5\t of Patient 4\n",
      "Begin Training Fold 2/5\t of Patient 4\n",
      "Finish Training Fold 2/5\t of Patient 4\n",
      "Begin Training Fold 3/5\t of Patient 4\n",
      "Finish Training Fold 3/5\t of Patient 4\n",
      "Begin Training Fold 4/5\t of Patient 4\n",
      "Finish Training Fold 4/5\t of Patient 4\n",
      "Begin Training Fold 5/5\t of Patient 4\n",
      "Finish Training Fold 5/5\t of Patient 4\n",
      "loss: 0.770\tAccuracy : 0.960\t\tval-loss: 0.797\tval-Accuracy : 0.950\n",
      "Begin Training Fold 1/5\t of Patient 6\n",
      "Finish Training Fold 1/5\t of Patient 6\n",
      "Begin Training Fold 2/5\t of Patient 6\n",
      "Finish Training Fold 2/5\t of Patient 6\n",
      "Begin Training Fold 3/5\t of Patient 6\n",
      "Finish Training Fold 3/5\t of Patient 6\n",
      "Begin Training Fold 4/5\t of Patient 6\n",
      "Finish Training Fold 4/5\t of Patient 6\n",
      "Begin Training Fold 5/5\t of Patient 6\n",
      "Finish Training Fold 5/5\t of Patient 6\n",
      "loss: 0.814\tAccuracy : 0.888\t\tval-loss: 0.872\tval-Accuracy : 0.872\n",
      "Begin Training Fold 1/5\t of Patient 7\n",
      "Finish Training Fold 1/5\t of Patient 7\n",
      "Begin Training Fold 2/5\t of Patient 7\n",
      "Finish Training Fold 2/5\t of Patient 7\n",
      "Begin Training Fold 3/5\t of Patient 7\n",
      "Finish Training Fold 3/5\t of Patient 7\n",
      "Begin Training Fold 4/5\t of Patient 7\n",
      "Finish Training Fold 4/5\t of Patient 7\n",
      "Begin Training Fold 5/5\t of Patient 7\n",
      "Finish Training Fold 5/5\t of Patient 7\n",
      "loss: 0.785\tAccuracy : 0.893\t\tval-loss: 0.862\tval-Accuracy : 0.885\n",
      "Begin Training Fold 1/5\t of Patient 8\n",
      "Finish Training Fold 1/5\t of Patient 8\n",
      "Begin Training Fold 2/5\t of Patient 8\n",
      "Finish Training Fold 2/5\t of Patient 8\n",
      "Begin Training Fold 3/5\t of Patient 8\n",
      "Finish Training Fold 3/5\t of Patient 8\n",
      "Begin Training Fold 4/5\t of Patient 8\n",
      "Finish Training Fold 4/5\t of Patient 8\n",
      "Begin Training Fold 5/5\t of Patient 8\n",
      "Finish Training Fold 5/5\t of Patient 8\n",
      "loss: 0.780\tAccuracy : 0.951\t\tval-loss: 0.795\tval-Accuracy : 0.953\n",
      "Begin Training Fold 1/5\t of Patient 9\n",
      "Finish Training Fold 1/5\t of Patient 9\n",
      "Begin Training Fold 2/5\t of Patient 9\n",
      "Finish Training Fold 2/5\t of Patient 9\n",
      "Begin Training Fold 3/5\t of Patient 9\n",
      "Finish Training Fold 3/5\t of Patient 9\n",
      "Begin Training Fold 4/5\t of Patient 9\n",
      "Finish Training Fold 4/5\t of Patient 9\n",
      "Begin Training Fold 5/5\t of Patient 9\n",
      "Finish Training Fold 5/5\t of Patient 9\n",
      "loss: 0.769\tAccuracy : 0.965\t\tval-loss: 0.808\tval-Accuracy : 0.930\n",
      "Begin Training Fold 1/5\t of Patient 10\n",
      "Finish Training Fold 1/5\t of Patient 10\n",
      "Begin Training Fold 2/5\t of Patient 10\n",
      "Finish Training Fold 2/5\t of Patient 10\n",
      "Begin Training Fold 3/5\t of Patient 10\n",
      "Finish Training Fold 3/5\t of Patient 10\n",
      "Begin Training Fold 4/5\t of Patient 10\n",
      "Finish Training Fold 4/5\t of Patient 10\n",
      "Begin Training Fold 5/5\t of Patient 10\n",
      "Finish Training Fold 5/5\t of Patient 10\n",
      "loss: 0.774\tAccuracy : 0.939\t\tval-loss: 0.803\tval-Accuracy : 0.943\n",
      "Begin Training Fold 1/5\t of Patient 11\n",
      "Finish Training Fold 1/5\t of Patient 11\n",
      "Begin Training Fold 2/5\t of Patient 11\n",
      "Finish Training Fold 2/5\t of Patient 11\n",
      "Begin Training Fold 3/5\t of Patient 11\n",
      "Finish Training Fold 3/5\t of Patient 11\n",
      "Begin Training Fold 4/5\t of Patient 11\n",
      "Finish Training Fold 4/5\t of Patient 11\n",
      "Begin Training Fold 5/5\t of Patient 11\n",
      "Finish Training Fold 5/5\t of Patient 11\n",
      "loss: 0.768\tAccuracy : 0.944\t\tval-loss: 0.832\tval-Accuracy : 0.920\n",
      "Begin Training Fold 1/5\t of Patient 12\n",
      "Finish Training Fold 1/5\t of Patient 12\n",
      "Begin Training Fold 2/5\t of Patient 12\n",
      "Finish Training Fold 2/5\t of Patient 12\n",
      "Begin Training Fold 3/5\t of Patient 12\n",
      "Finish Training Fold 3/5\t of Patient 12\n",
      "Begin Training Fold 4/5\t of Patient 12\n",
      "Finish Training Fold 4/5\t of Patient 12\n",
      "Begin Training Fold 5/5\t of Patient 12\n",
      "Finish Training Fold 5/5\t of Patient 12\n",
      "loss: 0.767\tAccuracy : 0.924\t\tval-loss: 0.810\tval-Accuracy : 0.944\n",
      "Begin Training Fold 1/5\t of Patient 14\n",
      "Finish Training Fold 1/5\t of Patient 14\n",
      "Begin Training Fold 2/5\t of Patient 14\n",
      "Finish Training Fold 2/5\t of Patient 14\n",
      "Begin Training Fold 3/5\t of Patient 14\n",
      "Finish Training Fold 3/5\t of Patient 14\n",
      "Begin Training Fold 4/5\t of Patient 14\n",
      "Finish Training Fold 4/5\t of Patient 14\n",
      "Begin Training Fold 5/5\t of Patient 14\n",
      "Finish Training Fold 5/5\t of Patient 14\n",
      "loss: 0.795\tAccuracy : 0.904\t\tval-loss: 0.895\tval-Accuracy : 0.854\n",
      "Begin Training Fold 1/5\t of Patient 15\n",
      "Finish Training Fold 1/5\t of Patient 15\n",
      "Begin Training Fold 2/5\t of Patient 15\n",
      "Finish Training Fold 2/5\t of Patient 15\n",
      "Begin Training Fold 3/5\t of Patient 15\n",
      "Finish Training Fold 3/5\t of Patient 15\n",
      "Begin Training Fold 4/5\t of Patient 15\n",
      "Finish Training Fold 4/5\t of Patient 15\n",
      "Begin Training Fold 5/5\t of Patient 15\n",
      "Finish Training Fold 5/5\t of Patient 15\n",
      "loss: 0.772\tAccuracy : 0.967\t\tval-loss: 0.768\tval-Accuracy : 0.986\n"
     ]
    }
   ],
   "source": [
    "p = 0\n",
    "fold_vloss = np.zeros((n_fold,n_patient))\n",
    "fold_loss = np.zeros((n_fold,n_patient))\n",
    "fold_vacc = np.zeros((n_fold,n_patient))\n",
    "fold_acc = np.zeros((n_fold,n_patient))\n",
    "for patient in np.unique(Patient):\n",
    "    id_patient = np.arange(len(Mean_Images))[Patient==patient]\n",
    "    n_fold = 5\n",
    "    length = len(id_patient)\n",
    "    \n",
    "    n_patient = len(np.unique(Patient))\n",
    "    \n",
    "    train_id, test_id = kfold(length,n_fold)\n",
    "    \n",
    "    for fold in range(n_fold):\n",
    "        X_train = Mean_Images[id_patient[train_id[fold]]]\n",
    "        X_test = Mean_Images[id_patient[test_id[fold]]]\n",
    "        y_train = Label[id_patient[train_id[fold]]]\n",
    "        y_test = Label[id_patient[test_id[fold]]] \n",
    "\n",
    "        print(\"Begin Training Fold %d/%d\\t of Patient %d\" % \n",
    "             (fold+1,n_fold, patient))\n",
    "\n",
    "        CNN = BasicCNN().cuda(0)\n",
    "        criterion = nn.CrossEntropyLoss()\n",
    "        optimizer = optim.SGD(CNN.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "        n_epochs = 50\n",
    "        for epoch in range(n_epochs):\n",
    "            running_loss = 0.0\n",
    "            batchsize = 5\n",
    "            for i in range(int(len(y_train)/batchsize)):\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(torch.from_numpy(X_train[i:i+batchsize]).to(torch.float32).cuda())\n",
    "                loss = criterion(outputs, torch.from_numpy(y_train[i:i+batchsize]).to(torch.long).cuda())\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "\n",
    "            #acc\n",
    "            _, idx = torch.max(CNN(torch.from_numpy(X_train[:]).to(torch.float32).cuda()).data,1)\n",
    "            acc = (idx == torch.from_numpy(y_train).cuda()).sum().item()/len(y_train)\n",
    "\n",
    "            #val Loss\n",
    "            val_outputs = CNN(torch.from_numpy(X_test[:]).to(torch.float32).cuda())\n",
    "            val_loss = criterion(val_outputs, torch.from_numpy(y_test[:]).to(torch.long).cuda())\n",
    "            _, idx = torch.max(val_outputs.data,1)\n",
    "            val_acc = (idx == torch.from_numpy(y_test).cuda()).sum().item()/len(y_test)\n",
    "\n",
    "            #if epoch%10==0:\n",
    "            #    print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "            #     (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "        fold_vloss[fold, p ] = val_loss.item()\n",
    "        fold_loss[fold, p] = running_loss/i\n",
    "        fold_vacc[fold, p] = val_acc\n",
    "        fold_acc[fold, p] = acc\n",
    "        print('Finish Training Fold %d/%d\\t of Patient %d' % \n",
    "             (fold+1,n_fold, patient))\n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))\n",
    "    \n",
    "    p = p + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Peresented Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 493,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x720 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig = plt.figure(figsize=(12,10))\n",
    "plt.grid()\n",
    "plt.boxplot(fold_vacc)\n",
    "plt.suptitle('Cross-Validation Accuracy\\n basic CNN')\n",
    "ax = plt.gca()\n",
    "plt.xlabel('Patient id')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.savefig('fig.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image Time Window"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CNN Parallel Model (A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ParallelCNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ParallelCNN, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3,3,3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(3,3,5, padding=2)\n",
    "        self.conv3 = nn.Conv2d(3,3,3, padding=1)\n",
    "        self.conv4 = nn.Conv2d(3,3,5, padding=2)\n",
    "        self.conv5 = nn.Conv2d(3,3,3, padding=1)\n",
    "        self.conv6 = nn.Conv2d(3,3,5, padding=2)\n",
    "        self.conv7 = nn.Conv2d(3,3,3, padding=1)\n",
    "        self.conv8 = nn.Conv2d(3,3,5, padding=2)\n",
    "        self.conv9 = nn.Conv2d(3,3,3, padding=1)\n",
    "        self.conv10 = nn.Conv2d(3,3,5, padding=2)\n",
    "        self.conv11 = nn.Conv2d(3,3,3, padding=1)\n",
    "        self.conv12 = nn.Conv2d(3,3,5, padding=2)\n",
    "        self.conv13 = nn.Conv2d(3,3,3, padding=1)\n",
    "        self.conv14 = nn.Conv2d(3,3,5, padding=2)\n",
    "        self.pool = nn.MaxPool2d(2,2)\n",
    "        self.fc1 = nn.Linear(3549,512)\n",
    "        self.fc2 = nn.Linear(512,4)\n",
    "        self.max = nn.Softmax()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        x[:,0] = F.relu(self.conv1(x[:,0]))\n",
    "        x[:,1] = F.relu(self.conv3(x[:,1]))\n",
    "        x[:,2] = F.relu(self.conv5(x[:,2]))\n",
    "        x[:,3] = F.relu(self.conv7(x[:,3]))\n",
    "        x[:,4] = F.relu(self.conv9(x[:,4]))\n",
    "        x[:,5] = F.relu(self.conv11(x[:,5]))\n",
    "        x[:,6] = F.relu(self.conv13(x[:,6]))\n",
    "        x[:,0] = F.relu(self.conv2(x[:,0]))\n",
    "        x[:,1] = F.relu(self.conv4(x[:,1]))\n",
    "        x[:,2] = F.relu(self.conv6(x[:,2]))\n",
    "        x[:,3] = F.relu(self.conv8(x[:,3]))\n",
    "        x[:,4] = F.relu(self.conv10(x[:,4]))\n",
    "        x[:,5] = F.relu(self.conv12(x[:,5]))\n",
    "        x[:,6] = F.relu(self.conv14(x[:,6]))\n",
    "        x = x[:,:,:,3:29,3:29]\n",
    "        x = x.reshape(batch_size, x.shape[2], x.shape[1]*x.shape[3],-1) # img reshape\n",
    "        x = self.pool(x)\n",
    "        x = x.view(batch_size,-1)\n",
    "        x = self.fc1(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:45: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training Fold 1/5\t of Patient 1\n",
      "[1,  30] loss: 1.414\tAccuracy : 0.517\t\tval-loss: 1.359\tval-Accuracy : 0.459\n",
      "[6,  30] loss: 1.117\tAccuracy : 0.748\t\tval-loss: 1.088\tval-Accuracy : 0.757\n",
      "[11,  30] loss: 0.913\tAccuracy : 0.830\t\tval-loss: 0.968\tval-Accuracy : 0.784\n",
      "[16,  30] loss: 0.862\tAccuracy : 0.830\t\tval-loss: 0.925\tval-Accuracy : 0.811\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-10-a0847679b8ae>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     36\u001b[0m                 \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mCNN\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     37\u001b[0m                 \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 38\u001b[1;33m                 \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     39\u001b[0m                 
\u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     40\u001b[0m                 \u001b[0mrunning_loss\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[0;32m    164\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[1;33m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    165\u001b[0m         \"\"\"\n\u001b[1;32m--> 166\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    167\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    168\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[0;32m     97\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[0;32m     98\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 99\u001b[1;33m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[0;32m    100\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    101\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Per-patient k-fold cross-validation of the Parallel CNN.\n",
    "# Each result array is indexed [fold, patient position] and keeps the metrics\n",
    "# measured after the last epoch of that fold.\n",
    "n_fold = 5\n",
    "n_epochs = 30\n",
    "batchsize = 4\n",
    "patient_ids = np.unique(Patient)\n",
    "n_patient = len(patient_ids)\n",
    "fold_vloss = np.zeros((n_fold,n_patient))\n",
    "fold_loss = np.zeros((n_fold,n_patient))\n",
    "fold_vacc = np.zeros((n_fold,n_patient))\n",
    "fold_acc = np.zeros((n_fold,n_patient))\n",
    "for p, patient in enumerate(patient_ids):\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "\n",
    "    train_id, test_id = kfold(len(id_patient),n_fold)\n",
    "\n",
    "    for fold in range(n_fold):\n",
    "        X_train = tmp[id_patient[train_id[fold]]]\n",
    "        X_test = tmp[id_patient[test_id[fold]]]\n",
    "        y_train = Label[id_patient[train_id[fold]]]\n",
    "        y_test = Label[id_patient[test_id[fold]]]\n",
    "\n",
    "        # Convert every split once per fold. Labels are cast to int64 so they can be\n",
    "        # used as CrossEntropyLoss targets and compared with the argmax indices.\n",
    "        X_train_t = torch.from_numpy(X_train).to(torch.float32).to(device)\n",
    "        X_test_t = torch.from_numpy(X_test).to(torch.float32).to(device)\n",
    "        y_train_t = torch.from_numpy(y_train).to(torch.long).to(device)\n",
    "        y_test_t = torch.from_numpy(y_test).to(torch.long).to(device)\n",
    "\n",
    "        print(\"Begin Training Fold %d/%d\\t of Patient %d\" % \n",
    "             (fold+1,n_fold, patient))\n",
    "\n",
    "        CNN = ParallelCNN().to(device)\n",
    "        criterion = nn.CrossEntropyLoss()\n",
    "        optimizer = optim.SGD(CNN.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "        # Only full mini-batches are used; a shorter remainder is dropped.\n",
    "        n_batches = len(y_train)//batchsize\n",
    "        for epoch in range(n_epochs):\n",
    "            CNN.train()\n",
    "            running_loss = 0.0\n",
    "            for batch in range(n_batches):\n",
    "                # Advance by a whole batch so consecutive mini-batches never overlap.\n",
    "                start = batch*batchsize\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(X_train_t[start:start+batchsize])\n",
    "                loss = criterion(outputs, y_train_t[start:start+batchsize])\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "            # Mean loss over the number of batches that were actually run.\n",
    "            train_loss = running_loss/n_batches\n",
    "\n",
    "            # Evaluation needs no gradients; no_grad avoids building a graph over\n",
    "            # the whole training and validation split.\n",
    "            CNN.eval()\n",
    "            with torch.no_grad():\n",
    "                #acc\n",
    "                _, idx = torch.max(CNN(X_train_t),1)\n",
    "                acc = (idx == y_train_t).sum().item()/len(y_train)\n",
    "\n",
    "                #val Loss\n",
    "                val_outputs = CNN(X_test_t)\n",
    "                val_loss = criterion(val_outputs, y_test_t).item()\n",
    "                _, idx = torch.max(val_outputs,1)\n",
    "                val_acc = (idx == y_test_t).sum().item()/len(y_test)\n",
    "\n",
    "            if epoch%5==0:\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                      (epoch+1, n_epochs, train_loss, float(acc), val_loss, val_acc))\n",
    "        fold_vloss[fold, p] = val_loss\n",
    "        fold_loss[fold, p] = train_loss\n",
    "        fold_vacc[fold, p] = val_acc\n",
    "        fold_acc[fold, p] = acc\n",
    "        print('Finish Training Fold %d/%d\\t of Patient %d' % \n",
    "             (fold+1,n_fold, patient))\n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Presented Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x720 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Store the per-fold metrics of the Parallel CNN, then draw one box per patient\n",
    "# from the validation accuracies.\n",
    "sio.savemat(\n",
    "    'Result/Res_ParallelCNN.mat',\n",
    "    {\"loss\":fold_loss,\"acc\":fold_acc,\"val loss\":fold_vloss,\"val acc\":fold_vacc},\n",
    ")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12,10))\n",
    "ax.grid()\n",
    "ax.boxplot(fold_vacc)\n",
    "fig.suptitle('Cross-Validation Accuracy\\n Parallel CNN')\n",
    "ax.set_xlabel('Patient id')\n",
    "ax.set_ylabel('Accuracy')\n",
    "fig.savefig('Result/ParallelCNN.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CNN Temporal Model (B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 276,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TemporalCNN(nn.Module):\n",
    "    \"\"\"CNN over the sequence of per-window EEG images (model B).\n",
    "\n",
    "    Each of the 7 time windows goes through its own 3x3 -> 5x5 convolution\n",
    "    pair. The 7 feature maps are tiled along the height axis into one tall\n",
    "    image, which is processed by a shared conv/pool stack and two linear layers.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(TemporalCNN, self).__init__()\n",
    "        # conv(2k+1) and conv(2k+2) are the convolution pair of time window k.\n",
    "        self.conv1 = nn.Conv2d(3,3,3)\n",
    "        self.conv2 = nn.Conv2d(3,3,5)\n",
    "        self.conv3 = nn.Conv2d(3,3,3)\n",
    "        self.conv4 = nn.Conv2d(3,3,5)\n",
    "        self.conv5 = nn.Conv2d(3,3,3)\n",
    "        self.conv6 = nn.Conv2d(3,3,5)\n",
    "        self.conv7 = nn.Conv2d(3,3,3)\n",
    "        self.conv8 = nn.Conv2d(3,3,5)\n",
    "        self.conv9 = nn.Conv2d(3,3,3)\n",
    "        self.conv10 = nn.Conv2d(3,3,5)\n",
    "        self.conv11 = nn.Conv2d(3,3,3)\n",
    "        self.conv12 = nn.Conv2d(3,3,5)\n",
    "        self.conv13 = nn.Conv2d(3,3,3)\n",
    "        self.conv14 = nn.Conv2d(3,3,5)\n",
    "        self.pool1 = nn.MaxPool2d(2)\n",
    "        self.pool2 = nn.MaxPool2d(2)\n",
    "        self.conv15 = nn.Conv2d(3,3,5)\n",
    "        self.conv16 = nn.Conv2d(3,3,7)\n",
    "        self.fc1 = nn.Linear(120,512)\n",
    "        self.fc2 = nn.Linear(512,4)\n",
    "        # dim=1 normalises over the class axis of the (batch, 4) output.\n",
    "        self.max = nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Classify a batch of image sequences.\n",
    "\n",
    "        Args:\n",
    "            x: float tensor of shape (batch, 7, 3, 32, 32).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, 4) holding one row of class probabilities\n",
    "            per sample.\n",
    "        \"\"\"\n",
    "        batch_size = x.shape[0]\n",
    "        window_convs = (\n",
    "            (self.conv1, self.conv2),\n",
    "            (self.conv3, self.conv4),\n",
    "            (self.conv5, self.conv6),\n",
    "            (self.conv7, self.conv8),\n",
    "            (self.conv9, self.conv10),\n",
    "            (self.conv11, self.conv12),\n",
    "            (self.conv13, self.conv14),\n",
    "        )\n",
    "        # Stacking the per-window feature maps keeps them on the device of x,\n",
    "        # so the module works wherever its parameters live.\n",
    "        x = torch.stack(\n",
    "            [F.relu(second(F.relu(first(x[:,i])))) for i, (first, second) in enumerate(window_convs)],\n",
    "            dim=1,\n",
    "        )\n",
    "        n_windows, n_channels, height, width = x.shape[1:]\n",
    "        # Tile the windows along the height axis of every channel. The channel axis\n",
    "        # must be moved in front of the window axis first: reshaping directly would\n",
    "        # interleave windows and channels.\n",
    "        x = x.permute(0,2,1,3,4).reshape(batch_size, n_channels, n_windows*height, width)\n",
    "        x = self.pool1(x)\n",
    "        x = F.relu(self.conv15(x))\n",
    "        x = F.relu(self.conv16(x))\n",
    "        x = self.pool2(x)\n",
    "        x = x.view(batch_size,-1)\n",
    "        x = self.fc1(x)\n",
    "        x = self.fc2(x)\n",
    "        # NOTE(review): nn.CrossEntropyLoss applies log-softmax itself, so training\n",
    "        # on these probabilities normalises twice -- consider returning the fc2 output.\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 277,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:65: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0.2572, 0.2522, 0.2505, 0.2401],\n",
       "        [0.2591, 0.2499, 0.2525, 0.2385]], device='cuda:0',\n",
       "       grad_fn=<SoftmaxBackward>)"
      ]
     },
     "execution_count": 277,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: run the first two samples through an untrained TemporalCNN.\n",
    "net = TemporalCNN().cuda()\n",
    "sample_batch = torch.from_numpy(tmp[:2]).to(torch.float32).cuda()\n",
    "net(sample_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 280,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training Fold 1/5\t of Patient 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:65: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 100] loss: 1.424\tAccuracy : 0.327\t\tval-loss: 1.384\tval-Accuracy : 0.243\n",
      "[11, 100] loss: 1.422\tAccuracy : 0.320\t\tval-loss: 1.384\tval-Accuracy : 0.270\n",
      "[21, 100] loss: 1.420\tAccuracy : 0.320\t\tval-loss: 1.384\tval-Accuracy : 0.270\n",
      "[31, 100] loss: 1.418\tAccuracy : 0.320\t\tval-loss: 1.383\tval-Accuracy : 0.270\n",
      "[41, 100] loss: 1.417\tAccuracy : 0.320\t\tval-loss: 1.383\tval-Accuracy : 0.270\n",
      "[51, 100] loss: 1.415\tAccuracy : 0.320\t\tval-loss: 1.383\tval-Accuracy : 0.270\n",
      "[61, 100] loss: 1.413\tAccuracy : 0.320\t\tval-loss: 1.383\tval-Accuracy : 0.270\n",
      "[71, 100] loss: 1.411\tAccuracy : 0.320\t\tval-loss: 1.383\tval-Accuracy : 0.270\n",
      "[81, 100] loss: 1.409\tAccuracy : 0.320\t\tval-loss: 1.383\tval-Accuracy : 0.270\n",
      "[91, 100] loss: 1.408\tAccuracy : 0.320\t\tval-loss: 1.383\tval-Accuracy : 0.270\n",
      "Finish Training Fold 1/5\t of Patient 1\n",
      "Begin Training Fold 2/5\t of Patient 1\n",
      "[1, 100] loss: 1.427\tAccuracy : 0.265\t\tval-loss: 1.390\tval-Accuracy : 0.216\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-280-7232a97298b5>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     37\u001b[0m                 \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mCNN\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     38\u001b[0m                 \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 39\u001b[1;33m                 \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     40\u001b[0m                 
\u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     41\u001b[0m                 \u001b[0mrunning_loss\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[0;32m    164\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[1;33m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    165\u001b[0m         \"\"\"\n\u001b[1;32m--> 166\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    167\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    168\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[0;32m     97\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[0;32m     98\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 99\u001b[1;33m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[0;32m    100\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    101\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Per-patient k-fold cross-validation of the Temporal CNN.\n",
    "# Each result array is indexed [fold, patient position] and keeps the metrics\n",
    "# measured after the last epoch of that fold.\n",
    "n_fold = 5\n",
    "n_epochs = 100\n",
    "batchsize = 4\n",
    "patient_ids = np.unique(Patient)\n",
    "n_patient = len(patient_ids)\n",
    "fold_vloss = np.zeros((n_fold,n_patient))\n",
    "fold_loss = np.zeros((n_fold,n_patient))\n",
    "fold_vacc = np.zeros((n_fold,n_patient))\n",
    "fold_acc = np.zeros((n_fold,n_patient))\n",
    "for p, patient in enumerate(patient_ids):\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "\n",
    "    train_id, test_id = kfold(len(id_patient),n_fold)\n",
    "\n",
    "    for fold in range(n_fold):\n",
    "        X_train = tmp[id_patient[train_id[fold]]]\n",
    "        X_test = tmp[id_patient[test_id[fold]]]\n",
    "        y_train = Label[id_patient[train_id[fold]]]\n",
    "        y_test = Label[id_patient[test_id[fold]]]\n",
    "\n",
    "        # Convert every split once per fold. Labels are cast to int64 so they can be\n",
    "        # used as CrossEntropyLoss targets and compared with the argmax indices.\n",
    "        X_train_t = torch.from_numpy(X_train).to(torch.float32).to(device)\n",
    "        X_test_t = torch.from_numpy(X_test).to(torch.float32).to(device)\n",
    "        y_train_t = torch.from_numpy(y_train).to(torch.long).to(device)\n",
    "        y_test_t = torch.from_numpy(y_test).to(torch.long).to(device)\n",
    "\n",
    "        print(\"Begin Training Fold %d/%d\\t of Patient %d\" % \n",
    "             (fold+1,n_fold, patient))\n",
    "\n",
    "        CNN = TemporalCNN().to(device)\n",
    "        criterion = nn.CrossEntropyLoss()\n",
    "        optimizer = optim.SGD(CNN.parameters(), lr=0.001)\n",
    "#        optimizer = optim.SGD(CNN.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "        # Only full mini-batches are used; a shorter remainder is dropped.\n",
    "        n_batches = len(y_train)//batchsize\n",
    "        for epoch in range(n_epochs):\n",
    "            CNN.train()\n",
    "            running_loss = 0.0\n",
    "            for batch in range(n_batches):\n",
    "                # Advance by a whole batch so consecutive mini-batches never overlap.\n",
    "                start = batch*batchsize\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(X_train_t[start:start+batchsize])\n",
    "                loss = criterion(outputs, y_train_t[start:start+batchsize])\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "            # Mean loss over the number of batches that were actually run.\n",
    "            train_loss = running_loss/n_batches\n",
    "\n",
    "            # Evaluation needs no gradients; no_grad avoids building a graph over\n",
    "            # the whole training and validation split.\n",
    "            CNN.eval()\n",
    "            with torch.no_grad():\n",
    "                #acc\n",
    "                _, idx = torch.max(CNN(X_train_t),1)\n",
    "                acc = (idx == y_train_t).sum().item()/len(y_train)\n",
    "\n",
    "                #val Loss\n",
    "                val_outputs = CNN(X_test_t)\n",
    "                val_loss = criterion(val_outputs, y_test_t).item()\n",
    "                _, idx = torch.max(val_outputs,1)\n",
    "                val_acc = (idx == y_test_t).sum().item()/len(y_test)\n",
    "\n",
    "            if epoch%10==0:\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                      (epoch+1, n_epochs, train_loss, float(acc), val_loss, val_acc))\n",
    "        fold_vloss[fold, p] = val_loss\n",
    "        fold_loss[fold, p] = train_loss\n",
    "        fold_vacc[fold, p] = val_acc\n",
    "        fold_acc[fold, p] = acc\n",
    "        print('Finish Training Fold %d/%d\\t of Patient %d' % \n",
    "             (fold+1,n_fold, patient))\n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Presented Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x720 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Store the per-fold metrics of the Temporal CNN, then draw one box per patient\n",
    "# from the validation accuracies.\n",
    "sio.savemat('Result/Res_TemporalCNN.mat',{\"loss\":fold_loss,\"acc\":fold_acc,\"val loss\":fold_vloss,\"val acc\":fold_vacc})\n",
    "\n",
    "fig = plt.figure(figsize=(12,10))\n",
    "plt.grid()\n",
    "plt.boxplot(fold_vacc)\n",
    "# Title and file name identify this model, so the figure no longer overwrites\n",
    "# the one saved for the Parallel CNN.\n",
    "plt.suptitle('Cross-Validation Accuracy\\n Temporal CNN')\n",
    "ax = plt.gca()\n",
    "plt.xlabel('Patient id')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.savefig('Result/TemporalCNN.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the current per-fold metrics under the 2d result name, then draw one box\n",
    "# per patient from the validation accuracies.\n",
    "# NOTE(review): the figure title and file name say \"Parallel CNN\" while the .mat\n",
    "# file is named 2dTemporalCNN -- confirm which model this cell belongs to.\n",
    "sio.savemat(\n",
    "    'Result/Res_2dTemporalCNN.mat',\n",
    "    {\"loss\":fold_loss,\"acc\":fold_acc,\"val loss\":fold_vloss,\"val acc\":fold_vacc},\n",
    ")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12,10))\n",
    "ax.grid()\n",
    "ax.boxplot(fold_vacc)\n",
    "fig.suptitle('Cross-Validation Accuracy\\n Parallel CNN')\n",
    "ax.set_xlabel('Patient id')\n",
    "ax.set_ylabel('Accuracy')\n",
    "fig.savefig('Result/2dParallelCNN.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LSTM Model (C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(nn.Module):\n",
    "    \"\"\"Per-window CNN features followed by an LSTM over the time windows (model C).\n",
    "\n",
    "    Each of the 7 time windows goes through its own 3x3 -> 5x5 convolution\n",
    "    pair and a 2x2 max-pool. The flattened feature maps are fed, in window\n",
    "    order, to an LSTM cell; the last hidden state is classified by two\n",
    "    linear layers.\n",
    "\n",
    "    Args:\n",
    "        hidden_size: number of units of the LSTM cell.\n",
    "    \"\"\"\n",
    "    def __init__(self, hidden_size=128):\n",
    "        super(LSTM, self).__init__()\n",
    "        # conv(2k+1) and conv(2k+2) are the convolution pair of time window k.\n",
    "        self.conv1 = nn.Conv2d(3,3,3)\n",
    "        self.conv2 = nn.Conv2d(3,3,5)\n",
    "        self.conv3 = nn.Conv2d(3,3,3)\n",
    "        self.conv4 = nn.Conv2d(3,3,5)\n",
    "        self.conv5 = nn.Conv2d(3,3,3)\n",
    "        self.conv6 = nn.Conv2d(3,3,5)\n",
    "        self.conv7 = nn.Conv2d(3,3,3)\n",
    "        self.conv8 = nn.Conv2d(3,3,5)\n",
    "        self.conv9 = nn.Conv2d(3,3,3)\n",
    "        self.conv10 = nn.Conv2d(3,3,5)\n",
    "        self.conv11 = nn.Conv2d(3,3,3)\n",
    "        self.conv12 = nn.Conv2d(3,3,5)\n",
    "        self.conv13 = nn.Conv2d(3,3,3)\n",
    "        self.conv14 = nn.Conv2d(3,3,5)\n",
    "        self.pool1 = nn.MaxPool2d(2)\n",
    "        # 507 = 3 channels * 13 * 13 pooled pixels per window.\n",
    "        self.rnn1 = nn.LSTMCell(507,hidden_size)\n",
    "        self.fc1 = nn.Linear(hidden_size,512)\n",
    "        self.fc2 = nn.Linear(512,4)\n",
    "        # dim=1 normalises over the class axis of the (batch, 4) output.\n",
    "        self.max = nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Classify a batch of image sequences.\n",
    "\n",
    "        Args:\n",
    "            x: float tensor of shape (batch, 7, 3, 32, 32).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, 4) holding one row of class probabilities\n",
    "            per sample.\n",
    "        \"\"\"\n",
    "        batch_size = x.shape[0]\n",
    "        window_convs = (\n",
    "            (self.conv1, self.conv2),\n",
    "            (self.conv3, self.conv4),\n",
    "            (self.conv5, self.conv6),\n",
    "            (self.conv7, self.conv8),\n",
    "            (self.conv9, self.conv10),\n",
    "            (self.conv11, self.conv12),\n",
    "            (self.conv13, self.conv14),\n",
    "        )\n",
    "        # LSTMCell takes the state as an (h, c) pair, each of shape\n",
    "        # (batch, hidden_size); both start at zero on the device and dtype of x.\n",
    "        hx = x.new_zeros(batch_size, self.rnn1.hidden_size)\n",
    "        cx = x.new_zeros(batch_size, self.rnn1.hidden_size)\n",
    "        for i, (first, second) in enumerate(window_convs):\n",
    "            features = F.relu(second(F.relu(first(x[:,i]))))\n",
    "            # One flattened feature vector per sample is one LSTM time step.\n",
    "            step = self.pool1(features).reshape(batch_size,-1)\n",
    "            hx, cx = self.rnn1(step, (hx, cx))\n",
    "        x = self.fc1(hx)\n",
    "        x = self.fc2(x)\n",
    "        # NOTE(review): nn.CrossEntropyLoss applies log-softmax itself, so training\n",
    "        # on these probabilities normalises twice -- consider returning the fc2 output.\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 239,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 2])\n",
      "torch.Size([2, 507])\n"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "Dimension out of range (expected to be in range of [-1, 0], but got 1)",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-239-37f3230988bf>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[0mnet\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mLSTM\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mnet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtmp\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    539\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    540\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 541\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    542\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    543\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-238-e8a90c20b4cc>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     40\u001b[0m         \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     41\u001b[0m         \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 42\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrnn1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mhx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     43\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# img reshape\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     44\u001b[0m     
    \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv15\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    539\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    540\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 541\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    542\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    543\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\rnn.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input, hx)\u001b[0m\n\u001b[0;32m    938\u001b[0m             \u001b[0mzeros\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhidden_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    939\u001b[0m             \u001b[0mhx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mzeros\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 940\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcheck_forward_hidden\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'[0]'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    941\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcheck_forward_hidden\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;34m'[1]'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    942\u001b[0m         return _VF.lstm_cell(\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\rnn.py\u001b[0m in \u001b[0;36mcheck_forward_hidden\u001b[1;34m(self, input, hx, hidden_label)\u001b[0m\n\u001b[0;32m    769\u001b[0m                     input.size(0), hidden_label, hx.size(0)))\n\u001b[0;32m    770\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 771\u001b[1;33m         \u001b[1;32mif\u001b[0m \u001b[0mhx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhidden_size\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    772\u001b[0m             raise RuntimeError(\n\u001b[0;32m    773\u001b[0m                 \"hidden{} has inconsistent hidden_size: got {}, expected {}\".format(\n",
      "\u001b[1;31mIndexError\u001b[0m: Dimension out of range (expected to be in range of [-1, 0], but got 1)"
     ]
    }
   ],
   "source": [
    "net = LSTM().cuda()\n",
    "net(torch.from_numpy(tmp[0:2]).to(torch.float32).cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 499,
   "metadata": {},
   "outputs": [],
   "source": [
    "tot_img = Images[:,0,:,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 393,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training Fold 1/5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:19: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1,  50] loss: 1.385\tAccuracy : 0.314\t\tval-loss: 1.352\tval-Accuracy : 0.309\n",
      "[11,  50] loss: 0.841\tAccuracy : 0.834\t\tval-loss: 0.899\tval-Accuracy : 0.845\n",
      "[21,  50] loss: 0.826\tAccuracy : 0.865\t\tval-loss: 0.877\tval-Accuracy : 0.867\n",
      "[31,  50] loss: 0.804\tAccuracy : 0.875\t\tval-loss: 0.876\tval-Accuracy : 0.869\n",
      "[41,  50] loss: 0.797\tAccuracy : 0.883\t\tval-loss: 0.876\tval-Accuracy : 0.867\n",
      "Finish Training Fold 1/5\n",
      "Begin Training Fold 2/5\n",
      "[1,  50] loss: 1.386\tAccuracy : 0.317\t\tval-loss: 1.370\tval-Accuracy : 0.277\n",
      "[11,  50] loss: 0.869\tAccuracy : 0.831\t\tval-loss: 0.918\tval-Accuracy : 0.828\n",
      "[21,  50] loss: 0.842\tAccuracy : 0.827\t\tval-loss: 0.918\tval-Accuracy : 0.824\n",
      "[31,  50] loss: 0.825\tAccuracy : 0.831\t\tval-loss: 0.911\tval-Accuracy : 0.826\n",
      "[41,  50] loss: 0.824\tAccuracy : 0.832\t\tval-loss: 0.907\tval-Accuracy : 0.833\n",
      "Finish Training Fold 2/5\n",
      "Begin Training Fold 3/5\n",
      "[1,  50] loss: 1.384\tAccuracy : 0.482\t\tval-loss: 1.330\tval-Accuracy : 0.485\n",
      "[11,  50] loss: 0.871\tAccuracy : 0.837\t\tval-loss: 0.913\tval-Accuracy : 0.824\n",
      "[21,  50] loss: 0.830\tAccuracy : 0.862\t\tval-loss: 0.906\tval-Accuracy : 0.841\n",
      "[31,  50] loss: 0.815\tAccuracy : 0.881\t\tval-loss: 0.880\tval-Accuracy : 0.861\n",
      "[41,  50] loss: 0.813\tAccuracy : 0.878\t\tval-loss: 0.882\tval-Accuracy : 0.861\n",
      "Finish Training Fold 3/5\n",
      "Begin Training Fold 4/5\n",
      "[1,  50] loss: 1.392\tAccuracy : 0.318\t\tval-loss: 1.384\tval-Accuracy : 0.348\n",
      "[11,  50] loss: 0.927\tAccuracy : 0.791\t\tval-loss: 0.941\tval-Accuracy : 0.800\n",
      "[21,  50] loss: 0.840\tAccuracy : 0.839\t\tval-loss: 0.917\tval-Accuracy : 0.828\n",
      "[31,  50] loss: 0.823\tAccuracy : 0.848\t\tval-loss: 0.899\tval-Accuracy : 0.845\n",
      "[41,  50] loss: 0.813\tAccuracy : 0.845\t\tval-loss: 0.897\tval-Accuracy : 0.846\n",
      "Finish Training Fold 4/5\n",
      "Begin Training Fold 5/5\n",
      "[1,  50] loss: 1.335\tAccuracy : 0.418\t\tval-loss: 1.265\tval-Accuracy : 0.436\n",
      "[11,  50] loss: 0.880\tAccuracy : 0.821\t\tval-loss: 0.935\tval-Accuracy : 0.807\n",
      "[21,  50] loss: 0.832\tAccuracy : 0.850\t\tval-loss: 0.907\tval-Accuracy : 0.835\n",
      "[31,  50] loss: 0.823\tAccuracy : 0.837\t\tval-loss: 0.905\tval-Accuracy : 0.833\n",
      "[41,  50] loss: 0.810\tAccuracy : 0.837\t\tval-loss: 0.906\tval-Accuracy : 0.833\n",
      "Finish Training Fold 5/5\n"
     ]
    }
   ],
   "source": [
    "id_patient = \n",
    "\n",
    "n_fold = 5\n",
    "length = len(Mean_Images)\n",
    "\n",
    "fold_vloss = np.zeros((n_fold,n_patient))\n",
    "fold_loss = np.zeros((n_fold,n_patient))\n",
    "fold_vacc = np.zeros((n_fold,n_patient))\n",
    "fold_acc = np.zeros((n_fold,n_patient))\n",
    "\n",
    "train_id, test_id = kfold(length,n_fold)\n",
    "for fold in range(n_fold):\n",
    "    X_train = Mean_Images[train_id[fold]]\n",
    "    X_test = Mean_Images[test_id[fold]]    \n",
    "    y_train = Label[train_id[fold]]\n",
    "    y_test = Label[test_id[fold]]    \n",
    "    \n",
    "    print(\"Begin Training Fold %d/%d\" % \n",
    "         (fold+1,n_fold))\n",
    "    \n",
    "    CNN = BasicCNN().cuda(0)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.SGD(CNN.parameters(), lr=0.001, momentum=0.9)\n",
    "    \n",
    "    n_epochs = 50\n",
    "    for epoch in range(n_epochs):\n",
    "        running_loss = 0.0\n",
    "        batchsize = 10\n",
    "        for i in range(int(len(y_train)/batchsize)):\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward + backward + optimize\n",
    "            outputs = CNN(torch.from_numpy(X_train[i:i+batchsize]).to(torch.float32).cuda())\n",
    "            loss = criterion(outputs, torch.from_numpy(y_train[i:i+batchsize]).to(torch.long).cuda())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        #acc\n",
    "        _, idx = torch.max(CNN(torch.from_numpy(X_train[:]).to(torch.float32).cuda()).data,1)\n",
    "        acc = (idx == torch.from_numpy(y_train).cuda()).sum().item()/len(y_train)\n",
    "\n",
    "        #val Loss\n",
    "        val_outputs = CNN(torch.from_numpy(X_test[:]).to(torch.float32).cuda())\n",
    "        val_loss = criterion(val_outputs, torch.from_numpy(y_test[:]).to(torch.long).cuda())\n",
    "        _, idx = torch.max(val_outputs.data,1)\n",
    "        val_acc = (idx == torch.from_numpy(y_test).cuda()).sum().item()/len(y_test)\n",
    "\n",
    "        if epoch%10==0:\n",
    "            print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "    fold_vloss.append(val_loss)\n",
    "    fold_loss.append(running_loss/i)\n",
    "    fold_vacc.append(val_acc)\n",
    "    fold_acc.append(acc)\n",
    "    print('Finish Training Fold %d/%d' % \n",
    "         (fold+1,n_fold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicCNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(BasicCNN, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3,32,(3,3),stride=(1,1), padding=1)\n",
    "        self.conv2 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.pool = nn.MaxPool2d((2,2))\n",
    "        self.conv5 = nn.Conv2d(32,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv6 = nn.Conv2d(64,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv7 = nn.Conv2d(64,128,(3,3),stride=(1,1),padding=1)\n",
    "        self.fc1 = nn.Linear(507,512)\n",
    "        self.fc2 = nn.Linear(512,4)\n",
    "        self.max = nn.Softmax()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = F.relu(self.conv4(x))\n",
    "        x = self.pool(x)\n",
    "        x = F.relu(self.conv5(x))\n",
    "        x = F.relu(self.conv6(x))\n",
    "        x = self.pool(x)\n",
    "        x = F.relu(self.conv7(x))\n",
    "        x = self.pool(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "192"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "64*3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaxCNN(nn.Module):\n",
    "    def __init__(self, input_image, kernel=(3,3), stride=1, padding=1,max_kernel=(2,2)):\n",
    "        super(MaxCNN, self).__init__()\n",
    "        \n",
    "        \n",
    "        n_window = input_image.shape[1]\n",
    "        n_channel = input_image.shape[2]\n",
    "        \n",
    "        self.conv1 = nn.Conv2d(n_channel,32,kernel,stride=stride, padding=padding)\n",
    "        self.conv2 = nn.Conv2d(32,32,kernel,stride=stride, padding=padding)\n",
    "        self.conv3 = nn.Conv2d(32,32,kernel,stride=stride, padding=padding)\n",
    "        self.conv4 = nn.Conv2d(32,32,kernel,stride=stride, padding=padding)\n",
    "        self.pool1 = nn.MaxPool2d(max_kernel)\n",
    "        self.conv5 = nn.Conv2d(32,64,kernel,stride=stride,padding=padding)\n",
    "        self.conv6 = nn.Conv2d(64,64,kernel,stride=stride,padding=padding)\n",
    "        self.conv7 = nn.Conv2d(64,128,kernel,stride=stride,padding=padding)\n",
    "        \n",
    "        self.pool = nn.MaxPool2d((n_window,1))\n",
    "        self.drop = nn.Dropout(p=0.5)\n",
    "        \n",
    "        self.fc = nn.Linear(n_window*int(4*4*128/n_window),512)\n",
    "        self.fc2 = nn.Linear(512,4)\n",
    "        self.max = nn.LogSoftmax()\n",
    "\n",
    "    def forward(self, x):\n",
    "        if x.get_device() == 0:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cuda()\n",
    "        else:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cpu()\n",
    "        for i in range(7):\n",
    "            tmp[:,i] = self.pool1( F.relu(self.conv7(self.pool1(F.relu(self.conv6(F.relu(self.conv5(self.pool1( F.relu(self.conv4(F.relu(self.conv3( F.relu(self.conv2(F.relu(self.conv1(x[:,i])))))))))))))))))\n",
    "        x = tmp.reshape(x.shape[0], x.shape[1],4*128*4,1)\n",
    "        x = self.pool(x)\n",
    "        x = x.view(x.shape[0],-1)\n",
    "        x = self.fc2(self.fc(x))\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaxCNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MaxCNN, self).__init__()\n",
    "        \n",
    "        \n",
    "        self.conv1 = nn.Conv2d(3,32,(3,3),stride=(1,1), padding=1)\n",
    "        self.conv2 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.pool1 = nn.MaxPool2d((2,2))\n",
    "        self.conv5 = nn.Conv2d(32,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv6 = nn.Conv2d(64,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv7 = nn.Conv2d(64,128,(3,3),stride=(1,1),padding=1)\n",
    "        \n",
    "        \n",
    "        self.pool = nn.MaxPool2d((7,1))\n",
    "        self.drop = nn.Dropout(p=0.5)\n",
    "        self.fc = nn.Linear(2044,512)\n",
    "        self.fc2 = nn.Linear(512,4)\n",
    "        self.max = nn.LogSoftmax()\n",
    "        \n",
    "    def forward(self, x):\n",
    "        if x.get_device() == 0:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cuda()\n",
    "        else:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cpu()\n",
    "        for i in range(7):\n",
    "            tmp[:,i] = self.pool1( F.relu(self.conv7(self.pool1(F.relu(self.conv6(F.relu(self.conv5(self.pool1( F.relu(self.conv4(F.relu(self.conv3( F.relu(self.conv2(F.relu(self.conv1(x[:,i])))))))))))))))))\n",
    "        x = tmp.reshape(x.shape[0], x.shape[1],4*128*4,1)\n",
    "        x = self.pool(x)\n",
    "        x = x.view(x.shape[0],-1)\n",
    "        #x = self.drop(x)\n",
    "        x = self.fc2(self.fc(x))\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training rep 1/5\t of Patient 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:34: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-43-18f565b68b5f>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     38\u001b[0m                 \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mCNN\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     39\u001b[0m                 \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 40\u001b[1;33m                 
\u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     41\u001b[0m                 \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     42\u001b[0m                 \u001b[0mrunning_loss\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[0;32m    164\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[1;33m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    165\u001b[0m         \"\"\"\n\u001b[1;32m--> 166\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    167\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    168\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[0;32m     97\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[0;32m     98\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 99\u001b[1;33m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[0;32m    100\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    101\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "p = 0\n",
    "n_rep = 5    \n",
    "n_patient = len(np.unique(Patient))\n",
    "fold_vloss = np.zeros((n_rep,n_patient))\n",
    "fold_loss = np.zeros((n_rep,n_patient))\n",
    "fold_vacc = np.zeros((n_rep,n_patient))\n",
    "fold_acc = np.zeros((n_rep,n_patient))\n",
    "\n",
    "for patient in np.unique(Patient):\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "    id_train = np.arange(len(tmp))[Patient!=patient]\n",
    "    \n",
    "    for rep in range(n_rep):\n",
    "        np.random.shuffle(id_patient)\n",
    "        np.random.shuffle(id_train)\n",
    "        \n",
    "        X_train = tmp[id_train]\n",
    "        X_test = tmp[id_patient]\n",
    "        y_train = Label[id_train]\n",
    "        y_test = Label[id_patient]\n",
    "        \n",
    "        print(\"Begin Training rep %d/%d\\t of Patient %d\" % \n",
    "             (rep+1,n_rep, patient))\n",
    "        \n",
    "        CNN = MaxCNN().cuda()\n",
    "        criterion = nn.NLLLoss()\n",
    "        optimizer = optim.SGD(CNN.parameters(), lr=0.01)\n",
    "        \n",
    "        n_epochs = 45\n",
    "        for epoch in range(n_epochs):\n",
    "            running_loss = 0.0\n",
    "            batchsize = 32\n",
    "            for i in range(int(len(y_train)/batchsize)):\n",
    "                \n",
    "                CNN.to(torch.device(\"cuda\"))\n",
    "                optimizer.zero_grad()\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(torch.from_numpy(X_train[i*batchsize:(i+1)*batchsize]).to(torch.float32).cuda())\n",
    "                loss = criterion(outputs, torch.from_numpy(y_train[i*batchsize:(i+1)*batchsize]).to(torch.long).cuda())\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "                \n",
    "            if epoch==50:\n",
    "                cnn_cpu = CNN.to(torch.device(\"cuda\"))\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(cnn_cpu(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32)).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = cnn_cpu(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32))\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long)).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "\n",
    "        fold_vloss[rep, p ] = val_loss\n",
    "        fold_loss[rep, p] = running_loss/i\n",
    "        fold_vacc[rep, p] = val_acc\n",
    "        fold_acc[rep, p] = acc\n",
    "      \n",
    "    \n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))\n",
    "    \n",
    "    p = p + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training rep 1/5\t of Patient 11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:34: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1,  45] loss: 1.402\tAccuracy : 0.286\t\tval-loss: 1.379\tval-Accuracy : 0.258\n",
      "[6,  45] loss: 1.397\tAccuracy : 0.286\t\tval-loss: 1.368\tval-Accuracy : 0.258\n",
      "[11,  45] loss: 1.207\tAccuracy : 0.482\t\tval-loss: 0.958\tval-Accuracy : 0.516\n",
      "[16,  45] loss: 1.054\tAccuracy : 0.578\t\tval-loss: 0.800\tval-Accuracy : 0.613\n",
      "[21,  45] loss: 0.941\tAccuracy : 0.656\t\tval-loss: 0.727\tval-Accuracy : 0.689\n",
      "[26,  45] loss: 0.793\tAccuracy : 0.741\t\tval-loss: 0.627\tval-Accuracy : 0.787\n",
      "[31,  45] loss: 0.565\tAccuracy : 0.845\t\tval-loss: 0.306\tval-Accuracy : 0.902\n",
      "[36,  45] loss: 0.332\tAccuracy : 0.887\t\tval-loss: 0.214\tval-Accuracy : 0.933\n",
      "[41,  45] loss: 0.232\tAccuracy : 0.912\t\tval-loss: 0.218\tval-Accuracy : 0.920\n",
      "Begin Training rep 2/5\t of Patient 11\n",
      "[1,  45] loss: 1.401\tAccuracy : 0.358\t\tval-loss: 1.379\tval-Accuracy : 0.387\n",
      "[6,  45] loss: 1.389\tAccuracy : 0.286\t\tval-loss: 1.365\tval-Accuracy : 0.258\n",
      "[11,  45] loss: 1.136\tAccuracy : 0.548\t\tval-loss: 1.012\tval-Accuracy : 0.578\n",
      "[16,  45] loss: 1.019\tAccuracy : 0.606\t\tval-loss: 0.904\tval-Accuracy : 0.582\n",
      "[21,  45] loss: 0.908\tAccuracy : 0.665\t\tval-loss: 0.820\tval-Accuracy : 0.640\n",
      "[26,  45] loss: 0.771\tAccuracy : 0.598\t\tval-loss: 0.890\tval-Accuracy : 0.622\n",
      "[31,  45] loss: 0.457\tAccuracy : 0.858\t\tval-loss: 0.644\tval-Accuracy : 0.840\n",
      "[36,  45] loss: 0.326\tAccuracy : 0.889\t\tval-loss: 0.718\tval-Accuracy : 0.871\n",
      "[41,  45] loss: 0.234\tAccuracy : 0.923\t\tval-loss: 0.713\tval-Accuracy : 0.907\n",
      "Begin Training rep 3/5\t of Patient 11\n",
      "[1,  45] loss: 1.403\tAccuracy : 0.286\t\tval-loss: 1.385\tval-Accuracy : 0.258\n",
      "[6,  45] loss: 1.397\tAccuracy : 0.286\t\tval-loss: 1.392\tval-Accuracy : 0.258\n",
      "[11,  45] loss: 1.378\tAccuracy : 0.286\t\tval-loss: 1.351\tval-Accuracy : 0.258\n",
      "[16,  45] loss: 1.115\tAccuracy : 0.550\t\tval-loss: 0.995\tval-Accuracy : 0.587\n",
      "[21,  45] loss: 0.998\tAccuracy : 0.599\t\tval-loss: 0.896\tval-Accuracy : 0.604\n",
      "[26,  45] loss: 0.867\tAccuracy : 0.667\t\tval-loss: 0.762\tval-Accuracy : 0.680\n",
      "[31,  45] loss: 0.635\tAccuracy : 0.816\t\tval-loss: 0.418\tval-Accuracy : 0.880\n",
      "[36,  45] loss: 0.361\tAccuracy : 0.889\t\tval-loss: 0.248\tval-Accuracy : 0.916\n",
      "[41,  45] loss: 0.245\tAccuracy : 0.907\t\tval-loss: 0.197\tval-Accuracy : 0.933\n",
      "Begin Training rep 4/5\t of Patient 11\n",
      "[1,  45] loss: 1.404\tAccuracy : 0.268\t\tval-loss: 1.392\tval-Accuracy : 0.262\n",
      "[6,  45] loss: 1.396\tAccuracy : 0.286\t\tval-loss: 1.410\tval-Accuracy : 0.258\n",
      "[11,  45] loss: 1.218\tAccuracy : 0.483\t\tval-loss: 1.115\tval-Accuracy : 0.520\n",
      "[16,  45] loss: 1.053\tAccuracy : 0.578\t\tval-loss: 0.800\tval-Accuracy : 0.600\n",
      "[21,  45] loss: 0.927\tAccuracy : 0.655\t\tval-loss: 0.742\tval-Accuracy : 0.662\n",
      "[26,  45] loss: 0.795\tAccuracy : 0.767\t\tval-loss: 0.661\tval-Accuracy : 0.787\n",
      "[31,  45] loss: 0.509\tAccuracy : 0.793\t\tval-loss: 0.645\tval-Accuracy : 0.804\n",
      "[36,  45] loss: 0.381\tAccuracy : 0.884\t\tval-loss: 0.299\tval-Accuracy : 0.920\n",
      "[41,  45] loss: 0.299\tAccuracy : 0.897\t\tval-loss: 0.265\tval-Accuracy : 0.907\n",
      "Begin Training rep 5/5\t of Patient 11\n",
      "[1,  45] loss: 1.404\tAccuracy : 0.285\t\tval-loss: 1.393\tval-Accuracy : 0.258\n",
      "[6,  45] loss: 1.397\tAccuracy : 0.286\t\tval-loss: 1.407\tval-Accuracy : 0.258\n",
      "[11,  45] loss: 1.396\tAccuracy : 0.286\t\tval-loss: 1.408\tval-Accuracy : 0.258\n",
      "[16,  45] loss: 1.238\tAccuracy : 0.454\t\tval-loss: 1.161\tval-Accuracy : 0.498\n",
      "[21,  45] loss: 1.053\tAccuracy : 0.570\t\tval-loss: 1.031\tval-Accuracy : 0.578\n",
      "[26,  45] loss: 0.924\tAccuracy : 0.627\t\tval-loss: 0.993\tval-Accuracy : 0.604\n",
      "[31,  45] loss: 0.763\tAccuracy : 0.742\t\tval-loss: 0.782\tval-Accuracy : 0.751\n",
      "[36,  45] loss: 0.576\tAccuracy : 0.750\t\tval-loss: 0.740\tval-Accuracy : 0.769\n",
      "[41,  45] loss: 0.390\tAccuracy : 0.881\t\tval-loss: 0.282\tval-Accuracy : 0.929\n",
      "loss: 0.203\tAccuracy : 0.904\t\tval-loss: 0.335\tval-Accuracy : 0.919\n",
      "Begin Training rep 1/5\t of Patient 12\n",
      "[1,  45] loss: 1.402\tAccuracy : 0.285\t\tval-loss: 1.384\tval-Accuracy : 0.263\n",
      "[6,  45] loss: 1.262\tAccuracy : 0.450\t\tval-loss: 1.258\tval-Accuracy : 0.493\n",
      "[11,  45] loss: 1.069\tAccuracy : 0.572\t\tval-loss: 1.008\tval-Accuracy : 0.622\n",
      "[16,  45] loss: 0.983\tAccuracy : 0.614\t\tval-loss: 1.032\tval-Accuracy : 0.622\n",
      "[21,  45] loss: 0.852\tAccuracy : 0.650\t\tval-loss: 1.417\tval-Accuracy : 0.631\n",
      "[26,  45] loss: 0.659\tAccuracy : 0.766\t\tval-loss: 1.111\tval-Accuracy : 0.700\n",
      "[31,  45] loss: 0.426\tAccuracy : 0.883\t\tval-loss: 0.603\tval-Accuracy : 0.843\n",
      "[36,  45] loss: 0.291\tAccuracy : 0.898\t\tval-loss: 0.524\tval-Accuracy : 0.829\n",
      "[41,  45] loss: 0.191\tAccuracy : 0.915\t\tval-loss: 0.421\tval-Accuracy : 0.894\n",
      "Begin Training rep 2/5\t of Patient 12\n",
      "[1,  45] loss: 1.402\tAccuracy : 0.285\t\tval-loss: 1.385\tval-Accuracy : 0.263\n",
      "[6,  45] loss: 1.397\tAccuracy : 0.285\t\tval-loss: 1.385\tval-Accuracy : 0.263\n",
      "[11,  45] loss: 1.238\tAccuracy : 0.468\t\tval-loss: 1.232\tval-Accuracy : 0.488\n",
      "[16,  45] loss: 1.052\tAccuracy : 0.534\t\tval-loss: 1.095\tval-Accuracy : 0.562\n",
      "[21,  45] loss: 0.943\tAccuracy : 0.623\t\tval-loss: 0.969\tval-Accuracy : 0.631\n",
      "[26,  45] loss: 0.790\tAccuracy : 0.736\t\tval-loss: 0.996\tval-Accuracy : 0.705\n",
      "[31,  45] loss: 0.522\tAccuracy : 0.852\t\tval-loss: 0.756\tval-Accuracy : 0.742\n",
      "[36,  45] loss: 0.351\tAccuracy : 0.882\t\tval-loss: 0.699\tval-Accuracy : 0.779\n",
      "[41,  45] loss: 0.274\tAccuracy : 0.899\t\tval-loss: 0.653\tval-Accuracy : 0.788\n",
      "Begin Training rep 3/5\t of Patient 12\n",
      "[1,  45] loss: 1.403\tAccuracy : 0.285\t\tval-loss: 1.384\tval-Accuracy : 0.263\n",
      "[6,  45] loss: 1.396\tAccuracy : 0.285\t\tval-loss: 1.384\tval-Accuracy : 0.263\n",
      "[11,  45] loss: 1.215\tAccuracy : 0.490\t\tval-loss: 1.180\tval-Accuracy : 0.544\n",
      "[16,  45] loss: 1.044\tAccuracy : 0.551\t\tval-loss: 0.995\tval-Accuracy : 0.576\n",
      "[21,  45] loss: 0.947\tAccuracy : 0.605\t\tval-loss: 0.976\tval-Accuracy : 0.581\n",
      "[26,  45] loss: 0.819\tAccuracy : 0.723\t\tval-loss: 0.926\tval-Accuracy : 0.691\n",
      "[31,  45] loss: 0.541\tAccuracy : 0.843\t\tval-loss: 0.697\tval-Accuracy : 0.760\n",
      "[36,  45] loss: 0.323\tAccuracy : 0.851\t\tval-loss: 1.152\tval-Accuracy : 0.673\n",
      "[41,  45] loss: 0.197\tAccuracy : 0.887\t\tval-loss: 1.182\tval-Accuracy : 0.742\n",
      "Begin Training rep 4/5\t of Patient 12\n",
      "[1,  45] loss: 1.404\tAccuracy : 0.287\t\tval-loss: 1.385\tval-Accuracy : 0.263\n",
      "[6,  45] loss: 1.392\tAccuracy : 0.285\t\tval-loss: 1.380\tval-Accuracy : 0.263\n",
      "[11,  45] loss: 1.132\tAccuracy : 0.528\t\tval-loss: 1.179\tval-Accuracy : 0.581\n",
      "[16,  45] loss: 0.999\tAccuracy : 0.591\t\tval-loss: 1.066\tval-Accuracy : 0.618\n",
      "[21,  45] loss: 0.879\tAccuracy : 0.673\t\tval-loss: 1.006\tval-Accuracy : 0.599\n",
      "[26,  45] loss: 0.689\tAccuracy : 0.733\t\tval-loss: 1.066\tval-Accuracy : 0.641\n",
      "[31,  45] loss: 0.460\tAccuracy : 0.790\t\tval-loss: 0.968\tval-Accuracy : 0.691\n",
      "[36,  45] loss: 0.288\tAccuracy : 0.889\t\tval-loss: 0.541\tval-Accuracy : 0.843\n",
      "[41,  45] loss: 0.218\tAccuracy : 0.917\t\tval-loss: 0.455\tval-Accuracy : 0.871\n",
      "Begin Training rep 5/5\t of Patient 12\n",
      "[1,  45] loss: 1.403\tAccuracy : 0.285\t\tval-loss: 1.384\tval-Accuracy : 0.263\n",
      "[6,  45] loss: 1.252\tAccuracy : 0.456\t\tval-loss: 1.223\tval-Accuracy : 0.530\n",
      "[11,  45] loss: 1.046\tAccuracy : 0.583\t\tval-loss: 1.013\tval-Accuracy : 0.618\n",
      "[16,  45] loss: 0.926\tAccuracy : 0.628\t\tval-loss: 0.999\tval-Accuracy : 0.622\n",
      "[21,  45] loss: 0.741\tAccuracy : 0.743\t\tval-loss: 0.754\tval-Accuracy : 0.719\n",
      "[26,  45] loss: 0.724\tAccuracy : 0.669\t\tval-loss: 1.009\tval-Accuracy : 0.664\n",
      "[31,  45] loss: 0.281\tAccuracy : 0.892\t\tval-loss: 0.760\tval-Accuracy : 0.816\n",
      "[36,  45] loss: 0.236\tAccuracy : 0.905\t\tval-loss: 0.652\tval-Accuracy : 0.839\n",
      "[41,  45] loss: 0.227\tAccuracy : 0.910\t\tval-loss: 0.666\tval-Accuracy : 0.848\n",
      "loss: 0.229\tAccuracy : 0.906\t\tval-loss: 0.675\tval-Accuracy : 0.829\n",
      "Begin Training rep 1/5\t of Patient 13\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "non-empty 3D or 4D input tensor expected but got ndim: 4",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-8-57c097f46e05>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     60\u001b[0m                 \u001b[0mval_loss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     61\u001b[0m                 \u001b[1;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mval_acc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 62\u001b[1;33m                     \u001b[0mval_outputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcnn_cpu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     63\u001b[0m                     \u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0midx\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mval_outputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     64\u001b[0m                     \u001b[0mval_loss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mval_outputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_test\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    539\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    540\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 541\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    542\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    543\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-7-a97bbd16b095>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     26\u001b[0m             \u001b[0mtmp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m128\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     27\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 28\u001b[1;33m             \u001b[0mtmp\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool1\u001b[0m\u001b[1;33m(\u001b[0m 
\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv7\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv6\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv5\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool1\u001b[0m\u001b[1;33m(\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv4\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv3\u001b[0m\u001b[1;33m(\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     29\u001b[0m         
\u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m*\u001b[0m\u001b[1;36m128\u001b[0m\u001b[1;33m*\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     30\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    539\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    540\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 541\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    542\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    543\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\pooling.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    139\u001b[0m         return F.max_pool2d(input, self.kernel_size, self.stride,\n\u001b[0;32m    140\u001b[0m                             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpadding\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdilation\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mceil_mode\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 141\u001b[1;33m                             self.return_indices)\n\u001b[0m\u001b[0;32m    142\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    143\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\_jit_internal.py\u001b[0m in \u001b[0;36mfn\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    136\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mif_true\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    137\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 138\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mif_false\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    139\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    140\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mif_true\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__doc__\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mif_false\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__doc__\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\functional.py\u001b[0m in \u001b[0;36m_max_pool2d\u001b[1;34m(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)\u001b[0m\n\u001b[0;32m    486\u001b[0m         \u001b[0mstride\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mannotate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mList\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mint\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    487\u001b[0m     return torch.max_pool2d(\n\u001b[1;32m--> 488\u001b[1;33m         input, kernel_size, stride, padding, dilation, ceil_mode)\n\u001b[0m\u001b[0;32m    489\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    490\u001b[0m max_pool2d = boolean_dispatch(\n",
      "\u001b[1;31mRuntimeError\u001b[0m: non-empty 3D or 4D input tensor expected but got ndim: 4"
     ]
    }
   ],
   "source": [
    "p = 12\n",
    "n_rep = 5    \n",
    "n_patient = len(np.unique(Patient))\n",
    "fold_vloss = np.zeros((n_rep,n_patient))\n",
    "fold_loss = np.zeros((n_rep,n_patient))\n",
    "fold_vacc = np.zeros((n_rep,n_patient))\n",
    "fold_acc = np.zeros((n_rep,n_patient))\n",
    "\n",
    "for patient in np.unique(Patient):\n",
    "    patient = patient + 13\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "    id_train = np.arange(len(tmp))[Patient!=patient]\n",
    "    \n",
    "    for rep in range(n_rep):\n",
    "        np.random.shuffle(id_patient)\n",
    "        np.random.shuffle(id_train)\n",
    "        \n",
    "        X_train = tmp[id_train]\n",
    "        X_test = tmp[id_patient]\n",
    "        y_train = Label[id_train]\n",
    "        y_test = Label[id_patient]\n",
    "        \n",
    "        print(\"Begin Training rep %d/%d\\t of Patient %d\" % \n",
    "             (rep+1,n_rep, patient))\n",
    "        \n",
    "        CNN = MaxCNN().cuda()\n",
    "        criterion = nn.NLLLoss()\n",
    "        optimizer = optim.SGD(CNN.parameters(), lr=0.01)\n",
    "        \n",
    "        n_epochs = 45\n",
    "        for epoch in range(n_epochs):\n",
    "            running_loss = 0.0\n",
    "            batchsize = 32\n",
    "            for i in range(int(len(y_train)/batchsize)):\n",
    "                \n",
    "                CNN.to(torch.device(\"cuda\"))\n",
    "                optimizer.zero_grad()\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(torch.from_numpy(X_train[i*batchsize:(i+1)*batchsize]).to(torch.float32).cuda())\n",
    "                loss = criterion(outputs, torch.from_numpy(y_train[i*batchsize:(i+1)*batchsize]).to(torch.long).cuda())\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "                \n",
    "            if epoch%5==0:\n",
    "                cnn_cpu = CNN.to(torch.device(\"cpu\"))\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                \n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(cnn_cpu(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32)).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = cnn_cpu(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32))\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long)).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "\n",
    "        fold_vloss[rep, p ] = val_loss\n",
    "        fold_loss[rep, p] = running_loss/i\n",
    "        fold_vacc[rep, p] = val_acc\n",
    "        fold_acc[rep, p] = acc\n",
    "      \n",
    "    \n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))\n",
    "    \n",
    "    p = p + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training rep 1/5\t of Patient 14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:34: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1,  45] loss: 1.403\tAccuracy : 0.286\t\tval-loss: 1.383\tval-Accuracy : 0.273\n",
      "[6,  45] loss: 1.397\tAccuracy : 0.284\t\tval-loss: 1.380\tval-Accuracy : 0.278\n",
      "[11,  45] loss: 1.243\tAccuracy : 0.464\t\tval-loss: 1.235\tval-Accuracy : 0.464\n",
      "[16,  45] loss: 1.071\tAccuracy : 0.552\t\tval-loss: 1.099\tval-Accuracy : 0.565\n",
      "[21,  45] loss: 0.970\tAccuracy : 0.621\t\tval-loss: 0.996\tval-Accuracy : 0.612\n",
      "[26,  45] loss: 0.838\tAccuracy : 0.715\t\tval-loss: 0.963\tval-Accuracy : 0.641\n",
      "[31,  45] loss: 0.609\tAccuracy : 0.831\t\tval-loss: 0.720\tval-Accuracy : 0.785\n",
      "[36,  45] loss: 0.392\tAccuracy : 0.906\t\tval-loss: 0.601\tval-Accuracy : 0.656\n",
      "[41,  45] loss: 0.243\tAccuracy : 0.912\t\tval-loss: 0.485\tval-Accuracy : 0.675\n",
      "Begin Training rep 2/5\t of Patient 14\n",
      "[1,  45] loss: 1.402\tAccuracy : 0.284\t\tval-loss: 1.383\tval-Accuracy : 0.278\n",
      "[6,  45] loss: 1.398\tAccuracy : 0.284\t\tval-loss: 1.381\tval-Accuracy : 0.278\n",
      "[11,  45] loss: 1.338\tAccuracy : 0.398\t\tval-loss: 1.318\tval-Accuracy : 0.469\n",
      "[16,  45] loss: 1.081\tAccuracy : 0.574\t\tval-loss: 1.017\tval-Accuracy : 0.608\n",
      "[21,  45] loss: 0.981\tAccuracy : 0.605\t\tval-loss: 0.967\tval-Accuracy : 0.646\n",
      "[26,  45] loss: 0.912\tAccuracy : 0.650\t\tval-loss: 0.921\tval-Accuracy : 0.656\n",
      "[31,  45] loss: 0.715\tAccuracy : 0.749\t\tval-loss: 0.909\tval-Accuracy : 0.742\n",
      "[36,  45] loss: 0.417\tAccuracy : 0.836\t\tval-loss: 0.597\tval-Accuracy : 0.828\n",
      "[41,  45] loss: 0.338\tAccuracy : 0.889\t\tval-loss: 0.527\tval-Accuracy : 0.833\n",
      "Begin Training rep 3/5\t of Patient 14\n",
      "[1,  45] loss: 1.405\tAccuracy : 0.284\t\tval-loss: 1.384\tval-Accuracy : 0.278\n",
      "[6,  45] loss: 1.262\tAccuracy : 0.446\t\tval-loss: 1.263\tval-Accuracy : 0.474\n",
      "[11,  45] loss: 1.065\tAccuracy : 0.576\t\tval-loss: 1.061\tval-Accuracy : 0.589\n",
      "[16,  45] loss: 0.951\tAccuracy : 0.636\t\tval-loss: 1.003\tval-Accuracy : 0.636\n",
      "[21,  45] loss: 0.815\tAccuracy : 0.670\t\tval-loss: 0.894\tval-Accuracy : 0.651\n",
      "[26,  45] loss: 0.668\tAccuracy : 0.762\t\tval-loss: 0.686\tval-Accuracy : 0.756\n",
      "[31,  45] loss: 0.453\tAccuracy : 0.873\t\tval-loss: 0.530\tval-Accuracy : 0.837\n",
      "[36,  45] loss: 0.275\tAccuracy : 0.909\t\tval-loss: 0.680\tval-Accuracy : 0.665\n",
      "[41,  45] loss: 0.197\tAccuracy : 0.936\t\tval-loss: 0.713\tval-Accuracy : 0.689\n",
      "Begin Training rep 4/5\t of Patient 14\n",
      "[1,  45] loss: 1.401\tAccuracy : 0.283\t\tval-loss: 1.383\tval-Accuracy : 0.282\n",
      "[6,  45] loss: 1.398\tAccuracy : 0.284\t\tval-loss: 1.382\tval-Accuracy : 0.278\n",
      "[11,  45] loss: 1.397\tAccuracy : 0.284\t\tval-loss: 1.382\tval-Accuracy : 0.278\n",
      "[16,  45] loss: 1.395\tAccuracy : 0.284\t\tval-loss: 1.380\tval-Accuracy : 0.278\n",
      "[21,  45] loss: 1.186\tAccuracy : 0.495\t\tval-loss: 1.156\tval-Accuracy : 0.502\n",
      "[26,  45] loss: 1.032\tAccuracy : 0.592\t\tval-loss: 1.023\tval-Accuracy : 0.627\n",
      "[31,  45] loss: 0.933\tAccuracy : 0.656\t\tval-loss: 0.981\tval-Accuracy : 0.665\n",
      "[36,  45] loss: 0.709\tAccuracy : 0.796\t\tval-loss: 0.710\tval-Accuracy : 0.799\n",
      "[41,  45] loss: 0.409\tAccuracy : 0.861\t\tval-loss: 0.602\tval-Accuracy : 0.665\n",
      "Begin Training rep 5/5\t of Patient 14\n",
      "[1,  45] loss: 1.403\tAccuracy : 0.284\t\tval-loss: 1.383\tval-Accuracy : 0.278\n",
      "[6,  45] loss: 1.397\tAccuracy : 0.284\t\tval-loss: 1.380\tval-Accuracy : 0.278\n",
      "[11,  45] loss: 1.194\tAccuracy : 0.482\t\tval-loss: 1.203\tval-Accuracy : 0.478\n",
      "[16,  45] loss: 1.020\tAccuracy : 0.572\t\tval-loss: 1.169\tval-Accuracy : 0.565\n",
      "[21,  45] loss: 0.914\tAccuracy : 0.645\t\tval-loss: 1.014\tval-Accuracy : 0.612\n",
      "[26,  45] loss: 0.781\tAccuracy : 0.718\t\tval-loss: 0.916\tval-Accuracy : 0.708\n",
      "[31,  45] loss: 0.610\tAccuracy : 0.809\t\tval-loss: 0.580\tval-Accuracy : 0.842\n",
      "[36,  45] loss: 0.403\tAccuracy : 0.870\t\tval-loss: 0.667\tval-Accuracy : 0.646\n",
      "[41,  45] loss: 0.295\tAccuracy : 0.898\t\tval-loss: 0.687\tval-Accuracy : 0.675\n",
      "loss: 0.411\tAccuracy : 0.899\t\tval-loss: 0.603\tval-Accuracy : 0.707\n",
      "Begin Training rep 1/5\t of Patient 15\n",
      "[1,  45] loss: 1.400\tAccuracy : 0.284\t\tval-loss: 1.382\tval-Accuracy : 0.273\n",
      "[6,  45] loss: 1.397\tAccuracy : 0.284\t\tval-loss: 1.381\tval-Accuracy : 0.273\n",
      "[11,  45] loss: 1.222\tAccuracy : 0.477\t\tval-loss: 1.399\tval-Accuracy : 0.309\n",
      "[16,  45] loss: 1.019\tAccuracy : 0.573\t\tval-loss: 1.598\tval-Accuracy : 0.364\n",
      "[21,  45] loss: 0.895\tAccuracy : 0.623\t\tval-loss: 1.710\tval-Accuracy : 0.418\n",
      "[26,  45] loss: 0.726\tAccuracy : 0.690\t\tval-loss: 1.776\tval-Accuracy : 0.414\n",
      "[31,  45] loss: 0.423\tAccuracy : 0.859\t\tval-loss: 2.021\tval-Accuracy : 0.295\n",
      "[36,  45] loss: 0.241\tAccuracy : 0.894\t\tval-loss: 2.203\tval-Accuracy : 0.377\n",
      "[41,  45] loss: 0.184\tAccuracy : 0.940\t\tval-loss: 2.258\tval-Accuracy : 0.395\n"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "index 13 is out of bounds for axis 1 with size 13",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-14-7a4d43dcb058>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     70\u001b[0m              (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n\u001b[0;32m     71\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 72\u001b[1;33m         \u001b[0mfold_vloss\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mrep\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mp\u001b[0m \u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mval_loss\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     73\u001b[0m         \u001b[0mfold_loss\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mrep\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mp\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mrunning_loss\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     74\u001b[0m         \u001b[0mfold_vacc\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mrep\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mp\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mval_acc\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mIndexError\u001b[0m: index 13 is out of bounds for axis 1 with size 13"
     ]
    }
   ],
   "source": [
    "p = 12\n",
    "n_rep = 5    \n",
    "n_patient = len(np.unique(Patient))\n",
    "fold_vloss = np.zeros((n_rep,n_patient))\n",
    "fold_loss = np.zeros((n_rep,n_patient))\n",
    "fold_vacc = np.zeros((n_rep,n_patient))\n",
    "fold_acc = np.zeros((n_rep,n_patient))\n",
    "\n",
    "for patient in np.unique(Patient):\n",
    "    patient = patient + 13\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "    id_train = np.arange(len(tmp))[Patient!=patient]\n",
    "    \n",
    "    for rep in range(n_rep):\n",
    "        np.random.shuffle(id_patient)\n",
    "        np.random.shuffle(id_train)\n",
    "        \n",
    "        X_train = tmp[id_train]\n",
    "        X_test = tmp[id_patient]\n",
    "        y_train = Label[id_train]\n",
    "        y_test = Label[id_patient]\n",
    "        \n",
    "        print(\"Begin Training rep %d/%d\\t of Patient %d\" % \n",
    "             (rep+1,n_rep, patient))\n",
    "        \n",
    "        CNN = MaxCNN().cuda()\n",
    "        criterion = nn.NLLLoss()\n",
    "        optimizer = optim.SGD(CNN.parameters(), lr=0.01)\n",
    "        \n",
    "        n_epochs = 45\n",
    "        for epoch in range(n_epochs):\n",
    "            running_loss = 0.0\n",
    "            batchsize = 32\n",
    "            for i in range(int(len(y_train)/batchsize)):\n",
    "                \n",
    "                CNN.to(torch.device(\"cuda\"))\n",
    "                optimizer.zero_grad()\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(torch.from_numpy(X_train[i*batchsize:(i+1)*batchsize]).to(torch.float32).cuda())\n",
    "                loss = criterion(outputs, torch.from_numpy(y_train[i*batchsize:(i+1)*batchsize]).to(torch.long).cuda())\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "                \n",
    "            if epoch%5==0:\n",
    "                cnn_cpu = CNN.to(torch.device(\"cpu\"))\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                \n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(cnn_cpu(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32)).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = cnn_cpu(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32))\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long)).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "\n",
    "        fold_vloss[rep, p ] = val_loss\n",
    "        fold_loss[rep, p] = running_loss/i\n",
    "        fold_vacc[rep, p] = val_acc\n",
    "        fold_acc[rep, p] = acc\n",
    "      \n",
    "    \n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))\n",
    "    \n",
    "    p = p + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the cross-validation metrics of the MaxPool CNN and plot its validation accuracy.\n",
    "# The output folder is created first so savemat/savefig do not fail when 'Result/' is missing.\n",
    "os.makedirs('Result', exist_ok=True)\n",
    "sio.savemat('Result/Res_MaxPoolCNN.mat',{\"loss\":fold_loss,\"acc\":fold_acc,\"val loss\":fold_vloss,\"val acc\":fold_vacc})\n",
    "\n",
    "# fold_vacc is filled as fold_vacc[rep, p], so boxplot draws one box per column p.\n",
    "fig, ax = plt.subplots(figsize=(12,10))\n",
    "ax.grid()\n",
    "ax.boxplot(fold_vacc)\n",
    "fig.suptitle('Cross-Validation Accuracy\\n MaxPool CNN')\n",
    "ax.set_xlabel('Patient id')\n",
    "ax.set_ylabel('Accuracy')\n",
    "fig.savefig('Result/MaxPoolCNN.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TempCNN(nn.Module):\n",
    "    \"\"\"Shared per-window CNN encoder followed by a convolution across the time windows.\n",
    "\n",
    "    Every time window of the input goes through the same stack of 3x3\n",
    "    convolutions (conv1..conv7, with three 2x2 max-pools). The flattened\n",
    "    per-window features are mixed by conv8, which uses the 7 windows as its\n",
    "    input channels, and a linear layer maps the result to 4 outputs.\n",
    "    forward() returns log-probabilities.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(TempCNN, self).__init__()\n",
    "        \n",
    "        # Frame encoder: applied with the same weights to every time window.\n",
    "        self.conv1 = nn.Conv2d(3,32,(3,3),stride=(1,1), padding=1)\n",
    "        self.conv2 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.pool1 = nn.MaxPool2d((2,2))\n",
    "        self.conv5 = nn.Conv2d(32,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv6 = nn.Conv2d(64,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv7 = nn.Conv2d(64,128,(3,3),stride=(1,1),padding=1)\n",
    "        \n",
    "        # Temporal mixing: the 7 windows are the input channels and the kernel spans the\n",
    "        # whole flattened feature vector. Declared as Conv2d because the kernel is 2-D and\n",
    "        # the layer is fed a 4-D tensor, which nn.Conv1d does not accept in current PyTorch.\n",
    "        # Parameter names and shapes are the same as with the previous Conv1d declaration.\n",
    "        self.conv8 = nn.Conv2d(7,64,(4*4*128,3),stride=(1,1),padding=1)\n",
    "        \n",
    "        # NOTE(review): pool and drop are never called in forward(); kept so the module's\n",
    "        # attributes are unchanged. Confirm whether dropout was meant to be applied.\n",
    "        self.pool = nn.MaxPool2d((7,1))\n",
    "        self.drop = nn.Dropout(p=0.5)\n",
    "        self.fc = nn.Linear(192,4)\n",
    "        # dim=1 is the class axis of the (batch, 4) logits; stating it removes the\n",
    "        # implicit-dimension deprecation warning.\n",
    "        self.max = nn.LogSoftmax(dim=1)\n",
    "        \n",
    "    def _encode_frames(self, frames):\n",
    "        \"\"\"Run the shared conv/pool stack on a (N, 3, H, W) batch of single-window images.\"\"\"\n",
    "        h = F.relu(self.conv1(frames))\n",
    "        h = F.relu(self.conv2(h))\n",
    "        h = F.relu(self.conv3(h))\n",
    "        h = self.pool1(F.relu(self.conv4(h)))\n",
    "        h = F.relu(self.conv5(h))\n",
    "        h = self.pool1(F.relu(self.conv6(h)))\n",
    "        h = self.pool1(F.relu(self.conv7(h)))\n",
    "        return h\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \"\"\"Classify a batch of image sequences.\n",
    "\n",
    "        x: tensor of shape (batch, 7, 3, H, W); H = W = 32 yields the 4*4*128\n",
    "           features per window that conv8 and fc are sized for.\n",
    "        Returns a (batch, 4) tensor of log-probabilities on the same device as x.\n",
    "        \"\"\"\n",
    "        n_batch, n_windows = x.shape[0], x.shape[1]\n",
    "        # Fold the windows into the batch axis so the encoder runs once, on whatever\n",
    "        # device x lives on, instead of filling a pre-allocated buffer window by window.\n",
    "        frames = x.reshape((n_batch * n_windows,) + tuple(x.shape[2:]))\n",
    "        feats = self._encode_frames(frames)\n",
    "        # Back to (batch, windows, features, 1): windows become conv8's input channels.\n",
    "        x = feats.reshape(n_batch, n_windows, -1, 1)\n",
    "        x = F.relu(self.conv8(x))\n",
    "        x = x.view(x.shape[0],-1)\n",
    "        x = self.fc(x)\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:33: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-1.3924, -1.3307, -1.3637, -1.4631],\n",
       "        [-1.3927, -1.3307, -1.3638, -1.4627]], device='cuda:0',\n",
       "       grad_fn=<LogSoftmaxBackward>)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: build the model on the GPU and push the first two samples through it.\n",
    "# NOTE(review): `tmp` is not defined in this cell; presumably a NumPy array of image\n",
    "# sequences created by an earlier cell -- confirm it exists on a fresh kernel.\n",
    "net = TempCNN().cuda()\n",
    "sample_batch = torch.from_numpy(tmp[0:2]).to(torch.float32).cuda()\n",
    "net(sample_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training rep 1/1\t of Patient 9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\ipykernel_launcher.py:33: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1,  45] loss: 1.404\tAccuracy : 0.283\t\tval-loss: 1.380\tval-Accuracy : 0.287\n",
      "[2,  45] loss: 1.401\tAccuracy : 0.283\t\tval-loss: 1.376\tval-Accuracy : 0.287\n",
      "[3,  45] loss: 1.399\tAccuracy : 0.283\t\tval-loss: 1.373\tval-Accuracy : 0.287\n",
      "[4,  45] loss: 1.398\tAccuracy : 0.283\t\tval-loss: 1.371\tval-Accuracy : 0.287\n",
      "[5,  45] loss: 1.397\tAccuracy : 0.283\t\tval-loss: 1.370\tval-Accuracy : 0.287\n",
      "[6,  45] loss: 1.396\tAccuracy : 0.283\t\tval-loss: 1.366\tval-Accuracy : 0.287\n",
      "[7,  45] loss: 1.390\tAccuracy : 0.283\t\tval-loss: 1.347\tval-Accuracy : 0.287\n",
      "[8,  45] loss: 1.328\tAccuracy : 0.427\t\tval-loss: 1.156\tval-Accuracy : 0.307\n",
      "[9,  45] loss: 1.217\tAccuracy : 0.519\t\tval-loss: 1.032\tval-Accuracy : 0.485\n",
      "[10,  45] loss: 1.113\tAccuracy : 0.565\t\tval-loss: 0.905\tval-Accuracy : 0.604\n",
      "[11,  45] loss: 1.030\tAccuracy : 0.598\t\tval-loss: 0.829\tval-Accuracy : 0.634\n",
      "[12,  45] loss: 0.970\tAccuracy : 0.628\t\tval-loss: 0.782\tval-Accuracy : 0.678\n",
      "[13,  45] loss: 0.921\tAccuracy : 0.655\t\tval-loss: 0.774\tval-Accuracy : 0.688\n",
      "[14,  45] loss: 0.840\tAccuracy : 0.702\t\tval-loss: 0.624\tval-Accuracy : 0.787\n",
      "[15,  45] loss: 0.756\tAccuracy : 0.751\t\tval-loss: 0.581\tval-Accuracy : 0.807\n",
      "[16,  45] loss: 0.757\tAccuracy : 0.783\t\tval-loss: 0.499\tval-Accuracy : 0.851\n",
      "[17,  45] loss: 0.627\tAccuracy : 0.810\t\tval-loss: 0.366\tval-Accuracy : 0.906\n",
      "[18,  45] loss: 0.556\tAccuracy : 0.840\t\tval-loss: 0.308\tval-Accuracy : 0.936\n",
      "[19,  45] loss: 0.448\tAccuracy : 0.868\t\tval-loss: 0.163\tval-Accuracy : 0.960\n",
      "[20,  45] loss: 0.399\tAccuracy : 0.868\t\tval-loss: 0.170\tval-Accuracy : 0.960\n",
      "[21,  45] loss: 0.293\tAccuracy : 0.905\t\tval-loss: 0.132\tval-Accuracy : 0.970\n",
      "[22,  45] loss: 0.242\tAccuracy : 0.925\t\tval-loss: 0.093\tval-Accuracy : 0.980\n",
      "[23,  45] loss: 0.205\tAccuracy : 0.931\t\tval-loss: 0.086\tval-Accuracy : 0.980\n",
      "[24,  45] loss: 0.175\tAccuracy : 0.942\t\tval-loss: 0.078\tval-Accuracy : 0.980\n",
      "[25,  45] loss: 0.159\tAccuracy : 0.941\t\tval-loss: 0.100\tval-Accuracy : 0.975\n",
      "[26,  45] loss: 0.169\tAccuracy : 0.943\t\tval-loss: 0.065\tval-Accuracy : 0.980\n",
      "[27,  45] loss: 0.982\tAccuracy : 0.836\t\tval-loss: 0.336\tval-Accuracy : 0.916\n",
      "[28,  45] loss: 0.310\tAccuracy : 0.915\t\tval-loss: 0.066\tval-Accuracy : 0.990\n",
      "[29,  45] loss: 0.195\tAccuracy : 0.931\t\tval-loss: 0.091\tval-Accuracy : 0.980\n",
      "[30,  45] loss: 0.159\tAccuracy : 0.940\t\tval-loss: 0.096\tval-Accuracy : 0.980\n",
      "[31,  45] loss: 0.155\tAccuracy : 0.933\t\tval-loss: 0.089\tval-Accuracy : 0.975\n",
      "[32,  45] loss: 0.160\tAccuracy : 0.949\t\tval-loss: 0.077\tval-Accuracy : 0.970\n",
      "[33,  45] loss: 0.117\tAccuracy : 0.953\t\tval-loss: 0.090\tval-Accuracy : 0.965\n",
      "[34,  45] loss: 0.109\tAccuracy : 0.950\t\tval-loss: 0.069\tval-Accuracy : 0.980\n",
      "[35,  45] loss: 0.101\tAccuracy : 0.955\t\tval-loss: 0.078\tval-Accuracy : 0.975\n",
      "[36,  45] loss: 0.094\tAccuracy : 0.957\t\tval-loss: 0.080\tval-Accuracy : 0.970\n",
      "[37,  45] loss: 0.090\tAccuracy : 0.958\t\tval-loss: 0.078\tval-Accuracy : 0.975\n",
      "[38,  45] loss: 0.086\tAccuracy : 0.959\t\tval-loss: 0.079\tval-Accuracy : 0.970\n",
      "[39,  45] loss: 0.083\tAccuracy : 0.958\t\tval-loss: 0.080\tval-Accuracy : 0.970\n",
      "[40,  45] loss: 0.080\tAccuracy : 0.960\t\tval-loss: 0.082\tval-Accuracy : 0.965\n",
      "[41,  45] loss: 0.078\tAccuracy : 0.962\t\tval-loss: 0.085\tval-Accuracy : 0.970\n",
      "[42,  45] loss: 0.074\tAccuracy : 0.962\t\tval-loss: 0.086\tval-Accuracy : 0.975\n",
      "[43,  45] loss: 0.072\tAccuracy : 0.963\t\tval-loss: 0.090\tval-Accuracy : 0.970\n",
      "[44,  45] loss: 0.071\tAccuracy : 0.964\t\tval-loss: 0.094\tval-Accuracy : 0.970\n",
      "[45,  45] loss: 0.069\tAccuracy : 0.963\t\tval-loss: 0.099\tval-Accuracy : 0.970\n",
      "loss: 0.069\tAccuracy : 0.963\t\tval-loss: 0.099\tval-Accuracy : 0.970\n",
      "Begin Training rep 1/1\t of Patient 12\n",
      "[1,  45] loss: 1.404\tAccuracy : 0.240\t\tval-loss: 1.385\tval-Accuracy : 0.240\n",
      "[2,  45] loss: 1.401\tAccuracy : 0.285\t\tval-loss: 1.384\tval-Accuracy : 0.263\n",
      "[3,  45] loss: 1.399\tAccuracy : 0.285\t\tval-loss: 1.384\tval-Accuracy : 0.263\n",
      "[4,  45] loss: 1.398\tAccuracy : 0.285\t\tval-loss: 1.384\tval-Accuracy : 0.263\n",
      "[5,  45] loss: 1.397\tAccuracy : 0.285\t\tval-loss: 1.384\tval-Accuracy : 0.263\n",
      "[6,  45] loss: 1.396\tAccuracy : 0.285\t\tval-loss: 1.383\tval-Accuracy : 0.263\n",
      "[7,  45] loss: 1.393\tAccuracy : 0.285\t\tval-loss: 1.378\tval-Accuracy : 0.263\n",
      "[8,  45] loss: 1.358\tAccuracy : 0.344\t\tval-loss: 1.328\tval-Accuracy : 0.318\n",
      "[9,  45] loss: 1.249\tAccuracy : 0.497\t\tval-loss: 1.229\tval-Accuracy : 0.530\n",
      "[10,  45] loss: 1.139\tAccuracy : 0.607\t\tval-loss: 1.209\tval-Accuracy : 0.571\n",
      "[11,  45] loss: 1.048\tAccuracy : 0.633\t\tval-loss: 1.135\tval-Accuracy : 0.553\n",
      "[12,  45] loss: 0.962\tAccuracy : 0.654\t\tval-loss: 1.078\tval-Accuracy : 0.562\n",
      "[13,  45] loss: 0.865\tAccuracy : 0.687\t\tval-loss: 1.236\tval-Accuracy : 0.581\n",
      "[14,  45] loss: 0.781\tAccuracy : 0.735\t\tval-loss: 1.049\tval-Accuracy : 0.636\n",
      "[15,  45] loss: 0.681\tAccuracy : 0.764\t\tval-loss: 1.046\tval-Accuracy : 0.608\n",
      "[16,  45] loss: 0.588\tAccuracy : 0.788\t\tval-loss: 1.112\tval-Accuracy : 0.641\n",
      "[17,  45] loss: 0.558\tAccuracy : 0.821\t\tval-loss: 0.808\tval-Accuracy : 0.760\n",
      "[18,  45] loss: 0.467\tAccuracy : 0.837\t\tval-loss: 1.016\tval-Accuracy : 0.747\n",
      "[19,  45] loss: 0.453\tAccuracy : 0.851\t\tval-loss: 0.793\tval-Accuracy : 0.816\n",
      "[20,  45] loss: 0.519\tAccuracy : 0.819\t\tval-loss: 0.884\tval-Accuracy : 0.770\n",
      "[21,  45] loss: 0.449\tAccuracy : 0.867\t\tval-loss: 0.907\tval-Accuracy : 0.751\n",
      "[22,  45] loss: 0.298\tAccuracy : 0.885\t\tval-loss: 0.991\tval-Accuracy : 0.779\n",
      "[23,  45] loss: 0.231\tAccuracy : 0.888\t\tval-loss: 1.303\tval-Accuracy : 0.760\n",
      "[24,  45] loss: 0.354\tAccuracy : 0.890\t\tval-loss: 0.884\tval-Accuracy : 0.760\n",
      "[25,  45] loss: 0.228\tAccuracy : 0.915\t\tval-loss: 1.024\tval-Accuracy : 0.802\n",
      "[26,  45] loss: 0.277\tAccuracy : 0.888\t\tval-loss: 0.807\tval-Accuracy : 0.765\n",
      "[27,  45] loss: 0.247\tAccuracy : 0.859\t\tval-loss: 0.909\tval-Accuracy : 0.751\n",
      "[28,  45] loss: 0.177\tAccuracy : 0.930\t\tval-loss: 1.099\tval-Accuracy : 0.829\n",
      "[29,  45] loss: 0.139\tAccuracy : 0.936\t\tval-loss: 1.224\tval-Accuracy : 0.829\n",
      "[30,  45] loss: 0.125\tAccuracy : 0.937\t\tval-loss: 1.362\tval-Accuracy : 0.820\n",
      "[31,  45] loss: 0.113\tAccuracy : 0.937\t\tval-loss: 1.460\tval-Accuracy : 0.825\n",
      "[32,  45] loss: 0.105\tAccuracy : 0.939\t\tval-loss: 1.542\tval-Accuracy : 0.829\n",
      "[33,  45] loss: 0.100\tAccuracy : 0.939\t\tval-loss: 1.617\tval-Accuracy : 0.839\n",
      "[34,  45] loss: 0.095\tAccuracy : 0.944\t\tval-loss: 1.673\tval-Accuracy : 0.839\n",
      "[35,  45] loss: 0.091\tAccuracy : 0.946\t\tval-loss: 1.712\tval-Accuracy : 0.848\n",
      "[36,  45] loss: 0.087\tAccuracy : 0.947\t\tval-loss: 1.763\tval-Accuracy : 0.848\n",
      "[37,  45] loss: 0.084\tAccuracy : 0.948\t\tval-loss: 1.797\tval-Accuracy : 0.857\n",
      "[38,  45] loss: 0.082\tAccuracy : 0.949\t\tval-loss: 1.817\tval-Accuracy : 0.857\n",
      "[39,  45] loss: 0.079\tAccuracy : 0.952\t\tval-loss: 1.853\tval-Accuracy : 0.857\n",
      "[40,  45] loss: 0.077\tAccuracy : 0.955\t\tval-loss: 1.889\tval-Accuracy : 0.853\n",
      "[41,  45] loss: 0.075\tAccuracy : 0.956\t\tval-loss: 1.911\tval-Accuracy : 0.853\n",
      "[42,  45] loss: 0.074\tAccuracy : 0.957\t\tval-loss: 1.904\tval-Accuracy : 0.857\n",
      "[43,  45] loss: 0.072\tAccuracy : 0.961\t\tval-loss: 2.002\tval-Accuracy : 0.862\n",
      "[44,  45] loss: 0.071\tAccuracy : 0.965\t\tval-loss: 1.996\tval-Accuracy : 0.862\n",
      "[45,  45] loss: 0.070\tAccuracy : 0.962\t\tval-loss: 2.043\tval-Accuracy : 0.871\n",
      "loss: 0.070\tAccuracy : 0.962\t\tval-loss: 2.043\tval-Accuracy : 0.871\n",
      "Begin Training rep 1/1\t of Patient 10\n",
      "[1,  45] loss: 1.402\tAccuracy : 0.285\t\tval-loss: 1.382\tval-Accuracy : 0.271\n",
      "[2,  45] loss: 1.400\tAccuracy : 0.285\t\tval-loss: 1.380\tval-Accuracy : 0.271\n",
      "[3,  45] loss: 1.398\tAccuracy : 0.285\t\tval-loss: 1.379\tval-Accuracy : 0.271\n",
      "[4,  45] loss: 1.396\tAccuracy : 0.285\t\tval-loss: 1.374\tval-Accuracy : 0.271\n",
      "[5,  45] loss: 1.378\tAccuracy : 0.285\t\tval-loss: 1.302\tval-Accuracy : 0.271\n",
      "[6,  45] loss: 1.270\tAccuracy : 0.467\t\tval-loss: 1.120\tval-Accuracy : 0.467\n",
      "[7,  45] loss: 1.172\tAccuracy : 0.559\t\tval-loss: 1.016\tval-Accuracy : 0.562\n",
      "[8,  45] loss: 1.097\tAccuracy : 0.586\t\tval-loss: 0.936\tval-Accuracy : 0.600\n",
      "[9,  45] loss: 1.027\tAccuracy : 0.620\t\tval-loss: 0.890\tval-Accuracy : 0.648\n",
      "[10,  45] loss: 0.963\tAccuracy : 0.657\t\tval-loss: 0.813\tval-Accuracy : 0.671\n",
      "[11,  45] loss: 0.887\tAccuracy : 0.700\t\tval-loss: 0.712\tval-Accuracy : 0.733\n",
      "[12,  45] loss: 0.842\tAccuracy : 0.714\t\tval-loss: 0.636\tval-Accuracy : 0.786\n",
      "[13,  45] loss: 0.764\tAccuracy : 0.754\t\tval-loss: 0.595\tval-Accuracy : 0.814\n",
      "[14,  45] loss: 0.690\tAccuracy : 0.784\t\tval-loss: 0.482\tval-Accuracy : 0.838\n",
      "[15,  45] loss: 0.628\tAccuracy : 0.815\t\tval-loss: 0.438\tval-Accuracy : 0.857\n",
      "[16,  45] loss: 0.604\tAccuracy : 0.826\t\tval-loss: 0.375\tval-Accuracy : 0.895\n",
      "[17,  45] loss: 0.553\tAccuracy : 0.840\t\tval-loss: 0.348\tval-Accuracy : 0.900\n",
      "[18,  45] loss: 0.425\tAccuracy : 0.867\t\tval-loss: 0.279\tval-Accuracy : 0.938\n",
      "[19,  45] loss: 0.396\tAccuracy : 0.889\t\tval-loss: 0.228\tval-Accuracy : 0.957\n",
      "[20,  45] loss: 0.338\tAccuracy : 0.893\t\tval-loss: 0.194\tval-Accuracy : 0.957\n",
      "[21,  45] loss: 0.262\tAccuracy : 0.914\t\tval-loss: 0.182\tval-Accuracy : 0.952\n",
      "[22,  45] loss: 0.219\tAccuracy : 0.925\t\tval-loss: 0.217\tval-Accuracy : 0.943\n",
      "[23,  45] loss: 0.224\tAccuracy : 0.934\t\tval-loss: 0.160\tval-Accuracy : 0.957\n",
      "[24,  45] loss: 0.256\tAccuracy : 0.901\t\tval-loss: 0.193\tval-Accuracy : 0.952\n",
      "[25,  45] loss: 0.266\tAccuracy : 0.921\t\tval-loss: 0.134\tval-Accuracy : 0.952\n",
      "[26,  45] loss: 0.172\tAccuracy : 0.945\t\tval-loss: 0.153\tval-Accuracy : 0.962\n",
      "[27,  45] loss: 0.139\tAccuracy : 0.950\t\tval-loss: 0.152\tval-Accuracy : 0.962\n",
      "[28,  45] loss: 0.121\tAccuracy : 0.953\t\tval-loss: 0.153\tval-Accuracy : 0.957\n",
      "[29,  45] loss: 0.111\tAccuracy : 0.954\t\tval-loss: 0.155\tval-Accuracy : 0.962\n",
      "[30,  45] loss: 0.103\tAccuracy : 0.955\t\tval-loss: 0.158\tval-Accuracy : 0.957\n",
      "[31,  45] loss: 0.098\tAccuracy : 0.955\t\tval-loss: 0.162\tval-Accuracy : 0.957\n",
      "[32,  45] loss: 0.093\tAccuracy : 0.959\t\tval-loss: 0.162\tval-Accuracy : 0.952\n",
      "[33,  45] loss: 0.088\tAccuracy : 0.959\t\tval-loss: 0.169\tval-Accuracy : 0.952\n",
      "[34,  45] loss: 0.085\tAccuracy : 0.959\t\tval-loss: 0.181\tval-Accuracy : 0.943\n",
      "[35,  45] loss: 0.082\tAccuracy : 0.959\t\tval-loss: 0.190\tval-Accuracy : 0.938\n",
      "[36,  45] loss: 0.080\tAccuracy : 0.960\t\tval-loss: 0.200\tval-Accuracy : 0.938\n",
      "[37,  45] loss: 0.077\tAccuracy : 0.960\t\tval-loss: 0.209\tval-Accuracy : 0.938\n",
      "[38,  45] loss: 0.075\tAccuracy : 0.961\t\tval-loss: 0.214\tval-Accuracy : 0.938\n",
      "[39,  45] loss: 0.073\tAccuracy : 0.961\t\tval-loss: 0.218\tval-Accuracy : 0.938\n",
      "[40,  45] loss: 0.071\tAccuracy : 0.961\t\tval-loss: 0.228\tval-Accuracy : 0.938\n",
      "[41,  45] loss: 0.069\tAccuracy : 0.962\t\tval-loss: 0.236\tval-Accuracy : 0.938\n",
      "[42,  45] loss: 0.068\tAccuracy : 0.962\t\tval-loss: 0.248\tval-Accuracy : 0.943\n",
      "[43,  45] loss: 0.067\tAccuracy : 0.963\t\tval-loss: 0.256\tval-Accuracy : 0.943\n",
      "[44,  45] loss: 0.066\tAccuracy : 0.963\t\tval-loss: 0.261\tval-Accuracy : 0.938\n",
      "[45,  45] loss: 0.065\tAccuracy : 0.963\t\tval-loss: 0.271\tval-Accuracy : 0.938\n",
      "loss: 0.065\tAccuracy : 0.963\t\tval-loss: 0.271\tval-Accuracy : 0.938\n",
      "Begin Training rep 1/1\t of Patient 8\n",
      "[1,  45] loss: 1.400\tAccuracy : 0.283\t\tval-loss: 1.373\tval-Accuracy : 0.290\n",
      "[2,  45] loss: 1.399\tAccuracy : 0.283\t\tval-loss: 1.370\tval-Accuracy : 0.290\n",
      "[3,  45] loss: 1.398\tAccuracy : 0.283\t\tval-loss: 1.367\tval-Accuracy : 0.290\n",
      "[4,  45] loss: 1.397\tAccuracy : 0.283\t\tval-loss: 1.364\tval-Accuracy : 0.290\n",
      "[5,  45] loss: 1.395\tAccuracy : 0.283\t\tval-loss: 1.359\tval-Accuracy : 0.290\n",
      "[6,  45] loss: 1.385\tAccuracy : 0.283\t\tval-loss: 1.320\tval-Accuracy : 0.290\n",
      "[7,  45] loss: 1.287\tAccuracy : 0.499\t\tval-loss: 1.221\tval-Accuracy : 0.373\n",
      "[8,  45] loss: 1.172\tAccuracy : 0.560\t\tval-loss: 1.076\tval-Accuracy : 0.534\n",
      "[9,  45] loss: 1.078\tAccuracy : 0.619\t\tval-loss: 0.941\tval-Accuracy : 0.627\n",
      "[10,  45] loss: 0.986\tAccuracy : 0.677\t\tval-loss: 0.871\tval-Accuracy : 0.705\n",
      "[11,  45] loss: 0.886\tAccuracy : 0.705\t\tval-loss: 0.763\tval-Accuracy : 0.705\n",
      "[12,  45] loss: 0.829\tAccuracy : 0.727\t\tval-loss: 0.767\tval-Accuracy : 0.746\n",
      "[13,  45] loss: 0.714\tAccuracy : 0.778\t\tval-loss: 0.555\tval-Accuracy : 0.855\n",
      "[14,  45] loss: 0.628\tAccuracy : 0.802\t\tval-loss: 0.532\tval-Accuracy : 0.834\n",
      "[15,  45] loss: 0.625\tAccuracy : 0.825\t\tval-loss: 0.472\tval-Accuracy : 0.886\n",
      "[16,  45] loss: 0.508\tAccuracy : 0.859\t\tval-loss: 0.493\tval-Accuracy : 0.902\n",
      "[17,  45] loss: 0.379\tAccuracy : 0.887\t\tval-loss: 0.454\tval-Accuracy : 0.902\n",
      "[18,  45] loss: 0.389\tAccuracy : 0.869\t\tval-loss: 0.445\tval-Accuracy : 0.907\n",
      "[19,  45] loss: 0.326\tAccuracy : 0.897\t\tval-loss: 0.337\tval-Accuracy : 0.933\n",
      "[20,  45] loss: 0.279\tAccuracy : 0.907\t\tval-loss: 0.660\tval-Accuracy : 0.948\n",
      "[21,  45] loss: 0.210\tAccuracy : 0.926\t\tval-loss: 0.715\tval-Accuracy : 0.938\n",
      "[22,  45] loss: 0.170\tAccuracy : 0.931\t\tval-loss: 0.685\tval-Accuracy : 0.943\n",
      "[23,  45] loss: 0.146\tAccuracy : 0.931\t\tval-loss: 0.567\tval-Accuracy : 0.943\n",
      "[24,  45] loss: 0.131\tAccuracy : 0.942\t\tval-loss: 0.690\tval-Accuracy : 0.943\n",
      "[25,  45] loss: 0.118\tAccuracy : 0.944\t\tval-loss: 0.750\tval-Accuracy : 0.943\n",
      "[26,  45] loss: 0.110\tAccuracy : 0.947\t\tval-loss: 0.761\tval-Accuracy : 0.938\n",
      "[27,  45] loss: 0.104\tAccuracy : 0.948\t\tval-loss: 0.730\tval-Accuracy : 0.938\n",
      "[28,  45] loss: 0.099\tAccuracy : 0.948\t\tval-loss: 0.796\tval-Accuracy : 0.938\n",
      "[29,  45] loss: 0.094\tAccuracy : 0.950\t\tval-loss: 0.847\tval-Accuracy : 0.938\n",
      "[30,  45] loss: 0.090\tAccuracy : 0.950\t\tval-loss: 0.866\tval-Accuracy : 0.943\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-41-e7fae06453d9>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     52\u001b[0m                 \u001b[0macc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     53\u001b[0m                 \u001b[1;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0macc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 54\u001b[1;33m                     \u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0midx\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mCNN\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     55\u001b[0m                     \u001b[0macc\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"cpu\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     56\u001b[0m                 \u001b[0macc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0macc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    539\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    540\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 541\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    542\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    543\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-16-e790eb8c12b5>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     26\u001b[0m             \u001b[0mtmp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m128\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     27\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 28\u001b[1;33m             \u001b[0mtmp\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool1\u001b[0m\u001b[1;33m(\u001b[0m 
\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv7\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv6\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv5\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool1\u001b[0m\u001b[1;33m(\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv4\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv3\u001b[0m\u001b[1;33m(\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     29\u001b[0m         
\u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtmp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m*\u001b[0m\u001b[1;36m128\u001b[0m\u001b[1;33m*\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     30\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv8\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    539\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    540\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 541\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    542\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    543\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\modules\\pooling.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    139\u001b[0m         return F.max_pool2d(input, self.kernel_size, self.stride,\n\u001b[0;32m    140\u001b[0m                             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpadding\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdilation\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mceil_mode\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 141\u001b[1;33m                             self.return_indices)\n\u001b[0m\u001b[0;32m    142\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    143\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\_jit_internal.py\u001b[0m in \u001b[0;36mfn\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    136\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mif_true\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    137\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 138\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mif_false\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    139\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    140\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mif_true\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__doc__\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mif_false\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__doc__\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\victo\\.conda\\envs\\pytorch_eeg\\lib\\site-packages\\torch\\nn\\functional.py\u001b[0m in \u001b[0;36m_max_pool2d\u001b[1;34m(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)\u001b[0m\n\u001b[0;32m    486\u001b[0m         \u001b[0mstride\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mannotate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mList\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mint\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    487\u001b[0m     return torch.max_pool2d(\n\u001b[1;32m--> 488\u001b[1;33m         input, kernel_size, stride, padding, dilation, ceil_mode)\n\u001b[0m\u001b[0;32m    489\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    490\u001b[0m max_pool2d = boolean_dispatch(\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "p = 0\n",
    "n_rep = 1  \n",
    "n_patient = len(np.unique(Patient))\n",
    "fold_vloss = np.zeros((n_rep,n_patient))\n",
    "fold_loss = np.zeros((n_rep,n_patient))\n",
    "fold_vacc = np.zeros((n_rep,n_patient))\n",
    "fold_acc = np.zeros((n_rep,n_patient))\n",
    "\n",
    "un = np.unique(Patient)\n",
    "np.random.shuffle(un)\n",
    "for patient in un:\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "    id_train = np.arange(len(tmp))[Patient!=patient]\n",
    "    \n",
    "    for rep in range(n_rep):\n",
    "        np.random.shuffle(id_patient)\n",
    "        np.random.shuffle(id_train)\n",
    "        \n",
    "        X_train = tmp[id_train]\n",
    "        X_test = tmp[id_patient]\n",
    "        y_train = Label[id_train]\n",
    "        y_test = Label[id_patient]\n",
    "        \n",
    "        print(\"Begin Training rep %d/%d\\t of Patient %d\" % \n",
    "             (rep+1,n_rep, patient))\n",
    "        \n",
    "        CNN = TempCNN().cuda()\n",
    "        criterion = nn.NLLLoss()\n",
    "        optimizer = optim.SGD(CNN.parameters(), lr=0.01)\n",
    "        \n",
    "        n_epochs = 45\n",
    "        for epoch in range(n_epochs):\n",
    "            running_loss = 0.0\n",
    "            batchsize = 32\n",
    "            for i in range(int(len(y_train)/batchsize)):\n",
    "                \n",
    "                CNN.to(torch.device(\"cuda\"))\n",
    "                optimizer.zero_grad()\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(torch.from_numpy(X_train[i*batchsize:(i+1)*batchsize]).to(torch.float32).cuda())\n",
    "                loss = criterion(outputs, torch.from_numpy(y_train[i*batchsize:(i+1)*batchsize]).to(torch.long).cuda())\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "                \n",
    "            if epoch%1==0:\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                \n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(CNN(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32).cuda()).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx.to(torch.device(\"cpu\")) == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = CNN(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32).cuda())\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long).cuda()).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "                \n",
    "            if epoch==900:\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                \n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(cnn_cpu(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32)).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = cnn_cpu(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32))\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long)).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "\n",
    "        fold_vloss[rep, p ] = val_loss\n",
    "        fold_loss[rep, p] = running_loss/i\n",
    "        fold_vacc[rep, p] = val_acc\n",
    "        fold_acc[rep, p] = acc\n",
    "      \n",
    "    \n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))\n",
    "    \n",
    "    p = p + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'sio' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-1-e97a8aad55e1>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0msio\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msavemat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'Result/Res_TemporalCNN.mat'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"loss\"\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mfold_loss\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"acc\"\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mfold_acc\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"val loss\"\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mfold_vloss\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"val acc\"\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mfold_vacc\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[0mfig\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfigsize\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m12\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgrid\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mboxplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfold_vacc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'sio' is not defined"
     ]
    }
   ],
   "source": [
    "sio.savemat('Result/Res_TemporalCNN.mat',{\"loss\":fold_loss,\"acc\":fold_acc,\"val loss\":fold_vloss,\"val acc\":fold_vacc})\n",
    "\n",
    "fig = plt.figure(figsize=(12,10))\n",
    "plt.grid()\n",
    "plt.boxplot(fold_vacc)\n",
    "plt.suptitle('Cross-Validation Accuracy\\n Temporal CNN')\n",
    "ax = plt.gca()\n",
    "plt.xlabel('Patient id')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.savefig('Result/TemporalCNN.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training rep 1/1\t of Patient 12\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'TempCNN' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-4-e7fae06453d9>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     25\u001b[0m              (rep+1,n_rep, patient))\n\u001b[0;32m     26\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 27\u001b[1;33m         \u001b[0mCNN\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mTempCNN\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     28\u001b[0m         \u001b[0mcriterion\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mNLLLoss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     29\u001b[0m         \u001b[0moptimizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mCNN\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.01\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'TempCNN' is not defined"
     ]
    }
   ],
   "source": [
    "p = 0\n",
    "n_rep = 1  \n",
    "n_patient = len(np.unique(Patient))\n",
    "fold_vloss = np.zeros((n_rep,n_patient))\n",
    "fold_loss = np.zeros((n_rep,n_patient))\n",
    "fold_vacc = np.zeros((n_rep,n_patient))\n",
    "fold_acc = np.zeros((n_rep,n_patient))\n",
    "\n",
    "un = np.unique(Patient)\n",
    "np.random.shuffle(un)\n",
    "for patient in un:\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "    id_train = np.arange(len(tmp))[Patient!=patient]\n",
    "    \n",
    "    for rep in range(n_rep):\n",
    "        np.random.shuffle(id_patient)\n",
    "        np.random.shuffle(id_train)\n",
    "        \n",
    "        X_train = tmp[id_train]\n",
    "        X_test = tmp[id_patient]\n",
    "        y_train = Label[id_train]\n",
    "        y_test = Label[id_patient]\n",
    "        \n",
    "        print(\"Begin Training rep %d/%d\\t of Patient %d\" % \n",
    "             (rep+1,n_rep, patient))\n",
    "        \n",
    "        CNN = TempCNN().cuda()\n",
    "        criterion = nn.NLLLoss()\n",
    "        optimizer = optim.SGD(CNN.parameters(), lr=0.01)\n",
    "        \n",
    "        n_epochs = 45\n",
    "        for epoch in range(n_epochs):\n",
    "            running_loss = 0.0\n",
    "            batchsize = 32\n",
    "            for i in range(int(len(y_train)/batchsize)):\n",
    "                \n",
    "                CNN.to(torch.device(\"cuda\"))\n",
    "                optimizer.zero_grad()\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(torch.from_numpy(X_train[i*batchsize:(i+1)*batchsize]).to(torch.float32).cuda())\n",
    "                loss = criterion(outputs, torch.from_numpy(y_train[i*batchsize:(i+1)*batchsize]).to(torch.long).cuda())\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "                \n",
    "            if epoch%1==0:\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                \n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(CNN(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32).cuda()).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx.to(torch.device(\"cpu\")) == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = CNN(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32).cuda())\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long).cuda()).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "                \n",
    "            if epoch==900:\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                \n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(cnn_cpu(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32)).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = cnn_cpu(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32))\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long)).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "\n",
    "        fold_vloss[rep, p ] = val_loss\n",
    "        fold_loss[rep, p] = running_loss/i\n",
    "        fold_vacc[rep, p] = val_acc\n",
    "        fold_acc[rep, p] = acc\n",
    "      \n",
    "    \n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))\n",
    "    \n",
    "    p = p + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LSTM, self).__init__()\n",
    "        \n",
    "        \n",
    "        self.conv1 = nn.Conv2d(3,32,(3,3),stride=(1,1), padding=1)\n",
    "        self.conv2 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.pool1 = nn.MaxPool2d((2,2))\n",
    "        self.conv5 = nn.Conv2d(32,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv6 = nn.Conv2d(64,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv7 = nn.Conv2d(64,128,(3,3),stride=(1,1),padding=1)\n",
    "        \n",
    "        self.lstm = nn.LSTM(4*4*128,128,7)\n",
    "        \n",
    "        self.pool = nn.MaxPool2d((7,1))\n",
    "        self.drop = nn.Dropout(p=0.5)\n",
    "        self.fc = nn.Linear(896,4)\n",
    "        self.max = nn.LogSoftmax()\n",
    "        \n",
    "        self.hidden = self.init_hidden()\n",
    "\n",
    "    def init_hidden(self):\n",
    "        return (torch.randn(7, 7, 128).cuda(),\n",
    "                torch.randn(7, 7, 128).cuda())    \n",
    "        \n",
    "    def forward(self, x):\n",
    "        if x.get_device() == 0:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cuda()\n",
    "        else:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cpu()\n",
    "        for i in range(7):\n",
    "            tmp[:,i] = self.pool1(F.relu(self.conv7(self.pool1(F.relu(self.conv6(F.relu(self.conv5(self.pool1( F.relu(self.conv4(F.relu(self.conv3( F.relu(self.conv2(F.relu(self.conv1(x[:,i])))))))))))))))))\n",
    "        x = tmp.reshape(x.shape[0], x.shape[1],4*128*4)\n",
    "        \n",
    "        lstm_out, self.hidden = self.lstm(x,self.hidden)\n",
    "        x = lstm_out.view(x.shape[0],-1)\n",
    "        \n",
    "        x = self.fc(x)\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LSTM, self).__init__()\n",
    "        \n",
    "        \n",
    "        self.conv1 = nn.Conv2d(3,32,(3,3),stride=(1,1), padding=1)\n",
    "        self.conv2 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.pool1 = nn.MaxPool2d((2,2))\n",
    "        self.conv5 = nn.Conv2d(32,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv6 = nn.Conv2d(64,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv7 = nn.Conv2d(64,128,(3,3),stride=(1,1),padding=1)\n",
    "        \n",
    "        self.lstm = nn.RNN(4*4*128,128,7)\n",
    "        \n",
    "        self.pool = nn.MaxPool2d((7,1))\n",
    "        self.drop = nn.Dropout(p=0.5)\n",
    "        self.fc = nn.Linear(896,4)\n",
    "        self.max = nn.LogSoftmax()\n",
    "        \n",
    "        self.lstm_out = torch.zeros(2,7,128)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        if x.get_device() == 0:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cuda()\n",
    "        else:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cpu()\n",
    "        for i in range(7):\n",
    "            img = x[:,i]\n",
    "            img = F.relu(self.conv1(img))\n",
    "            img = F.relu(self.conv2(img))\n",
    "            img = F.relu(self.conv3(img))\n",
    "            img = F.relu(self.conv4(img))\n",
    "            img = self.pool1(img)\n",
    "            img = F.relu(self.conv5(img))\n",
    "            img = F.relu(self.conv6(img))\n",
    "            img = self.pool1(img)\n",
    "            img = F.relu(self.conv7(img))\n",
    "            #x[:,i,]\n",
    "            tmp[:,i] = self.pool1(img)\n",
    "            del img\n",
    "            #tmp[:,i] = self.pool1(F.relu(self.conv7(self.pool1(F.relu(self.conv6(F.relu(self.conv5(self.pool1( F.relu(self.conv4(F.relu(self.conv3( F.relu(self.conv2(F.relu(self.conv1(x[:,i])))))))))))))))))\n",
    "        x = tmp.reshape(x.shape[0], x.shape[1],4*128*4)\n",
    "        del tmp\n",
    "        #self.lstm_out, self.hidden = self.lstm(x,self.hidden)\n",
    "        self.lstm_out, _ = self.lstm(x)\n",
    "        \n",
    "        x = self.lstm_out.view(x.shape[0],-1)\n",
    "        x = self.fc(x)\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/vdelv/anaconda3/envs/Pytorch_EEG/lib/python3.7/site-packages/ipykernel_launcher.py:51: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-1.3417, -1.3869, -1.4236, -1.3948],\n",
       "        [-1.3531, -1.3826, -1.4254, -1.3854]], device='cuda:0',\n",
       "       grad_fn=<LogSoftmaxBackward>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = LSTM().cuda()\n",
    "net(torch.from_numpy(tmp[0:2]).to(torch.float32).cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training rep 1/1\t of Patient 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/vdelv/anaconda3/envs/Pytorch_EEG/lib/python3.7/site-packages/ipykernel_launcher.py:51: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1,  50] loss: 1.401\tAccuracy : 0.284\t\tval-loss: 1.375\tval-Accuracy : 0.291\n",
      "[6,  50] loss: 0.759\tAccuracy : 0.744\t\tval-loss: 0.395\tval-Accuracy : 0.898\n",
      "[11,  50] loss: 0.180\tAccuracy : 0.951\t\tval-loss: 0.045\tval-Accuracy : 0.995\n",
      "[16,  50] loss: 0.116\tAccuracy : 0.970\t\tval-loss: 0.029\tval-Accuracy : 0.995\n",
      "[21,  50] loss: 0.067\tAccuracy : 0.979\t\tval-loss: 0.028\tval-Accuracy : 0.995\n",
      "[26,  50] loss: 0.033\tAccuracy : 0.987\t\tval-loss: 0.017\tval-Accuracy : 0.990\n"
     ]
    }
   ],
   "source": [
    "p = 0\n",
    "n_rep = 1  \n",
    "n_patient = len(np.unique(Patient))\n",
    "fold_vloss = np.zeros((n_rep,n_patient))\n",
    "fold_loss = np.zeros((n_rep,n_patient))\n",
    "fold_vacc = np.zeros((n_rep,n_patient))\n",
    "fold_acc = np.zeros((n_rep,n_patient))\n",
    "\n",
    "un = np.unique(Patient)\n",
    "np.random.shuffle(un)\n",
    "for patient in un:\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "    id_train = np.arange(len(tmp))[Patient!=patient]\n",
    "    \n",
    "    for rep in range(n_rep):\n",
    "        np.random.shuffle(id_patient)\n",
    "        np.random.shuffle(id_train)\n",
    "        \n",
    "        X_train = tmp[id_train]\n",
    "        X_test = tmp[id_patient]\n",
    "        y_train = Label[id_train]\n",
    "        y_test = Label[id_patient]\n",
    "        \n",
    "        print(\"Begin Training rep %d/%d\\t of Patient %d\" % \n",
    "             (rep+1,n_rep, patient))\n",
    "        \n",
    "        CNN = LSTM().cuda()\n",
    "        criterion = nn.CrossEntropyLoss()\n",
    "        optimizer = optim.Adam(CNN.parameters(), lr=0.0001)\n",
    "        \n",
    "        \n",
    "        n_epochs = 50\n",
    "        for epoch in range(n_epochs):\n",
    "            running_loss = 0.0\n",
    "            batchsize = 32\n",
    "            for i in range(int(len(y_train)/batchsize)):\n",
    "                #print(i)\n",
    "                optimizer.zero_grad()\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(torch.from_numpy(X_train[i*batchsize:(i+1)*batchsize]).to(torch.float32).cuda().detach())\n",
    "                loss = criterion(outputs, torch.from_numpy(y_train[i*batchsize:(i+1)*batchsize]).to(torch.long).cuda())\n",
    "                loss.backward(retain_graph=True)\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()         \n",
    "                \n",
    "                \n",
    "            if epoch%5==0:\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                \n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(CNN(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32).cuda()).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx.to(torch.device(\"cpu\")) == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = CNN(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32).cuda())\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long).cuda()).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "                \n",
    "        fold_vloss[rep, p ] = val_loss\n",
    "        fold_loss[rep, p] = running_loss/i\n",
    "        fold_vacc[rep, p] = val_acc\n",
    "        fold_acc[rep, p] = acc\n",
    "      \n",
    "    \n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))\n",
    "    \n",
    "    p = p + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mix Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mix(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Mix, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3,32,(3,3),stride=(1,1), padding=1)\n",
    "        self.conv2 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32,32,(3,3),stride=1, padding=1)\n",
    "        self.pool1 = nn.MaxPool2d((2,2))\n",
    "        self.conv5 = nn.Conv2d(32,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv6 = nn.Conv2d(64,64,(3,3),stride=(1,1),padding=1)\n",
    "        self.conv7 = nn.Conv2d(64,128,(3,3),stride=(1,1),padding=1)\n",
    "        \n",
    "        #\n",
    "        self.conv8 = nn.Conv2d(7,64,(4*4*128,3),stride=(1,1),padding=1)\n",
    "        self.lstm = nn.RNN(4*4*128,128,7)\n",
    "        \n",
    "        self.pool = nn.MaxPool2d((7,1))\n",
    "        self.drop = nn.Dropout(p=0.5)\n",
    "        self.fc1 = nn.Linear(1088,512)\n",
    "        self.fc2 = nn.Linear(512,4)\n",
    "        self.max = nn.LogSoftmax()\n",
    "        \n",
    "        self.lstm_out = torch.zeros(2,7,128)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        if x.get_device() == 0:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cuda()\n",
    "        else:\n",
    "            tmp = torch.zeros(x.shape[0],x.shape[1],128,4,4).cpu()\n",
    "        for i in range(7):\n",
    "            img = x[:,i]\n",
    "            img = F.relu(self.conv1(img))\n",
    "            img = F.relu(self.conv2(img))\n",
    "            img = F.relu(self.conv3(img))\n",
    "            img = F.relu(self.conv4(img))\n",
    "            img = self.pool1(img)\n",
    "            img = F.relu(self.conv5(img))\n",
    "            img = F.relu(self.conv6(img))\n",
    "            img = self.pool1(img)\n",
    "            img = F.relu(self.conv7(img))\n",
    "            #x[:,i,]\n",
    "            tmp[:,i] = self.pool1(img)\n",
    "            del img\n",
    "            #tmp[:,i] = self.pool1(F.relu(self.conv7(self.pool1(F.relu(self.conv6(F.relu(self.conv5(self.pool1( F.relu(self.conv4(F.relu(self.conv3( F.relu(self.conv2(F.relu(self.conv1(x[:,i])))))))))))))))))\n",
    "        \n",
    "        temp_conv = F.relu(self.conv8(tmp.reshape(x.shape[0], x.shape[1], 4*128*4,1)))\n",
    "        temp_conv = temp_conv.reshape(temp_conv.shape[0],-1)\n",
    "        \n",
    "        self.lstm_out, _ = self.lstm(tmp.reshape(x.shape[0], x.shape[1], 4*128*4))\n",
    "        lstm = self.lstm_out.view(x.shape[0],-1)\n",
    "        \n",
    "        x = torch.cat((temp_conv, lstm), 1)\n",
    "        del tmp\n",
    "        #self.lstm_out, self.hidden = self.lstm(x,self.hidden)\n",
    "        \n",
    "        x = self.fc1(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.max(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training rep 1/1\t of Patient 11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/vdelv/anaconda3/envs/Pytorch_EEG/lib/python3.7/site-packages/ipykernel_launcher.py:58: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1,  50] loss: 1.317\tAccuracy : 0.492\t\tval-loss: 0.935\tval-Accuracy : 0.498\n",
      "[6,  50] loss: 0.345\tAccuracy : 0.888\t\tval-loss: 0.296\tval-Accuracy : 0.867\n",
      "[11,  50] loss: 0.123\tAccuracy : 0.948\t\tval-loss: 0.260\tval-Accuracy : 0.924\n",
      "[16,  50] loss: 0.076\tAccuracy : 0.965\t\tval-loss: 0.259\tval-Accuracy : 0.911\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-38-270da0643345>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     53\u001b[0m                 \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m                 \u001b[0;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m                     \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCNN\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     56\u001b[0m                     \u001b[0macc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatchsize\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m                 \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "p = 0\n",
    "n_rep = 1  \n",
    "n_patient = len(np.unique(Patient))\n",
    "fold_vloss = np.zeros((n_rep,n_patient))\n",
    "fold_loss = np.zeros((n_rep,n_patient))\n",
    "fold_vacc = np.zeros((n_rep,n_patient))\n",
    "fold_acc = np.zeros((n_rep,n_patient))\n",
    "\n",
    "un = np.unique(Patient)\n",
    "np.random.shuffle(un)\n",
    "for patient in un:\n",
    "    id_patient = np.arange(len(tmp))[Patient==patient]\n",
    "    id_train = np.arange(len(tmp))[Patient!=patient]\n",
    "    \n",
    "    for rep in range(n_rep):\n",
    "        np.random.shuffle(id_patient)\n",
    "        np.random.shuffle(id_train)\n",
    "        \n",
    "        X_train = tmp[id_train]\n",
    "        X_test = tmp[id_patient]\n",
    "        y_train = Label[id_train]\n",
    "        y_test = Label[id_patient]\n",
    "        \n",
    "        print(\"Begin Training rep %d/%d\\t of Patient %d\" % \n",
    "             (rep+1,n_rep, patient))\n",
    "        \n",
    "        CNN = Mix().cuda()\n",
    "        criterion = nn.CrossEntropyLoss()\n",
    "        optimizer = optim.Adam(CNN.parameters(), lr=0.0001)\n",
    "        \n",
    "        \n",
    "        n_epochs = 50\n",
    "        for epoch in range(n_epochs):\n",
    "            running_loss = 0.0\n",
    "            batchsize = 32\n",
    "            for i in range(int(len(y_train)/batchsize)):\n",
    "                #print(i)\n",
    "                optimizer.zero_grad()\n",
    "                # forward + backward + optimize\n",
    "                outputs = CNN(torch.from_numpy(X_train[i*batchsize:(i+1)*batchsize]).to(torch.float32).cuda().detach())\n",
    "                loss = criterion(outputs, torch.from_numpy(y_train[i*batchsize:(i+1)*batchsize]).to(torch.long).cuda())\n",
    "                loss.backward(retain_graph=True)\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()         \n",
    "                \n",
    "                \n",
    "            if epoch%5==0:\n",
    "\n",
    "                check_id = np.arange(2000)\n",
    "                np.random.shuffle(check_id)\n",
    "                \n",
    "                #acc\n",
    "                acc = np.zeros(len(y_train))\n",
    "                for j in range(int(len(acc)/batchsize)+1):\n",
    "                    _, idx = torch.max(CNN(torch.from_numpy(X_train[j*batchsize:(j+1)*batchsize]).to(torch.float32).cuda()).data,1)\n",
    "                    acc[j*batchsize:(j+1)*batchsize] = (idx.to(torch.device(\"cpu\")) == torch.from_numpy(y_train[j*batchsize:(j+1)*batchsize])).numpy() + 0 \n",
    "                acc = np.mean(acc)\n",
    "                \n",
    "                #validation\n",
    "                val_acc = np.zeros(len(y_test))\n",
    "                val_loss = []\n",
    "                for j in range(int(len(val_acc)/batchsize)+1):\n",
    "                    val_outputs = CNN(torch.from_numpy(X_test[j*batchsize:(j+1)*batchsize]).to(torch.float32).cuda())\n",
    "                    _, idx = torch.max(val_outputs.data,1)\n",
    "                    val_loss.append(criterion(val_outputs, torch.from_numpy(y_test[j*batchsize:(j+1)*batchsize]).to(torch.long).cuda()).item())\n",
    "                    val_acc[j*batchsize:(j+1)*batchsize] = (idx.cpu().numpy() == y_test[j*batchsize:(j+1)*batchsize])+0\n",
    "                val_acc = np.mean(val_acc)\n",
    "                val_loss = np.mean(val_loss)\n",
    "\n",
    "                print('[%d, %3d] loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "             (epoch+1, n_epochs, running_loss/i, float(acc), val_loss, val_acc))\n",
    "                \n",
    "        fold_vloss[rep, p ] = val_loss\n",
    "        fold_loss[rep, p] = running_loss/i\n",
    "        fold_vacc[rep, p] = val_acc\n",
    "        fold_acc[rep, p] = acc\n",
    "      \n",
    "    \n",
    "    print('loss: %.3f\\tAccuracy : %.3f\\t\\tval-loss: %.3f\\tval-Accuracy : %.3f' %\n",
    "                 (np.mean(fold_loss[:,p]), np.mean(fold_acc[:,p]), np.mean(fold_vloss[:,p]),np.mean(fold_vacc[:,p])))\n",
    "    \n",
    "    p = p + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Results "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "fold_vloss = sio.loadmat(\"result_LSTM.mat\")['vloss']\n",
    "fold_loss = sio.loadmat(\"result_LSTM.mat\")['loss']\n",
    "fold_vacc = sio.loadmat(\"result_LSTM.mat\")['vacc']\n",
    "fold_acc = sio.loadmat(\"result_LSTM.mat\")['acc']  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig = plt.figure()\n",
    "plt.boxplot(fold_vacc)\n",
    "plt.grid()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Pytorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}