03_TrainModel.ipynb


{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* [Constants](#Constants)\n",
    "* [Load data](#Load-data)\n",
    "* [Train Word2Vec](#Train-Word2Vec)\n",
    "* [Prepare text](#Prepare-text)\n",
    "* [Defining the neural network](#Defining-the-neural-network) \n",
    "* [Training the neural net](#Training-the-neural-net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import string\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from os.path import isfile\n",
    "\n",
    "from keras.models import Model\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.layers import Embedding, Input, Conv1D, Dense, GlobalMaxPooling1D\n",
    "from keras.optimizers import RMSprop\n",
    "from keras.regularizers import l1\n",
    "\n",
    "from gensim.models import word2vec\n",
    "from gensim.models import KeyedVectors\n",
    "\n",
    "\n",
    "import logging\n",
    "logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of train/test data files generated by TextSections/TextPrep\n",
    "TRAIN_DATA_LOC = \"~/train_data.csv\"\n",
    "TEST_DATA_LOC  = \"~/test_data.csv\"\n",
    "\n",
    "# Columns we will use:\n",
    "VISITID        = \"visit_id\"\n",
    "OUTCOME        = \"readmitted\" # e.g. ReadmissionInLessThan30Days\n",
    "\n",
    "# Test/Train split\n",
    "SPLIT_SIZE = 0.9 # relative size of train:test\n",
    "SPLIT_SEED = 1234\n",
    "\n",
    "# Word2Vec hyperparameters\n",
    "WINDOW        = 2\n",
    "DIMENSIONS    = 1000\n",
    "MIN_COUNT     = 5\n",
    "USE_SKIPGRAM  = True  \n",
    "USE_HIER_SMAX = False  \n",
    "NUM_THREADS   = 50\n",
    "# Where to save the w2v model:\n",
    "W2V_FILENAME = './w2v_dims_{dims}_window_{window}.bin'.format(\n",
    "    dims    = DIMENSIONS,\n",
    "    window  = WINDOW\n",
    ")\n",
    "\n",
    "\n",
    "# Text Prep\n",
    "PADDING       = \"PADDING\"\n",
    "MAX_NOTE_LEN  = 700\n",
    "MIN_NOTE_LEN  = 20\n",
    "\n",
    "# Model Architecture\n",
    "UNITS         = 450\n",
    "FILTERSIZE    = 3\n",
    "LEARNING_RATE = 0.0001\n",
    "LOSS_FUNC     = 'binary_crossentropy'\n",
    "REG_FACTOR    = 0.05\n",
    "\n",
    "# Model Training\n",
    "CNN_FILENAME = \"./cnn.h5\"\n",
    "BATCH_SIZE   = 100\n",
    "EPOCHS       = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Read train and test hospital data.\n",
    "train = pd.read_csv(TRAIN_DATA_LOC, dtype = str)\n",
    "test  = pd.read_csv(TEST_DATA_LOC,  dtype = str)\n",
    "\n",
    "# Split the train data into a train and validation set.\n",
    "train, valid = train_test_split(train, \n",
    "                                stratify     = train[OUTCOME], \n",
    "                                train_size   = SPLIT_SIZE, \n",
    "                                random_state = SPLIT_SEED)\n",
    "\n",
    "# Prepare the sections.\n",
    "# If `sectiontext` is present, then include \"SECTIONNAME sectiontext\".\n",
    "# If not present, include only \"SECTIONNAME\".\n",
    "SECTIONNAMES = [x for x in trainTXT.columns if VISITID not in x and OUTCOME not in x]\n",
    "for x in SECTIONNAMES:\n",
    "    rep      = x.replace(\" \", \"_\").upper()\n",
    "    train[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in train[x]]\n",
    "    valid[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in valid[x]]\n",
    "    test[x]  = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in test[x]]"
   ]
  },
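  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick peek at the prepared sections (a small sketch, not part of the original pipeline): each prepared value should start with the upper-cased, underscore-joined section name, followed by the section text when present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print one prepared section value from the training frame.\n",
    "example_col = SECTIONNAMES[0]\n",
    "print(example_col, \"->\", train[example_col].iloc[0][:200])"
   ]
  },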
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train Word2Vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-10-27 12:32:33,194 : INFO : loading projection weights from ./word2vec/w2v_dims_1000_window_2.bin\n",
      "2017-10-27 12:32:33,507 : INFO : loaded (22330, 1000) matrix from ./word2vec/w2v_dims_1000_window_2.bin\n"
     ]
    }
   ],
   "source": [
    "# We will remove digits and punctuation:\n",
    "remove_digits_punc = str.maketrans('', '', string.digits + ''.join([x for x in string.punctuation if '_' not in x]))\n",
    "remove_digits_punc = {a:\" \" for a in remove_digits_punc.keys()}\n",
    "\n",
    "# (If the model already exists, don't recompute.)\n",
    "if not isfile(W2V_FILENAME):\n",
    "    # Use only training data to train word2vec:\n",
    "    notes = train[SECTIONNAMES].apply(lambda x: \" \".join(x), axis=1).values  \n",
    "    stop  = set([x for x in string.ascii_lowercase]) \n",
    "    for i in range(len(notes)):\n",
    "        notes[i] = [w for w in notes[i].translate(remove_digits_punc).split() if (w not in stop)]\n",
    "    \n",
    "    w2v = word2vec.Word2Vec(notes, \n",
    "                            size      = DIMENSIONS, \n",
    "                            window    = WINDOW, \n",
    "                            sg        = USE_SKIPGRAM, \n",
    "                            hs        = USE_HIER_SMAX, \n",
    "                            min_count = MIN_COUNT, \n",
    "                            workers   = NUM_THREADS)\n",
    "    w2v.wv.save_word2vec_format(W2V_FILENAME, binary=True)\n",
    "else:\n",
    "    w2v = KeyedVectors.load_word2vec_format(W2V_FILENAME, binary=True)"
   ]
  },
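  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An optional sanity check on the embeddings (a minimal sketch): look up the nearest neighbours of one token that is very likely in the vocabulary, namely the first upper-cased section-name token. `most_similar` is standard gensim; the query token is just an example and can be swapped for any word that survived `MIN_COUNT`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handle both the freshly trained Word2Vec model and the loaded KeyedVectors.\n",
    "kv = w2v.wv if hasattr(w2v, \"wv\") else w2v\n",
    "\n",
    "# Example query: the first section-name token.\n",
    "query = SECTIONNAMES[0].replace(\" \", \"_\").upper()\n",
    "if query in kv.vocab:\n",
    "    for word, sim in kv.most_similar(query, topn=5):\n",
    "        print(\"{:30s} {:.3f}\".format(word, sim))\n",
    "else:\n",
    "    print(query, \"is not in the Word2Vec vocabulary\")"
   ]
  },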
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Make the embedding matrix.\n",
    "# We include one extra word, `PADDING`. This is the word that will right-pad short notes.\n",
    "# For `PADDING`'s vector representation, we choose the zero vector.\n",
    "vocab = [PADDING] + sorted(list(w2v.wv.vocab.keys()))\n",
    "vset  = set(vocab)\n",
    "\n",
    "embeddings_index = {}\n",
    "for i in range(len(vocab)):\n",
    "    embeddings_index[vocab[i]] = i\n",
    "\n",
    "# reverse_embeddings_index = {b:a for a,b in embeddings_index.items()}\n",
    "\n",
    "# Adding PADDING as vocab word with embedding value of a zero vector\n",
    "embeddings_matrix = np.matrix(np.concatenate(([[0.] * DIMENSIONS], [w2v[x] for x in vocab[1:]])))"
   ]
  },
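  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small consistency check on the embedding matrix (a sketch): row 0 should be the all-zero `PADDING` vector, and every other row should equal the Word2Vec vector of the corresponding vocabulary word."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "some_word = vocab[1]  # first real vocabulary entry (index 0 is PADDING)\n",
    "assert not embeddings_matrix[0].any(), \"PADDING row should be all zeros\"\n",
    "assert np.allclose(embeddings_matrix[embeddings_index[some_word]], w2v[some_word])\n",
    "print(\"Embedding matrix shape:\", embeddings_matrix.shape)"
   ]
  },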
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Prepare text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_x = train[SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values  \n",
    "test_x  = test[ SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values  \n",
    "valid_x = valid[SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values  \n",
    "\n",
    "train_x = [[embeddings_index[x] for x in note.split() if x in vset] for note in train_x]\n",
    "valid_x = [[embeddings_index[x] for x in note.split() if x in vset] for note in valid_x]\n",
    "test_x  = [[embeddings_index[x] for x in note.split() if x in vset] for note in test_x]\n",
    "\n",
    "train_y = train[OUTCOME]\n",
    "valid_y = valid[OUTCOME]\n",
    "test_y  = test[OUTCOME]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ0AAAEWCAYAAAC9qEq5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3X2cHFWd7/HP1xCeHyIQFZJAggTY4FXE8KAoy72KJIhm\nFR9g0QiiuXFB94KsBvEBWbmirnJFgYiKEFARRTBiXAzeBXyKJGgIBAgMAUxChPBgABMCCb/945yG\nStvT0z2Zqpnp+b5fr35N96k6p87p6qlfn1OnqxQRmJmZVeFF/V0BMzMbOhx0zMysMg46ZmZWGQcd\nMzOrjIOOmZlVxkHHzMwq46DTYSTNlPTpPiprN0lPSRqWX98g6YN9UXYu7xeS3t9X5bWx3c9LekTS\nX6redqskHSZpeR+VJUnflfS4pJt7WcZGn4VelnG/pDf1Nr91BgedQST/066V9KSkv0r6naTpkp7f\njxExPSL+vcWymh4AIuLPEbFtRGzog7qfKenyuvInR8Slm1p2m/XYDfgYMCEiXtZgeZ8d7NusV0ja\ns6TiXw8cDoyOiAMbbPt4SRtyUHlK0n05SO1VW6cvPwsDUf5/eFjSNoW0D0q6ocX8ffqFrJM56Aw+\nb42I7YDdgXOATwDf6euNSNqsr8scIHYDHo2Ih/u7IhXaHbg/Iv7WZJ3fR8S2wA7Am4C1wC2SXlFF\nBduRe25lHLuGAf9aQrlWFBF+DJIHcD/wprq0A4HngFfk15cAn8/PdwauBf4KPAb8mvRF47KcZy3w\nFPBxYCwQwInAn4GbCmmb5fJuAL4A3Aw8AfwU2DEvOwxY3qi+wCTgGeDZvL1bC+V9MD9/EfAp4AHg\nYWAWsENeVqvH+3PdHgHOaPI+7ZDzr8rlfSqXXzuYPpfrcUmDvH/XjsKyLYD/yHV4CJgJbFXMR+pF\nPQysBE4o5N0J+Fl+3+YDnwd+k5fdlNv3t1yv9/RUXoO67QrMzvu5C/hQTj8ReBrYkMv+XIO8x9fq\nUpd+LfDjun2wWSHPUuBJ4D7guEK+DwF35mV3APsXPg+nAYuA1cAPgS3zshfn7a0CHs/PRxfKvAE4\nG/ht3od7AuPye/ckcD1wPnB5Ic/BwO9In/9bgcN6+N+akd+/ETntg8ANhXVel/fd6vz3dTn97Pz+\nPp3f42/k9H2AubnMJcC7C2Udmd+bJ4EVwGn9fXyp7DjW3xXwo42d1SDo5PQ/Ax/Ozy/hhaDzBdKB\ncXh+vAFQo7IKB5VZwDbAVg0ONDfkf5BX5HWuqv2T0yTo5OdnFg8IhfJqQecDpIPlHsC2wE+Ay+rq\n9q1cr1cB64B/6OZ9mkUKiNvlvHcDJ3ZXz7q83S4HziUd2HfMZf8M+EIh33rgrPxeHwmsAV6cl1+R\nH1sDE4BlFA70uX171tWj2/Ia1O0m4AJgS2A/0sH7f+Vlx9MgqBTyNlye98lDdftgs7zvnwD2zst2\nAfbNz9+VPyMHACIFh90Ln4ebSQFyR1Jgmp6X7QQcnd+f7YAfAdfUfVb+DOyb6zAc+D3pS8DmpCHE\nJ3jh8zgKeDS/by8iDS8+Coxs9r9F+tzV/n+eDzq5vo8D78vbPza/3qn+s5xfb5P38Ql5/VeTvixN\nyMtXAm/Iz19MDsxD4eHhtc7wIOmfot6zpAPC7hHxbET8OvKnvIkzI+JvEbG2m+WXRcTtkYZqPg28\ne1NOLhccB3w1IpZGxFPA6cAxdcN8n4uItRFxK+mb66vqC8l1OQY4PSKejIj7ga+QDha9JknANOCU\niHgsIp4E/m/eVs2zwFn5vZ5D+ta7d67T0cBnI2JNRNwBtHIuq2F5Deo2BjgE+EREPB0RC4FvA1N7\n3eCku88V5N61pK0iYmVELM7pHwS+FBHzI+mKiAcK+c6LiAcj4jFS0N4PICIejYir8vvzJKn38I91\n27wkIhZHxHrS5/oA4DMR8UxE/Ib0haDmvcCciJgTEc9FxFxgASkINfMZ4COSRtalvwW4JyIui4j1\nEfED4C7grd2UcxRpSPO7ef0/kb6kvSsvfxaYIGn7iHg8Iv7YQ706hoNOZxhF6sLX+zKp9/BLSUsl\nzWihrGVtLH+A9I1z55Zq2dyuubxi2ZsBLy2kFWebrSH1iOrtnOtUX9aoTazfSNK38FvyJI6/Av+Z\n02sezQfE+jqOJLWl+N719D43K6/erkAtENb0RZsbfq7yF473ANOBlZJ+LmmfvHgMcG+TMhvuQ0lb\nS/qmpAckPUHquY2o+0JTfM9qbV7TzfLdgXfV9lXeX68nBatuRcTtpKG9+v+V+s8nNH+PdwcOqtv+\ncUBt8srRpAD4gKQbJb22Wb06iYPOICfpANIH/zf1y/I3/Y9FxB7A24BTJb2xtribInvqCY0pPN+N\n9I3tEdL5iK0L9RrGxgfknsp9kPSPWix7PencSTseyXWqL2tFm+U0KnctaRhpRH7sEOnke09Wkdoy\nupA2ppt1e+NBYEdJ2xXS+qLNbyedB/w7EXFdRBxOOojfRRr6hHTgf3kvtvUxUi/uoIjYHjg0p6u4\n2cLzlaQ2b11IK76ny0i98hGFxzYRcU4Ldfks6bxUMaDUfz5h4/e4/vO9DLixbvvbRsSHAXJPcArw\nEuAa4MoW6tURHHQGKUnbSzqKdJ7g8oi4rcE6R0naMw8NrSad7HwuL36IdP6kXe+VNCH/s59FOtG8\ngXTeZEtJb5E0nHTyfotCvoeAsU1mHf0AOEXSOEnbkoauflj3Tb9HuS5XAmdL2k7S7sCpwOXNc25M\n0pbFBy+cUzpX0kvyOqMkHdFinX4CnJm/0e/D3w999XZ/EBHLSCfMv5Dr+0rSBIK22gzpy0LeB18n\nnVf6XIN1XippSp5evI407Ff7XH0bOE3Sa/Issz3zPujJdqSg/ldJO5IO/N3KQ3YLSO/p5rmnUBzq\nuhx4q6Qjcpu2zNPhRzcscOOyu0iTHD5aSJ4D7CXpnyVtJuk9pHNz1+bl9fvv2rz++yQNz48DJP1D\nru9xknaIiGdJ56KeY4hw0Bl8fibpSdI3qTOAr5JOVjYynjSr5ynSSdcLIuK/8rIvAJ/KXf/T2tj+\nZaTJCn8hnbT+KEBErAb+hXTQWUHq+RR/7/Kj/PdRSY3Gry/OZd9Emg31NPCRNupV9JG8/aWkHuD3\nc/mtGkU6ABYfLydNT+8C5uUhoOtpcI6lGyeTZtX9hdTOH5AO2DVnApfm/fHuNupacyzpZP+DwNWk\n80fXt5H/tZKeIh0AbwC2Bw5o9GWGdNw4NW/rMdK5l9o3+B+Rzsd8nzQz6xq6Py9U9P9Ik0QeAeaR\nhi57chzwWtIEgc+TAsW6XI9lwBTgk6Se5jLg32j9mHcWaTIAubxHSedpPpa393HgqIh4JK/yNeCd\nSj/APS8Pdb6ZdM7vQdJ+/yIvfBF7H3B//hxNz20ZEmozmcysQpK+CLwsIiq/IkOnkvRD4K6IaNpL\nsv7lno5ZBSTtI+mVecjpQNLw1
9X9Xa/BLA9XvVzSiyRNIvVsrunvellznfqrc7OBZjvSkNqupPH/\nr5B+S2S99zLSubKdSEO5H85Tk20A8/CamZlVxsNrZmZWmSE9vLbzzjvH2LFj+7saZmaDyi233PJI\nRNRftaElQzrojB07lgULFvR3NczMBhVJ9VdnaJmH18zMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zM\nKuOgY2ZmlSk16EiaJGmJpK5GNxDL16E6Ly9fJGn/nvJK+ve87kJJv5S0a2HZ6Xn9Ja1cct7MzKpV\nWtDJN/E6H5hMuu/EsZIm1K02mXT5/fGkWwFf2ELeL0fEKyNiP9I9Kz6T80wgXUZ8X2AScIH65jbK\nZmbWR8rs6RwIdOV73j9DutnYlLp1pgCz8r3U55FuT7tLs7wR8UQh/za8cMe+KcAVEbEuIu4j3ffk\nwLIaZ2Zm7Ssz6Ixi43uWL+fv7yfe3TpN80o6W9Iy0o2PPtPG9pA0TdICSQtWrVrVVoOKDjvsMA47\n7LBe57fG/L6adbZBOZEgIs6IiDHA90h3ZGwn70URMTEiJo4c2atLB1Xq3Ll3P/8wMxvsygw6K4Ax\nhdejc1or67SSF1LQObqN7ZmZWT8qM+jMB8ZLGidpc9JJ/tl168wGpuZZbAcDqyNiZbO8ksYX8k8B\n7iqUdYykLSSNI01OuLmsxpmZWftKu8p0RKyXdDJwHTAMuDgiFkuanpfPBOYAR5JO+q8BTmiWNxd9\njqS9geeAB4BaeYslXQncAawHToqIDWW1z8zM2lfqrQ0iYg4psBTTZhaeB3BSq3lz+tENVq8tOxs4\nu7f1NTOzcg3KiQRmZjY4OeiYmVllHHTMzKwyDjpmZlYZBx0zM6tMqbPXrG8Vr0pwyuF79WNNzMx6\nxz0dMzOrjIOOmZlVxkHHzMwq43M6g5TP75jZYOSejpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZ\nZRx0zMysMg46ZmZWGf9OZwAq/gbHzKyTuKdjZmaVcdAxM7PKeHhtgPCQmpkNBe7pmJlZZRx0zMys\nMg46ZmZWGQcdMzOrTKlBR9IkSUskdUma0WC5JJ2Xly+StH9PeSV9WdJdef2rJY3I6WMlrZW0MD9m\nltk2MzNrX2mz1yQNA84HDgeWA/MlzY6IOwqrTQbG58dBwIXAQT3knQucHhHrJX0ROB34RC7v3ojY\nr6w2DVS+oZuZDRZl9nQOBLoiYmlEPANcAUypW2cKMCuSecAISbs0yxsRv4yI9Tn/PGB0iW0wM7M+\nVGbQGQUsK7xentNaWaeVvAAfAH5ReD0uD63dKOkNjSolaZqkBZIWrFq1qrWWmJlZnxi0EwkknQGs\nB76Xk1YCu+XhtVOB70vavj5fRFwUERMjYuLIkSOrq7CZmZV6RYIVwJjC69E5rZV1hjfLK+l44Cjg\njRERABGxDliXn98i6V5gL2BBH7SlFL4KgZkNNWX2dOYD4yWNk7Q5cAwwu26d2cDUPIvtYGB1RKxs\nllfSJODjwNsiYk2tIEkj8wQEJO1BmpywtMT2mZlZm0rr6eTZZScD1wHDgIsjYrGk6Xn5TGAOcCTQ\nBawBTmiWNxf9DWALYK4kgHkRMR04FDhL0rPAc8D0iHisrPaZmVn7Sr3gZ0TMIQWWYtrMwvMATmo1\nb07fs5v1rwKu2pT6mplZuQbtRAIzMxt8HHTMzKwyDjpmZlYZBx0zM6uMg46ZmVXGQcfMzCpT6pRp\nq56vOG1mA5l7OmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46\nZmZWGQcdMzOrjIOOmZlVxkHHzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyvolbB/MN3cxsoHFP\nx8zMKuOgY2ZmlSk16EiaJGmJpC5JMxosl6Tz8vJFkvbvKa+kL0u6K69/taQRhWWn5/WXSDqizLb1\n1rlz737+YWY21JQWdCQNA84HJgMTgGMlTahbbTIwPj+mARe2kHcu8IqIeCVwN3B6zjMBOAbYF5gE\nXJDLMTOzAaLMns6BQFdELI2IZ4ArgCl160wBZkUyDxghaZdmeSPilxGxPuefB4wulHVFRKyLiPuA\nrlyOmZkNEGUGnVHAssLr5TmtlXVayQvwAeAXbWwPSdMkLZC0YNWqVS00w8zM+sqgnUgg6QxgPfC9\ndvJFxEURMTEiJo4cObKcypmZWUNl/k5nBTCm8Hp0TmtlneHN8ko6HjgKeGNERBvbMzOzflRmT2c+\nMF7SOEmbk07yz65bZzYwNc9iOxhYHRErm+WVNAn4OPC2iFhTV9YxkraQNI40OeHmEttnZmZtKq2n\nExHrJZ0MXAcMAy6OiMWSpuflM4E5wJGkk/5rgBOa5c1FfwPYApgrCWBeREzPZV8J3EEadjspIjaU\n1T4zM2tfqZfBiYg5pMBSTJtZeB7ASa3mzel7Ntne2cDZva2vmZmVa9BOJDAzs8HHQcfMzCrjoGNm\nZpVx0DEzs8r4fjpDhO+tY2YDgXs6ZmZWGQcdMzOrTEtBR9JPJL1FkoOUmZn1WqtB5ALgn4F7JJ0j\nae8S62RmZh2qpaATEddHxHHA/sD9wPWSfifpBEnDy6ygmZl1jpaHyyTtBBwPfBD4E/A1UhCaW0rN\nzMys47Q0ZVrS1cDewGXAW/OVoAF+KGlBWZUzM7PO0urvdL6VL8D5PElb5FtDTyyhXmZm1oFaHV77\nfIO03/dlRczMrPM17elIehkwCthK0qsB5UXbA1uXXDczM+swPQ2vHUGaPDAa+Goh/UngkyXVyczM\nOlTToBMRlwKXSjo6Iq6qqE5mZtahehpee29EXA6MlXRq/fKI+GqDbGZmZg31NLy2Tf67bdkVMTOz\nztfT8No389/PVVMdMzPrZK1e8PNLkraXNFzSryStkvTesitnZmadpdXf6bw5Ip4AjiJde21P4N/K\nqpSZmXWmVoNObRjuLcCPImJ1SfUxM7MO1uplcK6VdBewFviwpJHA0+VVy8zMOlFLQSciZkj6ErA6\nIjZI+hswpdyqWVnOnXv3889POXyvfqyJmQ017dwJdB/gPZKmAu8E3txTBkmTJC2R1CVpRoPlknRe\nXr5I0v495ZX0LkmLJT0naWIhfayktZIW5sfMNtpmZmYVaPXWBpcBLwcWAhtycgCzmuQZBpwPHA4s\nB+ZLmh0RdxRWmwyMz4+DgAuBg3rIezvwDuCbDTZ7b0Ts10qbzMyseq2e05kITIiIaKPsA4GuiFgK\nIOkK0pBcMehMAWblcudJGiFpF2Bsd3kj4s6c1kZVzMxsIGh1eO124GVtlj0KWFZ4vTyntbJOK3kb\nGZeH1m6U9IZGK0iaJmmBpAWrVq1qoUgzM+srrfZ0dgbukHQzsK6WGBFvK6VWvbMS2C0iHpX0GuAa\nSfvm3xc9LyIuAi4CmDhxYjs9NzMz20StBp0ze1H2CmBM4fXonNbKOsNbyLuRiFhHDogRcY
uke4G9\nAN9O28xsgGhpeC0ibiRdiWB4fj4f+GMP2eYD4yWNk7Q5cAwwu26d2cDUPIvtYNKU7JUt5t2IpJF5\nAgKS9iBNTljaSvvMzKwarV577UPAj3lhxtgo4JpmeSJiPXAycB1wJ3BlRCyWNF3S9LzaHFJg6AK+\nBfxLs7y5Lm+XtBx4LfBzSdflsg4FFklamOs6PSIea6V9ZmZWjVaH104izUb7A0BE3CPpJT1liog5\npMBSTJtZeB657Jby5vSrgasbpF8FDMgbzRV/jDnQ+IeiZlalVmevrYuIZ2ovJG1G+p2OmZlZy1oN\nOjdK+iSwlaTDgR8BPyuvWmZm1olaDTozgFXAbcD/Jg17faqsSpmZWWdq9YKfz0m6BrgmIvyLSjMz\n65WmPZ08lflMSY8AS4Al+a6hn6mmemZm1kl6Gl47BTgEOCAidoyIHUkX5jxE0iml187MzDpKT0Hn\nfcCxEXFfLSFfhPO9wNQyK2ZmZp2np6AzPCIeqU/M53WGl1MlMzPrVD0FnWd6uczMzOzv9DR77VWS\nnmiQLmDLEupjZmYdrGnQiYhhVVXEzMw6X6s/DjUzM9tkDjpmZlYZBx0zM6tMq7c2sCHAtzkws7K5\np2NmZpVx0DEzs8o46JiZWWUcdMzMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOz\nyjjomJlZZUoNOpImSVoiqUvSjAbLJem8vHyRpP17yivpXZIWS3pO0sS68k7P6y+RdESZbet05869\ne6OHmVlfKC3oSBoGnA9MBiYAx0qaULfaZGB8fkwDLmwh7+3AO4Cb6rY3ATgG2BeYBFyQyzEzswGi\nzJ7OgUBXRCyNiGeAK4ApdetMAWZFMg8YIWmXZnkj4s6IWNJge1OAKyJiXUTcB3TlcszMbIAoM+iM\nApYVXi/Paa2s00re3mwPSdMkLZC0YNWqVT0UaWZmfWnITSSIiIsiYmJETBw5cmR/V8fMbEgp8yZu\nK4Axhdejc1or6wxvIW9vtmdmZv2ozJ7OfGC8pHGSNied5J9dt85sYGqexXYwsDoiVraYt95s4BhJ\nW0gaR5qccHNfNsjMzDZNaT2diFgv6WTgOmAYcHFELJY0PS+fCcwBjiSd9F8DnNAsL4CktwNfB0YC\nP5e0MCKOyGVfCdwBrAdOiogNZbVvqPGtrM2sL5Q5vEZEzCEFlmLazMLzAE5qNW9Ovxq4ups8ZwNn\nb0KVzcysRENuIoGZmfUfBx0zM6uMg46ZmVXGQcfMzCpT6kQC60yeyWZmveWejpmZVcZBx8zMKuOg\nY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46ZmZWGQcdMzOrjIOOmZlVxlcksE3iqxOYWTscdEpSPBib\nmVni4TUzM6uMezrWZzzUZmY9cU/HzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyDjpmZlYZBx0z\nM6tMqUFH0iRJSyR1SZrRYLkknZeXL5K0f095Je0oaa6ke/LfF+f0sZLWSlqYHzPLbJuZmbWvtKAj\naRhwPjAZmAAcK2lC3WqTgfH5MQ24sIW8M4BfRcR44Ff5dc29EbFffkwvp2VmZtZbZfZ0DgS6ImJp\nRDwDXAFMqVtnCjArknnACEm79JB3CnBpfn4p8E8ltsHMzPpQmZfBGQUsK7xeDhzUwjqjesj70ohY\nmZ//BXhpYb1xkhYCq4FPRcSv6yslaRqpV8Vuu+3WTnusDb4kjpk1MqgnEkREAJFfrgR2i4j9gFOB\n70vavkGeiyJiYkRMHDlyZIW1NTOzMns6K4Axhdejc1or6wxvkvchSbtExMo8FPcwQESsA9bl57dI\nuhfYC1jQN82x3nKvx8xqyuzpzAfGSxonaXPgGGB23Tqzgal5FtvBwOo8dNYs72zg/fn5+4GfAkga\nmScgIGkP0uSEpeU1z8zM2lVaTyci1ks6GbgOGAZcHBGLJU3Py2cCc4AjgS5gDXBCs7y56HOAKyWd\nCDwAvDunHwqcJelZ4DlgekQ8Vlb7zMysfaXeTyci5pACSzFtZuF5ACe1mjenPwq8sUH6VcBVm1hl\nMzMr0aCeSGBmZoOL7xxqlfKkArOhzT0dMzOrjIOOmZlVxsNr1m881GY29LinY2ZmlXHQMTOzyjjo\nmJlZZXxOxwaE2vmd5Y+vZfSLt+rn2phZWdzTMTOzyjjo2ICz/PG1nDv37o1mt5lZZ3DQMTOzyjjo\nmJlZZRx0zMysMp69ZgOar1pg1lnc0zEzs8q4p2ODhns9ZoOfezpmZlYZ93T6kH9XYmbWnIOODUoe\najMbnBx0bNBzADIbPBx0rKN0N8TpYGQ2MDjo2JDg3pDZwOCgs4k8eWDwcQAy6z8OOjakOQCZVctB\nxyzz+SCz8pUadCRNAr4GDAO+HRHn1C1XXn4ksAY4PiL+2CyvpB2BHwJjgfuBd0fE43nZ6cCJwAbg\noxFxXZnts6GhlSFUByaz1pQWdCQNA84HDgeWA/MlzY6IOwqrTQbG58dBwIXAQT3knQH8KiLOkTQj\nv/6EpAnAMcC+wK7A9ZL2iogNZbXRrKbdc3sOUjZUldnTORDoioilAJKuAKYAxaAzBZgVEQHMkzRC\n0i6kXkx3eacAh+X8lwI3AJ/I6VdExDrgPklduQ6/L7GNZr3SnxNQigHP57SsamUGnVHAssLr5aTe\nTE/rjOoh70sjYmV+/hfgpYWy5jUoayOSpgHT8sunJC1ppTEFOwOP1F7ceOPebWYf1DZqe5lOffOA\nfF8ra3+ZTm0zPeuItm+Codz+Rm3fvbeFDeqJBBERkqLNPBcBF/V2m5IWRMTE3uYfzIZy22Fot38o\ntx2Gdvv7uu1lXmV6BTCm8Hp0TmtlnWZ5H8pDcOS/D7exPTMz60dlBp35wHhJ4yRtTjrJP7tundnA\nVCUHA6vz0FmzvLOB9+fn7wd+Wkg/RtIWksaRJifcXFbjzMysfaUNr0XEekknA9eRpj1fHBGLJU3P\ny2cCc0jTpbtIU6ZPaJY3F30OcKWkE4EHgHfnPIslXUmabLAeOKmkmWu9HprrAEO57TC02z+U2w5D\nu/192naliWNmZmbl851DzcysMg46ZmZWGQedFkmaJGmJpK58JYSOJOl+SbdJWihpQU7bUdJcSffk\nvy8urH96fk+WSDqi/2rePkkXS3pY0u2FtLbbKuk1+T3rknRevrzTgNdN+8+UtCLv/4WSjiws65j2\nSxoj6b8k3SFpsaR/zekdv/+btL2afR8RfvTwIE1muBfYA9gcuBWY0N/1Kqmt9wM716V9CZiRn88A\nvpifT8jvxRbAuPweDevvNrTR1kOB/YHbN6WtpFmSBwMCfgFM7u+2bUL7zwROa7BuR7Uf2AXYPz/f\nDrg7t7Hj93+Ttley793Tac3zl/SJiGeA2mV5hooppEsOkf/+UyH9iohYFxH3kWYhHtgP9euViLgJ\neKwuua225t+KbR8R8yL9F84q5BnQuml/dzqq/
RGxMvLFhSPiSeBO0hVMOn7/N2l7d/q07Q46renu\ncj2dKEgXS70lXzIIml96qNPel3bbOio/r08fzD4iaVEefqsNL3Vs+yWNBV4N/IEhtv/r2g4V7HsH\nHav3+ojYj3QF8JMkHVpcmL/RDIl59kOprQUXkoaR9wNWAl/p3+qUS9K2wFXA/4mIJ4rLOn3/N2h7\nJfveQac1Q+YSOxGxIv99GLiaNFw2lC491G5bV+Tn9emDUkQ8FBEbIuI54Fu8MFzace2XNJx00P1e\nRPwkJw+J/d+o7VXtewed1rRySZ9BT9I2krarPQfeDNzO0Lr0UFttzUMxT0g6OM/cmVrIM+jUDrjZ\n20n7Hzqs/bmu3wHujIivFhZ1/P7vru2V7fv+nkkxWB6ky/XcTZq5cUZ/16ekNu5BmqVyK7C41k5g\nJ+BXwD3A9cCOhTxn5PdkCQN81k6D9v6ANIzwLGk8+sTetBWYmP9B7wW+Qb7Sx0B/dNP+y4DbgEX5\nYLNLJ7YfeD1p6GwRsDA/jhwK+79J2yvZ974MjpmZVcbDa2ZmVhkHHTMzq4yDjpmZVcZBx8zMKuOg\nY2ZmlXHQMQMkPVVy+cdL2rXw+n5JO7eQ79WSvlNy3S6R9M4my0+W9IEy62BDh4OOWTWOB3btaaUG\nPgmc11chA3JvAAADJklEQVSVkNSbW9RfDHykr+pgQ5uDjlk3JI2UdJWk+flxSE4/M18Q8QZJSyV9\ntJDn0/meI7+R9ANJp+VexETge/k+JVvl1T8i6Y/5fiT7NNj+dsArI+LW/Po2SSOUPCppak6fJelw\nSVtK+m5e70+S/mdefryk2ZL+P/CrnP8buZ7XAy8pbPMcpfusLJL0HwARsQa4X9KguYK4DVwOOmbd\n+xpwbkQcABwNfLuwbB/gCNL1qT4rabik2nqvIl0wdSJARPwYWAAcFxH7RcTaXMYjEbE/6UKLpzXY\nfu3X3jW/BQ4B9gWWAm/I6a8FfgeclDYX/wM4FrhU0pZ5nf2Bd0bEP5IucbI36T4pU4HXAUjaKS/b\nNyJeCXy+sO0Fhe2Z9VpvutpmQ8WbgAl64WaI2+cr8wL8PCLWAeskPUy6BP4hwE8j4mngaUk/66H8\n2kUmbwHe0WD5LsCqwutfk2689gApUE2TNAp4PCL+Jun1wNcBIuIuSQ8Ae+W8cyOidu+cQ4EfRMQG\n4MHcAwJYDTwNfEfStcC1hW0/TAq0ZpvEPR2z7r0IODj3TvaLiFERUZtwsK6w3gZ69wWuVkZ3+dcC\nWxZe30TqbbwBuIEUkN5JCkY9+VtPK0TEelLP7cfAUcB/FhZvmetjtkkcdMy690sKJ9Al7dfD+r8F\n3prPrWxLOnDXPEm6NXA77gT2rL2IiGXAzsD4iFgK/IY0LHdTXuXXwHG5rnsBu5Eu0FjvJuA9kobl\nKwvXzv1sC+wQEXOAU0jDhDV7sfFQn1mvOOiYJVtLWl54nAp8FJiYT6rfAUxvVkBEzCddnXcR6X7x\nt5GGrAAuAWbWTSRoKiLuAnao3W4i+wPpaueQgswoUvABuAB4kaTbgB8Cx+chwHpXk66ifAfpFsO/\nz+nbAddKWpTLPLWQ5xBgbiv1NmvGV5k260OSto2IpyRtTepRTIt8P/pelncK8GREfLvHlUsi6dXA\nqRHxvv6qg3UO93TM+tZFkhYCfwSu2pSAk13IxueP+sPOwKf7uQ7WIdzTMTOzyrinY2ZmlXHQMTOz\nyjjomJlZZRx0zMysMg46ZmZWmf8GQp4xDBcjJYgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x2d78f12e518>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# We decide the max and min length (in words) of discharge notes.\n",
    "\n",
    "plt.hist([len(x) for x in train_x], normed=1, bins=100, alpha=.5)\n",
    "plt.vlines(MAX_NOTE_LEN, 0, .003); plt.vlines(MIN_NOTE_LEN, 0, .003);\n",
    "plt.xlabel(\"Length (words)\"); plt.ylabel(\"Density\");\n",
    "plt.title(\"Distribution of Length of Discharge Notes\")\n",
    "plt.show()"
   ]
  },
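  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To complement the histogram, a quick count of how the chosen bounds affect the training notes (a small sketch): notes shorter than `MIN_NOTE_LEN` will be dropped below, and notes longer than `MAX_NOTE_LEN` will be truncated by `pad_sequences`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lengths   = np.array([len(x) for x in train_x])\n",
    "too_short = int((lengths < MIN_NOTE_LEN).sum())\n",
    "too_long  = int((lengths > MAX_NOTE_LEN).sum())\n",
    "print(\"{} of {} training notes will be dropped (< {} words); {} will be truncated (> {} words).\".format(\n",
    "    too_short, len(lengths), MIN_NOTE_LEN, too_long, MAX_NOTE_LEN))"
   ]
  },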
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Keep only the notes that are long enough:\n",
    "subset_train = set(np.where([len(x) >= MIN_NOTE_LEN for x in train_x])[0])\n",
    "subset_test  = set(np.where([len(x) >= MIN_NOTE_LEN for x in test_x])[0])\n",
    "subset_valid = set(np.where([len(x) >= MIN_NOTE_LEN for x in valid_x])[0])\n",
    "\n",
    "def getsubset(orig, index):\n",
    "    return([j for i,j in enumerate(orig) if i in index])\n",
    "\n",
    "# Pad the notes that are too short:\n",
    "train_x = pad_sequences(getsubset(train_x, subset_train), maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n",
    "valid_x = pad_sequences(getsubset(valid_x, subset_valid), maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n",
    "test_x  = pad_sequences(getsubset(test_x,  subset_test),  maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n",
    "\n",
    "train_y = np.array(getsubset(train_y, subset_train))\n",
    "valid_y = np.array(getsubset(valid_y, subset_valid))\n",
    "test_y  = np.array(getsubset(test_y,  subset_test))"
   ]
  },
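  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check after filtering and padding (a sketch): every retained note is now a fixed-length row of `MAX_NOTE_LEN` vocabulary indices, right-padded with index 0, which maps to `PADDING`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"train:\", train_x.shape, \" valid:\", valid_x.shape, \" test:\", test_x.shape)\n",
    "assert train_x.shape[1] == MAX_NOTE_LEN\n",
    "assert embeddings_index[PADDING] == 0"
   ]
  },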
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining the neural network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "seq_input_layer = Input(shape=(MAX_NOTE_LEN,), dtype='int32')\n",
    "\n",
    "embedded_layer  = Embedding(embeddings_matrix.shape[0], embeddings_matrix.shape[1],\n",
    "                      weights      = [embeddings_matrix],\n",
    "                      input_length = MAX_NOTE_LEN,\n",
    "                      trainable    = True)(seq_input_layer)\n",
    "\n",
    "conv_layer      = Conv1D(UNITS, FILTERSIZE, activation='tanh')(embedded_layer)\n",
    "\n",
    "pool_layer      = GlobalMaxPooling1D()(conv_layer)\n",
    "\n",
    "out_layer       = Dense(1, \n",
    "                        activation           = 'sigmoid', \n",
    "                        activity_regularizer = l1(REG_FACTOR)\n",
    "                       )(pool_layer)\n",
    "\n",
    "optimizer = RMSprop(lr = LEARNING_RATE)\n",
    "model     = Model(inputs=seq_input_layer, outputs=out_layer)\n",
    "model.compile(loss=LOSS_FUNC, optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, 700)               0         \n",
      "_________________________________________________________________\n",
      "embedding_1 (Embedding)      (None, 700, 1000)         22331000  \n",
      "_________________________________________________________________\n",
      "conv1d_1 (Conv1D)            (None, 698, 450)          1350450   \n",
      "_________________________________________________________________\n",
      "global_max_pooling1d_1 (Glob (None, 450)               0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 451       \n",
      "=================================================================\n",
      "Total params: 23,681,901\n",
      "Trainable params: 23,681,901\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
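  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts above follow directly from the constants: the embedding layer holds 22,331 × 1,000 = 22,331,000 weights (the 22,330 Word2Vec words plus the `PADDING` row, each of dimension `DIMENSIONS`); the convolution holds 450 × (3 × 1,000) + 450 = 1,350,450 weights (450 filters of width `FILTERSIZE` over 1,000 input channels, plus biases); and the dense output layer holds 450 + 1 = 451, for a total of 23,681,901 trainable parameters."
   ]
  },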
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the neural net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the weights from a previous run, or train the model anew:\n",
    "if isfile(CNN_FILENAME):\n",
    "    model.load_weights(CNN_FILENAME)\n",
    "else:\n",
    "    model.fit(train_x, train_y, \n",
    "              batch_size      = BATCH_SIZE, \n",
    "              epochs          = EPOCHS, \n",
    "              validation_data = (valid_x, valid_y), \n",
    "              verbose         = True)"
   ]
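  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, the held-out test set prepared above can be scored. This is a minimal sketch rather than part of the original pipeline: it assumes the outcome is coded 0/1 and uses scikit-learn's `roc_auc_score` as the metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: evaluate the trained CNN on the held-out test set.\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "test_pred = model.predict(test_x, batch_size=BATCH_SIZE).ravel()\n",
    "print(\"Test AUC: {:.3f}\".format(roc_auc_score(test_y, test_pred)))"
   ]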
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}