248 lines (247 with data), 9.2 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import keras\n",
"import math\n",
"import numpy as np\n",
"import os\n",
"import sklearn.metrics as skm\n",
"import sys\n",
"sys.path.append(\"../../../ecg\")\n",
"\n",
"import load\n",
"import util\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 8761/8761 [00:03<00:00, 2561.29it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"8761/8761 [==============================] - 20s 2ms/step\n"
]
}
],
"source": [
"model_path = \"/deep/group/awni/ecg_models/default/1527627404-9/0.337-0.880-012-0.255-0.906.hdf5\"\n",
"data_json = \"../dev.json\"\n",
"\n",
"preproc = util.load(os.path.dirname(model_path))\n",
"dataset = load.load_dataset(data_json)\n",
"ecgs, labels = preproc.process(*dataset)\n",
"\n",
"model = keras.models.load_model(model_path)\n",
"probs = model.predict(ecgs, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def stats(ground_truth, preds):\n",
" labels = range(ground_truth.shape[2])\n",
" g = np.argmax(ground_truth, axis=2).ravel()\n",
" p = np.argmax(preds, axis=2).ravel()\n",
" stat_dict = {}\n",
" for i in labels:\n",
" # compute all the stats for each label\n",
" tp = np.sum(g[g==i] == p[g==i])\n",
" fp = np.sum(g[p==i] != p[p==i])\n",
" fn = np.sum(g==i) - tp\n",
" tn = np.sum(g!=i) - fp\n",
" stat_dict[i] = (tp, fp, fn, tn)\n",
" return stat_dict\n",
"\n",
"def to_set(preds):\n",
" idxs = np.argmax(preds, axis=2)\n",
" return [list(set(r)) for r in idxs]\n",
"\n",
"def set_stats(ground_truth, preds):\n",
" labels = range(ground_truth.shape[2])\n",
" ground_truth = to_set(ground_truth)\n",
" preds = to_set(preds)\n",
" stat_dict = {}\n",
" for x in labels:\n",
" tp = 0; fp = 0; fn = 0; tn = 0;\n",
" for g, p in zip(ground_truth, preds):\n",
" if x in g and x in p: # tp\n",
" tp += 1\n",
" if x not in g and x in p: # fp\n",
" fp += 1\n",
" if x in g and x not in p:\n",
" fn += 1\n",
" if x not in g and x not in p:\n",
" tn += 1\n",
" stat_dict[x] = (tp, fp, fn, tn)\n",
" return stat_dict\n",
"\n",
"def compute_f1(tp, fp, fn, tn):\n",
" precision = tp / float(tp + fp)\n",
" recall = tp / float(tp + fn)\n",
" specificity = tn / float(tn + fp)\n",
" npv = tn / float(tn + fn)\n",
" f1 = 2 * precision * recall / (precision + recall)\n",
" return f1, tp + fn\n",
"\n",
"def print_results(seq_sd, set_sd):\n",
" print \"\\t\\t Seq F1 Set F1\"\n",
" seq_tf1 = 0; seq_tot = 0\n",
" set_tf1 = 0; set_tot = 0\n",
" for k, v in seq_sd.items():\n",
" set_f1, n = compute_f1(*set_sd[k])\n",
" set_tf1 += n * set_f1\n",
" set_tot += n\n",
" seq_f1, n = compute_f1(*v)\n",
" seq_tf1 += n * seq_f1\n",
" seq_tot += n\n",
" print \"{:>10} {:10.3f} {:10.3f}\".format(\n",
" preproc.classes[k], seq_f1, set_f1)\n",
" print \"{:>10} {:10.3f} {:10.3f}\".format(\n",
" \"Average\", seq_tf1 / float(seq_tot), set_tf1 / float(set_tot))\n",
" \n",
"def c_statistic_with_95p_confidence_interval(cstat, num_positives, num_negatives, z_alpha_2=1.96):\n",
" \"\"\"\n",
" Calculates the confidence interval of an ROC curve (c-statistic), using the method described\n",
" under \"Confidence Interval for AUC\" here:\n",
" https://ncss-wpengine.netdna-ssl.com/wp-content/themes/ncss/pdf/Procedures/PASS/Confidence_Intervals_for_the_Area_Under_an_ROC_Curve.pdf\n",
" Args:\n",
" cstat: the c-statistic (equivalent to area under the ROC curve)\n",
" num_positives: number of positive examples in the set.\n",
" num_negatives: number of negative examples in the set.\n",
" z_alpha_2 (optional): the critical value for an N% confidence interval, e.g., 1.96 for 95%,\n",
" 2.326 for 98%, 2.576 for 99%, etc.\n",
" Returns:\n",
" The 95% confidence interval half-width, e.g., the Y in X ± Y.\n",
" \"\"\"\n",
" q1 = cstat / (2 - cstat)\n",
" q2 = 2 * cstat**2 / (1 + cstat)\n",
" numerator = cstat * (1 - cstat) \\\n",
" + (num_positives - 1) * (q1 - cstat**2) \\\n",
" + (num_negatives - 1) * (q2 - cstat**2)\n",
" standard_error_auc = math.sqrt(numerator / (num_positives * num_negatives))\n",
" return z_alpha_2 * standard_error_auc\n",
"\n",
"def roc_auc(ground_truth, probs, index):\n",
" gts = np.argmax(ground_truth, axis=2)\n",
" n_gts = np.zeros_like(gts)\n",
" n_gts[gts==index] = 1\n",
" n_pos = np.sum(n_gts == 1)\n",
" n_neg = n_gts.size - n_pos\n",
" n_ps = probs[..., index].squeeze()\n",
" n_gts, n_ps = n_gts.ravel(), n_ps.ravel()\n",
" return n_pos, n_neg, skm.roc_auc_score(n_gts, n_ps)\n",
"\n",
"def roc_auc_set(ground_truth, probs, index):\n",
" gts = np.argmax(ground_truth, axis=2)\n",
" max_ps = np.max(probs[...,index], axis=1)\n",
" max_gts = np.any(gts==index, axis=1)\n",
" pos = np.sum(max_gts)\n",
" neg = max_gts.size - pos\n",
" return pos, neg, skm.roc_auc_score(max_gts, max_ps)\n",
"\n",
"def print_aucs(ground_truth, probs):\n",
" seq_tauc = 0.0; seq_tot = 0.0\n",
" set_tauc = 0.0; set_tot = 0.0\n",
" print \"\\t AUC\"\n",
" for idx, cname in preproc.int_to_class.items():\n",
" pos, neg, seq_auc = roc_auc(ground_truth, probs, idx)\n",
" seq_tot += pos\n",
" seq_tauc += pos * seq_auc\n",
" seq_conf = c_statistic_with_95p_confidence_interval(seq_auc, pos, neg)\n",
" pos, neg, set_auc = roc_auc_set(ground_truth, probs, idx)\n",
" set_tot += pos\n",
" set_tauc += pos * set_auc\n",
" set_conf = c_statistic_with_95p_confidence_interval(set_auc, pos, neg)\n",
" print \"{: <8}\\t{:.3f} ({:.3f}-{:.3f})\\t{:.3f} ({:.3f}-{:.3f})\".format(\n",
" cname, seq_auc, seq_auc-seq_conf,seq_auc+seq_conf,\n",
" set_auc, set_auc-set_conf, set_auc+set_conf)\n",
" print \"Average\\t\\t{:.3f}\\t{:.3f}\".format(seq_tauc/seq_tot, set_tauc/set_tot)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t\t Seq F1 Set F1\n",
" AF 0.914 0.914\n",
" AVB 0.805 0.839\n",
" BIGEMINY 0.917 0.896\n",
" EAR 0.652 0.699\n",
" IVR 0.721 0.758\n",
"JUNCTIONAL 0.706 0.740\n",
" NOISE 0.911 0.847\n",
" SINUS 0.920 0.960\n",
" SVT 0.700 0.812\n",
" TRIGEMINY 0.924 0.918\n",
" VT 0.769 0.848\n",
"WENCKEBACH 0.779 0.822\n",
" Average 0.879 0.889\n",
"\t AUC\n",
"AF \t0.994 (0.994-0.995)\t0.994 (0.991-0.996)\n",
"AVB \t0.992 (0.990-0.993)\t0.990 (0.985-0.995)\n",
"BIGEMINY\t0.999 (0.998-1.000)\t0.998 (0.994-1.001)\n",
"EAR \t0.977 (0.975-0.980)\t0.967 (0.957-0.977)\n",
"IVR \t0.996 (0.994-0.998)\t0.991 (0.984-0.998)\n",
"JUNCTIONAL\t0.987 (0.985-0.989)\t0.984 (0.976-0.992)\n",
"NOISE \t0.994 (0.993-0.994)\t0.978 (0.973-0.984)\n",
"SINUS \t0.979 (0.979-0.980)\t0.987 (0.985-0.989)\n",
"SVT \t0.986 (0.984-0.988)\t0.983 (0.977-0.989)\n",
"TRIGEMINY\t0.999 (0.999-1.000)\t0.998 (0.994-1.001)\n",
"VT \t0.997 (0.995-0.998)\t0.992 (0.988-0.997)\n",
"WENCKEBACH\t0.991 (0.990-0.993)\t0.990 (0.985-0.996)\n",
"Average\t\t0.986\t0.987\n"
]
}
],
"source": [
"print_results(stats(labels, probs), set_stats(labels, probs))\n",
"print_aucs(labels, probs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}