{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from models.densenet import DenseNet\n",
    "from models.MHA_resnet import ResidualNetwork\n",
    "from models.resnet_bda import ResidualNetwork_classifier\n",
    "from models.resnet_rnn import ResidualNetwork_lstm\n",
    "from models.dense_rnn import  DenseNet_rnn\n",
    "from models.dense_bda import DenseNet_bda\n",
    "from models.dense_cam import DenseNet_cam\n",
    "from models.dense_mha import DenseNet_mha\n",
    "# from models.xresnet1d import *"
   ]
  },
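  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook loads several pretrained 12-lead ECG classifiers, evaluates them individually and as weighted ensembles with the challenge scoring code (`evaluate_model`), and finally trains a small learned ensemble head on top of the frozen models' outputs."
   ]
  },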
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model1 = DenseNet(100,30).cuda()\n",
    "# model1.load_state_dict(torch.load('model_weights/densenet-32-aur-0.964-auc 0.615 .pth'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from easydict import EasyDict as edict\n",
    "config = edict({'hidden_size':320, 'num_attention_heads':8})  \n",
    "model2 = ResidualNetwork(12,config).cuda()\n",
    "model2.load_state_dict(torch.load('model_weights/Multi_head_Arrythmia detection-13-aur-0.975-auc 0.62  (1).pth'), strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_IncompatibleKeys(missing_keys=[], unexpected_keys=['classifier.linear.weight', 'classifier.linear.bias', 'classifier.attention.fc1.weight', 'classifier.attention.fc1.bias', 'classifier.attention.fc2.weight', 'classifier.attention.fc2.bias'])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model3 = ResidualNetwork_classifier(12).cuda()\n",
    "model3.load_state_dict(torch.load('model_weights/RESNET_two_way-41-aur-0.972-auc 0.608 .pth'),strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model4 = ResidualNetwork_lstm(12,bidirectional = False, rnn = 'gru').cuda()\n",
    "model4.load_state_dict(torch.load('model_weights/F_CONV_GRU-24-aur-0.968-auc 0.64 .pth'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model5 = DenseNet_rnn(100,30).cuda()\n",
    "model5.load_state_dict(torch.load('model_weights/dense-41-aur-0.97-auc 0.64 .pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model6 = DenseNet_bda(100,30).cuda()\n",
    "# model6.load_state_dict(torch.load('model_weights/dense_bda-1-aur-0.967-auc 0.594 .pth'),strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    " model7 = DenseNet_cam(100,30).cuda()\n",
    "model7.load_state_dict(torch.load('model_weights/dense_cam_17-aur-0.972-auc 0.623 .pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from easydict import EasyDict as edict\n",
    "config = edict({'hidden_size':320, 'num_attention_heads':8})  \n",
    "model8 = DenseNet_mha(100,30,config).cuda()\n",
    "model8.load_state_dict(torch.load('model_weights/dense-24-aur-0.976-auc 0.634 .pth'), strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import numpy as np, os, sys, joblib\n",
    "import matplotlib.pyplot as pl\n",
    "import pandas as pd\n",
    "import random, os\n",
    "import librosa\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import torch.nn.functional as F\n",
    "from torchvision.transforms import ToTensor\n",
    "from torchvision.utils import make_grid\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    " #from torch.utils.data import random_split\n",
    "from torch.optim import lr_scheduler\n",
    "import time\n",
    "import tqdm\n",
    "from evaluate_model import *\n",
    "from my_helper_code import *\n",
    "from helper_code import *\n",
    "#from model import *\n",
    "from torch.nn import  Conv1d,Conv2d\n",
    "from dataset import My_Dataset_separate\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_everything(seed: int):\n",
    "    \n",
    "    \n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    \n",
    "seed_everything(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('dx_mapping_scored.csv')\n",
    "labels = df['SNOMEDCTCode'].values\n",
    "labels = [str(i) for i in labels]\n",
    "classes = list(labels)\n",
    "\n",
    "test_data_directory = 'test_data'\n",
    "test_header_files, test_recording_files = find_challenge_files(test_data_directory)\n",
    "test_num_recordings = len(test_recording_files)\n",
    "training_classes = list(labels)\n",
    "test_classes = list(labels)\n",
    "num_classes = len(classes)\n",
    "num_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_length = 4096\n",
    "twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')\n",
    "test_dataset = My_Dataset_separate(test_header_files, test_recording_files, twelve_leads,sample_length,test_classes)\n",
    "\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, batch_size=1,\n",
    "                             shuffle=False, num_workers=0, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from evaluate_model import *\n",
    "th = [.2,.3]\n",
    "\n",
    "def one_zero(data,th):\n",
    "    result = []\n",
    "    for i,j in enumerate(data):\n",
    "        if (data[i] > th) :\n",
    "            result.append(1)\n",
    "        else :\n",
    "            result.append(0)\n",
    "    return result\n",
    "\n",
    "def generating_output_files(models,weight, test_classes,test_loader, output_directory,th) :\n",
    "    for model in models :\n",
    "        model.eval()\n",
    "    for inputs, target, header_files in tqdm.tqdm(test_loader):\n",
    "        header_files = header_files[0]        \n",
    "        input_var = torch.autograd.Variable(inputs.cuda().float())\n",
    "        target_var = torch.autograd.Variable(target.cuda().float())\n",
    "        output = 0\n",
    "        if weight :\n",
    "            for i,model in enumerate(models):\n",
    "                \n",
    "                output = output+model(input_var)*(weight[i])\n",
    "            output = output/sum(weight)\n",
    "        else :\n",
    "            output = models[0](input_var)\n",
    "        \n",
    "        probabilities = output.detach().cpu().numpy().squeeze()\n",
    "        labels = one_zero(probabilities,th)\n",
    "        header = load_header(header_files)\n",
    "        recording_id = get_recording_id(header)\n",
    "        head, tail = os.path.split(header_files)\n",
    "        root, extension = os.path.splitext(tail)\n",
    "        output_file = os.path.join(output_directory, root + '.csv')\n",
    "        save_outputs(output_file, recording_id, test_classes, labels, probabilities)\n",
    "        \n",
    "        \n",
    "def test_model(model,weight, test_classes, test_loader, label_directory, output_directory,th) : \n",
    "    generating_output_files(model,weight, test_classes, test_loader, output_directory,th)\n",
    "    classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric,cf_matrix = evaluate_model('test_data','test_outputs')\n",
    "    print(f'Auroc : {auroc}')\n",
    "    print(f'Accuracy : {accuracy}')\n",
    "    print(f'f1 {f_measure}')\n",
    "    print(f'challenge_metric{challenge_metric}')\n",
    "    return\t classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric,cf_matrix\n",
    "# tensor([[[0.9949],\n",
    "#          [0.1105],\n",
    "#          [0.9002],\n",
    "#          [0.9921]]], device='cuda:0', grad_fn=<SigmoidBackward0>)   "
   ]
  },
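  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helpers above implement weighted soft voting: given per-model probability vectors $p_i$ and weights $w_i$, the ensemble score is $\\hat{p} = \\sum_i w_i\\, p_i / \\sum_i w_i$ (when `weight` is `None`, only the first model's output is used), and `one_zero` then binarizes the scores at each threshold in `th` before the challenge metrics are computed."
   ]
  },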
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                 | 2/4413 [00:00<04:40, 15.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for th 0.2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [02:16<00:00, 32.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights...\n",
      "Loading label and output files...\n",
      "Evaluating model...\n",
      "- AUROC and AUPRC...\n",
      "- Accuracy...\n",
      "- F-measure...\n",
      "- Challenge metric...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                 | 2/4413 [00:00<04:27, 16.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done.\n",
      "Auroc : 0.9760846512615529\n",
      "Accuracy : 0.5563108996147745\n",
      "f1 0.642042988659263\n",
      "challenge_metric0.7373136450157349\n",
      "for th 0.3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [02:18<00:00, 31.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights...\n",
      "Loading label and output files...\n",
      "Evaluating model...\n",
      "- AUROC and AUPRC...\n",
      "- Accuracy...\n",
      "- F-measure...\n",
      "- Challenge metric...\n",
      "Done.\n",
      "Auroc : 0.9760846512615529\n",
      "Accuracy : 0.6025379560389758\n",
      "f1 0.6498674132966807\n",
      "challenge_metric0.7153049900471004\n"
     ]
    }
   ],
   "source": [
    "\n",
    "weight = None\n",
    "label_directory = 'test_data'\n",
    "output_directory='test_outputs'\n",
    "models = [model8]\n",
    "cf_matrix_list = []\n",
    "for i in th :\n",
    "    print(f'for th {i}')\n",
    "    classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric,cf_matrix = test_model(models,weight, test_classes, test_loader, label_directory, output_directory,i) \n",
    "    cf_matrix_list.append(cf_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/4413 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for th 0.2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [08:26<00:00,  8.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights...\n",
      "Loading label and output files...\n",
      "Evaluating model...\n",
      "- AUROC and AUPRC...\n",
      "- Accuracy...\n",
      "- F-measure...\n",
      "- Challenge metric...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/4413 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done.\n",
      "Auroc : 0.9791910648836137\n",
      "Accuracy : 0.5859959211420802\n",
      "f1 0.6686548325577337\n",
      "challenge_metric0.7509528789248933\n",
      "for th 0.3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [08:24<00:00,  8.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights...\n",
      "Loading label and output files...\n",
      "Evaluating model...\n",
      "- AUROC and AUPRC...\n",
      "- Accuracy...\n",
      "- F-measure...\n",
      "- Challenge metric...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/4413 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done.\n",
      "Auroc : 0.9791910648836137\n",
      "Accuracy : 0.6317697711307501\n",
      "f1 0.6561949164910517\n",
      "challenge_metric0.7281301900704895\n",
      "for th 0.4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [08:24<00:00,  8.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights...\n",
      "Loading label and output files...\n",
      "Evaluating model...\n",
      "- AUROC and AUPRC...\n",
      "- Accuracy...\n",
      "- F-measure...\n",
      "- Challenge metric...\n",
      "Done.\n",
      "Auroc : 0.9791910648836137\n",
      "Accuracy : 0.6501246317697711\n",
      "f1 0.6374265369708496\n",
      "challenge_metric0.6994859830000091\n"
     ]
    }
   ],
   "source": [
    "model = [model2, model3, model4, model5,model7,model8]\n",
    "weight = [1,1,1,1,1,1]\n",
    "label_directory = 'test_data'\n",
    "output_directory='test_outputs'\n",
    "for i in th:\n",
    "    print(f'for th {i}')\n",
    "    cf_matrix = test_model(model,weight, test_classes, test_loader, label_directory, output_directory,i) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ensemble(nn.Module):\n",
    "    def __init__(self, n_models) :\n",
    "        super(ensemble, self).__init__()\n",
    "        self.n_models = n_models\n",
    "        self.conv1d = nn.Conv1d(self.n_models,1,1)\n",
    "        \n",
    "    def forward(self,x) :\n",
    "        return torch.sigmoid(self.conv1d(x)).squeeze()\n",
    "def compute_accuracy(labels, outputs):\n",
    "    num_recordings, num_classes = np.shape(labels)\n",
    "\n",
    "    num_correct_recordings = 0\n",
    "    for i in range(num_recordings):\n",
    "        if np.all(labels[i, :]==outputs[i, :]):\n",
    "            num_correct_recordings += 1\n",
    "\n",
    "    return float(num_correct_recordings) / float(num_recordings)"
   ]
  },
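  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape sanity check for the ensemble head (illustrative sketch, not executed in the original run).\n",
    "# Stacking the per-model class-probability vectors gives (batch, n_models, num_classes); the 1x1\n",
    "# Conv1d learns one scalar weight per model and collapses the stack back to one vector per class.\n",
    "dummy_stack = torch.rand(1, 5, num_classes).cuda()\n",
    "ensemble(5).cuda()(dummy_stack).shape  # expected: torch.Size([num_classes])"
   ]
  },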
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_files, output_files = find_challenge_files_eval('test_data','test_outputs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/99\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3it [00:00,  5.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 0 totalloss: 0.001 running_acc: 0.000 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "503it [00:39, 12.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 500 totalloss: 0.424 running_acc: 0.574 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1003it [01:18, 12.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1000 totalloss: 0.250 running_acc: 0.486 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1503it [01:57, 13.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1500 totalloss: 0.143 running_acc: 0.646 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2003it [03:16, 12.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2000 totalloss: 0.100 running_acc: 0.658 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2503it [03:55, 13.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2500 totalloss: 0.097 running_acc: 0.608 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3003it [04:34, 12.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3000 totalloss: 0.117 running_acc: 0.396 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3503it [05:15, 12.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3500 totalloss: 0.037 running_acc: 0.820 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4003it [05:54, 12.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 4000 totalloss: 0.043 running_acc: 0.782 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4413it [06:41, 11.00it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoc acc: 0.628  \n",
      "Epoch 1/99\n",
      "----------\n",
      "step: 0 totalloss: 0.000 running_acc: 0.002 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "502it [00:39, 12.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 500 totalloss: 0.063 running_acc: 0.678 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1002it [01:18, 13.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1000 totalloss: 0.110 running_acc: 0.496 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1502it [01:56, 12.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1500 totalloss: 0.071 running_acc: 0.648 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2003it [03:15, 12.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2000 totalloss: 0.076 running_acc: 0.616 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2501it [03:53, 12.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2500 totalloss: 0.107 running_acc: 0.576 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3003it [04:33, 12.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3000 totalloss: 0.149 running_acc: 0.398 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3503it [05:12, 12.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3500 totalloss: 0.041 running_acc: 0.800 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4003it [05:50, 13.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 4000 totalloss: 0.049 running_acc: 0.798 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4413it [06:38, 11.07it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoc acc: 0.635  \n",
      "Epoch 2/99\n",
      "----------\n",
      "step: 0 totalloss: 0.000 running_acc: 0.002 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "502it [00:39, 12.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 500 totalloss: 0.089 running_acc: 0.678 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1002it [01:17, 12.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1000 totalloss: 0.166 running_acc: 0.484 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1502it [01:57, 12.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1500 totalloss: 0.103 running_acc: 0.662 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2003it [03:14, 12.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2000 totalloss: 0.101 running_acc: 0.646 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2503it [03:54, 12.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2500 totalloss: 0.154 running_acc: 0.574 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3001it [04:33, 12.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3000 totalloss: 0.226 running_acc: 0.320 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3503it [05:12, 13.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3500 totalloss: 0.064 running_acc: 0.772 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4003it [05:52, 12.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 4000 totalloss: 0.063 running_acc: 0.800 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4413it [06:39, 11.05it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoc acc: 0.625  \n",
      "Epoch 3/99\n",
      "----------\n",
      "step: 0 totalloss: 0.000 running_acc: 0.002 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "502it [00:38, 12.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 500 totalloss: 0.121 running_acc: 0.646 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1002it [01:17, 12.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1000 totalloss: 0.226 running_acc: 0.482 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1502it [01:55, 12.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1500 totalloss: 0.143 running_acc: 0.660 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2003it [03:12, 12.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2000 totalloss: 0.143 running_acc: 0.660 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2503it [03:50, 13.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2500 totalloss: 0.211 running_acc: 0.596 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3003it [04:27, 13.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3000 totalloss: 0.270 running_acc: 0.384 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3503it [05:06, 13.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3500 totalloss: 0.084 running_acc: 0.776 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4003it [05:44, 13.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 4000 totalloss: 0.082 running_acc: 0.798 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4413it [06:30, 11.31it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoc acc: 0.634  \n",
      "Epoch 4/99\n",
      "----------\n",
      "step: 0 totalloss: 0.000 running_acc: 0.002 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "502it [00:38, 13.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 500 totalloss: 0.157 running_acc: 0.632 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1002it [01:16, 13.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1000 totalloss: 0.288 running_acc: 0.460 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1502it [01:53, 13.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1500 totalloss: 0.191 running_acc: 0.664 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2003it [03:12, 12.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2000 totalloss: 0.192 running_acc: 0.656 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2503it [03:51, 12.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 2500 totalloss: 0.264 running_acc: 0.590 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3001it [04:35, 11.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3000 totalloss: 0.331 running_acc: 0.398 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3502it [05:22, 11.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 3500 totalloss: 0.100 running_acc: 0.784 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4002it [06:07, 12.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 4000 totalloss: 0.139 running_acc: 0.806 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4303it [06:37, 10.83it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-62-0c80d7842741>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     24\u001b[0m                 \u001b[1;32mfor\u001b[0m \u001b[0mmodel\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mmodels\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     25\u001b[0m                         \u001b[1;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 26\u001b[1;33m                             \u001b[0ma\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_var\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     27\u001b[0m                             \u001b[1;32mif\u001b[0m \u001b[0ma\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m==\u001b[0m\u001b[1;36m30\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     28\u001b[0m                                 \u001b[0moutput_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ma\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mG:\\Thesis-Git\\models\\MHA_resnet.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     41\u001b[0m         \u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mresidual_unit_4\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     42\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 43\u001b[1;33m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     44\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     45\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mG:\\Thesis-Git\\models\\attention.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_state)\u001b[0m\n\u001b[0;32m     78\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhidden_state\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     79\u001b[0m         \u001b[0mhidden_state\u001b[0m \u001b[1;33m=\u001b[0m  \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 80\u001b[1;33m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     81\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput_linear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     82\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m  \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mG:\\Thesis-Git\\models\\attention.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     78\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhidden_state\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     79\u001b[0m         \u001b[0mhidden_state\u001b[0m \u001b[1;33m=\u001b[0m  \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 80\u001b[1;33m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     81\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput_linear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     82\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m  \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mG:\\Thesis-Git\\models\\attention.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     61\u001b[0m         \u001b[0mv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     62\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 63\u001b[1;33m         \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mscaled_dot_product_attention\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mq\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     64\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     65\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mG:\\Thesis-Git\\models\\attention.py\u001b[0m in \u001b[0;36mscaled_dot_product_attention\u001b[1;34m(self, q, k, v)\u001b[0m\n\u001b[0;32m     53\u001b[0m         \u001b[0mweights\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mscores\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mdim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     54\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 55\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbmm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     56\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     57\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "models = [model2, model3, model4, model5,model7]\n",
    "num_epochs = 100\n",
    "learning_rate = 0.001\n",
    "for i in models:\n",
    "    i.eval()\n",
    "net = ensemble(5).cuda()\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)\n",
    "for epoch in range(0,num_epochs):\n",
    "        since = time.time()\n",
    "        print('Epoch {}/{}'.format(epoch , num_epochs - 1))\n",
    "        print('-' * 10)\n",
    "    \n",
    "        net.train() \n",
    "        epoch_loss = 0.0\n",
    "        epoch_acc = 0.0\n",
    "        running_acc = 0.0\n",
    "        running_loss = 0.0\n",
    "        for i,(inputs, target, header_files) in tqdm.tqdm(enumerate(test_loader)):\n",
    "                header_files = header_files[0]        \n",
    "                input_var = torch.autograd.Variable(inputs.cuda().float())\n",
    "                target_var = torch.autograd.Variable(target.cuda().float())\n",
    "                output_list = []\n",
    "                for model in models:\n",
    "                        with torch.no_grad():\n",
    "                            a = model(input_var)\n",
    "                            if a.shape[0] ==30 :\n",
    "                                output_list.append(a.unsqueeze(dim = 0))\n",
    "                            else :\n",
    "                                output_list.append(a)\n",
    "\n",
    "                output = torch.cat(output_list,dim=0).unsqueeze(dim=0)\n",
    "                result = net(output).unsqueeze(dim=0)\n",
    "                \n",
    "                loss = criterion(result, target_var)\n",
    "                label = result.ge(0.5).float()\n",
    "                np_label = label.detach().cpu().numpy()\n",
    "                np_target = target_var.detach().cpu().numpy()\n",
    "                if np.all(np_label[0,:]==np_target[0,:]):\n",
    "                    epoch_acc += 1\n",
    "                    running_acc +=1 \n",
    "                    \n",
    "                running_loss += loss.data.item()\n",
    "                if (i%500) == 0:\n",
    "                \n",
    "                    print('step: {} totalloss: {loss:.3f} running_acc: {running_acc:.3f} '.format(i, loss = running_loss/500, running_acc = running_acc/500))\n",
    "                    \n",
    "                    running_acc = 0.0\n",
    "                    running_loss = 0.0\n",
    "        \n",
    "                loss.backward() \n",
    "                optimizer.step()\n",
    "\n",
    "        print('epoc acc: {epoch_acc:.3f}  '.format(epoch_acc = epoch_acc/4413) )      \n",
    "        PATH = f'log/weight-{epoch}-acc-{round(epoch_acc,3)} .pth' \n",
    "        torch.save(model.state_dict(), PATH)\n",
    "    "
   ]
  },
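  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (not executed in the original run): apply the trained ensemble head to one\n",
    "# recording by stacking the frozen base-model outputs exactly as in the training loop above.\n",
    "# The 0.5 decision threshold is an assumption, not a value tuned in this notebook.\n",
    "net.eval()\n",
    "inputs, target, header_files = next(iter(test_loader))\n",
    "with torch.no_grad():\n",
    "    input_var = inputs.cuda().float()\n",
    "    stacked = torch.cat([m(input_var).reshape(1, -1) for m in models], dim=0).unsqueeze(dim=0)\n",
    "    probs = net(stacked)\n",
    "    preds = probs.ge(0.5).float()\n",
    "preds"
   ]
  },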
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights_file = 'weights.csv'\n",
    "sinus_rhythm = set(['426783006'])\n",
    "classes, weights = load_weights(weights_file)\n",
    "labels = load_labels(label_files, classes)\n",
    "\n",
    "binary_outputs, scalar_outputs = load_classifier_outputs(output_files, classes)\n",
    "cf_matrix = compute_modified_confusion_matrix(labels,binary_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for j in range(26) :\n",
    "    cf_matrix[j,:] = cf_matrix[j,:]/sum(cf_matrix[j,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "fig, ax = plt.subplots(figsize=(50,50))  \n",
    "ax = sns.heatmap(cf_matrix, annot=True, cmap='Blues')\n",
    "\n",
    "ax.set_title('Confusion Matrix\\n\\n');\n",
    "ax.set_xlabel('\\nPredicted Values')\n",
    "ax.set_ylabel('Actual Values ');\n",
    "\n",
    "## Ticket labels - List must be in alphabetical order\n",
    "# ax.xaxis.set_ticklabels(['False])\n",
    "# ax.yaxis.set_ticklabels(['False','True'])\n",
    "\n",
    "## Display the visualization of the Confusion Matrix.\n",
    "plt.savefig('save_as_a_png.png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import seaborn as sns\n",
    "fig, ax = plt.subplots(figsize=(20,20))  \n",
    "ax = sns.heatmap(weights, annot=True, cmap='Blues')\n",
    "\n",
    "ax.set_title('Dx Scores\\n\\n');\n",
    "ax.set_xlabel('\\n Predicted Diseases');\n",
    "ax.set_ylabel('Actual Values ');\n",
    "\n",
    "## Ticket labels - List must be in alphabetical order\n",
    "# ax.xaxis.set_ticklabels(['False])\n",
    "# ax.yaxis.set_ticklabels(['False','True'])\n",
    "\n",
    "## Display the visualization of the Confusion Matrix.\n",
    "plt.savefig('Dx scores.png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}