{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* [Constants](#Constants)\n",
"* [Load data](#Load-data)\n",
"* [Train Word2Vec](#Train-Word2Vec)\n",
"* [Prepare text](#Prepare-text)\n",
"* [Defining the neural network](#Defining-the-neural-network) \n",
"* [Training the neural net](#Training-the-neural-net)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import string\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from os.path import isfile\n",
"\n",
"from keras.models import Model\n",
"from keras.preprocessing.sequence import pad_sequences\n",
"from keras.layers import Embedding, Input, Conv1D, Dense, GlobalMaxPooling1D\n",
"from keras.optimizers import RMSprop\n",
"from keras.regularizers import l1\n",
"\n",
"from gensim.models import word2vec\n",
"from gensim.models import KeyedVectors\n",
"\n",
"\n",
"import logging\n",
"logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Constants"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Location of train/test data files generated by TextSections/TextPrep\n",
"TRAIN_DATA_LOC = \"~/train_data.csv\"\n",
"TEST_DATA_LOC = \"~/test_data.csv\"\n",
"\n",
"# Columns we will use:\n",
"VISITID = \"visit_id\"\n",
"OUTCOME = \"readmitted\" # e.g. ReadmissionInLessThan30Days\n",
"\n",
"# Train/validation split (applied to TRAIN_DATA_LOC only; TEST_DATA_LOC stays held out)\n",
"SPLIT_SIZE = 0.9 # fraction of the training file kept for training; the rest becomes the validation set\n",
"SPLIT_SEED = 1234\n",
"\n",
"# Word2Vec hyperparameters\n",
"WINDOW = 2\n",
"DIMENSIONS = 1000\n",
"MIN_COUNT = 5\n",
"USE_SKIPGRAM = True # passed to gensim as `sg`: True -> skip-gram, False -> CBOW\n",
"USE_HIER_SMAX = False # passed to gensim as `hs`: True -> hierarchical softmax\n",
"NUM_THREADS = 50 # passed to gensim as `workers`\n",
"# Where to save the w2v model (reused instead of retraining if the file already exists).\n",
"# NOTE(review): the saved output of the Word2Vec cell loaded ./word2vec/w2v_dims_1000_window_2.bin,\n",
"# not this path -- confirm which location is intended.\n",
"W2V_FILENAME = './w2v_dims_{dims}_window_{window}.bin'.format(\n",
"    dims = DIMENSIONS,\n",
"    window = WINDOW\n",
")\n",
"\n",
"\n",
"# Text Prep\n",
"PADDING = \"PADDING\" # extra vocabulary word used to right-pad short notes\n",
"MAX_NOTE_LEN = 700 # notes are padded/truncated to this many words\n",
"MIN_NOTE_LEN = 20 # notes with fewer in-vocabulary words are dropped\n",
"\n",
"# Model Architecture\n",
"UNITS = 450 # number of Conv1D filters\n",
"FILTERSIZE = 3 # Conv1D kernel width, in words\n",
"LEARNING_RATE = 0.0001\n",
"LOSS_FUNC = 'binary_crossentropy'\n",
"REG_FACTOR = 0.05 # L1 activity-regularization strength on the output layer\n",
"\n",
"# Model Training\n",
"CNN_FILENAME = \"./cnn.h5\" # trained weights are loaded from here if the file exists\n",
"BATCH_SIZE = 100\n",
"EPOCHS = 4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"# Read train and test hospital data (every column as a string).\n",
"train = pd.read_csv(TRAIN_DATA_LOC, dtype = str)\n",
"test = pd.read_csv(TEST_DATA_LOC, dtype = str)\n",
"\n",
"# Split the train data into a train and validation set, stratified on the outcome.\n",
"train, valid = train_test_split(train,\n",
"                                stratify = train[OUTCOME],\n",
"                                train_size = SPLIT_SIZE,\n",
"                                random_state = SPLIT_SEED)\n",
"# Independent copies, so the column assignments below never write into a slice.\n",
"train, valid = train.copy(), valid.copy()\n",
"\n",
"# Section columns are every column other than the visit id and the outcome.\n",
"SECTIONNAMES = [x for x in train.columns if VISITID not in x and OUTCOME not in x]\n",
"\n",
"def add_section_headers(frame, sections):\n",
"    \"\"\"Prefix every section cell of `frame` with its section name, in place.\n",
"\n",
"    The name is upper-cased with spaces turned into underscores. If the section\n",
"    text is present the cell becomes \"SECTIONNAME sectiontext\"; if it is missing\n",
"    (NaN) the cell becomes just \"SECTIONNAME\".\n",
"    \"\"\"\n",
"    for x in sections:\n",
"        rep = x.replace(\" \", \"_\").upper()\n",
"        frame[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in frame[x]]\n",
"\n",
"for frame in (train, valid, test):\n",
"    add_section_headers(frame, SECTIONNAMES)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train Word2Vec"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2017-10-27 12:32:33,194 : INFO : loading projection weights from ./word2vec/w2v_dims_1000_window_2.bin\n",
"2017-10-27 12:32:33,507 : INFO : loaded (22330, 1000) matrix from ./word2vec/w2v_dims_1000_window_2.bin\n"
]
}
],
"source": [
"# Digits and punctuation are replaced by spaces before tokenizing. `_` is kept,\n",
"# because it joins the words of the SECTIONNAME tokens built in the previous cell.\n",
"chars_to_space = string.digits + ''.join([c for c in string.punctuation if c != '_'])\n",
"remove_digits_punc = {ord(c): \" \" for c in chars_to_space}\n",
"\n",
"# (If the model already exists, don't recompute.)\n",
"if not isfile(W2V_FILENAME):\n",
"    # Use only training data to train word2vec, so validation/test text never shapes the embeddings.\n",
"    # Single lowercase letters are dropped as stop words.\n",
"    stop = set(string.ascii_lowercase)\n",
"    joined_notes = train[SECTIONNAMES].apply(lambda row: \" \".join(row), axis=1)\n",
"    notes = [[w for w in note.translate(remove_digits_punc).split() if w not in stop]\n",
"             for note in joined_notes]\n",
"\n",
"    w2v_model = word2vec.Word2Vec(notes,\n",
"                                  size = DIMENSIONS,\n",
"                                  window = WINDOW,\n",
"                                  sg = USE_SKIPGRAM,\n",
"                                  hs = USE_HIER_SMAX,\n",
"                                  min_count = MIN_COUNT,\n",
"                                  workers = NUM_THREADS)\n",
"    w2v_model.wv.save_word2vec_format(W2V_FILENAME, binary=True)\n",
"    # Keep only the word vectors, so `w2v` is a KeyedVectors object whether it was\n",
"    # just trained or loaded from disk in the branch below.\n",
"    w2v = w2v_model.wv\n",
"else:\n",
"    w2v = KeyedVectors.load_word2vec_format(W2V_FILENAME, binary=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Make the embedding matrix.\n",
"# Row 0 is the extra word `PADDING`, which right-pads short notes; its vector is all zeros.\n",
"# Rows 1.. hold the word2vec vectors in sorted-vocabulary order.\n",
"vocab = [PADDING] + sorted(w2v.wv.vocab.keys())\n",
"vset = set(vocab)\n",
"\n",
"# word -> row of `embeddings_matrix` (a later duplicate overwrites an earlier one)\n",
"embeddings_index = {word: i for i, word in enumerate(vocab)}\n",
"\n",
"# Plain float ndarray: np.matrix is no longer recommended by NumPy.\n",
"embeddings_matrix = np.vstack([np.zeros(DIMENSIONS)] + [w2v.wv[word] for word in vocab[1:]])"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Prepare text"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"def notes_to_indices(frame):\n",
"    \"\"\"Turn each row of `frame` into a list of embedding-row indices.\n",
"\n",
"    The row's section columns are joined with spaces, digits/punctuation are blanked\n",
"    out with `remove_digits_punc`, and tokens outside the vocabulary are dropped.\n",
"    \"\"\"\n",
"    joined = frame[SECTIONNAMES].apply(lambda row: (\" \".join(row)).translate(remove_digits_punc), axis=1).values\n",
"    return [[embeddings_index[tok] for tok in note.split() if tok in vset] for note in joined]\n",
"\n",
"train_x = notes_to_indices(train)\n",
"valid_x = notes_to_indices(valid)\n",
"test_x = notes_to_indices(test)\n",
"\n",
"# NOTE(review): the CSVs were read with dtype=str, so these labels are strings;\n",
"# confirm they are numeric-convertible for the binary_crossentropy model.\n",
"train_y = train[OUTCOME]\n",
"valid_y = valid[OUTCOME]\n",
"test_y = test[OUTCOME]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ0AAAEWCAYAAAC9qEq5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3X2cHFWd7/HP1xCeHyIQFZJAggTY4FXE8KAoy72KJIhm\nFR9g0QiiuXFB94KsBvEBWbmirnJFgYiKEFARRTBiXAzeBXyKJGgIBAgMAUxChPBgABMCCb/945yG\nStvT0z2Zqpnp+b5fr35N96k6p87p6qlfn1OnqxQRmJmZVeFF/V0BMzMbOhx0zMysMg46ZmZWGQcd\nMzOrjIOOmZlVxkHHzMwq46DTYSTNlPTpPiprN0lPSRqWX98g6YN9UXYu7xeS3t9X5bWx3c9LekTS\nX6redqskHSZpeR+VJUnflfS4pJt7WcZGn4VelnG/pDf1Nr91BgedQST/066V9KSkv0r6naTpkp7f\njxExPSL+vcWymh4AIuLPEbFtRGzog7qfKenyuvInR8Slm1p2m/XYDfgYMCEiXtZgeZ8d7NusV0ja\ns6TiXw8cDoyOiAMbbPt4SRtyUHlK0n05SO1VW6cvPwsDUf5/eFjSNoW0D0q6ocX8ffqFrJM56Aw+\nb42I7YDdgXOATwDf6euNSNqsr8scIHYDHo2Ih/u7IhXaHbg/Iv7WZJ3fR8S2wA7Am4C1wC2SXlFF\nBduRe25lHLuGAf9aQrlWFBF+DJIHcD/wprq0A4HngFfk15cAn8/PdwauBf4KPAb8mvRF47KcZy3w\nFPBxYCwQwInAn4GbCmmb5fJuAL4A3Aw8AfwU2DEvOwxY3qi+wCTgGeDZvL1bC+V9MD9/EfAp4AHg\nYWAWsENeVqvH+3PdHgHOaPI+7ZDzr8rlfSqXXzuYPpfrcUmDvH/XjsKyLYD/yHV4CJgJbFXMR+pF\nPQysBE4o5N0J+Fl+3+YDnwd+k5fdlNv3t1yv9/RUXoO67QrMzvu5C/hQTj8ReBrYkMv+XIO8x9fq\nUpd+LfDjun2wWSHPUuBJ4D7guEK+DwF35mV3APsXPg+nAYuA1cAPgS3zshfn7a0CHs/PRxfKvAE4\nG/ht3od7AuPye/ckcD1wPnB5Ic/BwO9In/9bgcN6+N+akd+/ETntg8ANhXVel/fd6vz3dTn97Pz+\nPp3f42/k9H2AubnMJcC7C2Udmd+bJ4EVwGn9fXyp7DjW3xXwo42d1SDo5PQ/Ax/Ozy/hhaDzBdKB\ncXh+vAFQo7IKB5VZwDbAVg0ONDfkf5BX5HWuqv2T0yTo5OdnFg8IhfJqQecDpIPlHsC2wE+Ay+rq\n9q1cr1cB64B/6OZ9mkUKiNvlvHcDJ3ZXz7q83S4HziUd2HfMZf8M+EIh33rgrPxeHwmsAV6cl1+R\nH1sDE4BlFA70uX171tWj2/Ia1O0m4AJgS2A/0sH7f+Vlx9MgqBTyNlye98lDdftgs7zvnwD2zst2\nAfbNz9+VPyMHACIFh90Ln4ebSQFyR1Jgmp6X7QQcnd+f7YAfAdfUfVb+DOyb6zAc+D3pS8DmpCHE\nJ3jh8zgKeDS/by8iDS8+Coxs9r9F+tzV/n+eDzq5vo8D78vbPza/3qn+s5xfb5P38Ql5/VeTvixN\nyMtXAm/Iz19MDsxD4eHhtc7wIOmfot6zpAPC7hHxbET8OvKnvIkzI+JvEbG2m+WXRcTtkYZqPg28\ne1NOLhccB3w1IpZGxFPA6cAxdcN8n4uItRFxK+mb66vqC8l1OQY4PSKejIj7ga+QDha9JknANOCU\niHgsIp4E/m/eVs2zwFn5vZ5D+ta7d67T0cBnI2JNRNwBtHIuq2F5Deo2BjgE+EREPB0RC4FvA1N7\n3eCku88V5N61pK0iYmVELM7pHwS+FBHzI+mKiAcK+c6LiAcj4jFS0N4PICIejYir8vvzJKn38I91\n27wkIhZHxHrS5/oA4DMR8UxE/Ib0haDmvcCc
iJgTEc9FxFxgASkINfMZ4COSRtalvwW4JyIui4j1\nEfED4C7grd2UcxRpSPO7ef0/kb6kvSsvfxaYIGn7iHg8Iv7YQ706hoNOZxhF6sLX+zKp9/BLSUsl\nzWihrGVtLH+A9I1z55Zq2dyuubxi2ZsBLy2kFWebrSH1iOrtnOtUX9aoTazfSNK38FvyJI6/Av+Z\n02sezQfE+jqOJLWl+N719D43K6/erkAtENb0RZsbfq7yF473ANOBlZJ+LmmfvHgMcG+TMhvuQ0lb\nS/qmpAckPUHquY2o+0JTfM9qbV7TzfLdgXfV9lXeX68nBatuRcTtpKG9+v+V+s8nNH+PdwcOqtv+\ncUBt8srRpAD4gKQbJb22Wb06iYPOICfpANIH/zf1y/I3/Y9FxB7A24BTJb2xtribInvqCY0pPN+N\n9I3tEdL5iK0L9RrGxgfknsp9kPSPWix7PencSTseyXWqL2tFm+U0KnctaRhpRH7sEOnke09Wkdoy\nupA2ppt1e+NBYEdJ2xXS+qLNbyedB/w7EXFdRBxOOojfRRr6hHTgf3kvtvUxUi/uoIjYHjg0p6u4\n2cLzlaQ2b11IK76ny0i98hGFxzYRcU4Ldfks6bxUMaDUfz5h4/e4/vO9DLixbvvbRsSHAXJPcArw\nEuAa4MoW6tURHHQGKUnbSzqKdJ7g8oi4rcE6R0naMw8NrSad7HwuL36IdP6kXe+VNCH/s59FOtG8\ngXTeZEtJb5E0nHTyfotCvoeAsU1mHf0AOEXSOEnbkoauflj3Tb9HuS5XAmdL2k7S7sCpwOXNc25M\n0pbFBy+cUzpX0kvyOqMkHdFinX4CnJm/0e/D3w999XZ/EBHLSCfMv5Dr+0rSBIK22gzpy0LeB18n\nnVf6XIN1XippSp5evI407Ff7XH0bOE3Sa/Issz3zPujJdqSg/ldJO5IO/N3KQ3YLSO/p5rmnUBzq\nuhx4q6Qjcpu2zNPhRzcscOOyu0iTHD5aSJ4D7CXpnyVtJuk9pHNz1+bl9fvv2rz++yQNz48DJP1D\nru9xknaIiGdJ56KeY4hw0Bl8fibpSdI3qTOAr5JOVjYynjSr5ynSSdcLIuK/8rIvAJ/KXf/T2tj+\nZaTJCn8hnbT+KEBErAb+hXTQWUHq+RR/7/Kj/PdRSY3Gry/OZd9Emg31NPCRNupV9JG8/aWkHuD3\nc/mtGkU6ABYfLydNT+8C5uUhoOtpcI6lGyeTZtX9hdTOH5AO2DVnApfm/fHuNupacyzpZP+DwNWk\n80fXt5H/tZKeIh0AbwC2Bw5o9GWGdNw4NW/rMdK5l9o3+B+Rzsd8nzQz6xq6Py9U9P9Ik0QeAeaR\nhi57chzwWtIEgc+TAsW6XI9lwBTgk6Se5jLg32j9mHcWaTIAubxHSedpPpa393HgqIh4JK/yNeCd\nSj/APS8Pdb6ZdM7vQdJ+/yIvfBF7H3B//hxNz20ZEmozmcysQpK+CLwsIiq/IkOnkvRD4K6IaNpL\nsv7lno5ZBSTtI+mVecjpQNLw19X9Xa/BLA9XvVzSiyRNIvVsrunvellznfqrc7OBZjvSkNqupPH/\nr5B+S2S99zLSubKdSEO5H85Tk20A8/CamZlVxsNrZmZWmSE9vLbzzjvH2LFj+7saZmaDyi233PJI\nRNRftaElQzrojB07lgULFvR3NczMBhVJ9VdnaJmH18zMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zM\nKuOgY2ZmlSk16EiaJGmJpK5GNxDL16E6Ly9fJGn/nvJK+ve87kJJv5S0a2HZ6Xn9Ja1cct7MzKpV\nWtDJN/E6H5hMuu/EsZIm1K02mXT5/fGkWwFf2ELeL0fEKyNiP9I9Kz6T80wgXUZ8X2AScIH65jbK\nZmbWR8rs6RwIdOV73j9DutnYlLp1pgCz8r3U55FuT7tLs7wR8UQh/za8cMe+KcAVEbEuIu4j3ffk\nwLIaZ2Zm
7Ssz6Ixi43uWL+fv7yfe3TpN80o6W9Iy0o2PPtPG9pA0TdICSQtWrVrVVoOKDjvsMA47\n7LBe57fG/L6adbZBOZEgIs6IiDHA90h3ZGwn70URMTEiJo4c2atLB1Xq3Ll3P/8wMxvsygw6K4Ax\nhdejc1or67SSF1LQObqN7ZmZWT8qM+jMB8ZLGidpc9JJ/tl168wGpuZZbAcDqyNiZbO8ksYX8k8B\n7iqUdYykLSSNI01OuLmsxpmZWftKu8p0RKyXdDJwHTAMuDgiFkuanpfPBOYAR5JO+q8BTmiWNxd9\njqS9geeAB4BaeYslXQncAawHToqIDWW1z8zM2lfqrQ0iYg4psBTTZhaeB3BSq3lz+tENVq8tOxs4\nu7f1NTOzcg3KiQRmZjY4OeiYmVllHHTMzKwyDjpmZlYZBx0zM6tMqbPXrG8Vr0pwyuF79WNNzMx6\nxz0dMzOrjIOOmZlVxkHHzMwq43M6g5TP75jZYOSejpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZ\nZRx0zMysMg46ZmZWGf9OZwAq/gbHzKyTuKdjZmaVcdAxM7PKeHhtgPCQmpkNBe7pmJlZZRx0zMys\nMg46ZmZWGQcdMzOrTKlBR9IkSUskdUma0WC5JJ2Xly+StH9PeSV9WdJdef2rJY3I6WMlrZW0MD9m\nltk2MzNrX2mz1yQNA84HDgeWA/MlzY6IOwqrTQbG58dBwIXAQT3knQucHhHrJX0ROB34RC7v3ojY\nr6w2DVS+oZuZDRZl9nQOBLoiYmlEPANcAUypW2cKMCuSecAISbs0yxsRv4yI9Tn/PGB0iW0wM7M+\nVGbQGQUsK7xentNaWaeVvAAfAH5ReD0uD63dKOkNjSolaZqkBZIWrFq1qrWWmJlZnxi0EwkknQGs\nB76Xk1YCu+XhtVOB70vavj5fRFwUERMjYuLIkSOrq7CZmZV6RYIVwJjC69E5rZV1hjfLK+l44Cjg\njRERABGxDliXn98i6V5gL2BBH7SlFL4KgZkNNWX2dOYD4yWNk7Q5cAwwu26d2cDUPIvtYGB1RKxs\nllfSJODjwNsiYk2tIEkj8wQEJO1BmpywtMT2mZlZm0rr6eTZZScD1wHDgIsjYrGk6Xn5TGAOcCTQ\nBawBTmiWNxf9DWALYK4kgHkRMR04FDhL0rPAc8D0iHisrPaZmVn7Sr3gZ0TMIQWWYtrMwvMATmo1\nb07fs5v1rwKu2pT6mplZuQbtRAIzMxt8HHTMzKwyDjpmZlYZBx0zM6uMg46ZmVXGQcfMzCpT6pRp\nq56vOG1mA5l7OmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46\nZmZWGQcdMzOrjIOOmZlVxkHHzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyvolbB/MN3cxsoHFP\nx8zMKuOgY2ZmlSk16EiaJGmJpC5JMxosl6Tz8vJFkvbvKa+kL0u6K69/taQRhWWn5/WXSDqizLb1\n1rlz737+YWY21JQWdCQNA84HJgMTgGMlTahbbTIwPj+mARe2kHcu8IqIeCVwN3B6zjMBOAbYF5gE\nXJDLMTOzAaLMns6BQFdELI2IZ4ArgCl160wBZkUyDxghaZdmeSPilxGxPuefB4wulHVFRKyLiPuA\nrlyOmZkNEGUGnVHAssLr5TmtlXVayQvwAeAXbWwPSdMkLZC0YNWqVS00w8zM+sqgnUgg6QxgPfC9\ndvJFxEURMTEiJo4cObKcypmZWUNl/k5nBTCm8Hp0TmtlneHN8ko6HjgKeGNERBvbMzOzflRmT2c+\nMF7SOEmbk07yz65bZzYwNc9iOxhYHRErm+WVNAn4OPC2iFhTV9YxkraQNI40OeHmEttnZmZtKq2n\nExHrJZ0MXAcMAy6OiMWSpuflM4E5wJGkk/5rgBOa5c1FfwPYApgrCWBeRE
zPZV8J3EEadjspIjaU\n1T4zM2tfqZfBiYg5pMBSTJtZeB7ASa3mzel7Ntne2cDZva2vmZmVa9BOJDAzs8HHQcfMzCrjoGNm\nZpVx0DEzs8r4fjpDhO+tY2YDgXs6ZmZWGQcdMzOrTEtBR9JPJL1FkoOUmZn1WqtB5ALgn4F7JJ0j\nae8S62RmZh2qpaATEddHxHHA/sD9wPWSfifpBEnDy6ygmZl1jpaHyyTtBBwPfBD4E/A1UhCaW0rN\nzMys47Q0ZVrS1cDewGXAW/OVoAF+KGlBWZUzM7PO0urvdL6VL8D5PElb5FtDTyyhXmZm1oFaHV77\nfIO03/dlRczMrPM17elIehkwCthK0qsB5UXbA1uXXDczM+swPQ2vHUGaPDAa+Goh/UngkyXVyczM\nOlTToBMRlwKXSjo6Iq6qqE5mZtahehpee29EXA6MlXRq/fKI+GqDbGZmZg31NLy2Tf67bdkVMTOz\nztfT8No389/PVVMdMzPrZK1e8PNLkraXNFzSryStkvTesitnZmadpdXf6bw5Ip4AjiJde21P4N/K\nqpSZmXWmVoNObRjuLcCPImJ1SfUxM7MO1uplcK6VdBewFviwpJHA0+VVy8zMOlFLQSciZkj6ErA6\nIjZI+hswpdyqWVnOnXv3889POXyvfqyJmQ017dwJdB/gPZKmAu8E3txTBkmTJC2R1CVpRoPlknRe\nXr5I0v495ZX0LkmLJT0naWIhfayktZIW5sfMNtpmZmYVaPXWBpcBLwcWAhtycgCzmuQZBpwPHA4s\nB+ZLmh0RdxRWmwyMz4+DgAuBg3rIezvwDuCbDTZ7b0Ts10qbzMyseq2e05kITIiIaKPsA4GuiFgK\nIOkK0pBcMehMAWblcudJGiFpF2Bsd3kj4s6c1kZVzMxsIGh1eO124GVtlj0KWFZ4vTyntbJOK3kb\nGZeH1m6U9IZGK0iaJmmBpAWrVq1qoUgzM+srrfZ0dgbukHQzsK6WGBFvK6VWvbMS2C0iHpX0GuAa\nSfvm3xc9LyIuAi4CmDhxYjs9NzMz20StBp0ze1H2CmBM4fXonNbKOsNbyLuRiFhHDogRcYuke4G9\nAN9O28xsgGhpeC0ibiRdiWB4fj4f+GMP2eYD4yWNk7Q5cAwwu26d2cDUPIvtYNKU7JUt5t2IpJF5\nAgKS9iBNTljaSvvMzKwarV577UPAj3lhxtgo4JpmeSJiPXAycB1wJ3BlRCyWNF3S9LzaHFJg6AK+\nBfxLs7y5Lm+XtBx4LfBzSdflsg4FFklamOs6PSIea6V9ZmZWjVaH104izUb7A0BE3CPpJT1liog5\npMBSTJtZeB657Jby5vSrgasbpF8FDMgbzRV/jDnQ+IeiZlalVmevrYuIZ2ovJG1G+p2OmZlZy1oN\nOjdK+iSwlaTDgR8BPyuvWmZm1olaDTozgFXAbcD/Jg17faqsSpmZWWdq9YKfz0m6BrgmIvyLSjMz\n65WmPZ08lflMSY8AS4Al+a6hn6mmemZm1kl6Gl47BTgEOCAidoyIHUkX5jxE0iml187MzDpKT0Hn\nfcCxEXFfLSFfhPO9wNQyK2ZmZp2np6AzPCIeqU/M53WGl1MlMzPrVD0FnWd6uczMzOzv9DR77VWS\nnmiQLmDLEupjZmYdrGnQiYhhVVXEzMw6X6s/DjUzM9tkDjpmZlYZBx0zM6tMq7c2sCHAtzkws7K5\np2NmZpVx0DEzs8o46JiZWWUcdMzMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOz\nyjjomJlZZUoNOpImSVoiqUvSjAbLJem8vHyRpP17yivpXZIWS3pO0sS68k7P6y+RdESZbet05869\ne6OHmVlfKC3oSBoGnA9MBiYAx0qaULfaZGB8fkwDLmwh7+3AO4Cb6rY3ATgG2BeYBFyQyzEzswGi\nzJ7OgUBXRCyNiGeAK4ApdetMAWZFMg
8YIWmXZnkj4s6IWNJge1OAKyJiXUTcB3TlcszMbIAoM+iM\nApYVXi/Paa2s00re3mwPSdMkLZC0YNWqVT0UaWZmfWnITSSIiIsiYmJETBw5cmR/V8fMbEgp8yZu\nK4Axhdejc1or6wxvIW9vtmdmZv2ozJ7OfGC8pHGSNied5J9dt85sYGqexXYwsDoiVraYt95s4BhJ\nW0gaR5qccHNfNsjMzDZNaT2diFgv6WTgOmAYcHFELJY0PS+fCcwBjiSd9F8DnNAsL4CktwNfB0YC\nP5e0MCKOyGVfCdwBrAdOiogNZbVvqPGtrM2sL5Q5vEZEzCEFlmLazMLzAE5qNW9Ovxq4ups8ZwNn\nb0KVzcysRENuIoGZmfUfBx0zM6uMg46ZmVXGQcfMzCpT6kQC60yeyWZmveWejpmZVcZBx8zMKuOg\nY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46ZmZWGQcdMzOrjIOOmZlVxlcksE3iqxOYWTscdEpSPBib\nmVni4TUzM6uMezrWZzzUZmY9cU/HzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyDjpmZlYZBx0z\nM6tMqUFH0iRJSyR1SZrRYLkknZeXL5K0f095Je0oaa6ke/LfF+f0sZLWSlqYHzPLbJuZmbWvtKAj\naRhwPjAZmAAcK2lC3WqTgfH5MQ24sIW8M4BfRcR44Ff5dc29EbFffkwvp2VmZtZbZfZ0DgS6ImJp\nRDwDXAFMqVtnCjArknnACEm79JB3CnBpfn4p8E8ltsHMzPpQmZfBGQUsK7xeDhzUwjqjesj70ohY\nmZ//BXhpYb1xkhYCq4FPRcSv6yslaRqpV8Vuu+3WTnusDb4kjpk1MqgnEkREAJFfrgR2i4j9gFOB\n70vavkGeiyJiYkRMHDlyZIW1NTOzMns6K4Axhdejc1or6wxvkvchSbtExMo8FPcwQESsA9bl57dI\nuhfYC1jQN82x3nKvx8xqyuzpzAfGSxonaXPgGGB23Tqzgal5FtvBwOo8dNYs72zg/fn5+4GfAkga\nmScgIGkP0uSEpeU1z8zM2lVaTyci1ks6GbgOGAZcHBGLJU3Py2cCc4AjgS5gDXBCs7y56HOAKyWd\nCDwAvDunHwqcJelZ4DlgekQ8Vlb7zMysfaXeTyci5pACSzFtZuF5ACe1mjenPwq8sUH6VcBVm1hl\nMzMr0aCeSGBmZoOL7xxqlfKkArOhzT0dMzOrjIOOmZlVxsNr1m881GY29LinY2ZmlXHQMTOzyjjo\nmJlZZXxOxwaE2vmd5Y+vZfSLt+rn2phZWdzTMTOzyjjo2ICz/PG1nDv37o1mt5lZZ3DQMTOzyjjo\nmJlZZRx0zMysMp69ZgOar1pg1lnc0zEzs8q4p2ODhns9ZoOfezpmZlYZ93T6kH9XYmbWnIOODUoe\najMbnBx0bNBzADIbPBx0rKN0N8TpYGQ2MDjo2JDg3pDZwOCgs4k8eWDwcQAy6z8OOjakOQCZVctB\nxyzz+SCz8pUadCRNAr4GDAO+HRHn1C1XXn4ksAY4PiL+2CyvpB2BHwJjgfuBd0fE43nZ6cCJwAbg\noxFxXZnts6GhlSFUByaz1pQWdCQNA84HDgeWA/MlzY6IOwqrTQbG58dBwIXAQT3knQH8KiLOkTQj\nv/6EpAnAMcC+wK7A9ZL2iogNZbXRrKbdc3sOUjZUldnTORDoioilAJKuAKYAxaAzBZgVEQHMkzRC\n0i6kXkx3eacAh+X8lwI3AJ/I6VdExDrgPklduQ6/L7GNZr3SnxNQigHP57SsamUGnVHAssLr5aTe\nTE/rjOoh70sjYmV+/hfgpYWy5jUoayOSpgHT8sunJC1ppTEFOwOP1F7ceOPebWYf1DZqe5lOffOA\nfF8ra3+ZTm0zPeuItm+Codz+Rm3fvbeFDeqJBBERkqLNPBcBF/V2m5IWRMTE3uYfzIZy22Fot38o\ntx
2Gdvv7uu1lXmV6BTCm8Hp0TmtlnWZ5H8pDcOS/D7exPTMz60dlBp35wHhJ4yRtTjrJP7tundnA\nVCUHA6vz0FmzvLOB9+fn7wd+Wkg/RtIWksaRJifcXFbjzMysfaUNr0XEekknA9eRpj1fHBGLJU3P\ny2cCc0jTpbtIU6ZPaJY3F30OcKWkE4EHgHfnPIslXUmabLAeOKmkmWu9HprrAEO57TC02z+U2w5D\nu/192naliWNmZmbl851DzcysMg46ZmZWGQedFkmaJGmJpK58JYSOJOl+SbdJWihpQU7bUdJcSffk\nvy8urH96fk+WSDqi/2rePkkXS3pY0u2FtLbbKuk1+T3rknRevrzTgNdN+8+UtCLv/4WSjiws65j2\nSxoj6b8k3SFpsaR/zekdv/+btL2afR8RfvTwIE1muBfYA9gcuBWY0N/1Kqmt9wM716V9CZiRn88A\nvpifT8jvxRbAuPweDevvNrTR1kOB/YHbN6WtpFmSBwMCfgFM7u+2bUL7zwROa7BuR7Uf2AXYPz/f\nDrg7t7Hj93+Ttley793Tac3zl/SJiGeA2mV5hooppEsOkf/+UyH9iohYFxH3kWYhHtgP9euViLgJ\neKwuua225t+KbR8R8yL9F84q5BnQuml/dzqq/RGxMvLFhSPiSeBO0hVMOn7/N2l7d/q07Q46renu\ncj2dKEgXS70lXzIIml96qNPel3bbOio/r08fzD4iaVEefqsNL3Vs+yWNBV4N/IEhtv/r2g4V7HsH\nHav3+ojYj3QF8JMkHVpcmL/RDIl59kOprQUXkoaR9wNWAl/p3+qUS9K2wFXA/4mIJ4rLOn3/N2h7\nJfveQac1Q+YSOxGxIv99GLiaNFw2lC491G5bV+Tn9emDUkQ8FBEbIuI54Fu8MFzace2XNJx00P1e\nRPwkJw+J/d+o7VXtewed1rRySZ9BT9I2krarPQfeDNzO0Lr0UFttzUMxT0g6OM/cmVrIM+jUDrjZ\n20n7Hzqs/bmu3wHujIivFhZ1/P7vru2V7fv+nkkxWB6ky/XcTZq5cUZ/16ekNu5BmqVyK7C41k5g\nJ+BXwD3A9cCOhTxn5PdkCQN81k6D9v6ANIzwLGk8+sTetBWYmP9B7wW+Qb7Sx0B/dNP+y4DbgEX5\nYLNLJ7YfeD1p6GwRsDA/jhwK+79J2yvZ974MjpmZVcbDa2ZmVhkHHTMzq4yDjpmZVcZBx8zMKuOg\nY2ZmlXHQMQMkPVVy+cdL2rXw+n5JO7eQ79WSvlNy3S6R9M4my0+W9IEy62BDh4OOWTWOB3btaaUG\nPgmc11chA3JvAAADJklEQVSVkNSbW9RfDHykr+pgQ5uDjlk3JI2UdJWk+flxSE4/M18Q8QZJSyV9\ntJDn0/meI7+R9ANJp+VexETge/k+JVvl1T8i6Y/5fiT7NNj+dsArI+LW/Po2SSOUPCppak6fJelw\nSVtK+m5e70+S/mdefryk2ZL+P/CrnP8buZ7XAy8pbPMcpfusLJL0HwARsQa4X9KguYK4DVwOOmbd\n+xpwbkQcABwNfLuwbB/gCNL1qT4rabik2nqvIl0wdSJARPwYWAAcFxH7RcTaXMYjEbE/6UKLpzXY\nfu3X3jW/BQ4B9gWWAm/I6a8FfgeclDYX/wM4FrhU0pZ5nf2Bd0bEP5IucbI36T4pU4HXAUjaKS/b\nNyJeCXy+sO0Fhe2Z9VpvutpmQ8WbgAl64WaI2+cr8wL8PCLWAeskPUy6BP4hwE8j4mngaUk/66H8\n2kUmbwHe0WD5LsCqwutfk2689gApUE2TNAp4PCL+Jun1wNcBIuIuSQ8Ae+W8cyOidu+cQ4EfRMQG\n4MHcAwJYDTwNfEfStcC1hW0/TAq0ZpvEPR2z7r0IODj3TvaLiFERUZtwsK6w3gZ69wWuVkZ3+dcC\nWxZe30TqbbwBuIEUkN5JCkY9+VtPK0TEelLP7cfAUcB/FhZvmetj
tkkcdMy690sKJ9Al7dfD+r8F\n3prPrWxLOnDXPEm6NXA77gT2rL2IiGXAzsD4iFgK/IY0LHdTXuXXwHG5rnsBu5Eu0FjvJuA9kobl\nKwvXzv1sC+wQEXOAU0jDhDV7sfFQn1mvOOiYJVtLWl54nAp8FJiYT6rfAUxvVkBEzCddnXcR6X7x\nt5GGrAAuAWbWTSRoKiLuAnao3W4i+wPpaueQgswoUvABuAB4kaTbgB8Cx+chwHpXk66ifAfpFsO/\nz+nbAddKWpTLPLWQ5xBgbiv1NmvGV5k260OSto2IpyRtTepRTIt8P/pelncK8GREfLvHlUsi6dXA\nqRHxvv6qg3UO93TM+tZFkhYCfwSu2pSAk13IxueP+sPOwKf7uQ7WIdzTMTOzyrinY2ZmlXHQMTOz\nyjjomJlZZRx0zMysMg46ZmZWmf8GQp4xDBcjJYgAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x2d78f12e518>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# We decide the max and min length (in words) of discharge notes.\n",
"# The vertical lines mark MIN_NOTE_LEN (solid) and MAX_NOTE_LEN (dashed).\n",
"fig, ax = plt.subplots()\n",
"# `density=True` replaces `normed`, which matplotlib deprecated in 2.1 and later removed.\n",
"ax.hist([len(x) for x in train_x], density=True, bins=100, alpha=.5)\n",
"# axvline spans the full axis height instead of a hard-coded y-limit.\n",
"ax.axvline(MIN_NOTE_LEN, color='k', label='MIN_NOTE_LEN')\n",
"ax.axvline(MAX_NOTE_LEN, color='k', linestyle='--', label='MAX_NOTE_LEN')\n",
"ax.set(xlabel=\"Length (words)\", ylabel=\"Density\",\n",
"       title=\"Distribution of Length of Discharge Notes\")\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def long_enough(notes):\n",
"    \"\"\"Positions of the notes with at least MIN_NOTE_LEN tokens.\"\"\"\n",
"    return {pos for pos, note in enumerate(notes) if len(note) >= MIN_NOTE_LEN}\n",
"\n",
"# Keep only the notes that are long enough:\n",
"subset_train = long_enough(train_x)\n",
"subset_test = long_enough(test_x)\n",
"subset_valid = long_enough(valid_x)\n",
"\n",
"def getsubset(orig, index):\n",
"    \"\"\"Items of `orig` whose position is in the set `index`, in their original order.\"\"\"\n",
"    return [item for pos, item in enumerate(orig) if pos in index]\n",
"\n",
"def pad_notes(notes, keep):\n",
"    \"\"\"Keep the selected notes, right-pad them to MAX_NOTE_LEN and cut longer ones at the end.\n",
"\n",
"    Padding uses index 0, the zero-vector row reserved for PADDING in the embedding matrix.\n",
"    \"\"\"\n",
"    return pad_sequences(getsubset(notes, keep), maxlen=MAX_NOTE_LEN, value=0,\n",
"                         padding='post', truncating='post')\n",
"\n",
"train_x = pad_notes(train_x, subset_train)\n",
"valid_x = pad_notes(valid_x, subset_valid)\n",
"test_x = pad_notes(test_x, subset_test)\n",
"\n",
"train_y = np.array(getsubset(train_y, subset_train))\n",
"valid_y = np.array(getsubset(valid_y, subset_valid))\n",
"test_y = np.array(getsubset(test_y, subset_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Defining the neural network"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"# Input: MAX_NOTE_LEN word indices per note. Output: one sigmoid score for OUTCOME.\n",
"seq_input_layer = Input(shape=(MAX_NOTE_LEN,), dtype='int32')\n",
"\n",
"# Embedding initialised from the word2vec matrix and fine-tuned during training.\n",
"n_words, n_dims = embeddings_matrix.shape\n",
"embedded_layer = Embedding(n_words, n_dims,\n",
"                           weights=[embeddings_matrix],\n",
"                           input_length=MAX_NOTE_LEN,\n",
"                           trainable=True)(seq_input_layer)\n",
"\n",
"# A single convolution (UNITS filters of width FILTERSIZE), max-pooled over the whole note.\n",
"conv_layer = Conv1D(UNITS, FILTERSIZE, activation='tanh')(embedded_layer)\n",
"pool_layer = GlobalMaxPooling1D()(conv_layer)\n",
"\n",
"# NOTE(review): this L1 activity regularizer penalises the sigmoid output itself, pushing\n",
"# predictions towards 0 -- confirm this is intended rather than a kernel_regularizer.\n",
"out_layer = Dense(1, activation='sigmoid',\n",
"                  activity_regularizer=l1(REG_FACTOR))(pool_layer)\n",
"\n",
"optimizer = RMSprop(lr=LEARNING_RATE)\n",
"model = Model(inputs=seq_input_layer, outputs=out_layer)\n",
"model.compile(loss=LOSS_FUNC, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) (None, 700) 0 \n",
"_________________________________________________________________\n",
"embedding_1 (Embedding) (None, 700, 1000) 22331000 \n",
"_________________________________________________________________\n",
"conv1d_1 (Conv1D) (None, 698, 450) 1350450 \n",
"_________________________________________________________________\n",
"global_max_pooling1d_1 (Glob (None, 450) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 1) 451 \n",
"=================================================================\n",
"Total params: 23,681,901\n",
"Trainable params: 23,681,901\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training the neural net"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"# Load the weights from a previous run, or train the model anew and save the weights\n",
"# so the next run takes the fast path.\n",
"if isfile(CNN_FILENAME):\n",
"    model.load_weights(CNN_FILENAME)\n",
"else:\n",
"    history = model.fit(train_x, train_y,\n",
"                        batch_size = BATCH_SIZE,\n",
"                        epochs = EPOCHS,\n",
"                        validation_data = (valid_x, valid_y),\n",
"                        verbose = 1)\n",
"    # Without this save, the isfile() check above never succeeds and every run retrains.\n",
"    model.save_weights(CNN_FILENAME)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}