--- a +++ b/03_TrainModel.ipynb @@ -0,0 +1,413 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* [Constants](#Constants)\n", + "* [Load data](#Load-data)\n", + "* [Train Word2Vec](#Train-Word2Vec)\n", + "* [Prepare text](#Prepare-text)\n", + "* [Defining the neural network](#Defining-the-neural-network) \n", + "* [Training the neural net](#Training-the-neural-net)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import string\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "from os.path import isfile\n", + "\n", + "from keras.models import Model\n", + "from keras.preprocessing.sequence import pad_sequences\n", + "from keras.layers import Embedding, Input, Conv1D, Dense, GlobalMaxPooling1D\n", + "from keras.optimizers import RMSprop\n", + "from keras.regularizers import l1\n", + "\n", + "from gensim.models import word2vec\n", + "from gensim.models import KeyedVectors\n", + "\n", + "\n", + "import logging\n", + "logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Location of train/test data files generated by TextSections/TextPrep\n", + "TRAIN_DATA_LOC = \"~/train_data.csv\"\n", + "TEST_DATA_LOC = \"~/test_data.csv\"\n", + "\n", + "# Columns we will use:\n", + "VISITID = \"visit_id\"\n", + "OUTCOME = \"readmitted\" # e.g. ReadmissionInLessThan30Days\n", + "\n", + "# Test/Train split\n", + "SPLIT_SIZE = 0.9 # relative size of train:test\n", + "SPLIT_SEED = 1234\n", + "\n", + "# Word2Vec hyperparameters\n", + "WINDOW = 2\n", + "DIMENSIONS = 1000\n", + "MIN_COUNT = 5\n", + "USE_SKIPGRAM = True \n", + "USE_HIER_SMAX = False \n", + "NUM_THREADS = 50\n", + "# Where to save the w2v model:\n", + "W2V_FILENAME = './w2v_dims_{dims}_window_{window}.bin'.format(\n", + " dims = DIMENSIONS,\n", + " window = WINDOW\n", + ")\n", + "\n", + "\n", + "# Text Prep\n", + "PADDING = \"PADDING\"\n", + "MAX_NOTE_LEN = 700\n", + "MIN_NOTE_LEN = 20\n", + "\n", + "# Model Architecture\n", + "UNITS = 450\n", + "FILTERSIZE = 3\n", + "LEARNING_RATE = 0.0001\n", + "LOSS_FUNC = 'binary_crossentropy'\n", + "REG_FACTOR = 0.05\n", + "\n", + "# Model Training\n", + "CNN_FILENAME = \"./cnn.h5\"\n", + "BATCH_SIZE = 100\n", + "EPOCHS = 4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true, + "scrolled": true + }, + "outputs": [], + "source": [ + "# Read train and test hospital data.\n", + "train = pd.read_csv(TRAIN_DATA_LOC, dtype = str)\n", + "test = pd.read_csv(TEST_DATA_LOC, dtype = str)\n", + "\n", + "# Split the train data into a train and validation set.\n", + "train, valid = train_test_split(train, \n", + " stratify = train[OUTCOME], \n", + " train_size = SPLIT_SIZE, \n", + " random_state = SPLIT_SEED)\n", + "\n", + "# Prepare the sections.\n", + "# If `sectiontext` is present, then include \"SECTIONNAME sectiontext\".\n", + "# If not present, include only \"SECTIONNAME\".\n", + "SECTIONNAMES = [x for x in trainTXT.columns if VISITID not in x and OUTCOME not in x]\n", + "for x in SECTIONNAMES:\n", + " rep = x.replace(\" \", \"_\").upper()\n", + " train[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in train[x]]\n", + " valid[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in valid[x]]\n", + " test[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in test[x]]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train Word2Vec" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2017-10-27 12:32:33,194 : INFO : loading projection weights from ./word2vec/w2v_dims_1000_window_2.bin\n", + "2017-10-27 12:32:33,507 : INFO : loaded (22330, 1000) matrix from ./word2vec/w2v_dims_1000_window_2.bin\n" + ] + } + ], + "source": [ + "# We will remove digits and punctuation:\n", + "remove_digits_punc = str.maketrans('', '', string.digits + ''.join([x for x in string.punctuation if '_' not in x]))\n", + "remove_digits_punc = {a:\" \" for a in remove_digits_punc.keys()}\n", + "\n", + "# (If the model already exists, don't recompute.)\n", + "if not isfile(W2V_FILENAME):\n", + " # Use only training data to train word2vec:\n", + " notes = train[SECTIONNAMES].apply(lambda x: \" \".join(x), axis=1).values \n", + " stop = set([x for x in string.ascii_lowercase]) \n", + " for i in range(len(notes)):\n", + " notes[i] = [w for w in notes[i].translate(remove_digits_punc).split() if (w not in stop)]\n", + " \n", + " w2v = word2vec.Word2Vec(notes, \n", + " size = DIMENSIONS, \n", + " window = WINDOW, \n", + " sg = USE_SKIPGRAM, \n", + " hs = USE_HIER_SMAX, \n", + " min_count = MIN_COUNT, \n", + " workers = NUM_THREADS)\n", + " w2v.wv.save_word2vec_format(W2V_FILENAME, binary=True)\n", + "else:\n", + " w2v = KeyedVectors.load_word2vec_format(W2V_FILENAME, binary=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Make the embedding matrix.\n", + "# We include one extra word, `PADDING`. This is the word that will right-pad short notes.\n", + "# For `PADDING`'s vector representation, we choose the zero vector.\n", + "vocab = [PADDING] + sorted(list(w2v.wv.vocab.keys()))\n", + "vset = set(vocab)\n", + "\n", + "embeddings_index = {}\n", + "for i in range(len(vocab)):\n", + " embeddings_index[vocab[i]] = i\n", + "\n", + "# reverse_embeddings_index = {b:a for a,b in embeddings_index.items()}\n", + "\n", + "# Adding PADDING as vocab word with embedding value of a zero vector\n", + "embeddings_matrix = np.matrix(np.concatenate(([[0.] * DIMENSIONS], [w2v[x] for x in vocab[1:]])))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "# Prepare text" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true, + "scrolled": true + }, + "outputs": [], + "source": [ + "train_x = train[SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values \n", + "test_x = test[ SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values \n", + "valid_x = valid[SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values \n", + "\n", + "train_x = [[embeddings_index[x] for x in note.split() if x in vset] for note in train_x]\n", + "valid_x = [[embeddings_index[x] for x in note.split() if x in vset] for note in valid_x]\n", + "test_x = [[embeddings_index[x] for x in note.split() if x in vset] for note in test_x]\n", + "\n", + "train_y = train[OUTCOME]\n", + "valid_y = valid[OUTCOME]\n", + "test_y = test[OUTCOME]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ0AAAEWCAYAAAC9qEq5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3X2cHFWd7/HP1xCeHyIQFZJAggTY4FXE8KAoy72KJIhm\nFR9g0QiiuXFB94KsBvEBWbmirnJFgYiKEFARRTBiXAzeBXyKJGgIBAgMAUxChPBgABMCCb/945yG\nStvT0z2Zqpnp+b5fr35N96k6p87p6qlfn1OnqxQRmJmZVeFF/V0BMzMbOhx0zMysMg46ZmZWGQcd\nMzOrjIOOmZlVxkHHzMwq46DTYSTNlPTpPiprN0lPSRqWX98g6YN9UXYu7xeS3t9X5bWx3c9LekTS\nX6redqskHSZpeR+VJUnflfS4pJt7WcZGn4VelnG/pDf1Nr91BgedQST/066V9KSkv0r6naTpkp7f\njxExPSL+vcWymh4AIuLPEbFtRGzog7qfKenyuvInR8Slm1p2m/XYDfgYMCEiXtZgeZ8d7NusV0ja\ns6TiXw8cDoyOiAMbbPt4SRtyUHlK0n05SO1VW6cvPwsDUf5/eFjSNoW0D0q6ocX8ffqFrJM56Aw+\nb42I7YDdgXOATwDf6euNSNqsr8scIHYDHo2Ih/u7IhXaHbg/Iv7WZJ3fR8S2wA7Am4C1wC2SXlFF\nBduRe25lHLuGAf9aQrlWFBF+DJIHcD/wprq0A4HngFfk15cAn8/PdwauBf4KPAb8mvRF47KcZy3w\nFPBxYCwQwInAn4GbCmmb5fJuAL4A3Aw8AfwU2DEvOwxY3qi+wCTgGeDZvL1bC+V9MD9/EfAp4AHg\nYWAWsENeVqvH+3PdHgHOaPI+7ZDzr8rlfSqXXzuYPpfrcUmDvH/XjsKyLYD/yHV4CJgJbFXMR+pF\nPQysBE4o5N0J+Fl+3+YDnwd+k5fdlNv3t1yv9/RUXoO67QrMzvu5C/hQTj8ReBrYkMv+XIO8x9fq\nUpd+LfDjun2wWSHPUuBJ4D7guEK+DwF35mV3APsXPg+nAYuA1cAPgS3zshfn7a0CHs/PRxfKvAE4\nG/ht3od7AuPye/ckcD1wPnB5Ic/BwO9In/9bgcN6+N+akd+/ETntg8ANhXVel/fd6vz3dTn97Pz+\nPp3f42/k9H2AubnMJcC7C2Udmd+bJ4EVwGn9fXyp7DjW3xXwo42d1SDo5PQ/Ax/Ozy/hhaDzBdKB\ncXh+vAFQo7IKB5VZwDbAVg0ONDfkf5BX5HWuqv2T0yTo5OdnFg8IhfJqQecDpIPlHsC2wE+Ay+rq\n9q1cr1cB64B/6OZ9mkUKiNvlvHcDJ3ZXz7q83S4HziUd2HfMZf8M+EIh33rgrPxeHwmsAV6cl1+R\nH1sDE4BlFA70uX171tWj2/Ia1O0m4AJgS2A/0sH7f+Vlx9MgqBTyNlye98lDdftgs7zvnwD2zst2\nAfbNz9+VPyMHACIFh90Ln4ebSQFyR1Jgmp6X7QQcnd+f7YAfAdfUfVb+DOyb6zAc+D3pS8DmpCHE\nJ3jh8zgKeDS/by8iDS8+Coxs9r9F+tzV/n+eDzq5vo8D78vbPza/3qn+s5xfb5P38Ql5/VeTvixN\nyMtXAm/Iz19MDsxD4eHhtc7wIOmfot6zpAPC7hHxbET8OvKnvIkzI+JvEbG2m+WXRcTtkYZqPg28\ne1NOLhccB3w1IpZGxFPA6cAxdcN8n4uItRFxK+mb66vqC8l1OQY4PSKejIj7ga+QDha9JknANOCU\niHgsIp4E/m/eVs2zwFn5vZ5D+ta7d67T0cBnI2JNRNwBtHIuq2F5Deo2BjgE+EREPB0RC4FvA1N7\n3eCku88V5N61pK0iYmVELM7pHwS+FBHzI+mKiAcK+c6LiAcj4jFS0N4PICIejYir8vvzJKn38I91\n27wkIhZHxHrS5/oA4DMR8UxE/Ib0haDmvcCciJgTEc9FxFxgASkINfMZ4COSRtalvwW4JyIui4j1\nEfED4C7grd2UcxRpSPO7ef0/kb6kvSsvfxaYIGn7iHg8Iv7YQ706hoNOZxhF6sLX+zKp9/BLSUsl\nzWihrGVtLH+A9I1z55Zq2dyuubxi2ZsBLy2kFWebrSH1iOrtnOtUX9aoTazfSNK38FvyJI6/Av+Z\n02sezQfE+jqOJLWl+N719D43K6/erkAtENb0RZsbfq7yF473ANOBlZJ+LmmfvHgMcG+TMhvuQ0lb\nS/qmpAckPUHquY2o+0JTfM9qbV7TzfLdgXfV9lXeX68nBatuRcTtpKG9+v+V+s8nNH+PdwcOqtv+\ncUBt8srRpAD4gKQbJb22Wb06iYPOICfpANIH/zf1y/I3/Y9FxB7A24BTJb2xtribInvqCY0pPN+N\n9I3tEdL5iK0L9RrGxgfknsp9kPSPWix7PencSTseyXWqL2tFm+U0KnctaRhpRH7sEOnke09Wkdoy\nupA2ppt1e+NBYEdJ2xXS+qLNbyedB/w7EXFdRBxOOojfRRr6hHTgf3kvtvUxUi/uoIjYHjg0p6u4\n2cLzlaQ2b11IK76ny0i98hGFxzYRcU4Ldfks6bxUMaDUfz5h4/e4/vO9DLixbvvbRsSHAXJPcArw\nEuAa4MoW6tURHHQGKUnbSzqKdJ7g8oi4rcE6R0naMw8NrSad7HwuL36IdP6kXe+VNCH/s59FOtG8\ngXTeZEtJb5E0nHTyfotCvoeAsU1mHf0AOEXSOEnbkoauflj3Tb9HuS5XAmdL2k7S7sCpwOXNc25M\n0pbFBy+cUzpX0kvyOqMkHdFinX4CnJm/0e/D3w999XZ/EBHLSCfMv5Dr+0rSBIK22gzpy0LeB18n\nnVf6XIN1XippSp5evI407Ff7XH0bOE3Sa/Issz3zPujJdqSg/ldJO5IO/N3KQ3YLSO/p5rmnUBzq\nuhx4q6Qjcpu2zNPhRzcscOOyu0iTHD5aSJ4D7CXpnyVtJuk9pHNz1+bl9fvv2rz++yQNz48DJP1D\nru9xknaIiGdJ56KeY4hw0Bl8fibpSdI3qTOAr5JOVjYynjSr5ynSSdcLIuK/8rIvAJ/KXf/T2tj+\nZaTJCn8hnbT+KEBErAb+hXTQWUHq+RR/7/Kj/PdRSY3Gry/OZd9Emg31NPCRNupV9JG8/aWkHuD3\nc/mtGkU6ABYfLydNT+8C5uUhoOtpcI6lGyeTZtX9hdTOH5AO2DVnApfm/fHuNupacyzpZP+DwNWk\n80fXt5H/tZKeIh0AbwC2Bw5o9GWGdNw4NW/rMdK5l9o3+B+Rzsd8nzQz6xq6Py9U9P9Ik0QeAeaR\nhi57chzwWtIEgc+TAsW6XI9lwBTgk6Se5jLg32j9mHcWaTIAubxHSedpPpa393HgqIh4JK/yNeCd\nSj/APS8Pdb6ZdM7vQdJ+/yIvfBF7H3B//hxNz20ZEmozmcysQpK+CLwsIiq/IkOnkvRD4K6IaNpL\nsv7lno5ZBSTtI+mVecjpQNLw19X9Xa/BLA9XvVzSiyRNIvVsrunvellznfqrc7OBZjvSkNqupPH/\nr5B+S2S99zLSubKdSEO5H85Tk20A8/CamZlVxsNrZmZWmSE9vLbzzjvH2LFj+7saZmaDyi233PJI\nRNRftaElQzrojB07lgULFvR3NczMBhVJ9VdnaJmH18zMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zM\nKuOgY2ZmlSk16EiaJGmJpK5GNxDL16E6Ly9fJGn/nvJK+ve87kJJv5S0a2HZ6Xn9Ja1cct7MzKpV\nWtDJN/E6H5hMuu/EsZIm1K02mXT5/fGkWwFf2ELeL0fEKyNiP9I9Kz6T80wgXUZ8X2AScIH65jbK\nZmbWR8rs6RwIdOV73j9DutnYlLp1pgCz8r3U55FuT7tLs7wR8UQh/za8cMe+KcAVEbEuIu4j3ffk\nwLIaZ2Zm7Ssz6Ixi43uWL+fv7yfe3TpN80o6W9Iy0o2PPtPG9pA0TdICSQtWrVrVVoOKDjvsMA47\n7LBe57fG/L6adbZBOZEgIs6IiDHA90h3ZGwn70URMTEiJo4c2atLB1Xq3Ll3P/8wMxvsygw6K4Ax\nhdejc1or67SSF1LQObqN7ZmZWT8qM+jMB8ZLGidpc9JJ/tl168wGpuZZbAcDqyNiZbO8ksYX8k8B\n7iqUdYykLSSNI01OuLmsxpmZWftKu8p0RKyXdDJwHTAMuDgiFkuanpfPBOYAR5JO+q8BTmiWNxd9\njqS9geeAB4BaeYslXQncAawHToqIDWW1z8zM2lfqrQ0iYg4psBTTZhaeB3BSq3lz+tENVq8tOxs4\nu7f1NTOzcg3KiQRmZjY4OeiYmVllHHTMzKwyDjpmZlYZBx0zM6tMqbPXrG8Vr0pwyuF79WNNzMx6\nxz0dMzOrjIOOmZlVxkHHzMwq43M6g5TP75jZYOSejpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZ\nZRx0zMysMg46ZmZWGf9OZwAq/gbHzKyTuKdjZmaVcdAxM7PKeHhtgPCQmpkNBe7pmJlZZRx0zMys\nMg46ZmZWGQcdMzOrTKlBR9IkSUskdUma0WC5JJ2Xly+StH9PeSV9WdJdef2rJY3I6WMlrZW0MD9m\nltk2MzNrX2mz1yQNA84HDgeWA/MlzY6IOwqrTQbG58dBwIXAQT3knQucHhHrJX0ROB34RC7v3ojY\nr6w2DVS+oZuZDRZl9nQOBLoiYmlEPANcAUypW2cKMCuSecAISbs0yxsRv4yI9Tn/PGB0iW0wM7M+\nVGbQGQUsK7xentNaWaeVvAAfAH5ReD0uD63dKOkNjSolaZqkBZIWrFq1qrWWmJlZnxi0EwkknQGs\nB76Xk1YCu+XhtVOB70vavj5fRFwUERMjYuLIkSOrq7CZmZV6RYIVwJjC69E5rZV1hjfLK+l44Cjg\njRERABGxDliXn98i6V5gL2BBH7SlFL4KgZkNNWX2dOYD4yWNk7Q5cAwwu26d2cDUPIvtYGB1RKxs\nllfSJODjwNsiYk2tIEkj8wQEJO1BmpywtMT2mZlZm0rr6eTZZScD1wHDgIsjYrGk6Xn5TGAOcCTQ\nBawBTmiWNxf9DWALYK4kgHkRMR04FDhL0rPAc8D0iHisrPaZmVn7Sr3gZ0TMIQWWYtrMwvMATmo1\nb07fs5v1rwKu2pT6mplZuQbtRAIzMxt8HHTMzKwyDjpmZlYZBx0zM6uMg46ZmVXGQcfMzCpT6pRp\nq56vOG1mA5l7OmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46\nZmZWGQcdMzOrjIOOmZlVxkHHzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyvolbB/MN3cxsoHFP\nx8zMKuOgY2ZmlSk16EiaJGmJpC5JMxosl6Tz8vJFkvbvKa+kL0u6K69/taQRhWWn5/WXSDqizLb1\n1rlz737+YWY21JQWdCQNA84HJgMTgGMlTahbbTIwPj+mARe2kHcu8IqIeCVwN3B6zjMBOAbYF5gE\nXJDLMTOzAaLMns6BQFdELI2IZ4ArgCl160wBZkUyDxghaZdmeSPilxGxPuefB4wulHVFRKyLiPuA\nrlyOmZkNEGUGnVHAssLr5TmtlXVayQvwAeAXbWwPSdMkLZC0YNWqVS00w8zM+sqgnUgg6QxgPfC9\ndvJFxEURMTEiJo4cObKcypmZWUNl/k5nBTCm8Hp0TmtlneHN8ko6HjgKeGNERBvbMzOzflRmT2c+\nMF7SOEmbk07yz65bZzYwNc9iOxhYHRErm+WVNAn4OPC2iFhTV9YxkraQNI40OeHmEttnZmZtKq2n\nExHrJZ0MXAcMAy6OiMWSpuflM4E5wJGkk/5rgBOa5c1FfwPYApgrCWBeREzPZV8J3EEadjspIjaU\n1T4zM2tfqZfBiYg5pMBSTJtZeB7ASa3mzel7Ntne2cDZva2vmZmVa9BOJDAzs8HHQcfMzCrjoGNm\nZpVx0DEzs8r4fjpDhO+tY2YDgXs6ZmZWGQcdMzOrTEtBR9JPJL1FkoOUmZn1WqtB5ALgn4F7JJ0j\nae8S62RmZh2qpaATEddHxHHA/sD9wPWSfifpBEnDy6ygmZl1jpaHyyTtBBwPfBD4E/A1UhCaW0rN\nzMys47Q0ZVrS1cDewGXAW/OVoAF+KGlBWZUzM7PO0urvdL6VL8D5PElb5FtDTyyhXmZm1oFaHV77\nfIO03/dlRczMrPM17elIehkwCthK0qsB5UXbA1uXXDczM+swPQ2vHUGaPDAa+Goh/UngkyXVyczM\nOlTToBMRlwKXSjo6Iq6qqE5mZtahehpee29EXA6MlXRq/fKI+GqDbGZmZg31NLy2Tf67bdkVMTOz\nztfT8No389/PVVMdMzPrZK1e8PNLkraXNFzSryStkvTesitnZmadpdXf6bw5Ip4AjiJde21P4N/K\nqpSZmXWmVoNObRjuLcCPImJ1SfUxM7MO1uplcK6VdBewFviwpJHA0+VVy8zMOlFLQSciZkj6ErA6\nIjZI+hswpdyqWVnOnXv3889POXyvfqyJmQ017dwJdB/gPZKmAu8E3txTBkmTJC2R1CVpRoPlknRe\nXr5I0v495ZX0LkmLJT0naWIhfayktZIW5sfMNtpmZmYVaPXWBpcBLwcWAhtycgCzmuQZBpwPHA4s\nB+ZLmh0RdxRWmwyMz4+DgAuBg3rIezvwDuCbDTZ7b0Ts10qbzMyseq2e05kITIiIaKPsA4GuiFgK\nIOkK0pBcMehMAWblcudJGiFpF2Bsd3kj4s6c1kZVzMxsIGh1eO124GVtlj0KWFZ4vTyntbJOK3kb\nGZeH1m6U9IZGK0iaJmmBpAWrVq1qoUgzM+srrfZ0dgbukHQzsK6WGBFvK6VWvbMS2C0iHpX0GuAa\nSfvm3xc9LyIuAi4CmDhxYjs9NzMz20StBp0ze1H2CmBM4fXonNbKOsNbyLuRiFhHDogRcYuke4G9\nAN9O28xsgGhpeC0ibiRdiWB4fj4f+GMP2eYD4yWNk7Q5cAwwu26d2cDUPIvtYNKU7JUt5t2IpJF5\nAgKS9iBNTljaSvvMzKwarV577UPAj3lhxtgo4JpmeSJiPXAycB1wJ3BlRCyWNF3S9LzaHFJg6AK+\nBfxLs7y5Lm+XtBx4LfBzSdflsg4FFklamOs6PSIea6V9ZmZWjVaH104izUb7A0BE3CPpJT1liog5\npMBSTJtZeB657Jby5vSrgasbpF8FDMgbzRV/jDnQ+IeiZlalVmevrYuIZ2ovJG1G+p2OmZlZy1oN\nOjdK+iSwlaTDgR8BPyuvWmZm1olaDTozgFXAbcD/Jg17faqsSpmZWWdq9YKfz0m6BrgmIvyLSjMz\n65WmPZ08lflMSY8AS4Al+a6hn6mmemZm1kl6Gl47BTgEOCAidoyIHUkX5jxE0iml187MzDpKT0Hn\nfcCxEXFfLSFfhPO9wNQyK2ZmZp2np6AzPCIeqU/M53WGl1MlMzPrVD0FnWd6uczMzOzv9DR77VWS\nnmiQLmDLEupjZmYdrGnQiYhhVVXEzMw6X6s/DjUzM9tkDjpmZlYZBx0zM6tMq7c2sCHAtzkws7K5\np2NmZpVx0DEzs8o46JiZWWUcdMzMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOz\nyjjomJlZZUoNOpImSVoiqUvSjAbLJem8vHyRpP17yivpXZIWS3pO0sS68k7P6y+RdESZbet05869\ne6OHmVlfKC3oSBoGnA9MBiYAx0qaULfaZGB8fkwDLmwh7+3AO4Cb6rY3ATgG2BeYBFyQyzEzswGi\nzJ7OgUBXRCyNiGeAK4ApdetMAWZFMg8YIWmXZnkj4s6IWNJge1OAKyJiXUTcB3TlcszMbIAoM+iM\nApYVXi/Paa2s00re3mwPSdMkLZC0YNWqVT0UaWZmfWnITSSIiIsiYmJETBw5cmR/V8fMbEgp8yZu\nK4Axhdejc1or6wxvIW9vtmdmZv2ozJ7OfGC8pHGSNied5J9dt85sYGqexXYwsDoiVraYt95s4BhJ\nW0gaR5qccHNfNsjMzDZNaT2diFgv6WTgOmAYcHFELJY0PS+fCcwBjiSd9F8DnNAsL4CktwNfB0YC\nP5e0MCKOyGVfCdwBrAdOiogNZbVvqPGtrM2sL5Q5vEZEzCEFlmLazMLzAE5qNW9Ovxq4ups8ZwNn\nb0KVzcysRENuIoGZmfUfBx0zM6uMg46ZmVXGQcfMzCpT6kQC60yeyWZmveWejpmZVcZBx8zMKuOg\nY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46ZmZWGQcdMzOrjIOOmZlVxlcksE3iqxOYWTscdEpSPBib\nmVni4TUzM6uMezrWZzzUZmY9cU/HzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyDjpmZlYZBx0z\nM6tMqUFH0iRJSyR1SZrRYLkknZeXL5K0f095Je0oaa6ke/LfF+f0sZLWSlqYHzPLbJuZmbWvtKAj\naRhwPjAZmAAcK2lC3WqTgfH5MQ24sIW8M4BfRcR44Ff5dc29EbFffkwvp2VmZtZbZfZ0DgS6ImJp\nRDwDXAFMqVtnCjArknnACEm79JB3CnBpfn4p8E8ltsHMzPpQmZfBGQUsK7xeDhzUwjqjesj70ohY\nmZ//BXhpYb1xkhYCq4FPRcSv6yslaRqpV8Vuu+3WTnusDb4kjpk1MqgnEkREAJFfrgR2i4j9gFOB\n70vavkGeiyJiYkRMHDlyZIW1NTOzMns6K4Axhdejc1or6wxvkvchSbtExMo8FPcwQESsA9bl57dI\nuhfYC1jQN82x3nKvx8xqyuzpzAfGSxonaXPgGGB23Tqzgal5FtvBwOo8dNYs72zg/fn5+4GfAkga\nmScgIGkP0uSEpeU1z8zM2lVaTyci1ks6GbgOGAZcHBGLJU3Py2cCc4AjgS5gDXBCs7y56HOAKyWd\nCDwAvDunHwqcJelZ4DlgekQ8Vlb7zMysfaXeTyci5pACSzFtZuF5ACe1mjenPwq8sUH6VcBVm1hl\nMzMr0aCeSGBmZoOL7xxqlfKkArOhzT0dMzOrjIOOmZlVxsNr1m881GY29LinY2ZmlXHQMTOzyjjo\nmJlZZXxOxwaE2vmd5Y+vZfSLt+rn2phZWdzTMTOzyjjo2ICz/PG1nDv37o1mt5lZZ3DQMTOzyjjo\nmJlZZRx0zMysMp69ZgOar1pg1lnc0zEzs8q4p2ODhns9ZoOfezpmZlYZ93T6kH9XYmbWnIOODUoe\najMbnBx0bNBzADIbPBx0rKN0N8TpYGQ2MDjo2JDg3pDZwOCgs4k8eWDwcQAy6z8OOjakOQCZVctB\nxyzz+SCz8pUadCRNAr4GDAO+HRHn1C1XXn4ksAY4PiL+2CyvpB2BHwJjgfuBd0fE43nZ6cCJwAbg\noxFxXZnts6GhlSFUByaz1pQWdCQNA84HDgeWA/MlzY6IOwqrTQbG58dBwIXAQT3knQH8KiLOkTQj\nv/6EpAnAMcC+wK7A9ZL2iogNZbXRrKbdc3sOUjZUldnTORDoioilAJKuAKYAxaAzBZgVEQHMkzRC\n0i6kXkx3eacAh+X8lwI3AJ/I6VdExDrgPklduQ6/L7GNZr3SnxNQigHP57SsamUGnVHAssLr5aTe\nTE/rjOoh70sjYmV+/hfgpYWy5jUoayOSpgHT8sunJC1ppTEFOwOP1F7ceOPebWYf1DZqe5lOffOA\nfF8ra3+ZTm0zPeuItm+Codz+Rm3fvbeFDeqJBBERkqLNPBcBF/V2m5IWRMTE3uYfzIZy22Fot38o\ntx2Gdvv7uu1lXmV6BTCm8Hp0TmtlnWZ5H8pDcOS/D7exPTMz60dlBp35wHhJ4yRtTjrJP7tundnA\nVCUHA6vz0FmzvLOB9+fn7wd+Wkg/RtIWksaRJifcXFbjzMysfaUNr0XEekknA9eRpj1fHBGLJU3P\ny2cCc0jTpbtIU6ZPaJY3F30OcKWkE4EHgHfnPIslXUmabLAeOKmkmWu9HprrAEO57TC02z+U2w5D\nu/192naliWNmZmbl851DzcysMg46ZmZWGQedFkmaJGmJpK58JYSOJOl+SbdJWihpQU7bUdJcSffk\nvy8urH96fk+WSDqi/2rePkkXS3pY0u2FtLbbKuk1+T3rknRevrzTgNdN+8+UtCLv/4WSjiws65j2\nSxoj6b8k3SFpsaR/zekdv/+btL2afR8RfvTwIE1muBfYA9gcuBWY0N/1Kqmt9wM716V9CZiRn88A\nvpifT8jvxRbAuPweDevvNrTR1kOB/YHbN6WtpFmSBwMCfgFM7u+2bUL7zwROa7BuR7Uf2AXYPz/f\nDrg7t7Hj93+Ttley793Tac3zl/SJiGeA2mV5hooppEsOkf/+UyH9iohYFxH3kWYhHtgP9euViLgJ\neKwuua225t+KbR8R8yL9F84q5BnQuml/dzqq/RGxMvLFhSPiSeBO0hVMOn7/N2l7d/q07Q46renu\ncj2dKEgXS70lXzIIml96qNPel3bbOio/r08fzD4iaVEefqsNL3Vs+yWNBV4N/IEhtv/r2g4V7HsH\nHav3+ojYj3QF8JMkHVpcmL/RDIl59kOprQUXkoaR9wNWAl/p3+qUS9K2wFXA/4mIJ4rLOn3/N2h7\nJfveQac1Q+YSOxGxIv99GLiaNFw2lC491G5bV+Tn9emDUkQ8FBEbIuI54Fu8MFzace2XNJx00P1e\nRPwkJw+J/d+o7VXtewed1rRySZ9BT9I2krarPQfeDNzO0Lr0UFttzUMxT0g6OM/cmVrIM+jUDrjZ\n20n7Hzqs/bmu3wHujIivFhZ1/P7vru2V7fv+nkkxWB6ky/XcTZq5cUZ/16ekNu5BmqVyK7C41k5g\nJ+BXwD3A9cCOhTxn5PdkCQN81k6D9v6ANIzwLGk8+sTetBWYmP9B7wW+Qb7Sx0B/dNP+y4DbgEX5\nYLNLJ7YfeD1p6GwRsDA/jhwK+79J2yvZ974MjpmZVcbDa2ZmVhkHHTMzq4yDjpmZVcZBx8zMKuOg\nY2ZmlXHQMQMkPVVy+cdL2rXw+n5JO7eQ79WSvlNy3S6R9M4my0+W9IEy62BDh4OOWTWOB3btaaUG\nPgmc11chA3JvAAADJklEQVSVkNSbW9RfDHykr+pgQ5uDjlk3JI2UdJWk+flxSE4/M18Q8QZJSyV9\ntJDn0/meI7+R9ANJp+VexETge/k+JVvl1T8i6Y/5fiT7NNj+dsArI+LW/Po2SSOUPCppak6fJelw\nSVtK+m5e70+S/mdefryk2ZL+P/CrnP8buZ7XAy8pbPMcpfusLJL0HwARsQa4X9KguYK4DVwOOmbd\n+xpwbkQcABwNfLuwbB/gCNL1qT4rabik2nqvIl0wdSJARPwYWAAcFxH7RcTaXMYjEbE/6UKLpzXY\nfu3X3jW/BQ4B9gWWAm/I6a8FfgeclDYX/wM4FrhU0pZ5nf2Bd0bEP5IucbI36T4pU4HXAUjaKS/b\nNyJeCXy+sO0Fhe2Z9VpvutpmQ8WbgAl64WaI2+cr8wL8PCLWAeskPUy6BP4hwE8j4mngaUk/66H8\n2kUmbwHe0WD5LsCqwutfk2689gApUE2TNAp4PCL+Jun1wNcBIuIuSQ8Ae+W8cyOidu+cQ4EfRMQG\n4MHcAwJYDTwNfEfStcC1hW0/TAq0ZpvEPR2z7r0IODj3TvaLiFERUZtwsK6w3gZ69wWuVkZ3+dcC\nWxZe30TqbbwBuIEUkN5JCkY9+VtPK0TEelLP7cfAUcB/FhZvmetjtkkcdMy690sKJ9Al7dfD+r8F\n3prPrWxLOnDXPEm6NXA77gT2rL2IiGXAzsD4iFgK/IY0LHdTXuXXwHG5rnsBu5Eu0FjvJuA9kobl\nKwvXzv1sC+wQEXOAU0jDhDV7sfFQn1mvOOiYJVtLWl54nAp8FJiYT6rfAUxvVkBEzCddnXcR6X7x\nt5GGrAAuAWbWTSRoKiLuAnao3W4i+wPpaueQgswoUvABuAB4kaTbgB8Cx+chwHpXk66ifAfpFsO/\nz+nbAddKWpTLPLWQ5xBgbiv1NmvGV5k260OSto2IpyRtTepRTIt8P/pelncK8GREfLvHlUsi6dXA\nqRHxvv6qg3UO93TM+tZFkhYCfwSu2pSAk13IxueP+sPOwKf7uQ7WIdzTMTOzyrinY2ZmlXHQMTOz\nyjjomJlZZRx0zMysMg46ZmZWmf8GQp4xDBcjJYgAAAAASUVORK5CYII=\n", + "text/plain": [ + "<matplotlib.figure.Figure at 0x2d78f12e518>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# We decide the max and min length (in words) of discharge notes.\n", + "\n", + "plt.hist([len(x) for x in train_x], normed=1, bins=100, alpha=.5)\n", + "plt.vlines(MAX_NOTE_LEN, 0, .003); plt.vlines(MIN_NOTE_LEN, 0, .003);\n", + "plt.xlabel(\"Length (words)\"); plt.ylabel(\"Density\");\n", + "plt.title(\"Distribution of Length of Discharge Notes\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Keep only the notes that are long enough:\n", + "subset_train = set(np.where([len(x) >= MIN_NOTE_LEN for x in train_x])[0])\n", + "subset_test = set(np.where([len(x) >= MIN_NOTE_LEN for x in test_x])[0])\n", + "subset_valid = set(np.where([len(x) >= MIN_NOTE_LEN for x in valid_x])[0])\n", + "\n", + "def getsubset(orig, index):\n", + " return([j for i,j in enumerate(orig) if i in index])\n", + "\n", + "# Pad the notes that are too short:\n", + "train_x = pad_sequences(getsubset(train_x, subset_train), maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n", + "valid_x = pad_sequences(getsubset(valid_x, subset_valid), maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n", + "test_x = pad_sequences(getsubset(test_x, subset_test), maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n", + "\n", + "train_y = np.array(getsubset(train_y, subset_train))\n", + "valid_y = np.array(getsubset(valid_y, subset_valid))\n", + "test_y = np.array(getsubset(test_y, subset_test))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Defining the neural network" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true, + "scrolled": true + }, + "outputs": [], + "source": [ + "seq_input_layer = Input(shape=(MAX_NOTE_LEN,), dtype='int32')\n", + "\n", + "embedded_layer = Embedding(embeddings_matrix.shape[0], embeddings_matrix.shape[1],\n", + " weights = [embeddings_matrix],\n", + " input_length = MAX_NOTE_LEN,\n", + " trainable = True)(seq_input_layer)\n", + "\n", + "conv_layer = Conv1D(UNITS, FILTERSIZE, activation='tanh')(embedded_layer)\n", + "\n", + "pool_layer = GlobalMaxPooling1D()(conv_layer)\n", + "\n", + "out_layer = Dense(1, \n", + " activation = 'sigmoid', \n", + " activity_regularizer = l1(REG_FACTOR)\n", + " )(pool_layer)\n", + "\n", + "optimizer = RMSprop(lr = LEARNING_RATE)\n", + "model = Model(inputs=seq_input_layer, outputs=out_layer)\n", + "model.compile(loss=LOSS_FUNC, optimizer=optimizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_1 (InputLayer) (None, 700) 0 \n", + "_________________________________________________________________\n", + "embedding_1 (Embedding) (None, 700, 1000) 22331000 \n", + "_________________________________________________________________\n", + "conv1d_1 (Conv1D) (None, 698, 450) 1350450 \n", + "_________________________________________________________________\n", + "global_max_pooling1d_1 (Glob (None, 450) 0 \n", + "_________________________________________________________________\n", + "dense_1 (Dense) (None, 1) 451 \n", + "=================================================================\n", + "Total params: 23,681,901\n", + "Trainable params: 23,681,901\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training the neural net" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": true, + "scrolled": true + }, + "outputs": [], + "source": [ + "# Load the weights from a previous run, or train the model anew:\n", + "if isfile(CNN_FILENAME):\n", + " model.load_weights(CNN_FILENAME)\n", + "else:\n", + " model.fit(train_x, train_y, \n", + " batch_size = BATCH_SIZE, \n", + " epochs = EPOCHS, \n", + " validation_data = (valid_x, valid_y), \n", + " verbose = True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}