[bbafc7]: / ecg / examples / cinc17 / notebooks / cinc17_eval.ipynb

Download this file

163 lines (162 with data), 4.1 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import json\n",
    "import keras\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "sys.path.append(\"../../../ecg\")\n",
    "import scipy.stats as sst\n",
    "\n",
    "import util\n",
    "import load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 852/852 [00:00<00:00, 931.25it/s]\n"
     ]
    }
   ],
   "source": [
    "# Checkpoint under evaluation and the dataset file it is scored on.\n",
    "model_path = \"/deep/group/awni/ecg_models/default/1528249597-44/0.412-0.870-015-0.309-0.892.hdf5\"\n",
    "data_path = \"../dev.json\"\n",
    "\n",
    "# util.load is pointed at the directory that contains the checkpoint file.\n",
    "model_dir = os.path.dirname(model_path)\n",
    "\n",
    "data = load.load_dataset(data_path)\n",
    "preproc = util.load(model_dir)\n",
    "model = keras.models.load_model(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate a smoothed class prior from the label frequencies in the training file.\n",
    "data_path = \"../train.json\"\n",
    "with open(data_path, 'r') as fid:\n",
    "    train_labels = [json.loads(line)['labels'] for line in fid]\n",
    "\n",
    "# Tally records per class index, keyed on the first label of each record.\n",
    "counts = collections.Counter(preproc.class_to_int[rec[0]] for rec in train_labels)\n",
    "# Order by class index so position i holds the count for class index i\n",
    "# (assumes every class occurs at least once in the training labels).\n",
    "counts = [counts[idx] for idx in sorted(counts)]\n",
    "smooth = 500\n",
    "# Add two leading singleton axes so the prior broadcasts over the last axis.\n",
    "counts = np.array(counts)[None, None, :]\n",
    "# Every class receives `smooth` pseudo-counts, so the normaliser must add\n",
    "# smooth * n_classes for the prior to sum to one. This only rescales the prior\n",
    "# by a constant, so an argmax over (probabilities / prior) is unaffected.\n",
    "total = np.sum(counts) + smooth * counts.shape[-1]\n",
    "prior = (counts + smooth) / float(total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "probs = []\n",
    "labels = []\n",
    "# Score one record at a time: each record and its label are wrapped in\n",
    "# single-element lists before being handed to the preprocessor.\n",
    "for ecg, label in zip(*data):\n",
    "    inputs, targets = preproc.process([ecg], [label])\n",
    "    probs.append(model.predict(inputs))\n",
    "    labels.append(targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def majority_class(scores):\n",
    "    # Take the argmax along axis 2 at every position, then return the most\n",
    "    # frequent class index. np.bincount(...).argmax() resolves ties toward the\n",
    "    # smallest index, which is the same value scipy.stats.mode reports, without\n",
    "    # depending on the shape of mode's return value across SciPy versions.\n",
    "    votes = np.argmax(scores, axis=2).ravel()\n",
    "    return np.bincount(votes).argmax()\n",
    "\n",
    "\n",
    "# One prior-corrected prediction and one reference label per record.\n",
    "preds = [majority_class(p / prior) for p in probs]\n",
    "ground_truth = [majority_class(g) for g in labels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "          A      0.859     0.912     0.885        80\n",
      "          N      0.914     0.923     0.919       508\n",
      "          O      0.803     0.785     0.794       233\n",
      "          ~      0.731     0.613     0.667        31\n",
      "\n",
      "avg / total      0.872     0.873     0.872       852\n",
      "\n",
      "CINC Average 0.865827\n"
     ]
    }
   ],
   "source": [
    "import sklearn.metrics as skm\n",
    "\n",
    "# Per-class precision / recall / F1 table, labelled with preproc.classes.\n",
    "report = skm.classification_report(\n",
    "    ground_truth, preds, target_names=preproc.classes, digits=3)\n",
    "# average=None keeps one score per class; element 2 of the result is the F-score.\n",
    "scores = skm.precision_recall_fscore_support(\n",
    "    ground_truth, preds, average=None)\n",
    "print(report)\n",
    "# Mean F-score over the first three classes only; the last class is left out.\n",
    "# Note: '{:3f}' sets a minimum width of 3 and keeps the default six decimals.\n",
    "# Call-form print gives the same output on Python 2 and also runs on Python 3.\n",
    "print(\"CINC Average {:3f}\".format(np.mean(scores[2][:3])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}