{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from models.densenet import DenseNet\n",
"from models.MHA_resnet import ResidualNetwork\n",
"from models.resnet_bda import ResidualNetwork_classifier\n",
"from models.resnet_rnn import ResidualNetwork_lstm\n",
"from models.dense_rnn import DenseNet_rnn\n",
"from models.dense_bda import DenseNet_bda\n",
"from models.dense_cam import DenseNet_cam\n",
"from models.dense_mha import DenseNet_mha\n",
"# from models.xresnet1d import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# model1 = DenseNet(100,30).cuda()\n",
"# model1.load_state_dict(torch.load('model_weights/densenet-32-aur-0.964-auc 0.615 .pth'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from easydict import EasyDict as edict\n",
"\n",
"# Attention settings for the MHA ResNet; presumably these must match the checkpoint - confirm.\n",
"config = edict({'hidden_size': 320, 'num_attention_heads': 8})\n",
"model2 = ResidualNetwork(12, config).cuda()\n",
"model2.load_state_dict(\n",
"    torch.load('model_weights/Multi_head_Arrythmia detection-13-aur-0.975-auc 0.62 (1).pth'),\n",
"    strict=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"_IncompatibleKeys(missing_keys=[], unexpected_keys=['classifier.linear.weight', 'classifier.linear.bias', 'classifier.attention.fc1.weight', 'classifier.attention.fc1.bias', 'classifier.attention.fc2.weight', 'classifier.attention.fc2.bias'])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model3 = ResidualNetwork_classifier(12).cuda()\n",
"# NOTE(review): strict=False silently ignores the checkpoint's classifier.* tensors\n",
"# (listed as unexpected_keys in this cell's output); confirm the architecture matches the trained one.\n",
"model3.load_state_dict(\n",
"    torch.load('model_weights/RESNET_two_way-41-aur-0.972-auc 0.608 .pth'),\n",
"    strict=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ResNet + recurrent variant, configured as a unidirectional GRU.\n",
"model4 = ResidualNetwork_lstm(12, bidirectional=False, rnn='gru').cuda()\n",
"model4.load_state_dict(\n",
"    torch.load('model_weights/F_CONV_GRU-24-aur-0.968-auc 0.64 .pth')\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NOTE(review): the meaning of (100, 30) is defined in models/dense_rnn.py; 30 presumably equals num_classes.\n",
"model5 = DenseNet_rnn(100, 30).cuda()\n",
"model5.load_state_dict(\n",
"    torch.load('model_weights/dense-41-aur-0.97-auc 0.64 .pth')\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# model6 = DenseNet_bda(100,30).cuda()\n",
"# model6.load_state_dict(torch.load('model_weights/dense_bda-1-aur-0.967-auc 0.594 .pth'),strict=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model7 = DenseNet_cam(100, 30).cuda()\n",
"model7.load_state_dict(torch.load('model_weights/dense_cam_17-aur-0.972-auc 0.623 .pth'))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from easydict import EasyDict as edict\n",
"\n",
"# Same attention settings as the MHA ResNet, now on a DenseNet backbone.\n",
"config = edict({'hidden_size': 320, 'num_attention_heads': 8})\n",
"model8 = DenseNet_mha(100, 30, config).cuda()\n",
"model8.load_state_dict(\n",
"    torch.load('model_weights/dense-24-aur-0.976-auc 0.634 .pth'),\n",
"    strict=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"import torch.nn as nn\n",
"import numpy as np, os, sys, joblib\n",
"import matplotlib.pyplot as pl\n",
"import pandas as pd\n",
"import random, os\n",
"import librosa\n",
"from torch.utils.data.dataset import Dataset\n",
"import torch.nn.functional as F\n",
"from torchvision.transforms import ToTensor\n",
"from torchvision.utils import make_grid\n",
"from torch.utils.data.dataloader import DataLoader\n",
" #from torch.utils.data import random_split\n",
"from torch.optim import lr_scheduler\n",
"import time\n",
"import tqdm\n",
"from evaluate_model import *\n",
"from my_helper_code import *\n",
"from helper_code import *\n",
"#from model import *\n",
"from torch.nn import Conv1d,Conv2d\n",
"from dataset import My_Dataset_separate\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def seed_everything(seed: int):\n",
"    \"\"\"Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.\n",
"\n",
"    cudnn.benchmark must be False: its auto-tuner may choose different\n",
"    (possibly non-deterministic) convolution algorithms between runs, which\n",
"    would defeat cudnn.deterministic = True.\n",
"    \"\"\"\n",
"    random.seed(seed)\n",
"    # Only affects subprocesses; hash randomisation of the running interpreter is fixed at startup.\n",
"    os.environ['PYTHONHASHSEED'] = str(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"    torch.backends.cudnn.deterministic = True\n",
"    torch.backends.cudnn.benchmark = False\n",
"\n",
"\n",
"seed_everything(0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"30"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv('dx_mapping_scored.csv')\n",
"# The scored SNOMED-CT codes, as strings, serve as classes / training_classes / test_classes.\n",
"labels = [str(code) for code in df['SNOMEDCTCode'].values]\n",
"classes = list(labels)\n",
"training_classes = list(labels)\n",
"test_classes = list(labels)\n",
"num_classes = len(classes)\n",
"\n",
"test_data_directory = 'test_data'\n",
"test_header_files, test_recording_files = find_challenge_files(test_data_directory)\n",
"test_num_recordings = len(test_recording_files)\n",
"\n",
"num_classes"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"sample_length = 4096\n",
"twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF',\n",
"                'V1', 'V2', 'V3', 'V4', 'V5', 'V6')\n",
"test_dataset = My_Dataset_separate(\n",
"    test_header_files, test_recording_files, twelve_leads, sample_length, test_classes)\n",
"\n",
"# batch_size=1: generating_output_files only reads the first header path of each batch.\n",
"test_loader = DataLoader(\n",
"    dataset=test_dataset,\n",
"    batch_size=1,\n",
"    shuffle=False,\n",
"    num_workers=0,\n",
"    pin_memory=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from evaluate_model import *\n",
"\n",
"th = [.2, .3]\n",
"\n",
"\n",
"def one_zero(data, th):\n",
"    \"\"\"Binarise per-class probabilities with a fixed threshold.\n",
"\n",
"    Returns a list holding 1 where a value is strictly greater than `th`\n",
"    and 0 otherwise, in the same order as `data`.\n",
"    \"\"\"\n",
"    return [1 if p > th else 0 for p in data]\n",
"\n",
"\n",
"def generating_output_files(models, weight, test_classes, test_loader, output_directory, th):\n",
"    \"\"\"Predict every recording in `test_loader` and write one output CSV per recording.\n",
"\n",
"    models: list of torch modules. If `weight` is falsy (None or empty) only\n",
"        models[0] is used; otherwise the prediction is the weighted mean\n",
"        sum(w_i * model_i(x)) / sum(w_i), with exactly one weight per model.\n",
"    th: threshold handed to `one_zero` to turn probabilities into 0/1 labels.\n",
"    Each CSV is named after its header file's stem and written into\n",
"    `output_directory` (created if missing) via `save_outputs`.\n",
"\n",
"    Raises ValueError if `weight` is given with a length other than len(models).\n",
"    \"\"\"\n",
"    if weight and len(weight) != len(models):\n",
"        raise ValueError(f'got {len(weight)} weights for {len(models)} models')\n",
"    for m in models:\n",
"        m.eval()\n",
"    os.makedirs(output_directory, exist_ok=True)\n",
"    # Inference only: no_grad avoids building an autograd graph for every batch.\n",
"    with torch.no_grad():\n",
"        for inputs, _target, header_files in tqdm.tqdm(test_loader):\n",
"            # Assumes batch_size=1: only the first header path of the batch is used.\n",
"            header_file = header_files[0]\n",
"            input_var = inputs.cuda().float()\n",
"            if weight:\n",
"                output = sum(w * m(input_var) for w, m in zip(weight, models)) / sum(weight)\n",
"            else:\n",
"                output = models[0](input_var)\n",
"\n",
"            probabilities = output.cpu().numpy().squeeze()\n",
"            labels = one_zero(probabilities, th)\n",
"            header = load_header(header_file)\n",
"            recording_id = get_recording_id(header)\n",
"            root, _ = os.path.splitext(os.path.basename(header_file))\n",
"            output_file = os.path.join(output_directory, root + '.csv')\n",
"            save_outputs(output_file, recording_id, test_classes, labels, probabilities)\n",
"\n",
"\n",
"def test_model(model, weight, test_classes, test_loader, label_directory, output_directory, th):\n",
"    \"\"\"Write predictions at threshold `th`, score them, print a summary, return all metrics.\n",
"\n",
"    `model` is the list of models passed on to generating_output_files.\n",
"    Returns (classes, auroc, auprc, auroc_classes, auprc_classes, accuracy,\n",
"    f_measure, f_measure_classes, challenge_metric, cf_matrix) as unpacked\n",
"    from evaluate_model.\n",
"    \"\"\"\n",
"    generating_output_files(model, weight, test_classes, test_loader, output_directory, th)\n",
"    # Score the directories we were given, not hard-coded 'test_data' / 'test_outputs'.\n",
"    (classes, auroc, auprc, auroc_classes, auprc_classes, accuracy,\n",
"     f_measure, f_measure_classes, challenge_metric, cf_matrix) = evaluate_model(label_directory, output_directory)\n",
"    print(f'Auroc : {auroc}')\n",
"    print(f'Accuracy : {accuracy}')\n",
"    print(f'f1 {f_measure}')\n",
"    print(f'challenge_metric{challenge_metric}')\n",
"    return (classes, auroc, auprc, auroc_classes, auprc_classes, accuracy,\n",
"            f_measure, f_measure_classes, challenge_metric, cf_matrix)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 2/4413 [00:00<04:40, 15.73it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"for th 0.2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [02:16<00:00, 32.24it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading weights...\n",
"Loading label and output files...\n",
"Evaluating model...\n",
"- AUROC and AUPRC...\n",
"- Accuracy...\n",
"- F-measure...\n",
"- Challenge metric...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 2/4413 [00:00<04:27, 16.51it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Auroc : 0.9760846512615529\n",
"Accuracy : 0.5563108996147745\n",
"f1 0.642042988659263\n",
"challenge_metric0.7373136450157349\n",
"for th 0.3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [02:18<00:00, 31.85it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading weights...\n",
"Loading label and output files...\n",
"Evaluating model...\n",
"- AUROC and AUPRC...\n",
"- Accuracy...\n",
"- F-measure...\n",
"- Challenge metric...\n",
"Done.\n",
"Auroc : 0.9760846512615529\n",
"Accuracy : 0.6025379560389758\n",
"f1 0.6498674132966807\n",
"challenge_metric0.7153049900471004\n"
]
}
],
"source": [
"# Evaluate model8 on its own at every threshold in `th`.\n",
"weight = None  # falsy weight -> generating_output_files uses models[0] only\n",
"label_directory = 'test_data'\n",
"output_directory = 'test_outputs'\n",
"models = [model8]\n",
"cf_matrix_list = []\n",
"for i in th:\n",
"    print(f'for th {i}')\n",
"    (classes, auroc, auprc, auroc_classes, auprc_classes, accuracy,\n",
"     f_measure, f_measure_classes, challenge_metric, cf_matrix) = test_model(\n",
"        models, weight, test_classes, test_loader,\n",
"        label_directory, output_directory, i)\n",
"    cf_matrix_list.append(cf_matrix)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/4413 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"for th 0.2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [08:26<00:00, 8.72it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading weights...\n",
"Loading label and output files...\n",
"Evaluating model...\n",
"- AUROC and AUPRC...\n",
"- Accuracy...\n",
"- F-measure...\n",
"- Challenge metric...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/4413 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Auroc : 0.9791910648836137\n",
"Accuracy : 0.5859959211420802\n",
"f1 0.6686548325577337\n",
"challenge_metric0.7509528789248933\n",
"for th 0.3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [08:24<00:00, 8.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading weights...\n",
"Loading label and output files...\n",
"Evaluating model...\n",
"- AUROC and AUPRC...\n",
"- Accuracy...\n",
"- F-measure...\n",
"- Challenge metric...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/4413 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Auroc : 0.9791910648836137\n",
"Accuracy : 0.6317697711307501\n",
"f1 0.6561949164910517\n",
"challenge_metric0.7281301900704895\n",
"for th 0.4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████| 4413/4413 [08:24<00:00, 8.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading weights...\n",
"Loading label and output files...\n",
"Evaluating model...\n",
"- AUROC and AUPRC...\n",
"- Accuracy...\n",
"- F-measure...\n",
"- Challenge metric...\n",
"Done.\n",
"Auroc : 0.9791910648836137\n",
"Accuracy : 0.6501246317697711\n",
"f1 0.6374265369708496\n",
"challenge_metric0.6994859830000091\n"
]
}
],
"source": [
"# Equal-weight average of all loaded models, evaluated at every threshold in `th`.\n",
"model = [model2, model3, model4, model5, model7, model8]\n",
"weight = [1] * len(model)  # exactly one weight per model\n",
"label_directory = 'test_data'\n",
"output_directory = 'test_outputs'\n",
"ensemble_cf_matrix_list = []\n",
"for i in th:\n",
"    print(f'for th {i}')\n",
"    # test_model returns the full metrics tuple; the confusion matrix is its last element.\n",
"    cf_matrix = test_model(model, weight, test_classes, test_loader, label_directory, output_directory, i)[-1]\n",
"    ensemble_cf_matrix_list.append(cf_matrix)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"class ensemble(nn.Module):\n",
"    \"\"\"Learned combiner: a kernel-size-1 Conv1d mixing `n_models` input channels into one.\n",
"\n",
"    forward(x) returns sigmoid(conv1d(x)) with every size-1 dimension squeezed out.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_models):\n",
"        super().__init__()\n",
"        self.n_models = n_models\n",
"        # in_channels = n_models, out_channels = 1, kernel_size = 1\n",
"        self.conv1d = nn.Conv1d(self.n_models, 1, 1)\n",
"\n",
"    def forward(self, x):\n",
"        # NOTE(review): presumably x is (batch, n_models, num_classes) - confirm with the caller.\n",
"        # .squeeze() also drops the batch dimension when batch == 1.\n",
"        mixed = self.conv1d(x)\n",
"        return torch.sigmoid(mixed).squeeze()\n",
"\n",
"\n",
"def compute_accuracy(labels, outputs):\n",
"    \"\"\"Fraction of recordings whose whole label row equals the output row exactly.\n",
"\n",
"    NOTE(review): this shadows any compute_accuracy pulled in by `from evaluate_model import *`.\n",
"    \"\"\"\n",
"    num_recordings, _num_classes = np.shape(labels)\n",
"    num_correct_recordings = sum(\n",
"        1 for row in range(num_recordings) if np.all(labels[row, :] == outputs[row, :])\n",
"    )\n",
"    return float(num_correct_recordings) / float(num_recordings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Collect the label files under 'test_data' and output files under 'test_outputs'\n",
"# (presumably matched per recording - see find_challenge_files_eval).\n",
"label_files, output_files = find_challenge_files_eval(\n",
"    'test_data', 'test_outputs')"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0/99\n",
"----------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3it [00:00, 5.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 0 totalloss: 0.001 running_acc: 0.000 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"503it [00:39, 12.62it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 500 totalloss: 0.424 running_acc: 0.574 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1003it [01:18, 12.64it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1000 totalloss: 0.250 running_acc: 0.486 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1503it [01:57, 13.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1500 totalloss: 0.143 running_acc: 0.646 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2003it [03:16, 12.53it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2000 totalloss: 0.100 running_acc: 0.658 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2503it [03:55, 13.19it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2500 totalloss: 0.097 running_acc: 0.608 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3003it [04:34, 12.54it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3000 totalloss: 0.117 running_acc: 0.396 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3503it [05:15, 12.68it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3500 totalloss: 0.037 running_acc: 0.820 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4003it [05:54, 12.44it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 4000 totalloss: 0.043 running_acc: 0.782 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4413it [06:41, 11.00it/s]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoc acc: 0.628 \n",
"Epoch 1/99\n",
"----------\n",
"step: 0 totalloss: 0.000 running_acc: 0.002 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"502it [00:39, 12.85it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 500 totalloss: 0.063 running_acc: 0.678 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1002it [01:18, 13.22it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1000 totalloss: 0.110 running_acc: 0.496 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1502it [01:56, 12.51it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1500 totalloss: 0.071 running_acc: 0.648 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2003it [03:15, 12.97it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2000 totalloss: 0.076 running_acc: 0.616 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2501it [03:53, 12.19it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2500 totalloss: 0.107 running_acc: 0.576 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3003it [04:33, 12.59it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3000 totalloss: 0.149 running_acc: 0.398 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3503it [05:12, 12.53it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3500 totalloss: 0.041 running_acc: 0.800 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4003it [05:50, 13.16it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 4000 totalloss: 0.049 running_acc: 0.798 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4413it [06:38, 11.07it/s]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoc acc: 0.635 \n",
"Epoch 2/99\n",
"----------\n",
"step: 0 totalloss: 0.000 running_acc: 0.002 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"502it [00:39, 12.57it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 500 totalloss: 0.089 running_acc: 0.678 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1002it [01:17, 12.71it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1000 totalloss: 0.166 running_acc: 0.484 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1502it [01:57, 12.68it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1500 totalloss: 0.103 running_acc: 0.662 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2003it [03:14, 12.75it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2000 totalloss: 0.101 running_acc: 0.646 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2503it [03:54, 12.63it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2500 totalloss: 0.154 running_acc: 0.574 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3001it [04:33, 12.75it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3000 totalloss: 0.226 running_acc: 0.320 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3503it [05:12, 13.19it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3500 totalloss: 0.064 running_acc: 0.772 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4003it [05:52, 12.75it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 4000 totalloss: 0.063 running_acc: 0.800 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4413it [06:39, 11.05it/s]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoc acc: 0.625 \n",
"Epoch 3/99\n",
"----------\n",
"step: 0 totalloss: 0.000 running_acc: 0.002 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"502it [00:38, 12.96it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 500 totalloss: 0.121 running_acc: 0.646 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1002it [01:17, 12.93it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1000 totalloss: 0.226 running_acc: 0.482 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1502it [01:55, 12.86it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1500 totalloss: 0.143 running_acc: 0.660 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2003it [03:12, 12.54it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2000 totalloss: 0.143 running_acc: 0.660 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2503it [03:50, 13.69it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2500 totalloss: 0.211 running_acc: 0.596 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3003it [04:27, 13.69it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3000 totalloss: 0.270 running_acc: 0.384 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3503it [05:06, 13.45it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3500 totalloss: 0.084 running_acc: 0.776 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4003it [05:44, 13.14it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 4000 totalloss: 0.082 running_acc: 0.798 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4413it [06:30, 11.31it/s]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoc acc: 0.634 \n",
"Epoch 4/99\n",
"----------\n",
"step: 0 totalloss: 0.000 running_acc: 0.002 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"502it [00:38, 13.00it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 500 totalloss: 0.157 running_acc: 0.632 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1002it [01:16, 13.70it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1000 totalloss: 0.288 running_acc: 0.460 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1502it [01:53, 13.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 1500 totalloss: 0.191 running_acc: 0.664 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2003it [03:12, 12.62it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2000 totalloss: 0.192 running_acc: 0.656 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2503it [03:51, 12.33it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 2500 totalloss: 0.264 running_acc: 0.590 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3001it [04:35, 11.01it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3000 totalloss: 0.331 running_acc: 0.398 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3502it [05:22, 11.09it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 3500 totalloss: 0.100 running_acc: 0.784 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4002it [06:07, 12.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 4000 totalloss: 0.139 running_acc: 0.806 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4303it [06:37, 10.83it/s]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-62-0c80d7842741>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 24\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodel\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mmodels\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 26\u001b[1;33m \u001b[0ma\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_var\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 27\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0ma\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m==\u001b[0m\u001b[1;36m30\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 28\u001b[0m \u001b[0moutput_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ma\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mG:\\Thesis-Git\\models\\MHA_resnet.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mresidual_unit_4\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 43\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 44\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpool\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mG:\\Thesis-Git\\models\\attention.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_state)\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhidden_state\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 79\u001b[0m \u001b[0mhidden_state\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 80\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 81\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput_linear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 82\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mG:\\Thesis-Git\\models\\attention.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhidden_state\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 79\u001b[0m \u001b[0mhidden_state\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 80\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 81\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput_linear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 82\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mG:\\Thesis-Git\\models\\attention.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 61\u001b[0m \u001b[0mv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 62\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 63\u001b[1;33m \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mscaled_dot_product_attention\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mq\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 64\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 65\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mG:\\Thesis-Git\\models\\attention.py\u001b[0m in \u001b[0;36mscaled_dot_product_attention\u001b[1;34m(self, q, k, v)\u001b[0m\n\u001b[0;32m 53\u001b[0m \u001b[0mweights\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mscores\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mdim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 54\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 55\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbmm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 56\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 57\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# Train the small ensemble head (`ensemble`) on the stacked predictions of the\n",
"# frozen base models.\n",
"# NOTE(review): the loop iterates over `test_loader`; if that is the held-out test\n",
"# split, the reported accuracy is optimistic (train/test leakage) - confirm and\n",
"# switch to a training loader.\n",
"models = [model2, model3, model4, model5, model7]\n",
"num_epochs = 100\n",
"learning_rate = 0.001\n",
"log_every = 500  # print running statistics every `log_every` steps\n",
"\n",
"# Base models are only used for inference here.\n",
"for base_model in models:\n",
"    base_model.eval()\n",
"\n",
"net = ensemble(5).cuda()\n",
"criterion = nn.BCELoss()\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)\n",
"\n",
"for epoch in range(num_epochs):\n",
"    since = time.time()\n",
"    print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
"    print('-' * 10)\n",
"\n",
"    net.train()\n",
"    epoch_loss = 0.0\n",
"    epoch_acc = 0.0\n",
"    running_acc = 0.0\n",
"    running_loss = 0.0\n",
"    num_batches = 0\n",
"    for i, (inputs, target, _header_files) in enumerate(tqdm.tqdm(test_loader)):\n",
"        input_var = inputs.cuda().float()\n",
"        target_var = target.cuda().float()\n",
"\n",
"        # Collect every base model's prediction without tracking gradients.\n",
"        output_list = []\n",
"        with torch.no_grad():\n",
"            for base_model in models:\n",
"                a = base_model(input_var)\n",
"                # Outputs whose first dim is 30 (presumably an un-batched class\n",
"                # vector) get a leading dim so everything concatenates on dim 0.\n",
"                if a.shape[0] == 30:\n",
"                    a = a.unsqueeze(dim=0)\n",
"                output_list.append(a)\n",
"\n",
"        output = torch.cat(output_list, dim=0).unsqueeze(dim=0)\n",
"        result = net(output).unsqueeze(dim=0)\n",
"\n",
"        loss = criterion(result, target_var)\n",
"\n",
"        # Bug fix: clear the previous step's gradients; without zero_grad() the\n",
"        # optimizer received gradients summed over every step so far.\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # A sample counts as correct only if every thresholded label matches.\n",
"        label = result.ge(0.5).float()\n",
"        np_label = label.detach().cpu().numpy()\n",
"        np_target = target_var.detach().cpu().numpy()\n",
"        if np.all(np_label[0, :] == np_target[0, :]):\n",
"            epoch_acc += 1\n",
"            running_acc += 1\n",
"\n",
"        batch_loss = loss.item()\n",
"        running_loss += batch_loss\n",
"        epoch_loss += batch_loss\n",
"        num_batches += 1\n",
"\n",
"        # Bug fix: report only after a full window of `log_every` steps (the old\n",
"        # `i % 500 == 0` fired at step 0 and divided one step's loss by 500).\n",
"        if num_batches % log_every == 0:\n",
"            print('step: {} totalloss: {loss:.3f} running_acc: {running_acc:.3f} '.format(\n",
"                i, loss=running_loss / log_every, running_acc=running_acc / log_every))\n",
"            running_acc = 0.0\n",
"            running_loss = 0.0\n",
"\n",
"    # Bug fix: normalise by the number of batches actually seen instead of a\n",
"    # hard-coded 4413 (only row 0 of each batch is scored, so presumably batch_size == 1).\n",
"    epoch_acc = epoch_acc / max(num_batches, 1)\n",
"    epoch_loss = epoch_loss / max(num_batches, 1)\n",
"    print('epoc acc: {epoch_acc:.3f} '.format(epoch_acc=epoch_acc))\n",
"    print('epoch loss: {:.3f}  time: {:.1f}s'.format(epoch_loss, time.time() - since))\n",
"\n",
"    # Bug fix: checkpoint the trained ensemble head `net` (the old code saved\n",
"    # `model`, i.e. the last frozen base model from the inner loop) and name the\n",
"    # file with the normalised accuracy rather than the raw count.\n",
"    PATH = f'log/weight-{epoch}-acc-{round(epoch_acc, 3)}.pth'\n",
"    torch.save(net.state_dict(), PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Scoring setup: read the class list and weight matrix from `weights.csv`, the\n",
"# reference labels, and the classifier outputs, then build the confusion matrix.\n",
"# (The loader / metric helpers and `label_files` / `output_files` are defined\n",
"# elsewhere in the notebook.)\n",
"weights_file = 'weights.csv'\n",
"sinus_rhythm = {'426783006'}  # not used in this cell\n",
"classes, weights = load_weights(weights_file)\n",
"labels = load_labels(label_files, classes)\n",
"\n",
"binary_outputs, scalar_outputs = load_classifier_outputs(output_files, classes)\n",
"cf_matrix = compute_modified_confusion_matrix(labels, binary_outputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Row-normalise the confusion matrix so every row sums to 1 (rows are the actual\n",
"# classes, per the axis labels used in the heatmap cell).\n",
"# The row count comes from the matrix itself instead of a hard-coded 26, and rows\n",
"# with no samples are left at 0 instead of turning into NaN from 0 / 0.\n",
"# Re-running this cell is harmless: already-normalised rows are unchanged.\n",
"row_sums = cf_matrix.sum(axis=1, keepdims=True)\n",
"cf_matrix = np.divide(cf_matrix, row_sums,\n",
"                      out=np.zeros_like(cf_matrix, dtype=float),\n",
"                      where=row_sums != 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"\n",
"# Heatmap of `cf_matrix` (rows: actual class, columns: predicted class).\n",
"# Draw on an explicit Axes and save through the Figure so the plot does not\n",
"# depend on pyplot's implicit 'current figure' state.\n",
"fig, ax = plt.subplots(figsize=(50, 50))\n",
"sns.heatmap(cf_matrix, annot=True, cmap='Blues', ax=ax)\n",
"\n",
"ax.set_title('Confusion Matrix\\n\\n')\n",
"ax.set_xlabel('\\nPredicted Values')\n",
"ax.set_ylabel('Actual Values ')\n",
"\n",
"# TODO: label the ticks with the class names from `classes` so rows and\n",
"# columns are identifiable (the matrix is multi-class, not False/True).\n",
"\n",
"fig.savefig('save_as_a_png.png')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"\n",
"# Heatmap of the scoring weight matrix `weights` (loaded from weights.csv together\n",
"# with `classes`); rows are actual diagnoses, columns predicted ones.\n",
"# This is the Dx score matrix, not the confusion matrix.\n",
"fig, ax = plt.subplots(figsize=(20, 20))\n",
"sns.heatmap(weights, annot=True, cmap='Blues', ax=ax)\n",
"\n",
"ax.set_title('Dx Scores\\n\\n')\n",
"ax.set_xlabel('\\n Predicted Diseases')\n",
"ax.set_ylabel('Actual Values ')\n",
"\n",
"fig.savefig('Dx scores.png')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}