--- a
+++ b/examples/irhythm/notebooks/dev_results.ipynb
@@ -0,0 +1,247 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import keras\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import os\n",
+    "import sklearn.metrics as skm\n",
+    "import sys\n",
+    "sys.path.append(\"../../../ecg\")\n",
+    "\n",
+    "import load\n",
+    "import util\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 8761/8761 [00:03<00:00, 2561.29it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "8761/8761 [==============================] - 20s 2ms/step\n"
+     ]
+    }
+   ],
+   "source": [
+    "model_path = \"/deep/group/awni/ecg_models/default/1527627404-9/0.337-0.880-012-0.255-0.906.hdf5\"\n",
+    "data_json = \"../dev.json\"\n",
+    "\n",
+    "preproc = util.load(os.path.dirname(model_path))\n",
+    "dataset = load.load_dataset(data_json)\n",
+    "ecgs, labels = preproc.process(*dataset)\n",
+    "\n",
+    "model = keras.models.load_model(model_path)\n",
+    "probs = model.predict(ecgs, verbose=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def stats(ground_truth, preds):\n",
+    "    labels = range(ground_truth.shape[2])\n",
+    "    g = np.argmax(ground_truth, axis=2).ravel()\n",
+    "    p = np.argmax(preds, axis=2).ravel()\n",
+    "    stat_dict = {}\n",
+    "    for i in labels:\n",
+    "        # compute all the stats for each label\n",
+    "        tp = np.sum(g[g==i] == p[g==i])\n",
+    "        fp = np.sum(g[p==i] != p[p==i])\n",
+    "        fn = np.sum(g==i) - tp\n",
+    "        tn = np.sum(g!=i) - fp\n",
+    "        stat_dict[i] = (tp, fp, fn, tn)\n",
+    "    return stat_dict\n",
+    "\n",
+    "def to_set(preds):\n",
+    "    idxs = np.argmax(preds, axis=2)\n",
+    "    return [list(set(r)) for r in idxs]\n",
+    "\n",
+    "def set_stats(ground_truth, preds):\n",
+    "    labels = range(ground_truth.shape[2])\n",
+    "    ground_truth = to_set(ground_truth)\n",
+    "    preds = to_set(preds)\n",
+    "    stat_dict = {}\n",
+    "    for x in labels:\n",
+    "        tp = 0; fp = 0; fn = 0; tn = 0;\n",
+    "        for g, p in zip(ground_truth, preds):\n",
+    "            if x in g and x in p: # tp\n",
+    "                tp += 1\n",
+    "            if x not in g and x in p: # fp\n",
+    "                fp += 1\n",
+    "            if x in g and x not in p:\n",
+    "                fn += 1\n",
+    "            if x not in g and x not in p:\n",
+    "                tn += 1\n",
+    "        stat_dict[x] = (tp, fp, fn, tn)\n",
+    "    return stat_dict\n",
+    "\n",
+    "def compute_f1(tp, fp, fn, tn):\n",
+    "    precision = tp / float(tp + fp)\n",
+    "    recall = tp / float(tp + fn)\n",
+    "    specificity = tn / float(tn + fp)\n",
+    "    npv = tn / float(tn + fn)\n",
+    "    f1 = 2 * precision * recall / (precision + recall)\n",
+    "    return f1, tp + fn\n",
+    "\n",
+    "def print_results(seq_sd, set_sd):\n",
+    "    print \"\\t\\t Seq F1 Set F1\"\n",
+    "    seq_tf1 = 0; seq_tot = 0\n",
+    "    set_tf1 = 0; set_tot = 0\n",
+    "    for k, v in seq_sd.items():\n",
+    "        set_f1, n = compute_f1(*set_sd[k])\n",
+    "        set_tf1 += n * set_f1\n",
+    "        set_tot += n\n",
+    "        seq_f1, n = compute_f1(*v)\n",
+    "        seq_tf1 += n * seq_f1\n",
+    "        seq_tot += n\n",
+    "        print \"{:>10} {:10.3f} {:10.3f}\".format(\n",
+    "            preproc.classes[k], seq_f1, set_f1)\n",
+    "    print \"{:>10} {:10.3f} {:10.3f}\".format(\n",
+    "        \"Average\", seq_tf1 / float(seq_tot), set_tf1 / float(set_tot))\n",
+    "    \n",
+    "def c_statistic_with_95p_confidence_interval(cstat, num_positives, num_negatives, z_alpha_2=1.96):\n",
+    "    \"\"\"\n",
+    "    Calculates the confidence interval of an ROC curve (c-statistic), using the method described\n",
+    "    under \"Confidence Interval for AUC\" here:\n",
+    "    https://ncss-wpengine.netdna-ssl.com/wp-content/themes/ncss/pdf/Procedures/PASS/Confidence_Intervals_for_the_Area_Under_an_ROC_Curve.pdf\n",
+    "    Args:\n",
+    "        cstat: the c-statistic (equivalent to area under the ROC curve)\n",
+    "        num_positives: number of positive examples in the set.\n",
+    "        num_negatives: number of negative examples in the set.\n",
+    "        z_alpha_2 (optional): the critical value for an N% confidence interval, e.g., 1.96 for 95%,\n",
+    "            2.326 for 98%, 2.576 for 99%, etc.\n",
+    "    Returns:\n",
+    "        The 95% confidence interval half-width, e.g., the Y in X ± Y.\n",
+    "    \"\"\"\n",
+    "    q1 = cstat / (2 - cstat)\n",
+    "    q2 = 2 * cstat**2 / (1 + cstat)\n",
+    "    numerator = cstat * (1 - cstat) \\\n",
+    "        + (num_positives - 1) * (q1 - cstat**2) \\\n",
+    "        + (num_negatives - 1) * (q2 - cstat**2)\n",
+    "    standard_error_auc = math.sqrt(numerator / (num_positives * num_negatives))\n",
+    "    return z_alpha_2 * standard_error_auc\n",
+    "\n",
+    "def roc_auc(ground_truth, probs, index):\n",
+    "    gts = np.argmax(ground_truth, axis=2)\n",
+    "    n_gts = np.zeros_like(gts)\n",
+    "    n_gts[gts==index] = 1\n",
+    "    n_pos = np.sum(n_gts == 1)\n",
+    "    n_neg = n_gts.size - n_pos\n",
+    "    n_ps = probs[..., index].squeeze()\n",
+    "    n_gts, n_ps = n_gts.ravel(), n_ps.ravel()\n",
+    "    return n_pos, n_neg, skm.roc_auc_score(n_gts, n_ps)\n",
+    "\n",
+    "def roc_auc_set(ground_truth, probs, index):\n",
+    "    gts = np.argmax(ground_truth, axis=2)\n",
+    "    max_ps = np.max(probs[...,index], axis=1)\n",
+    "    max_gts = np.any(gts==index, axis=1)\n",
+    "    pos = np.sum(max_gts)\n",
+    "    neg = max_gts.size - pos\n",
+    "    return pos, neg, skm.roc_auc_score(max_gts, max_ps)\n",
+    "\n",
+    "def print_aucs(ground_truth, probs):\n",
+    "    seq_tauc = 0.0; seq_tot = 0.0\n",
+    "    set_tauc = 0.0; set_tot = 0.0\n",
+    "    print \"\\t AUC\"\n",
+    "    for idx, cname in preproc.int_to_class.items():\n",
+    "        pos, neg, seq_auc = roc_auc(ground_truth, probs, idx)\n",
+    "        seq_tot += pos\n",
+    "        seq_tauc += pos * seq_auc\n",
+    "        seq_conf = c_statistic_with_95p_confidence_interval(seq_auc, pos, neg)\n",
+    "        pos, neg, set_auc = roc_auc_set(ground_truth, probs, idx)\n",
+    "        set_tot += pos\n",
+    "        set_tauc += pos * set_auc\n",
+    "        set_conf = c_statistic_with_95p_confidence_interval(set_auc, pos, neg)\n",
+    "        print \"{: <8}\\t{:.3f} ({:.3f}-{:.3f})\\t{:.3f} ({:.3f}-{:.3f})\".format(\n",
+    "            cname, seq_auc, seq_auc-seq_conf,seq_auc+seq_conf,\n",
+    "            set_auc, set_auc-set_conf, set_auc+set_conf)\n",
+    "    print \"Average\\t\\t{:.3f}\\t{:.3f}\".format(seq_tauc/seq_tot, set_tauc/set_tot)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\t\t Seq F1 Set F1\n",
+      "        AF      0.914      0.914\n",
+      "       AVB      0.805      0.839\n",
+      "  BIGEMINY      0.917      0.896\n",
+      "       EAR      0.652      0.699\n",
+      "       IVR      0.721      0.758\n",
+      "JUNCTIONAL      0.706      0.740\n",
+      "     NOISE      0.911      0.847\n",
+      "     SINUS      0.920      0.960\n",
+      "       SVT      0.700      0.812\n",
+      " TRIGEMINY      0.924      0.918\n",
+      "        VT      0.769      0.848\n",
+      "WENCKEBACH      0.779      0.822\n",
+      "   Average      0.879      0.889\n",
+      "\t AUC\n",
+      "AF      \t0.994 (0.994-0.995)\t0.994 (0.991-0.996)\n",
+      "AVB     \t0.992 (0.990-0.993)\t0.990 (0.985-0.995)\n",
+      "BIGEMINY\t0.999 (0.998-1.000)\t0.998 (0.994-1.001)\n",
+      "EAR     \t0.977 (0.975-0.980)\t0.967 (0.957-0.977)\n",
+      "IVR     \t0.996 (0.994-0.998)\t0.991 (0.984-0.998)\n",
+      "JUNCTIONAL\t0.987 (0.985-0.989)\t0.984 (0.976-0.992)\n",
+      "NOISE   \t0.994 (0.993-0.994)\t0.978 (0.973-0.984)\n",
+      "SINUS   \t0.979 (0.979-0.980)\t0.987 (0.985-0.989)\n",
+      "SVT     \t0.986 (0.984-0.988)\t0.983 (0.977-0.989)\n",
+      "TRIGEMINY\t0.999 (0.999-1.000)\t0.998 (0.994-1.001)\n",
+      "VT      \t0.997 (0.995-0.998)\t0.992 (0.988-0.997)\n",
+      "WENCKEBACH\t0.991 (0.990-0.993)\t0.990 (0.985-0.996)\n",
+      "Average\t\t0.986\t0.987\n"
+     ]
+    }
+   ],
+   "source": [
+    "print_results(stats(labels, probs), set_stats(labels, probs))\n",
+    "print_aucs(labels, probs)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 2",
+   "language": "python",
+   "name": "python2"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}