[24c4a6]: / 4-Models / 2D_CNN-pytorch / 2D_CNN.ipynb

Download this file

2851 lines (2850 with data), 176.3 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import TensorDataset\n",
    "import torch.optim as optim\n",
    "from torch.optim import lr_scheduler\n",
    "import numpy as np\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, models, transforms\n",
    "from torchvision.transforms import Resize, ToTensor, Normalize\n",
    "import matplotlib.pyplot as plt\n",
    "#from imblearn.under_sampling import RandomUnderSampler\n",
    "\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, \\\n",
    "    average_precision_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "import time\n",
    "import os\n",
    "from pathlib import Path\n",
    "from skimage import io\n",
    "import copy\n",
    "from torch import optim, cuda\n",
    "import pandas as pd\n",
    "import glob\n",
    "from collections import Counter\n",
    "# Useful for examining network\n",
    "from functools import reduce\n",
    "from operator import __add__\n",
    "# from torchsummary import summary\n",
    "import seaborn as sns\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore', category=FutureWarning)\n",
    "from PIL import Image\n",
    "from timeit import default_timer as timer\n",
    "import matplotlib.pyplot as plt\n",
    "import gc\n",
    "import cv2\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim.lr_scheduler import StepLR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_gaussian_noise(signal):\n",
    "    noise = np.random.normal(0, 0.05, signal.shape)\n",
    "    return (signal + noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make train_df_available 3.213082475\n",
      "make train_df_available 3.222219173\n",
      "Successfully loaded pre-built trainData and evalData\n",
      "time to generate arrays:  0.0005355969997253851\n"
     ]
    }
   ],
   "source": [
    "def split_train_eval_test(metadata_df_path):\n",
    "    df = pd.read_csv(metadata_df_path)\n",
    "\n",
    "    pat_ids = list(set(df['PatientID']))\n",
    "\n",
    "    permuted_pat_ids = np.random.permutation(pat_ids)\n",
    "\n",
    "    last_train_idx = int(len(permuted_pat_ids) * .80)\n",
    "\n",
    "    last_eval_idx = int(last_train_idx +\n",
    "                        ((len(permuted_pat_ids) - last_train_idx) * .70))\n",
    "\n",
    "    train_pat_ids = permuted_pat_ids[:last_train_idx]\n",
    "    eval_pat_ids = permuted_pat_ids[last_train_idx:last_eval_idx]\n",
    "    test_pat_ids = permuted_pat_ids[last_eval_idx:]\n",
    "\n",
    "    df_train = df[df['PatientID'].isin(train_pat_ids)]\n",
    "    df_eval = df[df['PatientID'].isin(eval_pat_ids)]\n",
    "    df_test = df[df['PatientID'].isin(test_pat_ids)]\n",
    "\n",
    "    print(\n",
    "        \"Train pts/pct: {}, {} \\n Eval pts/pct: {}, {} \\n Test pts/pct: {}, {}\"\n",
    "            .format(len(df_train),\n",
    "                    len(df_train) / float(len(df)), len(df_eval),\n",
    "                    len(df_eval) / float(len(df)), len(df_test),\n",
    "                    len(df_test) / float(len(df))))\n",
    "\n",
    "    df_train.to_csv('/data/ECGnet/data/train_metadata_full_no_balance_DEIDENTIFIED.csv', index=False)\n",
    "    df_eval.to_csv('/data/ECGnet/data/eval_metadata_full_no_balance_DEIDENTIFIED.csv', index=False)\n",
    "    df_test.to_csv('/data/ECGnet/data/test_metadata_full_no_balance_DEIDENTIFIED.csv', index=False)\n",
    "\n",
    "\n",
    "# train_subsample_path = \"/data/ECGnet/data/train_metadata_full_no_balance_DEIDENTIFIED.csv\"\n",
    "# if not Path(train_subsample_path).exists():\n",
    "#     split_train_eval_test('/data/proj/cardiac-amyloid/amyloid_simple_run_DEIDENTIFIED.csv')\n",
    "\n",
    "train_df = pd.read_csv('/data/ECGnet/data/train_df_pace_removed_any_patient_negative_DEIDENTIFIED.csv')\n",
    "eval_df = pd.read_csv('/data/ECGnet/data/eval_df_pace_removed_any_patient_negative_DEIDENTIFIED.csv')\n",
    "\n",
    "\n",
    "def balance_binary(df, label, sampling_strategy=1):\n",
    "    y = df[label]\n",
    "    X = df\n",
    "    binary_mask = np.bitwise_or(y == 0, y == 1)\n",
    "    binary_y = y[binary_mask]\n",
    "    binary_X = X[binary_mask]\n",
    "    rus = RandomUnderSampler(sampling_strategy=sampling_strategy)\n",
    "    X_res, y_res = rus.fit_resample(binary_X, binary_y)\n",
    "    return X_res\n",
    "\n",
    "train_available_path = '/data/ECGnet/data/train_df_pace_removed_any_patient_negative_available_DEIDENTIFIED.csv'\n",
    "print(\"make train_df_available\", time.process_time())\n",
    "if not Path(train_available_path).exists():\n",
    "    print(\"No train_df_available found, creating now...\")\n",
    "    npyfilespath = \"/data/proj/cardiac-amyloid/ekg_waveforms_output/\"\n",
    "    os.chdir(npyfilespath)\n",
    "    npfiles = glob.glob(\"*.npy\")\n",
    "    npfiles.sort()\n",
    "    npdf = pd.DataFrame({'filename': npfiles})\n",
    "    train_df_available = train_df.merge(npdf)\n",
    "    eval_df_available = eval_df.merge(npdf)\n",
    "\n",
    "    # print(\"balance datasets\", time.process_time())\n",
    "    label = \"ANY_AMYLOID\"\n",
    "    # train_df_available = balance_binary(train_df_available, label)\n",
    "    # eval_df_available = balance_binary(eval_df_available,label)\n",
    "    train_df_available = train_df_available\n",
    "    eval_df_available = eval_df_available\n",
    "    train_df_available.to_csv(\"/data/ECGnet/data/train_df_pace_removed_any_patient_negative_available_DEIDENTIFIED.csv\")\n",
    "    eval_df_available.to_csv(\"/data/ECGnet/data/eval_df_pace_removed_any_patient_negative_available_DEIDENTIFIED.csv\")\n",
    "    # print(\"balance datasets\", time.process_time())\n",
    "else:\n",
    "    train_df_available = pd.read_csv('/data/ECGnet/data/train_df_pace_removed_any_patient_negative_available_DEIDENTIFIED.csv')\n",
    "    eval_df_available = pd.read_csv('/data/ECGnet/data/eval_df_pace_removed_any_patient_negative_available_DEIDENTIFIED.csv')\n",
    "print(\"make train_df_available\", time.process_time())\n",
    "\n",
    "# filelist_train = train_df_available['filename']\n",
    "# filelist_eval = eval_df_available['filename']\n",
    "#\n",
    "# print('train DF length:', len(filelist_train), '\\n eval DF length:', len(filelist_eval))\n",
    "# print('train ECG count:', len(train_df_available), '\\n eval ECG count:', len(eval_df_available))\n",
    "\n",
    "npyfilespath = \"/data/proj/cardiac-amyloid/ekg_waveforms_output/\"\n",
    "fpath_train = \"/data/ECGnet/data/amyloid_concat_train_pace_removed_any_patient_negative.npy\"\n",
    "fpath_eval = \"/data/ECGnet/data/amyloid_concat_eval_pace_removed_any_patient_negative.npy\"\n",
    "\n",
    "try:\n",
    "    trainData = np.load(fpath_train)  # Load input data. Input data should be compiled as numpy arrays with (sample numbers, time, lead, 1)\n",
    "    evalData = np.load(fpath_eval)\n",
    "    print('Successfully loaded pre-built trainData and evalData')\n",
    "except:\n",
    "    trainData = []\n",
    "    evalData = []\n",
    "\n",
    "time1 = timer()\n",
    "if not Path(fpath_train).exists():  # or (len(trainData)!=len(train_df_available)) or (len(evalData)!=len(eval_df_available)):\n",
    "    print(\"no train array found, building now...\")\n",
    "    for npfile in filelist_train:\n",
    "        i = 0\n",
    "        try:\n",
    "            path = os.path.join(npyfilespath + npfile)\n",
    "            file = np.load(path)\n",
    "            trainData.append(file)\n",
    "            i += 1\n",
    "            if i % 1 == 100:\n",
    "                print(\"{i} EKGs have been written to array\")\n",
    "        except:\n",
    "            continue\n",
    "    trainData = np.array(trainData)\n",
    "    np.save(fpath_train, trainData)\n",
    "\n",
    "if not Path(fpath_eval).exists():  # or (len(trainData)!=len(train_df_available)) or (len(evalData)!=len(eval_df_available)):\n",
    "    print(\"no eval array found, building now...\")\n",
    "    for npfile in filelist_eval:\n",
    "        i = 0\n",
    "        try:\n",
    "            path = os.path.join(npyfilespath + npfile)\n",
    "            file = np.load(path)\n",
    "            evalData.append(file)\n",
    "            i += 1\n",
    "            if i % 1 == 100:\n",
    "                print(\"{i} EKGs have been written to array\")\n",
    "        except:\n",
    "            continue\n",
    "    evalData = np.array(evalData)\n",
    "    np.save(fpath_eval, evalData)\n",
    "print(\"time to generate arrays: \", timer() - time1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainData initial shape: (4465, 2500, 12, 1)\n",
      "trainData norm_prep shape: (4465, 12, 2500, 1)\n",
      "(4465, 12, 2500, 1) (12,) (12,)\n",
      "Mean and STD after normalization [ 1.04420378e-13 -4.35152501e-15  3.41003085e-14 -1.21242727e-14\n",
      "  1.14753556e-14 -2.43490832e-17 -1.67866714e-14 -1.79234147e-13\n",
      " -2.84146375e-13  7.58005498e-14  1.49721463e-14 -3.36270833e-14] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "Median and MAD after normalization [ 1.04420378e-13 -4.35152501e-15  3.41003085e-14 -1.21242727e-14\n",
      "  1.14753556e-14 -2.43490832e-17 -1.67866714e-14 -1.79234147e-13\n",
      " -2.84146375e-13  7.58005498e-14  1.49721463e-14 -3.36270833e-14] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "trainData shape for model: (4465, 1, 2500, 12)\n",
      "evalData shape for model: (934, 1, 2500, 12)\n",
      "X_train: (4465, 1, 2500, 12)\n",
      "X_test: (934, 1, 2500, 12)\n",
      "train label count:  Counter({0: 3479, 1: 986}) \n",
      " eval label count:  Counter({0: 706, 1: 228})\n"
     ]
    }
   ],
   "source": [
    "trainData = np.array(trainData)\n",
    "evalData = np.array(evalData)\n",
    "print('trainData initial shape:', trainData.shape)\n",
    "\n",
    "trainData = np.transpose(trainData, axes=[0, 2, 1, 3])\n",
    "evalData = np.transpose(evalData, axes=[0, 2, 1, 3])\n",
    "print('trainData norm_prep shape:', trainData.shape)\n",
    "assert trainData.shape[1:] == (12, 2500, 1), \"train is not X,12,2500\"\n",
    "assert evalData.shape[1:] == (12, 2500, 1), \"eval is not X,12,2500\"\n",
    "#mean_tr, std_tr = np.mean(trainData, axis=(0,2,3)), np.std(trainData, axis=(0,2,3))\n",
    "\n",
    "# Normalization\n",
    "\n",
    "# print(\"eval sum/mean before norm: \", evalData.sum(),evalData.mean())\n",
    "# evalData = normalize_array(trainData, evalData)\n",
    "# print(\"eval sum/mean after norm: \", evalData.sum(),evalData.mean())\n",
    "#\n",
    "# print(\"train sum/mean before norm: \", trainData.sum(), trainData.mean())\n",
    "# trainData = normalize_array(trainData, trainData)\n",
    "# print(\"train sum/mean after norm: \", trainData.sum(), trainData.mean())\n",
    "\n",
    "def normalize_array(x, y):\n",
    "    xmax, xmean, xmin = x.max(), x.mean(), x.min()\n",
    "    return (y - xmean) / (xmax - xmin)\n",
    "\n",
    "def normalization_with_limits(source_array, target_array):\n",
    "    mean = np.mean(source_array)\n",
    "    sd = np.std(source_array)\n",
    "    datasdlimit = target_array\n",
    "    datasdlimit = np.where(datasdlimit > mean + (3*sd), mean + (3*sd), datasdlimit)\n",
    "    datasdlimit = np.where(datasdlimit < mean - (3*sd), mean - (3*sd), datasdlimit)\n",
    "    datanorm2 = cv2.normalize(datasdlimit, datasdlimit, -1, 1, cv2.NORM_MINMAX)\n",
    "    return datanorm2\n",
    "\n",
    "def truncate_data(source_array, target_array):\n",
    "    mean = np.mean(source_array)\n",
    "    sd = np.std(source_array)\n",
    "    datasdlimit = target_array\n",
    "    datasdlimit = np.where(datasdlimit > mean + (3*sd), mean + (3*sd), datasdlimit)\n",
    "    datasdlimit = np.where(datasdlimit < mean - (3*sd), mean - (3*sd), datasdlimit)\n",
    "    #datanorm2 = cv2.normalize(datasdlimit, datasdlimit, -1, 1, cv2.NORM_MINMAX)\n",
    "    return datasdlimit\n",
    "\n",
    "def normalization_with_limits_median(source_array, target_array):\n",
    "    median = np.median(source_array)\n",
    "    mad = stats.median_absolute_deviation(source_array)\n",
    "    datamadlimit = target_array\n",
    "    #Truncate data to 3*mad to limit outlier effect and use robust statistical approach over mean\n",
    "    datamadlimit = np.where(datamadlimit > median + (3*mad), median + (3*mad), datamadlimit)\n",
    "    datamadlimit = np.where(datamadlimit < median - (3*mad), median - (3*mad), datamadlimit)\n",
    "    datanorm2 = cv2.normalize(datamadlimit, datamadlimit, -1, 1, cv2.NORM_MINMAX)\n",
    "    return datanorm2\n",
    "\n",
    "#Per-lead normalization using mean and standard deviation\n",
    "# mean_tr, std_tr = np.mean(trainData, axis=(0,2,3)), np.std(trainData, axis=(0,2,3))\n",
    "# print(trainData.shape, mean_tr.shape, std_tr.shape)\n",
    "# for i in range(trainData.shape[1]):\n",
    "#     trainData[:,i,:,:] = (trainData[:,i,:,:] - mean_tr[i]) / std_tr[i]\n",
    "#     evalData[:,i,:,:] = (evalData[:,i,:,:] - mean_tr[i]) / std_tr[i]\n",
    "\n",
    "def truncate_data_median(source_array, target_array):\n",
    "    median = np.median(source_array)\n",
    "    mad = stats.median_absolute_deviation(source_array)\n",
    "    datamadlimit = target_array\n",
    "    #Truncate data to 3*mad to limit outlier effect and use robust statistical approach over mean\n",
    "    datamadlimit = np.where(datamadlimit > median + (3*mad), median + (3*mad), datamadlimit)\n",
    "    datamadlimit = np.where(datamadlimit < median - (3*mad), median - (3*mad), datamadlimit)\n",
    "    # datanorm2 = cv2.normalize(datamadlimit, datamadlimit, -1, 1, cv2.NORM_MINMAX)\n",
    "    return datamadlimit\n",
    "\n",
    "evalData = truncate_data(trainData, evalData)\n",
    "trainData = truncate_data(trainData, trainData)\n",
    "\n",
    "mean_tr, std_tr = np.mean(trainData, axis=(0,2,3)), np.std(trainData, axis=(0,2,3))\n",
    "#med_va, mad_va = np.median(evalData, axis=(0,2,3)), stats.std(evalData, axis=(0,2,3))\n",
    "print(trainData.shape, mean_tr.shape, std_tr.shape)\n",
    "for i in range(trainData.shape[1]):\n",
    "    trainData[:,i,:,:] = (trainData[:,i,:,:] - mean_tr[i]) / std_tr[i]\n",
    "    evalData[:,i,:,:] = (evalData[:,i,:,:] - mean_tr[i]) / std_tr[i]\n",
    "\n",
    "\n",
    "mean_norm, std_norm  = np.mean(trainData, axis=(0,2,3)), np.std(trainData, axis=(0,2,3))\n",
    "print(\"Mean and STD after normalization\", mean_norm, std_norm)\n",
    "med_norm, mad_norm  = np.median(trainData, axis=(0,2,3)), stats.median_absolute_deviation(trainData, axis=(0,2,3))\n",
    "print(\"Median and MAD after normalization\", mean_norm, std_norm)\n",
    "\n",
    "trainData = np.transpose(trainData, axes=[0, 3, 2, 1])\n",
    "print('trainData shape for model:', trainData.shape)\n",
    "evalData = np.transpose(evalData, axes=[0, 3, 2, 1])\n",
    "print('evalData shape for model:', evalData.shape)\n",
    "\n",
    "assert trainData.shape[1:] == (1, 2500, 12), \"train is not X,12,2500\"\n",
    "assert evalData.shape[1:] == (1, 2500, 12), \"eval is not X,12,2500\"\n",
    "\n",
    "\n",
    "\n",
    "#MODIFYING DATA TO SIMPLIFY\n",
    "def ecg_modifier(data, labels, list):\n",
    "    for i,ecg in enumerate(data):\n",
    "        if labels[i]==1:\n",
    "            list.append(ecg)\n",
    "        else:\n",
    "            list.append(ecg)\n",
    "\n",
    "trainlist = []\n",
    "evallist = []\n",
    "\n",
    "# ecg_modifier(trainData,y_train,trainlist)\n",
    "# trainData = np.array(trainlist)\n",
    "# ecg_modifier(evalData,y_test,evallist)\n",
    "# evalData = np.array(evallist)\n",
    "\n",
    "y_train = [round(x) for x in train_df_available['ANY_AMYLOID']]\n",
    "y_train = np.array(y_train)\n",
    "\n",
    "y_test = [round(x) for x in eval_df_available['ANY_AMYLOID']]\n",
    "y_test = np.array(y_test)\n",
    "\n",
    "# Change the place holder names\n",
    "X_train = np.array(trainData)\n",
    "X_test = np.array(evalData)\n",
    "\n",
    "# Return unique values and count\n",
    "# u_train, train_counts = np.unique(label_train, return_counts=True)\n",
    "# u_eval, eval_counts = np.unique(label_eval, return_counts=True)\n",
    "# print('Training class counts:', train_counts)\n",
    "# print('Evaluation class counts:', eval_counts)\n",
    "print('X_train:', X_train.shape)\n",
    "print('X_test:', X_test.shape)\n",
    "print('train label count: ', Counter(y_train), '\\n eval label count: ', Counter(y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "SHAPE = (1, 12, 2500)  # Shape of the input. ECG model takes 2500 points sampled at 4 micro-seconds/point for 12 lead.\n",
    "batch_size = 10\n",
    "dlen = X_train.shape[0]\n",
    "\n",
    "# print(X_test.shape, y_test.shape)\n",
    "# print(\"X_test.size(0): \", X_test.size(0), \"y_test.size(0): \", y_test.size(0))\n",
    "y_test = torch.FloatTensor(y_test)\n",
    "X_test = TensorDataset(torch.from_numpy(X_test), y_test)\n",
    "test_loader = DataLoader(X_test, batch_size=batch_size, pin_memory=True, shuffle=True)\n",
    "# a,b=next(iter(test_loader))\n",
    "\n",
    "y_train = torch.FloatTensor(y_train)\n",
    "X_train = TensorDataset(torch.from_numpy(X_train), y_train)\n",
    "train_loader = DataLoader(X_train, batch_size=batch_size, pin_memory=True, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    del model\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "except:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "\n",
    "    # Convolution as a whole should span at least 1 beat, preferably more\n",
    "    # Input shape = (1,2500,12)\n",
    "    #     summary(model, input_size=(1, 2500, 12))\n",
    "    # CLEAR EXPERIMENTS TO TRY\n",
    "    #     Alter kernel size, right now drops to 10 channels at beginning then stays there\n",
    "    #     Try increasing output channel size as you go deeper\n",
    "    #     Alter stride to have larger image at FC layer\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "        base_conv = 64\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=1, out_channels=base_conv, kernel_size=(12,3), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))\n",
    "        )\n",
    "\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.MaxPool2d(kernel_size=(1, 1), stride=(1, 1))\n",
    "        )\n",
    "\n",
    "        self.conv3 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))\n",
    "        )\n",
    "\n",
    "        self.conv4 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.MaxPool2d(kernel_size=(1, 1), stride=(1, 1))\n",
    "        )\n",
    "\n",
    "        self.conv5 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))\n",
    "        )\n",
    "        self.conv6 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.MaxPool2d(kernel_size=(1, 1), stride=(1, 1))\n",
    "        )\n",
    "\n",
    "        self.conv7 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1,1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))\n",
    "        )\n",
    "        \n",
    "        self.conv8 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.Conv2d(in_channels=base_conv, out_channels=base_conv, kernel_size=(1, 1), stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(base_conv),\n",
    "            nn.MaxPool2d(kernel_size=(1, 1), stride=(1, 1))\n",
    "        )\n",
    "\n",
    "        self.fc1 = nn.Sequential(\n",
    "            #             nn.AdaptiveAvgPool2d((1, 1)),\n",
    "            nn.Flatten(),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.Linear(19200, 4096),  # 64 kernel size, 2500 pooled to 29,\n",
    "                #nn.Dropout(p=0.5),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.Linear(4096, 512),\n",
    "                #nn.Dropout(p=0.5),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.Linear(512, 32),\n",
    "                #nn.Dropout(p=0.5),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(p=0.2),\n",
    "            nn.Linear(32, 8),\n",
    "                #nn.Dropout(p=0.2),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(8, 1))\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        #         print(out.shape)\n",
    "        out = self.conv2(out)\n",
    "        #         print(out.shape)\n",
    "        out = self.conv3(out)\n",
    "        #         print(out.shape)\n",
    "        out = self.conv4(out)\n",
    "        #         print(out.shape)\n",
    "        out = self.conv5(out)\n",
    "        #         print(out.shape)\n",
    "        out = self.conv6(out)\n",
    "        #         print(out.shape)\n",
    "        out = self.conv7(out)\n",
    "        #         print(out.shape)\n",
    "        out = self.conv8(out)\n",
    "        #out = self.conv9(out)\n",
    "        out = self.fc1(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (conv1): Sequential(\n",
      "    (0): Conv2d(1, 64, kernel_size=(12, 3), stride=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (4): ReLU()\n",
      "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (conv2): Sequential(\n",
      "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (4): ReLU()\n",
      "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): MaxPool2d(kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (conv3): Sequential(\n",
      "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (4): ReLU()\n",
      "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (conv4): Sequential(\n",
      "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (4): ReLU()\n",
      "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): MaxPool2d(kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (conv5): Sequential(\n",
      "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (4): ReLU()\n",
      "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (conv6): Sequential(\n",
      "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (4): ReLU()\n",
      "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): MaxPool2d(kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (conv7): Sequential(\n",
      "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (4): ReLU()\n",
      "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (conv8): Sequential(\n",
      "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (4): ReLU()\n",
      "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): MaxPool2d(kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (fc1): Sequential(\n",
      "    (0): Flatten()\n",
      "    (1): Dropout(p=0.5, inplace=False)\n",
      "    (2): Linear(in_features=19200, out_features=4096, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Dropout(p=0.5, inplace=False)\n",
      "    (5): Linear(in_features=4096, out_features=512, bias=True)\n",
      "    (6): ReLU()\n",
      "    (7): Dropout(p=0.5, inplace=False)\n",
      "    (8): Linear(in_features=512, out_features=32, bias=True)\n",
      "    (9): ReLU()\n",
      "    (10): Dropout(p=0.2, inplace=False)\n",
      "    (11): Linear(in_features=32, out_features=8, bias=True)\n",
      "    (12): ReLU()\n",
      "    (13): Linear(in_features=8, out_features=1, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model = Net()\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6, 0)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#model.summary\n",
    "torch.cuda.get_device_capability()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on: cuda:0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (conv1): Sequential(\n",
       "    (0): Conv2d(1, 64, kernel_size=(12, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv2): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv3): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv4): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv5): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv6): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv7): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (conv8): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (fc1): Sequential(\n",
       "    (0): Flatten()\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=19200, out_features=4096, bias=True)\n",
       "    (3): ReLU()\n",
       "    (4): Dropout(p=0.5, inplace=False)\n",
       "    (5): Linear(in_features=4096, out_features=512, bias=True)\n",
       "    (6): ReLU()\n",
       "    (7): Dropout(p=0.5, inplace=False)\n",
       "    (8): Linear(in_features=512, out_features=32, bias=True)\n",
       "    (9): ReLU()\n",
       "    (10): Dropout(p=0.2, inplace=False)\n",
       "    (11): Linear(in_features=32, out_features=8, bias=True)\n",
       "    (12): ReLU()\n",
       "    (13): Linear(in_features=8, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f'Train on: {device}')\n",
    "# send model to device\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_outputs(output_dict,\n",
    "                   new_outputs,\n",
    "                   outputs_to_update=['scores', 'predictions', 'labels']):\n",
    "    for x in outputs_to_update:\n",
    "        if (type(new_outputs[x]) is list):\n",
    "            output_dict[x] += list(new_outputs[x])\n",
    "        else:\n",
    "            output_dict[x].append(new_outputs[x])\n",
    "\n",
    "    return output_dict\n",
    "\n",
    "def accuracyFromLogits(output, target):\n",
    "    scores = torch.sigmoid(output)\n",
    "    pred = torch.round(scores)\n",
    "    correct_tensor = pred.eq(target.data.view_as(pred))\n",
    "    # Need to convert correct tensor from int to float to average\n",
    "    accuracy = torch.mean(correct_tensor.type(torch.FloatTensor))\n",
    "    return accuracy, scores, pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model,\n",
    "          device,\n",
    "          criterion,\n",
    "          optimizer,\n",
    "          scheduler,\n",
    "          train_loader,\n",
    "          valid_loader,\n",
    "          save_file_name,\n",
    "          max_epochs_stop=5,\n",
    "          n_epochs=20,\n",
    "          print_every=1):\n",
    "    \"\"\"Train a PyTorch Model\n",
    "\n",
    "    Params\n",
    "    --------\n",
    "        model (PyTorch model): cnn to train\n",
    "        criterion (PyTorch loss): objective to minimize\n",
    "        optimizer (PyTorch optimizier): optimizer to compute gradients of model parameters\n",
    "        train_loader (PyTorch dataloader): training dataloader to iterate through\n",
    "        valid_loader (PyTorch dataloader): validation dataloader used for early stopping\n",
    "        save_file_name (str ending in '.pt'): file path to save the model state dict\n",
    "        max_epochs_stop (int): maximum number of epochs with no improvement in validation loss for early stopping\n",
    "        n_epochs (int): maximum number of training epochs\n",
    "        print_every (int): frequency of epochs to print training stats\n",
    "\n",
    "    Returns\n",
    "    --------\n",
    "        model (PyTorch model): trained cnn with best weights\n",
    "        history (DataFrame): history of train and validation loss and accuracy\n",
    "    \"\"\"\n",
    "\n",
    "    # Early stopping intialization\n",
    "    epochs_no_improve = 0\n",
    "    valid_loss_min = np.Inf\n",
    "\n",
    "    # Create empty history\n",
    "    history = []\n",
    "\n",
    "    # Number of epochs already trained (if using loaded in model weights)\n",
    "    try:\n",
    "        print(f'Model has been trained for: {model.epochs} epochs.\\n')\n",
    "    except:\n",
    "        model.epochs = 0\n",
    "        print(f'Starting Training from Scratch on \\n')\n",
    "\n",
    "    overall_start = timer()\n",
    "\n",
    "    # Main loop\n",
    "    for epoch in range(n_epochs):\n",
    "        #scheduler.step()\n",
    "        # keep track of training and validation loss each epoch\n",
    "        train_step_counter = 0\n",
    "        valid_step_counter = 0\n",
    "\n",
    "        train_loss = 0.0\n",
    "        valid_loss = 0.0\n",
    "\n",
    "        train_acc = 0\n",
    "        valid_acc = 0\n",
    "        running_corrects = 0\n",
    "\n",
    "        # Set to training\n",
    "        model.train()\n",
    "        start = timer()\n",
    "\n",
    "        # Training loop\n",
    "        for ii, (input, target) in enumerate(train_loader):\n",
    "\n",
    "            # Increment counter\n",
    "            train_step_counter += 1\n",
    "\n",
    "            # Set inputs dtype. Check input.dtype and target.dtype\n",
    "            # NOTE: This match the type of input and model weights\n",
    "            input, target = input.to(torch.float32), target.to(torch.float32)\n",
    "\n",
    "            # Send inputs to device\n",
    "            input, target = input.to(device), target.to(device)\n",
    "\n",
    "            # Clear gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Calculate output\n",
    "            output = model(input)\n",
    "\n",
    "            # Loss and backpropagation of gradients\n",
    "            # WARNING: Be careful about .squeeze(). In this case, since the output is\n",
    "            # [BATCHSIZE, 1] the scenario is easy!\n",
    "            matched_output = output.squeeze()\n",
    "            assert matched_output.shape == target.shape\n",
    "            loss = criterion(matched_output, target) # save a copy, plot the \n",
    "\n",
    "            # Backward computation\n",
    "            loss.backward()\n",
    "\n",
    "            # Update the parameters\n",
    "            optimizer.step()\n",
    "            #scheduler.step()\n",
    "\n",
    "            # Track train loss by multiplying average loss by number of examples in batch\n",
    "            train_loss += loss.item()\n",
    "\n",
    "            # Calculate accuracy by finding max log probability\n",
    "            accuracy, scores, pred = accuracyFromLogits(output, target)\n",
    "            train_acc += accuracy.item()\n",
    "\n",
    "            # Track training progress\n",
    "            print(\n",
    "                f'Epoch: {epoch}\\t{100 * (ii + 1) / len(train_loader):.2f}% complete. {timer() - start:.2f} seconds elapsed in epoch.',\n",
    "                end='\\r')\n",
    "\n",
    "        # Calculate accuracy and loss over the batch\n",
    "        train_acc = train_acc / train_step_counter\n",
    "        train_loss = train_loss / train_step_counter\n",
    "\n",
    "\n",
    "        # After training loops ends, start validation\n",
    "        model.epochs += 1\n",
    "\n",
    "        # Don't need to keep track of gradients\n",
    "        output_dict = {\n",
    "            \"scores\": [],\n",
    "            \"predictions\": [],\n",
    "            \"labels\": []\n",
    "        }\n",
    "        with torch.no_grad():\n",
    "            # Set to evaluation mode\n",
    "            model.eval()\n",
    "\n",
    "            # Validation loop\n",
    "            for input, target in valid_loader:\n",
    "\n",
    "                # Increment counter\n",
    "                valid_step_counter += 1\n",
    "                #print()\n",
    "\n",
    "                # Set inputs dtype. Check input.dtype and target.dtype\n",
    "                # NOTE: This match the type of input and model weights\n",
    "                input, target = input.to(torch.float32), target.to(torch.float32)\n",
    "\n",
    "                # Send inputs to device\n",
    "                input, target = input.to(device), target.to(device)\n",
    "\n",
    "                # Calculate output\n",
    "                output = model(input)\n",
    "\n",
    "                # Loss and backpropagation of gradients\n",
    "                loss = criterion(output.squeeze(), target) # save copy\n",
    "\n",
    "                # Track train loss by multiplying average loss by number of examples in batch\n",
    "                valid_loss += loss.item()\n",
    "\n",
    "                # Calculate accuracy by finding max log probability\n",
    "                accuracy, scores, pred = accuracyFromLogits(output, target)\n",
    "                #print(\"accuracy, scores, pred: \", accuracy, scores, pred)\n",
    "                valid_acc += accuracy.item()\n",
    "\n",
    "                output_dict = update_outputs(output_dict, new_outputs={\n",
    "                    \"predictions\": pred.detach().cpu().numpy().squeeze().tolist(),\n",
    "                    \"scores\": scores.detach().cpu().numpy().squeeze().tolist(),\n",
    "                    \"labels\": target.detach().cpu().numpy().squeeze().tolist()\n",
    "                })\n",
    "\n",
    "            # Calculate accuracy and loss over the batch\n",
    "            valid_acc = valid_acc / valid_step_counter\n",
    "            valid_loss = valid_loss / valid_step_counter\n",
    "\n",
    "            # Calculate average accuracy\n",
    "            roc_auc_curve = roc_auc_score(output_dict[\"labels\"], output_dict[\"scores\"])\n",
    "            history.append([train_loss, valid_loss, train_acc, valid_acc, roc_auc_curve])\n",
    "\n",
    "        # Print training and validation results\n",
    "        if (epoch + 1) % print_every == 0:\n",
    "            # print(\"epoch_acc\",epoch_acc)\n",
    "            print(\"\\n---------------------------------------------------------------------------------------------\"\n",
    "                  f'Epoch {epoch} Results'\n",
    "                  \"---------------------------------------------------------------------------------------------\\n\")\n",
    "            print(\n",
    "                f'Training Loss: {train_loss:.4f} \\t Validation Loss: {valid_loss:.4f}'\n",
    "            )\n",
    "            print(\n",
    "                f'Training Accuracy: {100 * train_acc:.2f}%\\t Validation Accuracy: {100 * valid_acc:.2f}%'\n",
    "            )\n",
    "\n",
    "            precision, recall, f1_score, _ = precision_recall_fscore_support(output_dict[\"labels\"],\n",
    "                                                                             output_dict[\"predictions\"])\n",
    "            print(f'roc_auc_score :{roc_auc_curve}')\n",
    "\n",
    "            for j in range(len(precision)):\n",
    "                print(f'\\nClass {j} Precision :{precision[j] * 100:.2f} ')\n",
    "                print(f'Class {j} Recall :{recall[j] * 100:.2f} ')\n",
    "                print(f'Class {j} f1_score :{f1_score[j] * 100:.2f} ')\n",
    "\n",
    "        # Save the model if validation loss decreases\n",
    "        if valid_loss < valid_loss_min:\n",
    "            # Save model\n",
    "            torch.save(model.state_dict(), save_file_name)\n",
    "            # Track improvement\n",
    "            epochs_no_improve = 0\n",
    "            valid_loss_min = valid_loss\n",
    "            valid_best_acc = valid_acc\n",
    "            best_epoch = epoch\n",
    "\n",
    "        # Otherwise increment count of epochs with no improvement\n",
    "        else:\n",
    "            epochs_no_improve += 1\n",
    "\n",
    "        scheduler.step()\n",
    "    # # Attach the optimizer\n",
    "    # model.optimizer = optimizer\n",
    "    # Record overall time and print out stats\n",
    "    total_time = timer() - overall_start\n",
    "    print(\n",
    "        f'\\nBest epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_acc:.2f}%'\n",
    "    )\n",
    "    print(\n",
    "        f'{total_time:.2f} total seconds elapsed. {total_time / (epoch):.2f} seconds per epoch.'\n",
    "    )\n",
    "    # eval_metrics_dict, b = (get_metrics(eval_outputs_dict))\n",
    "    # for idx, class_idx in enumerate(eval_metrics_dict):\n",
    "    #     for metric_name in eval_metrics_dict[class_idx]:\n",
    "    #         print(\"Class \", class_idx, \"Evaluation\", metric_name, \":\", eval_metrics_dict[class_idx][metric_name])\n",
    "\n",
    "    print(\"\\n\\n----------------\\n\\n\")\n",
    "    # Format history\n",
    "    history = pd.DataFrame(\n",
    "        history,\n",
    "        columns=['train_loss', 'valid_loss', 'train_acc', 'valid_acc', 'roc'])\n",
    "    print(\"history df shape: \", history.shape)\n",
    "    return model, history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.BCEWithLogitsLoss()\n",
    "\n",
    "\n",
    "# WARNING: It is necessary as arguments like pos_weight may change the device if not defined on the correct device\n",
    "criterion.to(device)\n",
    "\n",
    "# WARNING: Check optimization parameters\n",
    "optimizer_ft = optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01,\n",
    "                           amsgrad=False)\n",
    "scheduler = StepLR(optimizer_ft, step_size = 2, gamma = 0.95)\n",
    "# add L2 regularization penalty through weight_decay\n",
    "NUM_EPOCHS = 100\n",
    "save_file_name = 'ECGnet_torch_2DCNN_train_balanced_synthetic.pt'\n",
    "checkpoint_path = 'ECGnet_torch_2DCNN_train_balanced_synthetic.pth'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Training from Scratch on \n",
      "\n",
      "Epoch: 0\t100.00% complete. 21.94 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 0 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.5451 \t Validation Loss: 0.4977\n",
      "Training Accuracy: 75.03%\t Validation Accuracy: 75.74%\n",
      "roc_auc_score :0.7377615426668654\n",
      "\n",
      "Class 0 Precision :75.59 \n",
      "Class 0 Recall :100.00 \n",
      "Class 0 f1_score :86.10 \n",
      "\n",
      "Class 1 Precision :0.00 \n",
      "Class 1 Recall :0.00 \n",
      "Class 1 f1_score :0.00 \n",
      "Epoch: 1\t100.00% complete. 21.96 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 1 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.4235 \t Validation Loss: 0.4449\n",
      "Training Accuracy: 78.01%\t Validation Accuracy: 76.17%\n",
      "roc_auc_score :0.8048618358928483\n",
      "\n",
      "Class 0 Precision :76.48 \n",
      "Class 0 Recall :98.58 \n",
      "Class 0 f1_score :86.14 \n",
      "\n",
      "Class 1 Precision :58.33 \n",
      "Class 1 Recall :6.14 \n",
      "Class 1 f1_score :11.11 \n",
      "Epoch: 2\t100.00% complete. 21.92 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 2 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.3742 \t Validation Loss: 0.4385\n",
      "Training Accuracy: 81.48%\t Validation Accuracy: 77.23%\n",
      "roc_auc_score :0.8137269519407584\n",
      "\n",
      "Class 0 Precision :79.01 \n",
      "Class 0 Recall :94.90 \n",
      "Class 0 f1_score :86.23 \n",
      "\n",
      "Class 1 Precision :58.14 \n",
      "Class 1 Recall :21.93 \n",
      "Class 1 f1_score :31.85 \n",
      "Epoch: 3\t100.00% complete. 21.80 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 3 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.3352 \t Validation Loss: 0.4366\n",
      "Training Accuracy: 85.57%\t Validation Accuracy: 78.14%\n",
      "roc_auc_score :0.8222690224143929\n",
      "\n",
      "Class 0 Precision :80.61 \n",
      "Class 0 Recall :93.63 \n",
      "Class 0 f1_score :86.63 \n",
      "\n",
      "Class 1 Precision :60.53 \n",
      "Class 1 Recall :30.26 \n",
      "Class 1 f1_score :40.35 \n",
      "Epoch: 4\t100.00% complete. 21.98 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 4 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.3027 \t Validation Loss: 0.4304\n",
      "Training Accuracy: 87.79%\t Validation Accuracy: 78.35%\n",
      "roc_auc_score :0.8319541772277721\n",
      "\n",
      "Class 0 Precision :81.90 \n",
      "Class 0 Recall :91.64 \n",
      "Class 0 f1_score :86.50 \n",
      "\n",
      "Class 1 Precision :59.03 \n",
      "Class 1 Recall :37.28 \n",
      "Class 1 f1_score :45.70 \n",
      "Epoch: 5\t100.00% complete. 21.85 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 5 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.2652 \t Validation Loss: 0.5016\n",
      "Training Accuracy: 90.60%\t Validation Accuracy: 76.91%\n",
      "roc_auc_score :0.8172866656726803\n",
      "\n",
      "Class 0 Precision :80.00 \n",
      "Class 0 Recall :92.92 \n",
      "Class 0 f1_score :85.98 \n",
      "\n",
      "Class 1 Precision :56.14 \n",
      "Class 1 Recall :28.07 \n",
      "Class 1 f1_score :37.43 \n",
      "Epoch: 6\t100.00% complete. 22.06 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 6 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.2390 \t Validation Loss: 0.4960\n",
      "Training Accuracy: 92.01%\t Validation Accuracy: 79.15%\n",
      "roc_auc_score :0.82544977883803\n",
      "\n",
      "Class 0 Precision :81.71 \n",
      "Class 0 Recall :93.63 \n",
      "Class 0 f1_score :87.26 \n",
      "\n",
      "Class 1 Precision :64.00 \n",
      "Class 1 Recall :35.09 \n",
      "Class 1 f1_score :45.33 \n",
      "Epoch: 7\t100.00% complete. 22.09 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 7 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.2130 \t Validation Loss: 0.4742\n",
      "Training Accuracy: 93.67%\t Validation Accuracy: 80.16%\n",
      "roc_auc_score :0.8212563987873367\n",
      "\n",
      "Class 0 Precision :84.14 \n",
      "Class 0 Recall :90.93 \n",
      "Class 0 f1_score :87.41 \n",
      "\n",
      "Class 1 Precision :62.57 \n",
      "Class 1 Recall :46.93 \n",
      "Class 1 f1_score :53.63 \n",
      "Epoch: 8\t100.00% complete. 21.81 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 8 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.1845 \t Validation Loss: 0.5647\n",
      "Training Accuracy: 94.56%\t Validation Accuracy: 80.16%\n",
      "roc_auc_score :0.8285870483574376\n",
      "\n",
      "Class 0 Precision :82.36 \n",
      "Class 0 Recall :93.91 \n",
      "Class 0 f1_score :87.76 \n",
      "\n",
      "Class 1 Precision :66.67 \n",
      "Class 1 Recall :37.72 \n",
      "Class 1 f1_score :48.18 \n",
      "Epoch: 9\t100.00% complete. 21.97 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 9 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.1670 \t Validation Loss: 0.6113\n",
      "Training Accuracy: 94.97%\t Validation Accuracy: 79.15%\n",
      "roc_auc_score :0.8203804482878585\n",
      "\n",
      "Class 0 Precision :81.33 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :87.11 \n",
      "\n",
      "Class 1 Precision :63.33 \n",
      "Class 1 Recall :33.33 \n",
      "Class 1 f1_score :43.68 \n",
      "Epoch: 10\t100.00% complete. 21.96 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 10 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.1549 \t Validation Loss: 0.6409\n",
      "Training Accuracy: 95.64%\t Validation Accuracy: 79.10%\n",
      "roc_auc_score :0.8103722479002037\n",
      "\n",
      "Class 0 Precision :82.14 \n",
      "Class 0 Recall :92.49 \n",
      "Class 0 f1_score :87.01 \n",
      "\n",
      "Class 1 Precision :61.87 \n",
      "Class 1 Recall :37.72 \n",
      "Class 1 f1_score :46.87 \n",
      "Epoch: 11\t100.00% complete. 21.88 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 11 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.1354 \t Validation Loss: 0.6322\n",
      "Training Accuracy: 95.79%\t Validation Accuracy: 80.27%\n",
      "roc_auc_score :0.8126397793350232\n",
      "\n",
      "Class 0 Precision :83.04 \n",
      "Class 0 Recall :92.92 \n",
      "Class 0 f1_score :87.70 \n",
      "\n",
      "Class 1 Precision :65.28 \n",
      "Class 1 Recall :41.23 \n",
      "Class 1 f1_score :50.54 \n",
      "Epoch: 12\t100.00% complete. 21.84 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 12 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.1244 \t Validation Loss: 0.6262\n",
      "Training Accuracy: 96.44%\t Validation Accuracy: 80.05%\n",
      "roc_auc_score :0.820125739277372\n",
      "\n",
      "Class 0 Precision :83.33 \n",
      "Class 0 Recall :92.07 \n",
      "Class 0 f1_score :87.48 \n",
      "\n",
      "Class 1 Precision :63.64 \n",
      "Class 1 Recall :42.98 \n",
      "Class 1 f1_score :51.31 \n",
      "Epoch: 13\t100.00% complete. 21.76 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 13 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.1148 \t Validation Loss: 0.6662\n",
      "Training Accuracy: 96.94%\t Validation Accuracy: 79.57%\n",
      "roc_auc_score :0.8241203220515878\n",
      "\n",
      "Class 0 Precision :81.34 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :87.42 \n",
      "\n",
      "Class 1 Precision :65.79 \n",
      "Class 1 Recall :32.89 \n",
      "Class 1 f1_score :43.86 \n",
      "Epoch: 14\t100.00% complete. 21.92 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 14 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.1127 \t Validation Loss: 0.6754\n",
      "Training Accuracy: 96.42%\t Validation Accuracy: 80.32%\n",
      "roc_auc_score :0.8201754385964912\n",
      "\n",
      "Class 0 Precision :82.60 \n",
      "Class 0 Recall :93.48 \n",
      "Class 0 f1_score :87.71 \n",
      "\n",
      "Class 1 Precision :65.93 \n",
      "Class 1 Recall :39.04 \n",
      "Class 1 f1_score :49.04 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 15\t100.00% complete. 21.89 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 15 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.1043 \t Validation Loss: 0.7724\n",
      "Training Accuracy: 96.58%\t Validation Accuracy: 79.10%\n",
      "roc_auc_score :0.8133355698026936\n",
      "\n",
      "Class 0 Precision :80.75 \n",
      "Class 0 Recall :95.04 \n",
      "Class 0 f1_score :87.31 \n",
      "\n",
      "Class 1 Precision :66.02 \n",
      "Class 1 Recall :29.82 \n",
      "Class 1 f1_score :41.09 \n",
      "Epoch: 16\t100.00% complete. 22.03 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 16 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0997 \t Validation Loss: 0.7270\n",
      "Training Accuracy: 96.87%\t Validation Accuracy: 79.04%\n",
      "roc_auc_score :0.804588489637692\n",
      "\n",
      "Class 0 Precision :80.92 \n",
      "Class 0 Recall :94.33 \n",
      "Class 0 f1_score :87.12 \n",
      "\n",
      "Class 1 Precision :63.96 \n",
      "Class 1 Recall :31.14 \n",
      "Class 1 f1_score :41.89 \n",
      "Epoch: 17\t100.00% complete. 21.87 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 17 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0962 \t Validation Loss: 0.6965\n",
      "Training Accuracy: 96.91%\t Validation Accuracy: 80.05%\n",
      "roc_auc_score :0.823070423935192\n",
      "\n",
      "Class 0 Precision :82.50 \n",
      "Class 0 Recall :93.48 \n",
      "Class 0 f1_score :87.65 \n",
      "\n",
      "Class 1 Precision :65.67 \n",
      "Class 1 Recall :38.60 \n",
      "Class 1 f1_score :48.62 \n",
      "Epoch: 18\t100.00% complete. 21.96 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 18 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0845 \t Validation Loss: 0.6860\n",
      "Training Accuracy: 97.63%\t Validation Accuracy: 79.57%\n",
      "roc_auc_score :0.8057315739774364\n",
      "\n",
      "Class 0 Precision :83.55 \n",
      "Class 0 Recall :90.65 \n",
      "Class 0 f1_score :86.96 \n",
      "\n",
      "Class 1 Precision :60.71 \n",
      "Class 1 Recall :44.74 \n",
      "Class 1 f1_score :51.52 \n",
      "Epoch: 19\t100.00% complete. 21.56 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 19 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0836 \t Validation Loss: 0.6240\n",
      "Training Accuracy: 97.32%\t Validation Accuracy: 79.89%\n",
      "roc_auc_score :0.8257169126782964\n",
      "\n",
      "Class 0 Precision :83.79 \n",
      "Class 0 Recall :90.79 \n",
      "Class 0 f1_score :87.15 \n",
      "\n",
      "Class 1 Precision :61.54 \n",
      "Class 1 Recall :45.61 \n",
      "Class 1 f1_score :52.39 \n",
      "Epoch: 20\t100.00% complete. 20.17 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 20 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0781 \t Validation Loss: 0.7997\n",
      "Training Accuracy: 97.81%\t Validation Accuracy: 79.57%\n",
      "roc_auc_score :0.8019792753839272\n",
      "\n",
      "Class 0 Precision :82.03 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :87.51 \n",
      "\n",
      "Class 1 Precision :65.35 \n",
      "Class 1 Recall :36.40 \n",
      "Class 1 f1_score :46.76 \n",
      "Epoch: 21\t100.00% complete. 20.69 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 21 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0789 \t Validation Loss: 0.7253\n",
      "Training Accuracy: 97.54%\t Validation Accuracy: 79.57%\n",
      "roc_auc_score :0.812062024750261\n",
      "\n",
      "Class 0 Precision :82.43 \n",
      "Class 0 Recall :93.06 \n",
      "Class 0 f1_score :87.43 \n",
      "\n",
      "Class 1 Precision :64.23 \n",
      "Class 1 Recall :38.60 \n",
      "Class 1 f1_score :48.22 \n",
      "Epoch: 22\t100.00% complete. 20.25 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 22 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0751 \t Validation Loss: 0.7650\n",
      "Training Accuracy: 97.83%\t Validation Accuracy: 79.95%\n",
      "roc_auc_score :0.8063155409770886\n",
      "\n",
      "Class 0 Precision :81.68 \n",
      "Class 0 Recall :94.76 \n",
      "Class 0 f1_score :87.74 \n",
      "\n",
      "Class 1 Precision :67.83 \n",
      "Class 1 Recall :34.21 \n",
      "Class 1 f1_score :45.48 \n",
      "Epoch: 23\t100.00% complete. 20.00 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 23 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0683 \t Validation Loss: 0.7674\n",
      "Training Accuracy: 98.17%\t Validation Accuracy: 80.74%\n",
      "roc_auc_score :0.815913721982009\n",
      "\n",
      "Class 0 Precision :82.61 \n",
      "Class 0 Recall :94.19 \n",
      "Class 0 f1_score :88.02 \n",
      "\n",
      "Class 1 Precision :68.22 \n",
      "Class 1 Recall :38.60 \n",
      "Class 1 f1_score :49.30 \n",
      "Epoch: 24\t100.00% complete. 20.55 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 24 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0656 \t Validation Loss: 0.7198\n",
      "Training Accuracy: 98.10%\t Validation Accuracy: 79.84%\n",
      "roc_auc_score :0.8144600168977685\n",
      "\n",
      "Class 0 Precision :82.85 \n",
      "Class 0 Recall :93.06 \n",
      "Class 0 f1_score :87.66 \n",
      "\n",
      "Class 1 Precision :65.25 \n",
      "Class 1 Recall :40.35 \n",
      "Class 1 f1_score :49.86 \n",
      "Epoch: 25\t100.00% complete. 20.62 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 25 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0651 \t Validation Loss: 0.7393\n",
      "Training Accuracy: 98.14%\t Validation Accuracy: 78.83%\n",
      "roc_auc_score :0.8113041101336912\n",
      "\n",
      "Class 0 Precision :82.21 \n",
      "Class 0 Recall :91.64 \n",
      "Class 0 f1_score :86.67 \n",
      "\n",
      "Class 1 Precision :59.86 \n",
      "Class 1 Recall :38.60 \n",
      "Class 1 f1_score :46.93 \n",
      "Epoch: 26\t100.00% complete. 20.54 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 26 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0659 \t Validation Loss: 0.7362\n",
      "Training Accuracy: 98.05%\t Validation Accuracy: 79.57%\n",
      "roc_auc_score :0.8175724367576165\n",
      "\n",
      "Class 0 Precision :82.21 \n",
      "Class 0 Recall :92.92 \n",
      "Class 0 f1_score :87.23 \n",
      "\n",
      "Class 1 Precision :63.24 \n",
      "Class 1 Recall :37.72 \n",
      "Class 1 f1_score :47.25 \n",
      "Epoch: 27\t100.00% complete. 20.08 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 27 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0627 \t Validation Loss: 0.7662\n",
      "Training Accuracy: 98.26%\t Validation Accuracy: 78.56%\n",
      "roc_auc_score :0.8137145271109786\n",
      "\n",
      "Class 0 Precision :81.23 \n",
      "Class 0 Recall :93.20 \n",
      "Class 0 f1_score :86.81 \n",
      "\n",
      "Class 1 Precision :61.29 \n",
      "Class 1 Recall :33.33 \n",
      "Class 1 f1_score :43.18 \n",
      "Epoch: 28\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 28 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0688 \t Validation Loss: 0.8259\n",
      "Training Accuracy: 97.70%\t Validation Accuracy: 78.40%\n",
      "roc_auc_score :0.8132920828984643\n",
      "\n",
      "Class 0 Precision :80.26 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :86.79 \n",
      "\n",
      "Class 1 Precision :62.14 \n",
      "Class 1 Recall :28.07 \n",
      "Class 1 f1_score :38.67 \n",
      "Epoch: 29\t100.00% complete. 20.61 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 29 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0611 \t Validation Loss: 0.7604\n",
      "Training Accuracy: 98.32%\t Validation Accuracy: 78.67%\n",
      "roc_auc_score :0.8007554296506137\n",
      "\n",
      "Class 0 Precision :81.89 \n",
      "Class 0 Recall :92.21 \n",
      "Class 0 f1_score :86.74 \n",
      "\n",
      "Class 1 Precision :60.43 \n",
      "Class 1 Recall :36.84 \n",
      "Class 1 f1_score :45.78 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 30\t100.00% complete. 20.51 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 30 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0603 \t Validation Loss: 0.7687\n",
      "Training Accuracy: 98.30%\t Validation Accuracy: 79.57%\n",
      "roc_auc_score :0.810950002484966\n",
      "\n",
      "Class 0 Precision :81.27 \n",
      "Class 0 Recall :94.62 \n",
      "Class 0 f1_score :87.43 \n",
      "\n",
      "Class 1 Precision :66.07 \n",
      "Class 1 Recall :32.46 \n",
      "Class 1 f1_score :43.53 \n",
      "Epoch: 31\t100.00% complete. 20.15 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 31 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0593 \t Validation Loss: 0.7236\n",
      "Training Accuracy: 98.37%\t Validation Accuracy: 80.37%\n",
      "roc_auc_score :0.8175165250236072\n",
      "\n",
      "Class 0 Precision :82.89 \n",
      "Class 0 Recall :93.34 \n",
      "Class 0 f1_score :87.81 \n",
      "\n",
      "Class 1 Precision :66.19 \n",
      "Class 1 Recall :40.35 \n",
      "Class 1 f1_score :50.14 \n",
      "Epoch: 32\t100.00% complete. 20.03 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 32 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0565 \t Validation Loss: 0.7014\n",
      "Training Accuracy: 98.32%\t Validation Accuracy: 79.79%\n",
      "roc_auc_score :0.828183241389593\n",
      "\n",
      "Class 0 Precision :82.33 \n",
      "Class 0 Recall :93.06 \n",
      "Class 0 f1_score :87.37 \n",
      "\n",
      "Class 1 Precision :63.97 \n",
      "Class 1 Recall :38.16 \n",
      "Class 1 f1_score :47.80 \n",
      "Epoch: 33\t100.00% complete. 20.55 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 33 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0571 \t Validation Loss: 0.8409\n",
      "Training Accuracy: 98.34%\t Validation Accuracy: 78.30%\n",
      "roc_auc_score :0.8063093285621987\n",
      "\n",
      "Class 0 Precision :80.10 \n",
      "Class 0 Recall :94.62 \n",
      "Class 0 f1_score :86.75 \n",
      "\n",
      "Class 1 Precision :62.00 \n",
      "Class 1 Recall :27.19 \n",
      "Class 1 f1_score :37.80 \n",
      "Epoch: 34\t100.00% complete. 20.55 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 34 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0585 \t Validation Loss: 0.8204\n",
      "Training Accuracy: 98.17%\t Validation Accuracy: 79.20%\n",
      "roc_auc_score :0.813174047015556\n",
      "\n",
      "Class 0 Precision :80.77 \n",
      "Class 0 Recall :95.18 \n",
      "Class 0 f1_score :87.39 \n",
      "\n",
      "Class 1 Precision :66.67 \n",
      "Class 1 Recall :29.82 \n",
      "Class 1 f1_score :41.21 \n",
      "Epoch: 35\t100.00% complete. 20.60 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 35 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0587 \t Validation Loss: 0.7496\n",
      "Training Accuracy: 98.26%\t Validation Accuracy: 79.89%\n",
      "roc_auc_score :0.8202499875751703\n",
      "\n",
      "Class 0 Precision :82.03 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :87.51 \n",
      "\n",
      "Class 1 Precision :65.35 \n",
      "Class 1 Recall :36.40 \n",
      "Class 1 f1_score :46.76 \n",
      "Epoch: 36\t100.00% complete. 20.14 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 36 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0480 \t Validation Loss: 0.8689\n",
      "Training Accuracy: 98.93%\t Validation Accuracy: 79.04%\n",
      "roc_auc_score :0.8109934893891952\n",
      "\n",
      "Class 0 Precision :80.62 \n",
      "Class 0 Recall :95.47 \n",
      "Class 0 f1_score :87.42 \n",
      "\n",
      "Class 1 Precision :67.35 \n",
      "Class 1 Recall :28.95 \n",
      "Class 1 f1_score :40.49 \n",
      "Epoch: 37\t100.00% complete. 20.03 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 37 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0488 \t Validation Loss: 0.7923\n",
      "Training Accuracy: 98.70%\t Validation Accuracy: 78.83%\n",
      "roc_auc_score :0.8161808558222752\n",
      "\n",
      "Class 0 Precision :81.73 \n",
      "Class 0 Recall :92.49 \n",
      "Class 0 f1_score :86.78 \n",
      "\n",
      "Class 1 Precision :60.74 \n",
      "Class 1 Recall :35.96 \n",
      "Class 1 f1_score :45.18 \n",
      "Epoch: 38\t100.00% complete. 20.06 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 38 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0453 \t Validation Loss: 0.7235\n",
      "Training Accuracy: 98.64%\t Validation Accuracy: 80.32%\n",
      "roc_auc_score :0.8210886635853087\n",
      "\n",
      "Class 0 Precision :82.93 \n",
      "Class 0 Recall :92.92 \n",
      "Class 0 f1_score :87.64 \n",
      "\n",
      "Class 1 Precision :65.03 \n",
      "Class 1 Recall :40.79 \n",
      "Class 1 f1_score :50.13 \n",
      "Epoch: 39\t100.00% complete. 19.98 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 39 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0490 \t Validation Loss: 0.7612\n",
      "Training Accuracy: 98.61%\t Validation Accuracy: 79.57%\n",
      "roc_auc_score :0.822119924457035\n",
      "\n",
      "Class 0 Precision :81.87 \n",
      "Class 0 Recall :94.05 \n",
      "Class 0 f1_score :87.54 \n",
      "\n",
      "Class 1 Precision :65.85 \n",
      "Class 1 Recall :35.53 \n",
      "Class 1 f1_score :46.15 \n",
      "Epoch: 40\t100.00% complete. 20.04 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 40 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0516 \t Validation Loss: 0.7799\n",
      "Training Accuracy: 98.41%\t Validation Accuracy: 80.37%\n",
      "roc_auc_score :0.815168232195219\n",
      "\n",
      "Class 0 Precision :82.48 \n",
      "Class 0 Recall :94.05 \n",
      "Class 0 f1_score :87.89 \n",
      "\n",
      "Class 1 Precision :67.44 \n",
      "Class 1 Recall :38.16 \n",
      "Class 1 f1_score :48.74 \n",
      "Epoch: 41\t100.00% complete. 20.00 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 41 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0451 \t Validation Loss: 0.7786\n",
      "Training Accuracy: 98.68%\t Validation Accuracy: 80.53%\n",
      "roc_auc_score :0.8219024899358878\n",
      "\n",
      "Class 0 Precision :81.93 \n",
      "Class 0 Recall :95.04 \n",
      "Class 0 f1_score :88.00 \n",
      "\n",
      "Class 1 Precision :69.57 \n",
      "Class 1 Recall :35.09 \n",
      "Class 1 f1_score :46.65 \n",
      "Epoch: 42\t100.00% complete. 20.00 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 42 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0450 \t Validation Loss: 0.7583\n",
      "Training Accuracy: 98.79%\t Validation Accuracy: 80.96%\n",
      "roc_auc_score :0.8170692311515332\n",
      "\n",
      "Class 0 Precision :82.81 \n",
      "Class 0 Recall :94.19 \n",
      "Class 0 f1_score :88.14 \n",
      "\n",
      "Class 1 Precision :68.70 \n",
      "Class 1 Recall :39.47 \n",
      "Class 1 f1_score :50.14 \n",
      "Epoch: 43\t100.00% complete. 20.06 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 43 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0418 \t Validation Loss: 0.7837\n",
      "Training Accuracy: 99.04%\t Validation Accuracy: 80.32%\n",
      "roc_auc_score :0.8116085184632971\n",
      "\n",
      "Class 0 Precision :82.20 \n",
      "Class 0 Recall :94.19 \n",
      "Class 0 f1_score :87.79 \n",
      "\n",
      "Class 1 Precision :67.20 \n",
      "Class 1 Recall :36.84 \n",
      "Class 1 f1_score :47.59 \n",
      "Epoch: 44\t100.00% complete. 20.00 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 44 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0419 \t Validation Loss: 0.8265\n",
      "Training Accuracy: 98.90%\t Validation Accuracy: 78.56%\n",
      "roc_auc_score :0.7998919039809154\n",
      "\n",
      "Class 0 Precision :80.93 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :86.88 \n",
      "\n",
      "Class 1 Precision :62.07 \n",
      "Class 1 Recall :31.58 \n",
      "Class 1 f1_score :41.86 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 45\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 45 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0469 \t Validation Loss: 0.8442\n",
      "Training Accuracy: 98.61%\t Validation Accuracy: 78.40%\n",
      "roc_auc_score :0.8025694547984692\n",
      "\n",
      "Class 0 Precision :80.63 \n",
      "Class 0 Recall :94.33 \n",
      "Class 0 f1_score :86.95 \n",
      "\n",
      "Class 1 Precision :62.96 \n",
      "Class 1 Recall :29.82 \n",
      "Class 1 f1_score :40.48 \n",
      "Epoch: 46\t100.00% complete. 20.01 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 46 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0406 \t Validation Loss: 0.7878\n",
      "Training Accuracy: 98.95%\t Validation Accuracy: 80.43%\n",
      "roc_auc_score :0.8059552209134735\n",
      "\n",
      "Class 0 Precision :82.38 \n",
      "Class 0 Recall :94.05 \n",
      "Class 0 f1_score :87.83 \n",
      "\n",
      "Class 1 Precision :67.19 \n",
      "Class 1 Recall :37.72 \n",
      "Class 1 f1_score :48.31 \n",
      "Epoch: 47\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 47 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0409 \t Validation Loss: 0.7808\n",
      "Training Accuracy: 98.88%\t Validation Accuracy: 78.99%\n",
      "roc_auc_score :0.8038678495104616\n",
      "\n",
      "Class 0 Precision :81.80 \n",
      "Class 0 Recall :92.92 \n",
      "Class 0 f1_score :87.00 \n",
      "\n",
      "Class 1 Precision :62.12 \n",
      "Class 1 Recall :35.96 \n",
      "Class 1 f1_score :45.56 \n",
      "Epoch: 48\t100.00% complete. 19.98 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 48 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0409 \t Validation Loss: 0.8661\n",
      "Training Accuracy: 98.90%\t Validation Accuracy: 78.62%\n",
      "roc_auc_score :0.8087197455394861\n",
      "\n",
      "Class 0 Precision :80.46 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :86.91 \n",
      "\n",
      "Class 1 Precision :62.86 \n",
      "Class 1 Recall :28.95 \n",
      "Class 1 f1_score :39.64 \n",
      "Epoch: 49\t100.00% complete. 19.93 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 49 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0393 \t Validation Loss: 0.8274\n",
      "Training Accuracy: 98.99%\t Validation Accuracy: 78.88%\n",
      "roc_auc_score :0.803029173500323\n",
      "\n",
      "Class 0 Precision :81.15 \n",
      "Class 0 Recall :93.91 \n",
      "Class 0 f1_score :87.07 \n",
      "\n",
      "Class 1 Precision :63.25 \n",
      "Class 1 Recall :32.46 \n",
      "Class 1 f1_score :42.90 \n",
      "Epoch: 50\t100.00% complete. 19.94 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 50 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0366 \t Validation Loss: 0.7627\n",
      "Training Accuracy: 98.88%\t Validation Accuracy: 80.11%\n",
      "roc_auc_score :0.8142363699617317\n",
      "\n",
      "Class 0 Precision :82.72 \n",
      "Class 0 Recall :92.92 \n",
      "Class 0 f1_score :87.53 \n",
      "\n",
      "Class 1 Precision :64.54 \n",
      "Class 1 Recall :39.91 \n",
      "Class 1 f1_score :49.32 \n",
      "Epoch: 51\t100.00% complete. 19.97 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 51 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0368 \t Validation Loss: 0.7529\n",
      "Training Accuracy: 99.08%\t Validation Accuracy: 81.44%\n",
      "roc_auc_score :0.819454798469261\n",
      "\n",
      "Class 0 Precision :83.27 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :88.52 \n",
      "\n",
      "Class 1 Precision :70.68 \n",
      "Class 1 Recall :41.23 \n",
      "Class 1 f1_score :52.08 \n",
      "Epoch: 52\t100.00% complete. 20.02 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 52 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0359 \t Validation Loss: 0.7974\n",
      "Training Accuracy: 98.97%\t Validation Accuracy: 80.64%\n",
      "roc_auc_score :0.8203307489687393\n",
      "\n",
      "Class 0 Precision :82.03 \n",
      "Class 0 Recall :95.04 \n",
      "Class 0 f1_score :88.06 \n",
      "\n",
      "Class 1 Precision :69.83 \n",
      "Class 1 Recall :35.53 \n",
      "Class 1 f1_score :47.09 \n",
      "Epoch: 53\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 53 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0400 \t Validation Loss: 0.7680\n",
      "Training Accuracy: 98.90%\t Validation Accuracy: 80.64%\n",
      "roc_auc_score :0.8256299388698374\n",
      "\n",
      "Class 0 Precision :82.35 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :87.99 \n",
      "\n",
      "Class 1 Precision :68.55 \n",
      "Class 1 Recall :37.28 \n",
      "Class 1 f1_score :48.30 \n",
      "Epoch: 54\t100.00% complete. 20.01 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 54 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0429 \t Validation Loss: 0.7824\n",
      "Training Accuracy: 98.70%\t Validation Accuracy: 79.52%\n",
      "roc_auc_score :0.8167026986730281\n",
      "\n",
      "Class 0 Precision :81.75 \n",
      "Class 0 Recall :93.91 \n",
      "Class 0 f1_score :87.41 \n",
      "\n",
      "Class 1 Precision :65.04 \n",
      "Class 1 Recall :35.09 \n",
      "Class 1 f1_score :45.58 \n",
      "Epoch: 55\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 55 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0354 \t Validation Loss: 0.8072\n",
      "Training Accuracy: 99.02%\t Validation Accuracy: 77.98%\n",
      "roc_auc_score :0.8009852890015408\n",
      "\n",
      "Class 0 Precision :80.99 \n",
      "Class 0 Recall :92.92 \n",
      "Class 0 f1_score :86.54 \n",
      "\n",
      "Class 1 Precision :59.68 \n",
      "Class 1 Recall :32.46 \n",
      "Class 1 f1_score :42.05 \n",
      "Epoch: 56\t100.00% complete. 20.00 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 56 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0367 \t Validation Loss: 0.8518\n",
      "Training Accuracy: 99.04%\t Validation Accuracy: 78.56%\n",
      "roc_auc_score :0.8076263605188609\n",
      "\n",
      "Class 0 Precision :80.78 \n",
      "Class 0 Recall :94.05 \n",
      "Class 0 f1_score :86.91 \n",
      "\n",
      "Class 1 Precision :62.50 \n",
      "Class 1 Recall :30.70 \n",
      "Class 1 f1_score :41.18 \n",
      "Epoch: 57\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 57 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0397 \t Validation Loss: 0.7943\n",
      "Training Accuracy: 98.97%\t Validation Accuracy: 80.00%\n",
      "roc_auc_score :0.8108878783360668\n",
      "\n",
      "Class 0 Precision :81.66 \n",
      "Class 0 Recall :94.62 \n",
      "Class 0 f1_score :87.66 \n",
      "\n",
      "Class 1 Precision :67.24 \n",
      "Class 1 Recall :34.21 \n",
      "Class 1 f1_score :45.35 \n",
      "Epoch: 58\t100.00% complete. 19.96 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 58 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0359 \t Validation Loss: 0.8203\n",
      "Training Accuracy: 98.93%\t Validation Accuracy: 79.10%\n",
      "roc_auc_score :0.8122980965160777\n",
      "\n",
      "Class 0 Precision :81.04 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :87.25 \n",
      "\n",
      "Class 1 Precision :64.86 \n",
      "Class 1 Recall :31.58 \n",
      "Class 1 f1_score :42.48 \n",
      "Epoch: 59\t100.00% complete. 20.08 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 59 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0354 \t Validation Loss: 0.7775\n",
      "Training Accuracy: 99.04%\t Validation Accuracy: 79.79%\n",
      "roc_auc_score :0.8134349684409323\n",
      "\n",
      "Class 0 Precision :81.70 \n",
      "Class 0 Recall :94.19 \n",
      "Class 0 f1_score :87.50 \n",
      "\n",
      "Class 1 Precision :65.83 \n",
      "Class 1 Recall :34.65 \n",
      "Class 1 f1_score :45.40 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 60\t100.00% complete. 20.01 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 60 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0295 \t Validation Loss: 0.7803\n",
      "Training Accuracy: 99.37%\t Validation Accuracy: 80.00%\n",
      "roc_auc_score :0.8167772476517071\n",
      "\n",
      "Class 0 Precision :81.43 \n",
      "Class 0 Recall :95.04 \n",
      "Class 0 f1_score :87.71 \n",
      "\n",
      "Class 1 Precision :68.18 \n",
      "Class 1 Recall :32.89 \n",
      "Class 1 f1_score :44.38 \n",
      "Epoch: 61\t100.00% complete. 20.06 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 61 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0399 \t Validation Loss: 0.8105\n",
      "Training Accuracy: 98.79%\t Validation Accuracy: 79.63%\n",
      "roc_auc_score :0.8074337756572734\n",
      "\n",
      "Class 0 Precision :81.54 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :87.53 \n",
      "\n",
      "Class 1 Precision :66.38 \n",
      "Class 1 Recall :33.77 \n",
      "Class 1 f1_score :44.77 \n",
      "Epoch: 62\t100.00% complete. 19.98 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 62 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0335 \t Validation Loss: 0.7650\n",
      "Training Accuracy: 99.19%\t Validation Accuracy: 80.43%\n",
      "roc_auc_score :0.8112295611550122\n",
      "\n",
      "Class 0 Precision :82.14 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :87.88 \n",
      "\n",
      "Class 1 Precision :68.03 \n",
      "Class 1 Recall :36.40 \n",
      "Class 1 f1_score :47.43 \n",
      "Epoch: 63\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 63 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0353 \t Validation Loss: 0.7998\n",
      "Training Accuracy: 99.06%\t Validation Accuracy: 79.89%\n",
      "roc_auc_score :0.8134846677600517\n",
      "\n",
      "Class 0 Precision :81.55 \n",
      "Class 0 Recall :95.18 \n",
      "Class 0 f1_score :87.84 \n",
      "\n",
      "Class 1 Precision :69.09 \n",
      "Class 1 Recall :33.33 \n",
      "Class 1 f1_score :44.97 \n",
      "Epoch: 64\t100.00% complete. 20.05 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 64 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0307 \t Validation Loss: 0.8475\n",
      "Training Accuracy: 99.26%\t Validation Accuracy: 79.41%\n",
      "roc_auc_score :0.8069367824660802\n",
      "\n",
      "Class 0 Precision :81.27 \n",
      "Class 0 Recall :94.62 \n",
      "Class 0 f1_score :87.43 \n",
      "\n",
      "Class 1 Precision :66.07 \n",
      "Class 1 Recall :32.46 \n",
      "Class 1 f1_score :43.53 \n",
      "Epoch: 65\t100.00% complete. 20.06 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 65 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0342 \t Validation Loss: 0.8037\n",
      "Training Accuracy: 99.15%\t Validation Accuracy: 79.95%\n",
      "roc_auc_score :0.8169139207792854\n",
      "\n",
      "Class 0 Precision :81.92 \n",
      "Class 0 Recall :94.33 \n",
      "Class 0 f1_score :87.69 \n",
      "\n",
      "Class 1 Precision :66.94 \n",
      "Class 1 Recall :35.53 \n",
      "Class 1 f1_score :46.42 \n",
      "Epoch: 66\t100.00% complete. 20.22 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 66 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0325 \t Validation Loss: 0.8694\n",
      "Training Accuracy: 99.22%\t Validation Accuracy: 79.36%\n",
      "roc_auc_score :0.8109251528254062\n",
      "\n",
      "Class 0 Precision :80.77 \n",
      "Class 0 Recall :95.18 \n",
      "Class 0 f1_score :87.39 \n",
      "\n",
      "Class 1 Precision :66.67 \n",
      "Class 1 Recall :29.82 \n",
      "Class 1 f1_score :41.21 \n",
      "Epoch: 67\t100.00% complete. 20.48 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 67 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0333 \t Validation Loss: 0.8188\n",
      "Training Accuracy: 99.19%\t Validation Accuracy: 79.79%\n",
      "roc_auc_score :0.8118197405695542\n",
      "\n",
      "Class 0 Precision :81.77 \n",
      "Class 0 Recall :94.05 \n",
      "Class 0 f1_score :87.48 \n",
      "\n",
      "Class 1 Precision :65.57 \n",
      "Class 1 Recall :35.09 \n",
      "Class 1 f1_score :45.71 \n",
      "Epoch: 68\t100.00% complete. 19.95 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 68 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0315 \t Validation Loss: 0.7890\n",
      "Training Accuracy: 99.22%\t Validation Accuracy: 79.15%\n",
      "roc_auc_score :0.8120930868247105\n",
      "\n",
      "Class 0 Precision :81.94 \n",
      "Class 0 Recall :93.20 \n",
      "Class 0 f1_score :87.21 \n",
      "\n",
      "Class 1 Precision :63.36 \n",
      "Class 1 Recall :36.40 \n",
      "Class 1 f1_score :46.24 \n",
      "Epoch: 69\t100.00% complete. 19.98 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 69 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0298 \t Validation Loss: 0.7909\n",
      "Training Accuracy: 99.17%\t Validation Accuracy: 78.99%\n",
      "roc_auc_score :0.8159385716415686\n",
      "\n",
      "Class 0 Precision :81.48 \n",
      "Class 0 Recall :93.48 \n",
      "Class 0 f1_score :87.07 \n",
      "\n",
      "Class 1 Precision :62.90 \n",
      "Class 1 Recall :34.21 \n",
      "Class 1 f1_score :44.32 \n",
      "Epoch: 70\t100.00% complete. 20.00 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 70 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0297 \t Validation Loss: 0.7797\n",
      "Training Accuracy: 99.26%\t Validation Accuracy: 80.64%\n",
      "roc_auc_score :0.82691590875205\n",
      "\n",
      "Class 0 Precision :82.41 \n",
      "Class 0 Recall :94.90 \n",
      "Class 0 f1_score :88.22 \n",
      "\n",
      "Class 1 Precision :70.25 \n",
      "Class 1 Recall :37.28 \n",
      "Class 1 f1_score :48.71 \n",
      "Epoch: 71\t100.00% complete. 20.00 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 71 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0380 \t Validation Loss: 0.7924\n",
      "Training Accuracy: 98.93%\t Validation Accuracy: 79.79%\n",
      "roc_auc_score :0.8210451766810795\n",
      "\n",
      "Class 0 Precision :81.70 \n",
      "Class 0 Recall :94.19 \n",
      "Class 0 f1_score :87.50 \n",
      "\n",
      "Class 1 Precision :65.83 \n",
      "Class 1 Recall :34.65 \n",
      "Class 1 f1_score :45.40 \n",
      "Epoch: 72\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 72 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0319 \t Validation Loss: 0.8082\n",
      "Training Accuracy: 99.19%\t Validation Accuracy: 79.47%\n",
      "roc_auc_score :0.814273644451071\n",
      "\n",
      "Class 0 Precision :81.63 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :87.28 \n",
      "\n",
      "Class 1 Precision :64.23 \n",
      "Class 1 Recall :34.65 \n",
      "Class 1 f1_score :45.01 \n",
      "Epoch: 73\t100.00% complete. 19.97 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 73 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0336 \t Validation Loss: 0.7939\n",
      "Training Accuracy: 99.17%\t Validation Accuracy: 79.36%\n",
      "roc_auc_score :0.8117265543462056\n",
      "\n",
      "Class 0 Precision :81.45 \n",
      "Class 0 Recall :93.91 \n",
      "Class 0 f1_score :87.24 \n",
      "\n",
      "Class 1 Precision :64.17 \n",
      "Class 1 Recall :33.77 \n",
      "Class 1 f1_score :44.25 \n",
      "Epoch: 74\t100.00% complete. 20.06 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 74 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0331 \t Validation Loss: 0.8664\n",
      "Training Accuracy: 99.19%\t Validation Accuracy: 77.98%\n",
      "roc_auc_score :0.8004199592465583\n",
      "\n",
      "Class 0 Precision :80.17 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :86.74 \n",
      "\n",
      "Class 1 Precision :61.76 \n",
      "Class 1 Recall :27.63 \n",
      "Class 1 f1_score :38.18 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 75\t100.00% complete. 20.07 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 75 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0295 \t Validation Loss: 0.8167\n",
      "Training Accuracy: 99.24%\t Validation Accuracy: 79.79%\n",
      "roc_auc_score :0.8174792505342677\n",
      "\n",
      "Class 0 Precision :81.23 \n",
      "Class 0 Recall :95.04 \n",
      "Class 0 f1_score :87.60 \n",
      "\n",
      "Class 1 Precision :67.59 \n",
      "Class 1 Recall :32.02 \n",
      "Class 1 f1_score :43.45 \n",
      "Epoch: 76\t100.00% complete. 20.01 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 76 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0306 \t Validation Loss: 0.7587\n",
      "Training Accuracy: 99.24%\t Validation Accuracy: 80.11%\n",
      "roc_auc_score :0.824474429700313\n",
      "\n",
      "Class 0 Precision :82.24 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :87.62 \n",
      "\n",
      "Class 1 Precision :65.89 \n",
      "Class 1 Recall :37.28 \n",
      "Class 1 f1_score :47.62 \n",
      "Epoch: 77\t100.00% complete. 20.04 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 77 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0264 \t Validation Loss: 0.8102\n",
      "Training Accuracy: 99.44%\t Validation Accuracy: 78.88%\n",
      "roc_auc_score :0.8116520053675265\n",
      "\n",
      "Class 0 Precision :81.46 \n",
      "Class 0 Recall :93.34 \n",
      "Class 0 f1_score :87.00 \n",
      "\n",
      "Class 1 Precision :62.40 \n",
      "Class 1 Recall :34.21 \n",
      "Class 1 f1_score :44.19 \n",
      "Epoch: 78\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 78 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0299 \t Validation Loss: 0.8146\n",
      "Training Accuracy: 99.33%\t Validation Accuracy: 80.43%\n",
      "roc_auc_score :0.8146588141742458\n",
      "\n",
      "Class 0 Precision :81.37 \n",
      "Class 0 Recall :95.89 \n",
      "Class 0 f1_score :88.04 \n",
      "\n",
      "Class 1 Precision :71.57 \n",
      "Class 1 Recall :32.02 \n",
      "Class 1 f1_score :44.24 \n",
      "Epoch: 79\t100.00% complete. 20.14 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 79 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0306 \t Validation Loss: 0.8154\n",
      "Training Accuracy: 99.40%\t Validation Accuracy: 79.47%\n",
      "roc_auc_score :0.8084898861885592\n",
      "\n",
      "Class 0 Precision :81.78 \n",
      "Class 0 Recall :93.48 \n",
      "Class 0 f1_score :87.24 \n",
      "\n",
      "Class 1 Precision :63.78 \n",
      "Class 1 Recall :35.53 \n",
      "Class 1 f1_score :45.63 \n",
      "Epoch: 80\t100.00% complete. 20.66 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 80 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0299 \t Validation Loss: 0.8219\n",
      "Training Accuracy: 99.33%\t Validation Accuracy: 80.27%\n",
      "roc_auc_score :0.8150377714825306\n",
      "\n",
      "Class 0 Precision :81.67 \n",
      "Class 0 Recall :95.33 \n",
      "Class 0 f1_score :87.97 \n",
      "\n",
      "Class 1 Precision :70.00 \n",
      "Class 1 Recall :33.77 \n",
      "Class 1 f1_score :45.56 \n",
      "Epoch: 81\t100.00% complete. 20.32 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 81 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0299 \t Validation Loss: 0.7977\n",
      "Training Accuracy: 99.06%\t Validation Accuracy: 79.95%\n",
      "roc_auc_score :0.8174978877789374\n",
      "\n",
      "Class 0 Precision :82.24 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :87.62 \n",
      "\n",
      "Class 1 Precision :65.89 \n",
      "Class 1 Recall :37.28 \n",
      "Class 1 f1_score :47.62 \n",
      "Epoch: 82\t100.00% complete. 20.06 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 82 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0261 \t Validation Loss: 0.8321\n",
      "Training Accuracy: 99.46%\t Validation Accuracy: 79.73%\n",
      "roc_auc_score :0.8106766562298096\n",
      "\n",
      "Class 0 Precision :81.56 \n",
      "Class 0 Recall :94.62 \n",
      "Class 0 f1_score :87.61 \n",
      "\n",
      "Class 1 Precision :66.96 \n",
      "Class 1 Recall :33.77 \n",
      "Class 1 f1_score :44.90 \n",
      "Epoch: 83\t100.00% complete. 20.24 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 83 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0287 \t Validation Loss: 0.7902\n",
      "Training Accuracy: 99.33%\t Validation Accuracy: 79.79%\n",
      "roc_auc_score :0.8152862680781273\n",
      "\n",
      "Class 0 Precision :82.40 \n",
      "Class 0 Recall :93.48 \n",
      "Class 0 f1_score :87.59 \n",
      "\n",
      "Class 1 Precision :65.41 \n",
      "Class 1 Recall :38.16 \n",
      "Class 1 f1_score :48.20 \n",
      "Epoch: 84\t100.00% complete. 20.70 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 84 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0294 \t Validation Loss: 0.8150\n",
      "Training Accuracy: 99.31%\t Validation Accuracy: 79.63%\n",
      "roc_auc_score :0.8167710352368173\n",
      "\n",
      "Class 0 Precision :81.70 \n",
      "Class 0 Recall :94.19 \n",
      "Class 0 f1_score :87.50 \n",
      "\n",
      "Class 1 Precision :65.83 \n",
      "Class 1 Recall :34.65 \n",
      "Class 1 f1_score :45.40 \n",
      "Epoch: 85\t100.00% complete. 20.35 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 85 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0272 \t Validation Loss: 0.7500\n",
      "Training Accuracy: 99.40%\t Validation Accuracy: 80.59%\n",
      "roc_auc_score :0.822629342478008\n",
      "\n",
      "Class 0 Precision :83.02 \n",
      "Class 0 Recall :93.48 \n",
      "Class 0 f1_score :87.94 \n",
      "\n",
      "Class 1 Precision :66.91 \n",
      "Class 1 Recall :40.79 \n",
      "Class 1 f1_score :50.68 \n",
      "Epoch: 86\t100.00% complete. 19.98 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 86 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0294 \t Validation Loss: 0.8381\n",
      "Training Accuracy: 99.28%\t Validation Accuracy: 78.19%\n",
      "roc_auc_score :0.8079494060931365\n",
      "\n",
      "Class 0 Precision :80.58 \n",
      "Class 0 Recall :94.05 \n",
      "Class 0 f1_score :86.80 \n",
      "\n",
      "Class 1 Precision :61.82 \n",
      "Class 1 Recall :29.82 \n",
      "Class 1 f1_score :40.24 \n",
      "Epoch: 87\t100.00% complete. 19.99 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 87 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0242 \t Validation Loss: 0.8519\n",
      "Training Accuracy: 99.57%\t Validation Accuracy: 79.68%\n",
      "roc_auc_score :0.8093223497838079\n",
      "\n",
      "Class 0 Precision :81.74 \n",
      "Class 0 Recall :94.48 \n",
      "Class 0 f1_score :87.65 \n",
      "\n",
      "Class 1 Precision :66.95 \n",
      "Class 1 Recall :34.65 \n",
      "Class 1 f1_score :45.66 \n",
      "Epoch: 88\t100.00% complete. 19.97 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 88 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0285 \t Validation Loss: 0.7911\n",
      "Training Accuracy: 99.37%\t Validation Accuracy: 80.05%\n",
      "roc_auc_score :0.8161932806520551\n",
      "\n",
      "Class 0 Precision :82.26 \n",
      "Class 0 Recall :93.91 \n",
      "Class 0 f1_score :87.70 \n",
      "\n",
      "Class 1 Precision :66.41 \n",
      "Class 1 Recall :37.28 \n",
      "Class 1 f1_score :47.75 \n",
      "Epoch: 89\t100.00% complete. 19.97 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 89 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0267 \t Validation Loss: 0.7789\n",
      "Training Accuracy: 99.37%\t Validation Accuracy: 79.95%\n",
      "roc_auc_score :0.8178644202574424\n",
      "\n",
      "Class 0 Precision :82.40 \n",
      "Class 0 Recall :93.48 \n",
      "Class 0 f1_score :87.59 \n",
      "\n",
      "Class 1 Precision :65.41 \n",
      "Class 1 Recall :38.16 \n",
      "Class 1 f1_score :48.20 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 90\t100.00% complete. 19.98 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 90 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0251 \t Validation Loss: 0.7689\n",
      "Training Accuracy: 99.46%\t Validation Accuracy: 80.32%\n",
      "roc_auc_score :0.8211383629044282\n",
      "\n",
      "Class 0 Precision :82.12 \n",
      "Class 0 Recall :94.33 \n",
      "Class 0 f1_score :87.80 \n",
      "\n",
      "Class 1 Precision :67.48 \n",
      "Class 1 Recall :36.40 \n",
      "Class 1 f1_score :47.29 \n",
      "Epoch: 91\t100.00% complete. 19.98 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 91 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0273 \t Validation Loss: 0.7538\n",
      "Training Accuracy: 99.37%\t Validation Accuracy: 80.69%\n",
      "roc_auc_score :0.830177426569256\n",
      "\n",
      "Class 0 Precision :82.88 \n",
      "Class 0 Recall :93.91 \n",
      "Class 0 f1_score :88.05 \n",
      "\n",
      "Class 1 Precision :67.91 \n",
      "Class 1 Recall :39.91 \n",
      "Class 1 f1_score :50.28 \n",
      "Epoch: 92\t100.00% complete. 20.02 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 92 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0308 \t Validation Loss: 0.7863\n",
      "Training Accuracy: 99.15%\t Validation Accuracy: 79.89%\n",
      "roc_auc_score :0.8186968838526912\n",
      "\n",
      "Class 0 Precision :81.56 \n",
      "Class 0 Recall :94.62 \n",
      "Class 0 f1_score :87.61 \n",
      "\n",
      "Class 1 Precision :66.96 \n",
      "Class 1 Recall :33.77 \n",
      "Class 1 f1_score :44.90 \n",
      "Epoch: 93\t100.00% complete. 20.00 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 93 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0298 \t Validation Loss: 0.8290\n",
      "Training Accuracy: 99.24%\t Validation Accuracy: 79.36%\n",
      "roc_auc_score :0.8142363699617314\n",
      "\n",
      "Class 0 Precision :81.53 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :87.22 \n",
      "\n",
      "Class 1 Precision :63.93 \n",
      "Class 1 Recall :34.21 \n",
      "Class 1 f1_score :44.57 \n",
      "Epoch: 94\t100.00% complete. 19.95 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 94 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0280 \t Validation Loss: 0.8295\n",
      "Training Accuracy: 99.24%\t Validation Accuracy: 80.32%\n",
      "roc_auc_score :0.8182620148103971\n",
      "\n",
      "Class 0 Precision :81.96 \n",
      "Class 0 Recall :94.62 \n",
      "Class 0 f1_score :87.84 \n",
      "\n",
      "Class 1 Precision :68.07 \n",
      "Class 1 Recall :35.53 \n",
      "Class 1 f1_score :46.69 \n",
      "Epoch: 95\t100.00% complete. 20.47 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 95 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0286 \t Validation Loss: 0.8599\n",
      "Training Accuracy: 99.26%\t Validation Accuracy: 77.93%\n",
      "roc_auc_score :0.8132920828984643\n",
      "\n",
      "Class 0 Precision :80.49 \n",
      "Class 0 Recall :93.48 \n",
      "Class 0 f1_score :86.50 \n",
      "\n",
      "Class 1 Precision :59.65 \n",
      "Class 1 Recall :29.82 \n",
      "Class 1 f1_score :39.77 \n",
      "Epoch: 96\t100.00% complete. 20.73 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 96 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0263 \t Validation Loss: 0.8357\n",
      "Training Accuracy: 99.44%\t Validation Accuracy: 78.78%\n",
      "roc_auc_score :0.8083035137418617\n",
      "\n",
      "Class 0 Precision :81.43 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :87.16 \n",
      "\n",
      "Class 1 Precision :63.64 \n",
      "Class 1 Recall :33.77 \n",
      "Class 1 f1_score :44.13 \n",
      "Epoch: 97\t100.00% complete. 20.54 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 97 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0272 \t Validation Loss: 0.8160\n",
      "Training Accuracy: 99.31%\t Validation Accuracy: 78.94%\n",
      "roc_auc_score :0.8099249540281299\n",
      "\n",
      "Class 0 Precision :81.13 \n",
      "Class 0 Recall :93.77 \n",
      "Class 0 f1_score :86.99 \n",
      "\n",
      "Class 1 Precision :62.71 \n",
      "Class 1 Recall :32.46 \n",
      "Class 1 f1_score :42.77 \n",
      "Epoch: 98\t100.00% complete. 20.14 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 98 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0280 \t Validation Loss: 0.8553\n",
      "Training Accuracy: 99.24%\t Validation Accuracy: 78.19%\n",
      "roc_auc_score :0.8123850703245366\n",
      "\n",
      "Class 0 Precision :80.58 \n",
      "Class 0 Recall :94.05 \n",
      "Class 0 f1_score :86.80 \n",
      "\n",
      "Class 1 Precision :61.82 \n",
      "Class 1 Recall :29.82 \n",
      "Class 1 f1_score :40.24 \n",
      "Epoch: 99\t100.00% complete. 20.03 seconds elapsed in epoch.\n",
      "---------------------------------------------------------------------------------------------Epoch 99 Results---------------------------------------------------------------------------------------------\n",
      "\n",
      "Training Loss: 0.0239 \t Validation Loss: 0.8146\n",
      "Training Accuracy: 99.51%\t Validation Accuracy: 79.52%\n",
      "roc_auc_score :0.820939565627951\n",
      "\n",
      "Class 0 Precision :81.52 \n",
      "Class 0 Recall :94.33 \n",
      "Class 0 f1_score :87.46 \n",
      "\n",
      "Class 1 Precision :65.81 \n",
      "Class 1 Recall :33.77 \n",
      "Class 1 f1_score :44.64 \n",
      "\n",
      "Best epoch: 4 with loss: 0.43 and acc: 79.52%\n",
      "2144.07 total seconds elapsed. 21.66 seconds per epoch.\n",
      "\n",
      "\n",
      "----------------\n",
      "\n",
      "\n",
      "history df shape:  (100, 5)\n"
     ]
    }
   ],
   "source": [
    "model, history = train(\n",
    "    model,\n",
    "    device,\n",
    "    criterion,\n",
    "    optimizer_ft,\n",
    "    scheduler,\n",
    "    train_loader,\n",
    "    test_loader,\n",
    "    save_file_name=save_file_name,\n",
    "    max_epochs_stop=4,\n",
    "    n_epochs=NUM_EPOCHS,\n",
    "    print_every=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(history):\n",
    "  \n",
    "  # The history object contains results on the training and test\n",
    "  # sets for each epoch\n",
    "  acc = history['train_acc']\n",
    "  val_acc = history['valid_acc']\n",
    "  loss = history['train_loss']\n",
    "  val_loss = history['valid_loss']\n",
    "  auc = history['roc']\n",
    "\n",
    "  # Get the number of epochs\n",
    "  epochs = range(len(acc))\n",
    "\n",
    "  plt.title('Training and validation accuracy')\n",
    "  plt.plot(epochs, acc, color='blue', label='Train')\n",
    "  plt.plot(epochs, val_acc, color='orange', label='Val')\n",
    "  plt.xlabel('Epoch')\n",
    "  plt.ylabel('Accuracy')\n",
    "  plt.grid(True, axis = 'y')\n",
    "  plt.legend()\n",
    "\n",
    "#   _ = plt.figure()\n",
    "#   plt.title('Training and validation loss')\n",
    "#   plt.plot(epochs, loss, color='blue', label='Train')\n",
    "#   plt.plot(epochs, val_loss, color='orange', label='Val')\n",
    "#   plt.xlabel('Epoch')\n",
    "#   plt.ylabel('Loss')\n",
    "#   plt.grid(True, axis = 'y')\n",
    "#   plt.legend()\n",
    "\n",
    "  _ = plt.figure()\n",
    "  plt.title('Validation AUC')\n",
    "  plt.plot(epochs, auc, color='blue', label='Train')\n",
    "  plt.xlabel('Epoch')\n",
    "  plt.ylabel('Auc')\n",
    "  plt.grid(True, axis = 'y')\n",
    "  plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>train_acc</th>\n",
       "      <th>valid_acc</th>\n",
       "      <th>roc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.545076</td>\n",
       "      <td>0.497729</td>\n",
       "      <td>0.750336</td>\n",
       "      <td>0.757447</td>\n",
       "      <td>0.737762</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.423507</td>\n",
       "      <td>0.444942</td>\n",
       "      <td>0.780089</td>\n",
       "      <td>0.761702</td>\n",
       "      <td>0.804862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.374174</td>\n",
       "      <td>0.438525</td>\n",
       "      <td>0.814765</td>\n",
       "      <td>0.772340</td>\n",
       "      <td>0.813727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.335164</td>\n",
       "      <td>0.436611</td>\n",
       "      <td>0.855705</td>\n",
       "      <td>0.781383</td>\n",
       "      <td>0.822269</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.302722</td>\n",
       "      <td>0.430409</td>\n",
       "      <td>0.877852</td>\n",
       "      <td>0.783511</td>\n",
       "      <td>0.831954</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>0.028597</td>\n",
       "      <td>0.859924</td>\n",
       "      <td>0.992617</td>\n",
       "      <td>0.779255</td>\n",
       "      <td>0.813292</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>0.026329</td>\n",
       "      <td>0.835668</td>\n",
       "      <td>0.994407</td>\n",
       "      <td>0.787766</td>\n",
       "      <td>0.808304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>0.027203</td>\n",
       "      <td>0.816045</td>\n",
       "      <td>0.993065</td>\n",
       "      <td>0.789362</td>\n",
       "      <td>0.809925</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>0.028048</td>\n",
       "      <td>0.855256</td>\n",
       "      <td>0.992394</td>\n",
       "      <td>0.781915</td>\n",
       "      <td>0.812385</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>0.023886</td>\n",
       "      <td>0.814558</td>\n",
       "      <td>0.995078</td>\n",
       "      <td>0.795213</td>\n",
       "      <td>0.820940</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    train_loss  valid_loss  train_acc  valid_acc       roc\n",
       "0     0.545076    0.497729   0.750336   0.757447  0.737762\n",
       "1     0.423507    0.444942   0.780089   0.761702  0.804862\n",
       "2     0.374174    0.438525   0.814765   0.772340  0.813727\n",
       "3     0.335164    0.436611   0.855705   0.781383  0.822269\n",
       "4     0.302722    0.430409   0.877852   0.783511  0.831954\n",
       "..         ...         ...        ...        ...       ...\n",
       "95    0.028597    0.859924   0.992617   0.779255  0.813292\n",
       "96    0.026329    0.835668   0.994407   0.787766  0.808304\n",
       "97    0.027203    0.816045   0.993065   0.789362  0.809925\n",
       "98    0.028048    0.855256   0.992394   0.781915  0.812385\n",
       "99    0.023886    0.814558   0.995078   0.795213  0.820940\n",
       "\n",
       "[100 rows x 5 columns]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_loss    0.302722\n",
      "valid_loss    0.430409\n",
      "train_acc     0.877852\n",
      "valid_acc     0.783511\n",
      "roc           0.831954\n",
      "Name: 4, dtype: float64\n",
      "train_loss    0.036840\n",
      "valid_loss    0.752872\n",
      "train_acc     0.990828\n",
      "valid_acc     0.814362\n",
      "roc           0.819455\n",
      "Name: 51, dtype: float64\n"
     ]
    }
   ],
   "source": [
    "print(history.iloc[history['roc'].argmax()])\n",
    "print(history.iloc[history['valid_acc'].argmax()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot(history)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}