[f6de96]: / .ipynb_checkpoints / Ensemble-checkpoint.ipynb

Download this file

601 lines (600 with data), 29.5 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from models.densenet import DenseNet\n",
    "from models.MHA_resnet import ResidualNetwork\n",
    "from models.resnet_bda import ResidualNetwork_classifier\n",
    "from models.resnet_rnn import ResidualNetwork_lstm\n",
    "from models.dense_rnn import  DenseNet_rnn\n",
    "from models.dense_bda import DenseNet_bda\n",
    "from models.dense_cam import DenseNet_cam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build DenseNet_cam on the GPU and restore its saved weights.\n",
    "dense_cam_checkpoint = 'model_weights/dense_cam_17-aur-0.972-auc 0.623 .pth'\n",
    "model7 = DenseNet_cam(100, 30).cuda()\n",
    "model7.load_state_dict(torch.load(dense_cam_checkpoint))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build DenseNet on the GPU and restore its saved weights.\n",
    "densenet_checkpoint = 'model_weights/densenet-32-aur-0.964-auc 0.615 .pth'\n",
    "model1 = DenseNet(100, 30).cuda()\n",
    "model1.load_state_dict(torch.load(densenet_checkpoint))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build DenseNet_rnn on the GPU and restore its saved weights.\n",
    "dense_rnn_checkpoint = 'model_weights/dense-41-aur-0.97-auc 0.64 .pth'\n",
    "model5 = DenseNet_rnn(100, 30).cuda()\n",
    "model5.load_state_dict(torch.load(dense_rnn_checkpoint))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_IncompatibleKeys(missing_keys=[], unexpected_keys=['classifier.linear.weight', 'classifier.linear.bias', 'classifier.attention.fc1.weight', 'classifier.attention.fc1.bias', 'classifier.attention.fc2.weight', 'classifier.attention.fc2.bias'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build DenseNet_bda on the GPU and restore its saved weights.\n",
    "# strict=False lets load_state_dict ignore checkpoint keys the model does not define.\n",
    "# NOTE(review): the saved cell output lists classifier.* keys as unexpected -- confirm\n",
    "# that dropping them is intended.\n",
    "dense_bda_checkpoint = 'model_weights/dense_bda-1-aur-0.967-auc 0.594 .pth'\n",
    "model6 = DenseNet_bda(100, 30).cuda()\n",
    "model6.load_state_dict(torch.load(dense_bda_checkpoint), strict=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from easydict import EasyDict as edict\n",
    "\n",
    "# Attention settings handed to ResidualNetwork, then the saved weights are restored.\n",
    "config = edict({'hidden_size': 320, 'num_attention_heads': 8})\n",
    "mha_resnet_checkpoint = 'model_weights/Multi_head_Arrythmia detection-13-aur-0.975-auc 0.62  (1).pth'\n",
    "model2 = ResidualNetwork(12, config).cuda()\n",
    "model2.load_state_dict(torch.load(mha_resnet_checkpoint), strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_IncompatibleKeys(missing_keys=[], unexpected_keys=['classifier.linear.weight', 'classifier.linear.bias', 'classifier.attention.fc1.weight', 'classifier.attention.fc1.bias', 'classifier.attention.fc2.weight', 'classifier.attention.fc2.bias'])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build ResidualNetwork_classifier on the GPU and restore its saved weights.\n",
    "# strict=False lets load_state_dict ignore checkpoint keys the model does not define.\n",
    "# NOTE(review): the saved cell output lists classifier.* keys as unexpected -- confirm\n",
    "# that dropping them is intended.\n",
    "resnet_bda_checkpoint = 'model_weights/RESNET_two_way-41-aur-0.972-auc 0.608 .pth'\n",
    "model3 = ResidualNetwork_classifier(12).cuda()\n",
    "model3.load_state_dict(torch.load(resnet_bda_checkpoint), strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build ResidualNetwork_lstm on the GPU and restore its saved weights.\n",
    "resnet_lstm_checkpoint = 'model_weights/F_CONV_LSTM-25-aur-0.966-auc 0.631 .pth'\n",
    "model4 = ResidualNetwork_lstm(12).cuda()\n",
    "model4.load_state_dict(torch.load(resnet_lstm_checkpoint))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import numpy as np, os, sys, joblib\n",
    "import matplotlib.pyplot as pl\n",
    "import pandas as pd\n",
    "import random, os\n",
    "import librosa\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import torch.nn.functional as F\n",
    "from torchvision.transforms import ToTensor\n",
    "from torchvision.utils import make_grid\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    " #from torch.utils.data import random_split\n",
    "from torch.optim import lr_scheduler\n",
    "import time\n",
    "import tqdm\n",
    "from evaluate_model import *\n",
    "from my_helper_code import *\n",
    "from helper_code import *\n",
    "#from model import *\n",
    "from torch.nn import  Conv1d,Conv2d\n",
    "from dataset import My_Dataset_separate\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_everything(seed: int):\n",
    "    \n",
    "    \n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    \n",
    "seed_everything(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Codes from the SNOMEDCTCode column of the mapping file, kept as strings.\n",
    "df = pd.read_csv('dx_mapping_scored.csv')\n",
    "labels = [str(code) for code in df['SNOMEDCTCode'].values]\n",
    "classes = list(labels)\n",
    "training_classes = list(labels)\n",
    "test_classes = list(labels)\n",
    "num_classes = len(classes)\n",
    "\n",
    "# Header/recording file lists for the test directory.\n",
    "test_data_directory = 'test_data'\n",
    "test_header_files, test_recording_files = find_challenge_files(test_data_directory)\n",
    "test_num_recordings = len(test_recording_files)\n",
    "\n",
    "num_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')\n",
    "sample_length = 4096\n",
    "\n",
    "test_dataset = My_Dataset_separate(\n",
    "    test_header_files,\n",
    "    test_recording_files,\n",
    "    twelve_leads,\n",
    "    sample_length,\n",
    "    test_classes,\n",
    ")\n",
    "\n",
    "# One recording per batch, in file order, loaded in the main process.\n",
    "test_loader = DataLoader(\n",
    "    dataset=test_dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from evaluate_model import *\n",
    "th = [.2,.3,.4]\n",
    "\n",
    "def one_zero(data,th):\n",
    "    result = []\n",
    "    for i,j in enumerate(data):\n",
    "        if (data[i] > th) :\n",
    "            result.append(1)\n",
    "        else :\n",
    "            result.append(0)\n",
    "    return result\n",
    "\n",
    "def generating_output_files(models,weight, test_classes,test_loader, output_directory,th) :\n",
    "    \"\"\"Run the model(s) over `test_loader` and save one prediction file per batch.\n",
    "\n",
    "    Args:\n",
    "        models: Sequence of models; each one is switched to eval mode.\n",
    "        weight: Per-model weights for a weighted average of the model outputs.\n",
    "            When falsy (None or empty) only `models[0]` is used.\n",
    "        test_classes: Class identifiers forwarded to `save_outputs`.\n",
    "        test_loader: Iterable yielding `(inputs, target, header_files)`; only the\n",
    "            first header path of each batch is used and `target` is ignored.\n",
    "        output_directory: Directory that receives one CSV per batch, named after\n",
    "            the header file. It is created if it does not exist.\n",
    "        th: Threshold passed to `one_zero` to binarise the scores.\n",
    "    \"\"\"\n",
    "    for model in models :\n",
    "        model.eval()\n",
    "    os.makedirs(output_directory, exist_ok=True)\n",
    "    # Inference only: no_grad stops every forward pass from recording an autograd graph.\n",
    "    with torch.no_grad():\n",
    "        for inputs, _target, header_files in tqdm.tqdm(test_loader):\n",
    "            header_file = header_files[0]\n",
    "            batch = inputs.cuda().float()\n",
    "            if weight :\n",
    "                output = 0\n",
    "                for i, model in enumerate(models):\n",
    "                    output = output + model(batch) * weight[i]\n",
    "                output = output / sum(weight)\n",
    "            else :\n",
    "                output = models[0](batch)\n",
    "\n",
    "            probabilities = output.detach().cpu().numpy().squeeze()\n",
    "            labels = one_zero(probabilities, th)\n",
    "            header = load_header(header_file)\n",
    "            recording_id = get_recording_id(header)\n",
    "            root, _extension = os.path.splitext(os.path.basename(header_file))\n",
    "            output_file = os.path.join(output_directory, root + '.csv')\n",
    "            save_outputs(output_file, recording_id, test_classes, labels, probabilities)\n",
    "        \n",
    "        \n",
    "def test_model(model,weight, test_classes, test_loader, label_directory, output_directory,th) : \n",
    "    \"\"\"Write predictions for the test loader, score them and print headline metrics.\n",
    "\n",
    "    Args:\n",
    "        model: Sequence of models forwarded to `generating_output_files`.\n",
    "        weight: Ensemble weights (or None) forwarded to `generating_output_files`.\n",
    "        test_classes: Class identifiers forwarded to `generating_output_files`.\n",
    "        test_loader: Loader forwarded to `generating_output_files`.\n",
    "        label_directory: First argument given to `evaluate_model`.\n",
    "        output_directory: Directory the predictions are written to; also the\n",
    "            second argument given to `evaluate_model`.\n",
    "        th: Binarisation threshold forwarded to `generating_output_files`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(classes, auroc, auprc, auroc_classes, auprc_classes, accuracy,\n",
    "        f_measure, f_measure_classes, challenge_metric, cf_matrix)` as unpacked\n",
    "        from `evaluate_model`.\n",
    "    \"\"\"\n",
    "    generating_output_files(model,weight, test_classes, test_loader, output_directory,th)\n",
    "    # Score the directories the caller passed in; they used to be ignored in favour\n",
    "    # of the hard-coded 'test_data' / 'test_outputs'.\n",
    "    (classes, auroc, auprc, auroc_classes, auprc_classes, accuracy,\n",
    "     f_measure, f_measure_classes, challenge_metric, cf_matrix) = evaluate_model(label_directory, output_directory)\n",
    "    print(f'Auroc : {auroc}')\n",
    "    print(f'Accuracy : {accuracy}')\n",
    "    print(f'f1 {f_measure}')\n",
    "    print(f'challenge_metric{challenge_metric}')\n",
    "    return (classes, auroc, auprc, auroc_classes, auprc_classes, accuracy,\n",
    "            f_measure, f_measure_classes, challenge_metric, cf_matrix)\n",
    "# tensor([[[0.9949],\n",
    "#          [0.1105],\n",
    "#          [0.9002],\n",
    "#          [0.9921]]], device='cuda:0', grad_fn=<SigmoidBackward0>)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/4413 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for th 0.2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/4413 [00:02<?, ?it/s]\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "mat1 and mat2 shapes cannot be multiplied (320x1 and 320x30)",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-8-d1e75f711adc>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mth\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      6\u001b[0m     \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'for th {i}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 7\u001b[1;33m     \u001b[0mclasses\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mauroc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mauprc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mauroc_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mauprc_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mf_measure\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mf_measure_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mchallenge_metric\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcf_matrix\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtest_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodels\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel_directory\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_directory\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m.3\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-7-f81113d8e1b5>\u001b[0m in \u001b[0;36mtest_model\u001b[1;34m(model, weight, test_classes, test_loader, label_directory, output_directory, th)\u001b[0m\n\u001b[0;32m     37\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     38\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mtest_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel_directory\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_directory\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mth\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 39\u001b[1;33m     \u001b[0mgenerating_output_files\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_directory\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mth\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     40\u001b[0m     \u001b[0mclasses\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mauroc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mauprc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mauroc_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mauprc_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mf_measure\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mf_measure_classes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mchallenge_metric\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcf_matrix\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mevaluate_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'test_data'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m'test_outputs'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     41\u001b[0m     \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'Auroc : {auroc}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-7-f81113d8e1b5>\u001b[0m in \u001b[0;36mgenerating_output_files\u001b[1;34m(models, weight, test_classes, test_loader, output_directory, th)\u001b[0m\n\u001b[0;32m     24\u001b[0m             \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     25\u001b[0m         \u001b[1;32melse\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 26\u001b[1;33m             \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodels\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_var\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     27\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     28\u001b[0m         \u001b[0mprobabilities\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mG:\\Thesis-Git\\models\\dense_cam.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     78\u001b[0m         \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     79\u001b[0m         \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 80\u001b[1;33m         \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfc\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     81\u001b[0m         \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     82\u001b[0m          \u001b[1;31m# print(out.shape)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    101\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    102\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 103\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    104\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    105\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\functional.py\u001b[0m in \u001b[0;36mlinear\u001b[1;34m(input, weight, bias)\u001b[0m\n\u001b[0;32m   1846\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mhas_torch_function_variadic\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1847\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1848\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1849\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1850\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (320x1 and 320x30)"
     ]
    }
   ],
   "source": [
    "# Evaluate model7 on its own (weight=None selects models[0]) at every threshold in `th`.\n",
    "weight = None\n",
    "label_directory = 'test_data'\n",
    "output_directory='test_outputs'\n",
    "models = [model7]\n",
    "cf_matrix_list = []\n",
    "for threshold in th :\n",
    "    print(f'for th {threshold}')\n",
    "    # Pass the loop threshold; the literal .3 used before made every iteration identical.\n",
    "    classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric,cf_matrix = test_model(models,weight, test_classes, test_loader, label_directory, output_directory,threshold)\n",
    "    cf_matrix_list.append(cf_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/4413 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for th 0.3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [10:37<00:00,  6.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights...\n",
      "Loading label and output files...\n",
      "Evaluating model...\n",
      "- AUROC and AUPRC...\n",
      "- Accuracy...\n",
      "- F-measure...\n",
      "- Challenge metric...\n",
      "Done.\n",
      "Auroc : 0.9775901025836848\n",
      "Accuracy : 0.6322229775662814\n",
      "f1 0.6577955757512897\n",
      "challenge_metric0.726762723432914\n"
     ]
    }
   ],
   "source": [
    "# model = [model1, model2, model3, model4,model5,model6]\n",
    "# weight = [1,1,1,1,1,1]\n",
    "# label_directory = 'test_data'\n",
    "# output_directory='test_outputs'\n",
    "# for i in th:\n",
    "#     print(f'for th {i}')\n",
    "#     cf_matrix = test_model(model,weight, test_classes, test_loader, label_directory, output_directory,i) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ensemble(nn.Module):\n",
    "    def __init__(self, n_models) :\n",
    "        super(ensemble, self).__init__()\n",
    "        self.n_models = n_models\n",
    "        self.conv1d = nn.Conv1d(self.n_models,1,1)\n",
    "        \n",
    "    def forward(self,x) :\n",
    "        return torch.sigmoid(self.conv1d(x)).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File lists returned for the 'test_data' and 'test_outputs' directories.\n",
    "eval_file_lists = find_challenge_files_eval('test_data', 'test_outputs')\n",
    "label_files, output_files = eval_file_lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the raw prediction of every model for each test recording.\n",
    "# NOTE(review): `output` is rebuilt for each recording and never stored, so only the\n",
    "# last recording's predictions survive the loop -- confirm whether they should be kept.\n",
    "for inputs, target, header_files in tqdm.tqdm(test_loader):\n",
    "    header_files = header_files[0]\n",
    "    input_var = inputs.cuda().float()\n",
    "    target_var = target.cuda().float()\n",
    "    output = []\n",
    "    if weight :\n",
    "        # torch.no_grad is a class and must be instantiated to act as a context\n",
    "        # manager; `with torch.no_grad:` raises as soon as it is reached.\n",
    "        with torch.no_grad():\n",
    "            for model in models:\n",
    "                output.append(model(input_var))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "sinus_rhythm = {'426783006'}\n",
    "weights_file = 'weights.csv'\n",
    "\n",
    "# Classes and weights from the weights file, then labels and saved outputs for those classes.\n",
    "classes, weights = load_weights(weights_file)\n",
    "binary_outputs, scalar_outputs = load_classifier_outputs(output_files, classes)\n",
    "labels = load_labels(label_files, classes)\n",
    "\n",
    "cf_matrix = compute_modified_confusion_matrix(labels, binary_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise every row of the confusion matrix so that it sums to 1.\n",
    "# The row count comes from the matrix itself instead of a hard-coded 26, and\n",
    "# all-zero rows stay at 0 rather than turning into NaN through 0/0.\n",
    "row_sums = cf_matrix.sum(axis=1, keepdims=True)\n",
    "cf_matrix = np.divide(cf_matrix, row_sums, out=np.zeros(cf_matrix.shape, dtype=float), where=row_sums != 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The notebook's import cell binds pyplot to `pl`, so `plt` is bound here explicitly.\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Annotated heatmap of the confusion matrix, drawn on an explicit Axes.\n",
    "fig, ax = plt.subplots(figsize=(50,50))\n",
    "sns.heatmap(cf_matrix, annot=True, cmap='Blues', ax=ax)\n",
    "\n",
    "ax.set_title('Confusion Matrix\\n\\n')\n",
    "ax.set_xlabel('\\nPredicted Values')\n",
    "ax.set_ylabel('Actual Values ')\n",
    "\n",
    "# Write the figure to disk first, then display it.\n",
    "fig.savefig('save_as_a_png.png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The notebook's import cell binds pyplot to `pl`, so `plt` is bound here explicitly.\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Annotated heatmap of the `weights` matrix loaded from the weights file.\n",
    "fig, ax = plt.subplots(figsize=(20,20))\n",
    "sns.heatmap(weights, annot=True, cmap='Blues', ax=ax)\n",
    "\n",
    "ax.set_title('Dx Scores\\n\\n')\n",
    "ax.set_xlabel('\\n Predicted Diseases')\n",
    "ax.set_ylabel('Actual Values ')\n",
    "\n",
    "# Write the figure to disk first, then display it.\n",
    "fig.savefig('Dx scores.png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}