Switch to side-by-side view

--- a
+++ b/examples/cinc17/notebooks/cinc17_eval.ipynb
@@ -0,0 +1,162 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Using TensorFlow backend.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import collections\n",
+    "import json\n",
+    "import keras\n",
+    "import numpy as np\n",
+    "import os\n",
+    "import sys\n",
+    "sys.path.append(\"../../../ecg\")\n",
+    "import scipy.stats as sst\n",
+    "\n",
+    "import util\n",
+    "import load"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 852/852 [00:00<00:00, 931.25it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "model_path = \"/deep/group/awni/ecg_models/default/1528249597-44/0.412-0.870-015-0.309-0.892.hdf5\"\n",
+    "data_path = \"../dev.json\"\n",
+    "\n",
+    "data = load.load_dataset(data_path)\n",
+    "preproc = util.load(os.path.dirname(model_path))\n",
+    "model = keras.models.load_model(model_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_path = \"../train.json\"\n",
+    "with open(\"../train.json\", 'r') as fid:\n",
+    "    train_labels = [json.loads(l)['labels'] for l in fid]\n",
+    "counts = collections.Counter(preproc.class_to_int[l[0]] for l in train_labels)\n",
+    "counts = sorted(counts.most_common(), key=lambda x: x[0])\n",
+    "counts = zip(*counts)[1]\n",
+    "smooth = 500\n",
+    "counts = np.array(counts)[None, None, :]\n",
+    "total = np.sum(counts) + counts.shape[1]\n",
+    "prior = (counts + smooth) / float(total)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "probs = []\n",
+    "labels = []\n",
+    "for x, y  in zip(*data):\n",
+    "    x, y = preproc.process([x], [y])\n",
+    "    probs.append(model.predict(x))\n",
+    "    labels.append(y)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "preds = []\n",
+    "ground_truth = []\n",
+    "for p, g in zip(probs, labels):\n",
+    "    preds.append(sst.mode(np.argmax(p / prior, axis=2).squeeze())[0][0])\n",
+    "    ground_truth.append(sst.mode(np.argmax(g, axis=2).squeeze())[0][0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "             precision    recall  f1-score   support\n",
+      "\n",
+      "          A      0.859     0.912     0.885        80\n",
+      "          N      0.914     0.923     0.919       508\n",
+      "          O      0.803     0.785     0.794       233\n",
+      "          ~      0.731     0.613     0.667        31\n",
+      "\n",
+      "avg / total      0.872     0.873     0.872       852\n",
+      "\n",
+      "CINC Average 0.865827\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sklearn.metrics as skm\n",
+    "report = skm.classification_report(\n",
+    "            ground_truth, preds,\n",
+    "            target_names=preproc.classes,\n",
+    "            digits=3)\n",
+    "scores = skm.precision_recall_fscore_support(\n",
+    "                    ground_truth,\n",
+    "                    preds,\n",
+    "                    average=None)\n",
+    "print(report)\n",
+    "print \"CINC Average {:3f}\".format(np.mean(scores[2][:3]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 2",
+   "language": "python",
+   "name": "python2"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}