|
a |
|
b/03_TrainModel.ipynb |
|
|
1 |
{ |
|
|
2 |
"cells": [ |
|
|
3 |
{ |
|
|
4 |
"cell_type": "markdown", |
|
|
5 |
"metadata": {}, |
|
|
6 |
"source": [ |
|
|
7 |
"* [Constants](#Constants)\n", |
|
|
8 |
"* [Load data](#Load-data)\n", |
|
|
9 |
"* [Train Word2Vec](#Train-Word2Vec)\n", |
|
|
10 |
"* [Prepare text](#Prepare-text)\n", |
|
|
11 |
"* [Defining the neural network](#Defining-the-neural-network) \n", |
|
|
12 |
"* [Training the neural net](#Training-the-neural-net)" |
|
|
13 |
] |
|
|
14 |
}, |
|
|
15 |
{ |
|
|
16 |
"cell_type": "code", |
|
|
17 |
"execution_count": 4, |
|
|
18 |
"metadata": {}, |
|
|
19 |
"outputs": [], |
|
|
20 |
"source": [ |
|
|
21 |
"import numpy as np\n", |
|
|
22 |
"import pandas as pd\n", |
|
|
23 |
"import matplotlib.pyplot as plt\n", |
|
|
24 |
"import string\n", |
|
|
25 |
"\n", |
|
|
26 |
"from sklearn.model_selection import train_test_split\n", |
|
|
27 |
"from os.path import isfile\n", |
|
|
28 |
"\n", |
|
|
29 |
"from keras.models import Model\n", |
|
|
30 |
"from keras.preprocessing.sequence import pad_sequences\n", |
|
|
31 |
"from keras.layers import Embedding, Input, Conv1D, Dense, GlobalMaxPooling1D\n", |
|
|
32 |
"from keras.optimizers import RMSprop\n", |
|
|
33 |
"from keras.regularizers import l1\n", |
|
|
34 |
"\n", |
|
|
35 |
"from gensim.models import word2vec\n", |
|
|
36 |
"from gensim.models import KeyedVectors\n", |
|
|
37 |
"\n", |
|
|
38 |
"\n", |
|
|
39 |
"import logging\n", |
|
|
40 |
"logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)" |
|
|
41 |
] |
|
|
42 |
}, |
|
|
43 |
{ |
|
|
44 |
"cell_type": "markdown", |
|
|
45 |
"metadata": {}, |
|
|
46 |
"source": [ |
|
|
47 |
"# Constants" |
|
|
48 |
] |
|
|
49 |
}, |
|
|
50 |
{ |
|
|
51 |
"cell_type": "code", |
|
|
52 |
"execution_count": null, |
|
|
53 |
"metadata": {}, |
|
|
54 |
"outputs": [], |
|
|
55 |
"source": [ |
|
|
56 |
"# Location of train/test data files generated by TextSections/TextPrep\n", |
|
|
57 |
"TRAIN_DATA_LOC = \"~/train_data.csv\"\n", |
|
|
58 |
"TEST_DATA_LOC = \"~/test_data.csv\"\n", |
|
|
59 |
"\n", |
|
|
60 |
"# Columns we will use:\n", |
|
|
61 |
"VISITID = \"visit_id\"\n", |
|
|
62 |
"OUTCOME = \"readmitted\" # e.g. ReadmissionInLessThan30Days\n", |
|
|
63 |
"\n", |
|
|
64 |
"# Test/Train split\n", |
|
|
65 |
"SPLIT_SIZE = 0.9 # relative size of train:test\n", |
|
|
66 |
"SPLIT_SEED = 1234\n", |
|
|
67 |
"\n", |
|
|
68 |
"# Word2Vec hyperparameters\n", |
|
|
69 |
"WINDOW = 2\n", |
|
|
70 |
"DIMENSIONS = 1000\n", |
|
|
71 |
"MIN_COUNT = 5\n", |
|
|
72 |
"USE_SKIPGRAM = True \n", |
|
|
73 |
"USE_HIER_SMAX = False \n", |
|
|
74 |
"NUM_THREADS = 50\n", |
|
|
75 |
"# Where to save the w2v model:\n", |
|
|
76 |
"W2V_FILENAME = './w2v_dims_{dims}_window_{window}.bin'.format(\n", |
|
|
77 |
" dims = DIMENSIONS,\n", |
|
|
78 |
" window = WINDOW\n", |
|
|
79 |
")\n", |
|
|
80 |
"\n", |
|
|
81 |
"\n", |
|
|
82 |
"# Text Prep\n", |
|
|
83 |
"PADDING = \"PADDING\"\n", |
|
|
84 |
"MAX_NOTE_LEN = 700\n", |
|
|
85 |
"MIN_NOTE_LEN = 20\n", |
|
|
86 |
"\n", |
|
|
87 |
"# Model Architecture\n", |
|
|
88 |
"UNITS = 450\n", |
|
|
89 |
"FILTERSIZE = 3\n", |
|
|
90 |
"LEARNING_RATE = 0.0001\n", |
|
|
91 |
"LOSS_FUNC = 'binary_crossentropy'\n", |
|
|
92 |
"REG_FACTOR = 0.05\n", |
|
|
93 |
"\n", |
|
|
94 |
"# Model Training\n", |
|
|
95 |
"CNN_FILENAME = \"./cnn.h5\"\n", |
|
|
96 |
"BATCH_SIZE = 100\n", |
|
|
97 |
"EPOCHS = 4" |
|
|
98 |
] |
|
|
99 |
}, |
|
|
100 |
{ |
|
|
101 |
"cell_type": "markdown", |
|
|
102 |
"metadata": {}, |
|
|
103 |
"source": [ |
|
|
104 |
"# Load data" |
|
|
105 |
] |
|
|
106 |
}, |
|
|
107 |
{ |
|
|
108 |
"cell_type": "code", |
|
|
109 |
"execution_count": 2, |
|
|
110 |
"metadata": { |
|
|
111 |
"collapsed": true, |
|
|
112 |
"scrolled": true |
|
|
113 |
}, |
|
|
114 |
"outputs": [], |
|
|
115 |
"source": [ |
|
|
116 |
"# Read train and test hospital data.\n", |
|
|
117 |
"train = pd.read_csv(TRAIN_DATA_LOC, dtype = str)\n", |
|
|
118 |
"test = pd.read_csv(TEST_DATA_LOC, dtype = str)\n", |
|
|
119 |
"\n", |
|
|
120 |
"# Split the train data into a train and validation set.\n", |
|
|
121 |
"train, valid = train_test_split(train, \n", |
|
|
122 |
" stratify = train[OUTCOME], \n", |
|
|
123 |
" train_size = SPLIT_SIZE, \n", |
|
|
124 |
" random_state = SPLIT_SEED)\n", |
|
|
125 |
"\n", |
|
|
126 |
"# Prepare the sections.\n", |
|
|
127 |
"# If `sectiontext` is present, then include \"SECTIONNAME sectiontext\".\n", |
|
|
128 |
"# If not present, include only \"SECTIONNAME\".\n", |
|
|
129 |
"SECTIONNAMES = [x for x in trainTXT.columns if VISITID not in x and OUTCOME not in x]\n", |
|
|
130 |
"for x in SECTIONNAMES:\n", |
|
|
131 |
" rep = x.replace(\" \", \"_\").upper()\n", |
|
|
132 |
" train[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in train[x]]\n", |
|
|
133 |
" valid[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in valid[x]]\n", |
|
|
134 |
" test[x] = [\" \".join([rep, t]) if not pd.isnull(t) else rep for t in test[x]]" |
|
|
135 |
] |
|
|
136 |
}, |
|
|
137 |
{ |
|
|
138 |
"cell_type": "markdown", |
|
|
139 |
"metadata": {}, |
|
|
140 |
"source": [ |
|
|
141 |
"# Train Word2Vec" |
|
|
142 |
] |
|
|
143 |
}, |
|
|
144 |
{ |
|
|
145 |
"cell_type": "code", |
|
|
146 |
"execution_count": 3, |
|
|
147 |
"metadata": { |
|
|
148 |
"scrolled": true |
|
|
149 |
}, |
|
|
150 |
"outputs": [ |
|
|
151 |
{ |
|
|
152 |
"name": "stderr", |
|
|
153 |
"output_type": "stream", |
|
|
154 |
"text": [ |
|
|
155 |
"2017-10-27 12:32:33,194 : INFO : loading projection weights from ./word2vec/w2v_dims_1000_window_2.bin\n", |
|
|
156 |
"2017-10-27 12:32:33,507 : INFO : loaded (22330, 1000) matrix from ./word2vec/w2v_dims_1000_window_2.bin\n" |
|
|
157 |
] |
|
|
158 |
} |
|
|
159 |
], |
|
|
160 |
"source": [ |
|
|
161 |
"# We will remove digits and punctuation:\n", |
|
|
162 |
"remove_digits_punc = str.maketrans('', '', string.digits + ''.join([x for x in string.punctuation if '_' not in x]))\n", |
|
|
163 |
"remove_digits_punc = {a:\" \" for a in remove_digits_punc.keys()}\n", |
|
|
164 |
"\n", |
|
|
165 |
"# (If the model already exists, don't recompute.)\n", |
|
|
166 |
"if not isfile(W2V_FILENAME):\n", |
|
|
167 |
" # Use only training data to train word2vec:\n", |
|
|
168 |
" notes = train[SECTIONNAMES].apply(lambda x: \" \".join(x), axis=1).values \n", |
|
|
169 |
" stop = set([x for x in string.ascii_lowercase]) \n", |
|
|
170 |
" for i in range(len(notes)):\n", |
|
|
171 |
" notes[i] = [w for w in notes[i].translate(remove_digits_punc).split() if (w not in stop)]\n", |
|
|
172 |
" \n", |
|
|
173 |
" w2v = word2vec.Word2Vec(notes, \n", |
|
|
174 |
" size = DIMENSIONS, \n", |
|
|
175 |
" window = WINDOW, \n", |
|
|
176 |
" sg = USE_SKIPGRAM, \n", |
|
|
177 |
" hs = USE_HIER_SMAX, \n", |
|
|
178 |
" min_count = MIN_COUNT, \n", |
|
|
179 |
" workers = NUM_THREADS)\n", |
|
|
180 |
" w2v.wv.save_word2vec_format(W2V_FILENAME, binary=True)\n", |
|
|
181 |
"else:\n", |
|
|
182 |
" w2v = KeyedVectors.load_word2vec_format(W2V_FILENAME, binary=True)" |
|
|
183 |
] |
|
|
184 |
}, |
|
|
185 |
{ |
|
|
186 |
"cell_type": "code", |
|
|
187 |
"execution_count": 4, |
|
|
188 |
"metadata": { |
|
|
189 |
"collapsed": true |
|
|
190 |
}, |
|
|
191 |
"outputs": [], |
|
|
192 |
"source": [ |
|
|
193 |
"# Make the embedding matrix.\n", |
|
|
194 |
"# We include one extra word, `PADDING`. This is the word that will right-pad short notes.\n", |
|
|
195 |
"# For `PADDING`'s vector representation, we choose the zero vector.\n", |
|
|
196 |
"vocab = [PADDING] + sorted(list(w2v.wv.vocab.keys()))\n", |
|
|
197 |
"vset = set(vocab)\n", |
|
|
198 |
"\n", |
|
|
199 |
"embeddings_index = {}\n", |
|
|
200 |
"for i in range(len(vocab)):\n", |
|
|
201 |
" embeddings_index[vocab[i]] = i\n", |
|
|
202 |
"\n", |
|
|
203 |
"# reverse_embeddings_index = {b:a for a,b in embeddings_index.items()}\n", |
|
|
204 |
"\n", |
|
|
205 |
"# Adding PADDING as vocab word with embedding value of a zero vector\n", |
|
|
206 |
"embeddings_matrix = np.matrix(np.concatenate(([[0.] * DIMENSIONS], [w2v[x] for x in vocab[1:]])))" |
|
|
207 |
] |
|
|
208 |
}, |
|
|
209 |
{ |
|
|
210 |
"cell_type": "markdown", |
|
|
211 |
"metadata": { |
|
|
212 |
"collapsed": true |
|
|
213 |
}, |
|
|
214 |
"source": [ |
|
|
215 |
"# Prepare text" |
|
|
216 |
] |
|
|
217 |
}, |
|
|
218 |
{ |
|
|
219 |
"cell_type": "code", |
|
|
220 |
"execution_count": 5, |
|
|
221 |
"metadata": { |
|
|
222 |
"collapsed": true, |
|
|
223 |
"scrolled": true |
|
|
224 |
}, |
|
|
225 |
"outputs": [], |
|
|
226 |
"source": [ |
|
|
227 |
"train_x = train[SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values \n", |
|
|
228 |
"test_x = test[ SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values \n", |
|
|
229 |
"valid_x = valid[SECTIONNAMES].apply(lambda x: (\" \".join(x)).translate(remove_digits_punc), axis=1).values \n", |
|
|
230 |
"\n", |
|
|
231 |
"train_x = [[embeddings_index[x] for x in note.split() if x in vset] for note in train_x]\n", |
|
|
232 |
"valid_x = [[embeddings_index[x] for x in note.split() if x in vset] for note in valid_x]\n", |
|
|
233 |
"test_x = [[embeddings_index[x] for x in note.split() if x in vset] for note in test_x]\n", |
|
|
234 |
"\n", |
|
|
235 |
"train_y = train[OUTCOME]\n", |
|
|
236 |
"valid_y = valid[OUTCOME]\n", |
|
|
237 |
"test_y = test[OUTCOME]" |
|
|
238 |
] |
|
|
239 |
}, |
|
|
240 |
{ |
|
|
241 |
"cell_type": "code", |
|
|
242 |
"execution_count": 6, |
|
|
243 |
"metadata": { |
|
|
244 |
"scrolled": true |
|
|
245 |
}, |
|
|
246 |
"outputs": [ |
|
|
247 |
{ |
|
|
248 |
"data": { |
|
|
249 |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ0AAAEWCAYAAAC9qEq5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3X2cHFWd7/HP1xCeHyIQFZJAggTY4FXE8KAoy72KJIhm\nFR9g0QiiuXFB94KsBvEBWbmirnJFgYiKEFARRTBiXAzeBXyKJGgIBAgMAUxChPBgABMCCb/945yG\nStvT0z2Zqpnp+b5fr35N96k6p87p6qlfn1OnqxQRmJmZVeFF/V0BMzMbOhx0zMysMg46ZmZWGQcd\nMzOrjIOOmZlVxkHHzMwq46DTYSTNlPTpPiprN0lPSRqWX98g6YN9UXYu7xeS3t9X5bWx3c9LekTS\nX6redqskHSZpeR+VJUnflfS4pJt7WcZGn4VelnG/pDf1Nr91BgedQST/066V9KSkv0r6naTpkp7f\njxExPSL+vcWymh4AIuLPEbFtRGzog7qfKenyuvInR8Slm1p2m/XYDfgYMCEiXtZgeZ8d7NusV0ja\ns6TiXw8cDoyOiAMbbPt4SRtyUHlK0n05SO1VW6cvPwsDUf5/eFjSNoW0D0q6ocX8ffqFrJM56Aw+\nb42I7YDdgXOATwDf6euNSNqsr8scIHYDHo2Ih/u7IhXaHbg/Iv7WZJ3fR8S2wA7Am4C1wC2SXlFF\nBduRe25lHLuGAf9aQrlWFBF+DJIHcD/wprq0A4HngFfk15cAn8/PdwauBf4KPAb8mvRF47KcZy3w\nFPBxYCwQwInAn4GbCmmb5fJuAL4A3Aw8AfwU2DEvOwxY3qi+wCTgGeDZvL1bC+V9MD9/EfAp4AHg\nYWAWsENeVqvH+3PdHgHOaPI+7ZDzr8rlfSqXXzuYPpfrcUmDvH/XjsKyLYD/yHV4CJgJbFXMR+pF\nPQysBE4o5N0J+Fl+3+YDnwd+k5fdlNv3t1yv9/RUXoO67QrMzvu5C/hQTj8ReBrYkMv+XIO8x9fq\nUpd+LfDjun2wWSHPUuBJ4D7guEK+DwF35mV3APsXPg+nAYuA1cAPgS3zshfn7a0CHs/PRxfKvAE4\nG/ht3od7AuPye/ckcD1wPnB5Ic/BwO9In/9bgcN6+N+akd+/ETntg8ANhXVel/fd6vz3dTn97Pz+\nPp3f42/k9H2AubnMJcC7C2Udmd+bJ4EVwGn9fXyp7DjW3xXwo42d1SDo5PQ/Ax/Ozy/hhaDzBdKB\ncXh+vAFQo7IKB5VZwDbAVg0ONDfkf5BX5HWuqv2T0yTo5OdnFg8IhfJqQecDpIPlHsC2wE+Ay+rq\n9q1cr1cB64B/6OZ9mkUKiNvlvHcDJ3ZXz7q83S4HziUd2HfMZf8M+EIh33rgrPxeHwmsAV6cl1+R\nH1sDE4BlFA70uX171tWj2/Ia1O0m4AJgS2A/0sH7f+Vlx9MgqBTyNlye98lDdftgs7zvnwD2zst2\nAfbNz9+VPyMHACIFh90Ln4ebSQFyR1Jgmp6X7QQcnd+f7YAfAdfUfVb+DOyb6zAc+D3pS8DmpCHE\nJ3jh8zgKeDS/by8iDS8+Coxs9r9F+tzV/n+eDzq5vo8D78vbPza/3qn+s5xfb5P38Ql5/VeTvixN\nyMtXAm/Iz19MDsxD4eHhtc7wIOmfot6zpAPC7hHxbET8OvKnvIkzI+JvEbG2m+WXRcTtkYZqPg28\ne1NOLhccB3w1IpZGxFPA6cAxdcN8n4uItRFxK+mb66vqC8l1OQY4PSKejIj7ga+QDha9JknANOCU\niHgsIp4E/m/eVs2zwFn5vZ5D+ta7d67T0cBnI2JNRNwBtHIuq2F5Deo2BjgE+EREPB0RC4FvA1N7\n3eCku88V5N61pK0iYmVELM7pHwS+FBHzI+mKiAcK+c6LiAcj4jFS0N4PICIejYir8vvzJKn38I91\n27wkIhZHxHrS5/oA4DMR8UxE/Ib0haDmvcCciJgTEc9FxFxgASkINfMZ4COSRtalvwW4JyIui4j1\nEfED4C7grd2UcxRpSPO7ef0/kb6kvSsvfxaYIGn7iHg8Iv7YQ706hoNOZxhF6sLX+zKp9/BLSUsl\nzWihrGVtLH+A9I1z55Zq2dyuubxi2ZsBLy2kFWebrSH1iOrtnOtUX9aoTazfSNK38FvyJI6/Av+Z\n02sezQfE+jqOJLWl+N719D43K6/erkAtENb0RZsbfq7yF473ANOBlZJ+LmmfvHgMcG+TMhvuQ0lb\nS/qmpAckPUHquY2o+0JTfM9qbV7TzfLdgXfV9lXeX68nBatuRcTtpKG9+v+V+s8nNH+PdwcOqtv+\ncUBt8srRpAD4gKQbJb22Wb06iYPOICfpANIH/zf1y/I3/Y9FxB7A24BTJb2xtribInvqCY0pPN+N\n9I3tEdL5iK0L9RrGxgfknsp9kPSPWix7PencSTseyXWqL2tFm+U0KnctaRhpRH7sEOnke09Wkdoy\nupA2ppt1e+NBYEdJ2xXS+qLNbyedB/w7EXFdRBxOOojfRRr6hHTgf3kvtvUxUi/uoIjYHjg0p6u4\n2cLzlaQ2b11IK76ny0i98hGFxzYRcU4Ldfks6bxUMaDUfz5h4/e4/vO9DLixbvvbRsSHAXJPcArw\nEuAa4MoW6tURHHQGKUnbSzqKdJ7g8oi4rcE6R0naMw8NrSad7HwuL36IdP6kXe+VNCH/s59FOtG8\ngXTeZEtJb5E0nHTyfotCvoeAsU1mHf0AOEXSOEnbkoauflj3Tb9HuS5XAmdL2k7S7sCpwOXNc25M\n0pbFBy+cUzpX0kvyOqMkHdFinX4CnJm/0e/D3w999XZ/EBHLSCfMv5Dr+0rSBIK22gzpy0LeB18n\nnVf6XIN1XippSp5evI407Ff7XH0bOE3Sa/Issz3zPujJdqSg/ldJO5IO/N3KQ3YLSO/p5rmnUBzq\nuhx4q6Qjcpu2zNPhRzcscOOyu0iTHD5aSJ4D7CXpnyVtJuk9pHNz1+bl9fvv2rz++yQNz48DJP1D\nru9xknaIiGdJ56KeY4hw0Bl8fibpSdI3qTOAr5JOVjYynjSr5ynSSdcLIuK/8rIvAJ/KXf/T2tj+\nZaTJCn8hnbT+KEBErAb+hXTQWUHq+RR/7/Kj/PdRSY3Gry/OZd9Emg31NPCRNupV9JG8/aWkHuD3\nc/mtGkU6ABYfLydNT+8C5uUhoOtpcI6lGyeTZtX9hdTOH5AO2DVnApfm/fHuNupacyzpZP+DwNWk\n80fXt5H/tZKeIh0AbwC2Bw5o9GWGdNw4NW/rMdK5l9o3+B+Rzsd8nzQz6xq6Py9U9P9Ik0QeAeaR\nhi57chzwWtIEgc+TAsW6XI9lwBTgk6Se5jLg32j9mHcWaTIAubxHSedpPpa393HgqIh4JK/yNeCd\nSj/APS8Pdb6ZdM7vQdJ+/yIvfBF7H3B//hxNz20ZEmozmcysQpK+CLwsIiq/IkOnkvRD4K6IaNpL\nsv7lno5ZBSTtI+mVecjpQNLw19X9Xa/BLA9XvVzSiyRNIvVsrunvellznfqrc7OBZjvSkNqupPH/\nr5B+S2S99zLSubKdSEO5H85Tk20A8/CamZlVxsNrZmZWmSE9vLbzzjvH2LFj+7saZmaDyi233PJI\nRNRftaElQzrojB07lgULFvR3NczMBhVJ9VdnaJmH18zMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zM\nKuOgY2ZmlSk16EiaJGmJpK5GNxDL16E6Ly9fJGn/nvJK+ve87kJJv5S0a2HZ6Xn9Ja1cct7MzKpV\nWtDJN/E6H5hMuu/EsZIm1K02mXT5/fGkWwFf2ELeL0fEKyNiP9I9Kz6T80wgXUZ8X2AScIH65jbK\nZmbWR8rs6RwIdOV73j9DutnYlLp1pgCz8r3U55FuT7tLs7wR8UQh/za8cMe+KcAVEbEuIu4j3ffk\nwLIaZ2Zm7Ssz6Ixi43uWL+fv7yfe3TpN80o6W9Iy0o2PPtPG9pA0TdICSQtWrVrVVoOKDjvsMA47\n7LBe57fG/L6adbZBOZEgIs6IiDHA90h3ZGwn70URMTEiJo4c2atLB1Xq3Ll3P/8wMxvsygw6K4Ax\nhdejc1or67SSF1LQObqN7ZmZWT8qM+jMB8ZLGidpc9JJ/tl168wGpuZZbAcDqyNiZbO8ksYX8k8B\n7iqUdYykLSSNI01OuLmsxpmZWftKu8p0RKyXdDJwHTAMuDgiFkuanpfPBOYAR5JO+q8BTmiWNxd9\njqS9geeAB4BaeYslXQncAawHToqIDWW1z8zM2lfqrQ0iYg4psBTTZhaeB3BSq3lz+tENVq8tOxs4\nu7f1NTOzcg3KiQRmZjY4OeiYmVllHHTMzKwyDjpmZlYZBx0zM6tMqbPXrG8Vr0pwyuF79WNNzMx6\nxz0dMzOrjIOOmZlVxkHHzMwq43M6g5TP75jZYOSejpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZ\nZRx0zMysMg46ZmZWGf9OZwAq/gbHzKyTuKdjZmaVcdAxM7PKeHhtgPCQmpkNBe7pmJlZZRx0zMys\nMg46ZmZWGQcdMzOrTKlBR9IkSUskdUma0WC5JJ2Xly+StH9PeSV9WdJdef2rJY3I6WMlrZW0MD9m\nltk2MzNrX2mz1yQNA84HDgeWA/MlzY6IOwqrTQbG58dBwIXAQT3knQucHhHrJX0ROB34RC7v3ojY\nr6w2DVS+oZuZDRZl9nQOBLoiYmlEPANcAUypW2cKMCuSecAISbs0yxsRv4yI9Tn/PGB0iW0wM7M+\nVGbQGQUsK7xentNaWaeVvAAfAH5ReD0uD63dKOkNjSolaZqkBZIWrFq1qrWWmJlZnxi0EwkknQGs\nB76Xk1YCu+XhtVOB70vavj5fRFwUERMjYuLIkSOrq7CZmZV6RYIVwJjC69E5rZV1hjfLK+l44Cjg\njRERABGxDliXn98i6V5gL2BBH7SlFL4KgZkNNWX2dOYD4yWNk7Q5cAwwu26d2cDUPIvtYGB1RKxs\nllfSJODjwNsiYk2tIEkj8wQEJO1BmpywtMT2mZlZm0rr6eTZZScD1wHDgIsjYrGk6Xn5TGAOcCTQ\nBawBTmiWNxf9DWALYK4kgHkRMR04FDhL0rPAc8D0iHisrPaZmVn7Sr3gZ0TMIQWWYtrMwvMATmo1\nb07fs5v1rwKu2pT6mplZuQbtRAIzMxt8HHTMzKwyDjpmZlYZBx0zM6uMg46ZmVXGQcfMzCpT6pRp\nq56vOG1mA5l7OmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46\nZmZWGQcdMzOrjIOOmZlVxkHHzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyvolbB/MN3cxsoHFP\nx8zMKuOgY2ZmlSk16EiaJGmJpC5JMxosl6Tz8vJFkvbvKa+kL0u6K69/taQRhWWn5/WXSDqizLb1\n1rlz737+YWY21JQWdCQNA84HJgMTgGMlTahbbTIwPj+mARe2kHcu8IqIeCVwN3B6zjMBOAbYF5gE\nXJDLMTOzAaLMns6BQFdELI2IZ4ArgCl160wBZkUyDxghaZdmeSPilxGxPuefB4wulHVFRKyLiPuA\nrlyOmZkNEGUGnVHAssLr5TmtlXVayQvwAeAXbWwPSdMkLZC0YNWqVS00w8zM+sqgnUgg6QxgPfC9\ndvJFxEURMTEiJo4cObKcypmZWUNl/k5nBTCm8Hp0TmtlneHN8ko6HjgKeGNERBvbMzOzflRmT2c+\nMF7SOEmbk07yz65bZzYwNc9iOxhYHRErm+WVNAn4OPC2iFhTV9YxkraQNI40OeHmEttnZmZtKq2n\nExHrJZ0MXAcMAy6OiMWSpuflM4E5wJGkk/5rgBOa5c1FfwPYApgrCWBeREzPZV8J3EEadjspIjaU\n1T4zM2tfqZfBiYg5pMBSTJtZeB7ASa3mzel7Ntne2cDZva2vmZmVa9BOJDAzs8HHQcfMzCrjoGNm\nZpVx0DEzs8r4fjpDhO+tY2YDgXs6ZmZWGQcdMzOrTEtBR9JPJL1FkoOUmZn1WqtB5ALgn4F7JJ0j\nae8S62RmZh2qpaATEddHxHHA/sD9wPWSfifpBEnDy6ygmZl1jpaHyyTtBBwPfBD4E/A1UhCaW0rN\nzMys47Q0ZVrS1cDewGXAW/OVoAF+KGlBWZUzM7PO0urvdL6VL8D5PElb5FtDTyyhXmZm1oFaHV77\nfIO03/dlRczMrPM17elIehkwCthK0qsB5UXbA1uXXDczM+swPQ2vHUGaPDAa+Goh/UngkyXVyczM\nOlTToBMRlwKXSjo6Iq6qqE5mZtahehpee29EXA6MlXRq/fKI+GqDbGZmZg31NLy2Tf67bdkVMTOz\nztfT8No389/PVVMdMzPrZK1e8PNLkraXNFzSryStkvTesitnZmadpdXf6bw5Ip4AjiJde21P4N/K\nqpSZmXWmVoNObRjuLcCPImJ1SfUxM7MO1uplcK6VdBewFviwpJHA0+VVy8zMOlFLQSciZkj6ErA6\nIjZI+hswpdyqWVnOnXv3889POXyvfqyJmQ017dwJdB/gPZKmAu8E3txTBkmTJC2R1CVpRoPlknRe\nXr5I0v495ZX0LkmLJT0naWIhfayktZIW5sfMNtpmZmYVaPXWBpcBLwcWAhtycgCzmuQZBpwPHA4s\nB+ZLmh0RdxRWmwyMz4+DgAuBg3rIezvwDuCbDTZ7b0Ts10qbzMyseq2e05kITIiIaKPsA4GuiFgK\nIOkK0pBcMehMAWblcudJGiFpF2Bsd3kj4s6c1kZVzMxsIGh1eO124GVtlj0KWFZ4vTyntbJOK3kb\nGZeH1m6U9IZGK0iaJmmBpAWrVq1qoUgzM+srrfZ0dgbukHQzsK6WGBFvK6VWvbMS2C0iHpX0GuAa\nSfvm3xc9LyIuAi4CmDhxYjs9NzMz20StBp0ze1H2CmBM4fXonNbKOsNbyLuRiFhHDogRcYuke4G9\nAN9O28xsgGhpeC0ibiRdiWB4fj4f+GMP2eYD4yWNk7Q5cAwwu26d2cDUPIvtYNKU7JUt5t2IpJF5\nAgKS9iBNTljaSvvMzKwarV577UPAj3lhxtgo4JpmeSJiPXAycB1wJ3BlRCyWNF3S9LzaHFJg6AK+\nBfxLs7y5Lm+XtBx4LfBzSdflsg4FFklamOs6PSIea6V9ZmZWjVaH104izUb7A0BE3CPpJT1liog5\npMBSTJtZeB657Jby5vSrgasbpF8FDMgbzRV/jDnQ+IeiZlalVmevrYuIZ2ovJG1G+p2OmZlZy1oN\nOjdK+iSwlaTDgR8BPyuvWmZm1olaDTozgFXAbcD/Jg17faqsSpmZWWdq9YKfz0m6BrgmIvyLSjMz\n65WmPZ08lflMSY8AS4Al+a6hn6mmemZm1kl6Gl47BTgEOCAidoyIHUkX5jxE0iml187MzDpKT0Hn\nfcCxEXFfLSFfhPO9wNQyK2ZmZp2np6AzPCIeqU/M53WGl1MlMzPrVD0FnWd6uczMzOzv9DR77VWS\nnmiQLmDLEupjZmYdrGnQiYhhVVXEzMw6X6s/DjUzM9tkDjpmZlYZBx0zM6tMq7c2sCHAtzkws7K5\np2NmZpVx0DEzs8o46JiZWWUcdMzMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOz\nyjjomJlZZUoNOpImSVoiqUvSjAbLJem8vHyRpP17yivpXZIWS3pO0sS68k7P6y+RdESZbet05869\ne6OHmVlfKC3oSBoGnA9MBiYAx0qaULfaZGB8fkwDLmwh7+3AO4Cb6rY3ATgG2BeYBFyQyzEzswGi\nzJ7OgUBXRCyNiGeAK4ApdetMAWZFMg8YIWmXZnkj4s6IWNJge1OAKyJiXUTcB3TlcszMbIAoM+iM\nApYVXi/Paa2s00re3mwPSdMkLZC0YNWqVT0UaWZmfWnITSSIiIsiYmJETBw5cmR/V8fMbEgp8yZu\nK4Axhdejc1or6wxvIW9vtmdmZv2ozJ7OfGC8pHGSNied5J9dt85sYGqexXYwsDoiVraYt95s4BhJ\nW0gaR5qccHNfNsjMzDZNaT2diFgv6WTgOmAYcHFELJY0PS+fCcwBjiSd9F8DnNAsL4CktwNfB0YC\nP5e0MCKOyGVfCdwBrAdOiogNZbVvqPGtrM2sL5Q5vEZEzCEFlmLazMLzAE5qNW9Ovxq4ups8ZwNn\nb0KVzcysRENuIoGZmfUfBx0zM6uMg46ZmVXGQcfMzCpT6kQC60yeyWZmveWejpmZVcZBx8zMKuOg\nY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46ZmZWGQcdMzOrjIOOmZlVxlcksE3iqxOYWTscdEpSPBib\nmVni4TUzM6uMezrWZzzUZmY9cU/HzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyDjpmZlYZBx0z\nM6tMqUFH0iRJSyR1SZrRYLkknZeXL5K0f095Je0oaa6ke/LfF+f0sZLWSlqYHzPLbJuZmbWvtKAj\naRhwPjAZmAAcK2lC3WqTgfH5MQ24sIW8M4BfRcR44Ff5dc29EbFffkwvp2VmZtZbZfZ0DgS6ImJp\nRDwDXAFMqVtnCjArknnACEm79JB3CnBpfn4p8E8ltsHMzPpQmZfBGQUsK7xeDhzUwjqjesj70ohY\nmZ//BXhpYb1xkhYCq4FPRcSv6yslaRqpV8Vuu+3WTnusDb4kjpk1MqgnEkREAJFfrgR2i4j9gFOB\n70vavkGeiyJiYkRMHDlyZIW1NTOzMns6K4Axhdejc1or6wxvkvchSbtExMo8FPcwQESsA9bl57dI\nuhfYC1jQN82x3nKvx8xqyuzpzAfGSxonaXPgGGB23Tqzgal5FtvBwOo8dNYs72zg/fn5+4GfAkga\nmScgIGkP0uSEpeU1z8zM2lVaTyci1ks6GbgOGAZcHBGLJU3Py2cCc4AjgS5gDXBCs7y56HOAKyWd\nCDwAvDunHwqcJelZ4DlgekQ8Vlb7zMysfaXeTyci5pACSzFtZuF5ACe1mjenPwq8sUH6VcBVm1hl\nMzMr0aCeSGBmZoOL7xxqlfKkArOhzT0dMzOrjIOOmZlVxsNr1m881GY29LinY2ZmlXHQMTOzyjjo\nmJlZZXxOxwaE2vmd5Y+vZfSLt+rn2phZWdzTMTOzyjjo2ICz/PG1nDv37o1mt5lZZ3DQMTOzyjjo\nmJlZZRx0zMysMp69ZgOar1pg1lnc0zEzs8q4p2ODhns9ZoOfezpmZlYZ93T6kH9XYmbWnIOODUoe\najMbnBx0bNBzADIbPBx0rKN0N8TpYGQ2MDjo2JDg3pDZwOCgs4k8eWDwcQAy6z8OOjakOQCZVctB\nxyzz+SCz8pUadCRNAr4GDAO+HRHn1C1XXn4ksAY4PiL+2CyvpB2BHwJjgfuBd0fE43nZ6cCJwAbg\noxFxXZnts6GhlSFUByaz1pQWdCQNA84HDgeWA/MlzY6IOwqrTQbG58dBwIXAQT3knQH8KiLOkTQj\nv/6EpAnAMcC+wK7A9ZL2iogNZbXRrKbdc3sOUjZUldnTORDoioilAJKuAKYAxaAzBZgVEQHMkzRC\n0i6kXkx3eacAh+X8lwI3AJ/I6VdExDrgPklduQ6/L7GNZr3SnxNQigHP57SsamUGnVHAssLr5aTe\nTE/rjOoh70sjYmV+/hfgpYWy5jUoayOSpgHT8sunJC1ppTEFOwOP1F7ceOPebWYf1DZqe5lOffOA\nfF8ra3+ZTm0zPeuItm+Codz+Rm3fvbeFDeqJBBERkqLNPBcBF/V2m5IWRMTE3uYfzIZy22Fot38o\ntx2Gdvv7uu1lXmV6BTCm8Hp0TmtlnWZ5H8pDcOS/D7exPTMz60dlBp35wHhJ4yRtTjrJP7tundnA\nVCUHA6vz0FmzvLOB9+fn7wd+Wkg/RtIWksaRJifcXFbjzMysfaUNr0XEekknA9eRpj1fHBGLJU3P\ny2cCc0jTpbtIU6ZPaJY3F30OcKWkE4EHgHfnPIslXUmabLAeOKmkmWu9HprrAEO57TC02z+U2w5D\nu/192naliWNmZmbl851DzcysMg46ZmZWGQedFkmaJGmJpK58JYSOJOl+SbdJWihpQU7bUdJcSffk\nvy8urH96fk+WSDqi/2rePkkXS3pY0u2FtLbbKuk1+T3rknRevrzTgNdN+8+UtCLv/4WSjiws65j2\nSxoj6b8k3SFpsaR/zekdv/+btL2afR8RfvTwIE1muBfYA9gcuBWY0N/1Kqmt9wM716V9CZiRn88A\nvpifT8jvxRbAuPweDevvNrTR1kOB/YHbN6WtpFmSBwMCfgFM7u+2bUL7zwROa7BuR7Uf2AXYPz/f\nDrg7t7Hj93+Ttley793Tac3zl/SJiGeA2mV5hooppEsOkf/+UyH9iohYFxH3kWYhHtgP9euViLgJ\neKwuua225t+KbR8R8yL9F84q5BnQuml/dzqq/RGxMvLFhSPiSeBO0hVMOn7/N2l7d/q07Q46renu\ncj2dKEgXS70lXzIIml96qNPel3bbOio/r08fzD4iaVEefqsNL3Vs+yWNBV4N/IEhtv/r2g4V7HsH\nHav3+ojYj3QF8JMkHVpcmL/RDIl59kOprQUXkoaR9wNWAl/p3+qUS9K2wFXA/4mIJ4rLOn3/N2h7\nJfveQac1Q+YSOxGxIv99GLiaNFw2lC491G5bV+Tn9emDUkQ8FBEbIuI54Fu8MFzace2XNJx00P1e\nRPwkJw+J/d+o7VXtewed1rRySZ9BT9I2krarPQfeDNzO0Lr0UFttzUMxT0g6OM/cmVrIM+jUDrjZ\n20n7Hzqs/bmu3wHujIivFhZ1/P7vru2V7fv+nkkxWB6ky/XcTZq5cUZ/16ekNu5BmqVyK7C41k5g\nJ+BXwD3A9cCOhTxn5PdkCQN81k6D9v6ANIzwLGk8+sTetBWYmP9B7wW+Qb7Sx0B/dNP+y4DbgEX5\nYLNLJ7YfeD1p6GwRsDA/jhwK+79J2yvZ974MjpmZVcbDa2ZmVhkHHTMzq4yDjpmZVcZBx8zMKuOg\nY2ZmlXHQMQMkPVVy+cdL2rXw+n5JO7eQ79WSvlNy3S6R9M4my0+W9IEy62BDh4OOWTWOB3btaaUG\nPgmc11chA3JvAAADJklEQVSVkNSbW9RfDHykr+pgQ5uDjlk3JI2UdJWk+flxSE4/M18Q8QZJSyV9\ntJDn0/meI7+R9ANJp+VexETge/k+JVvl1T8i6Y/5fiT7NNj+dsArI+LW/Po2SSOUPCppak6fJelw\nSVtK+m5e70+S/mdefryk2ZL+P/CrnP8buZ7XAy8pbPMcpfusLJL0HwARsQa4X9KguYK4DVwOOmbd\n+xpwbkQcABwNfLuwbB/gCNL1qT4rabik2nqvIl0wdSJARPwYWAAcFxH7RcTaXMYjEbE/6UKLpzXY\nfu3X3jW/BQ4B9gWWAm/I6a8FfgeclDYX/wM4FrhU0pZ5nf2Bd0bEP5IucbI36T4pU4HXAUjaKS/b\nNyJeCXy+sO0Fhe2Z9VpvutpmQ8WbgAl64WaI2+cr8wL8PCLWAeskPUy6BP4hwE8j4mngaUk/66H8\n2kUmbwHe0WD5LsCqwutfk2689gApUE2TNAp4PCL+Jun1wNcBIuIuSQ8Ae+W8cyOidu+cQ4EfRMQG\n4MHcAwJYDTwNfEfStcC1hW0/TAq0ZpvEPR2z7r0IODj3TvaLiFERUZtwsK6w3gZ69wWuVkZ3+dcC\nWxZe30TqbbwBuIEUkN5JCkY9+VtPK0TEelLP7cfAUcB/FhZvmetjtkkcdMy690sKJ9Al7dfD+r8F\n3prPrWxLOnDXPEm6NXA77gT2rL2IiGXAzsD4iFgK/IY0LHdTXuXXwHG5rnsBu5Eu0FjvJuA9kobl\nKwvXzv1sC+wQEXOAU0jDhDV7sfFQn1mvOOiYJVtLWl54nAp8FJiYT6rfAUxvVkBEzCddnXcR6X7x\nt5GGrAAuAWbWTSRoKiLuAnao3W4i+wPpaueQgswoUvABuAB4kaTbgB8Cx+chwHpXk66ifAfpFsO/\nz+nbAddKWpTLPLWQ5xBgbiv1NmvGV5k260OSto2IpyRtTepRTIt8P/pelncK8GREfLvHlUsi6dXA\nqRHxvv6qg3UO93TM+tZFkhYCfwSu2pSAk13IxueP+sPOwKf7uQ7WIdzTMTOzyrinY2ZmlXHQMTOz\nyjjomJlZZRx0zMysMg46ZmZWmf8GQp4xDBcjJYgAAAAASUVORK5CYII=\n", |
|
|
250 |
"text/plain": [ |
|
|
251 |
"<matplotlib.figure.Figure at 0x2d78f12e518>" |
|
|
252 |
] |
|
|
253 |
}, |
|
|
254 |
"metadata": {}, |
|
|
255 |
"output_type": "display_data" |
|
|
256 |
} |
|
|
257 |
], |
|
|
258 |
"source": [ |
|
|
259 |
"# We decide the max and min length (in words) of discharge notes.\n", |
|
|
260 |
"\n", |
|
|
261 |
"plt.hist([len(x) for x in train_x], normed=1, bins=100, alpha=.5)\n", |
|
|
262 |
"plt.vlines(MAX_NOTE_LEN, 0, .003); plt.vlines(MIN_NOTE_LEN, 0, .003);\n", |
|
|
263 |
"plt.xlabel(\"Length (words)\"); plt.ylabel(\"Density\");\n", |
|
|
264 |
"plt.title(\"Distribution of Length of Discharge Notes\")\n", |
|
|
265 |
"plt.show()" |
|
|
266 |
] |
|
|
267 |
}, |
|
|
268 |
{ |
|
|
269 |
"cell_type": "code", |
|
|
270 |
"execution_count": 7, |
|
|
271 |
"metadata": { |
|
|
272 |
"collapsed": true |
|
|
273 |
}, |
|
|
274 |
"outputs": [], |
|
|
275 |
"source": [ |
|
|
276 |
"# Keep only the notes that are long enough:\n", |
|
|
277 |
"subset_train = set(np.where([len(x) >= MIN_NOTE_LEN for x in train_x])[0])\n", |
|
|
278 |
"subset_test = set(np.where([len(x) >= MIN_NOTE_LEN for x in test_x])[0])\n", |
|
|
279 |
"subset_valid = set(np.where([len(x) >= MIN_NOTE_LEN for x in valid_x])[0])\n", |
|
|
280 |
"\n", |
|
|
281 |
"def getsubset(orig, index):\n", |
|
|
282 |
" return([j for i,j in enumerate(orig) if i in index])\n", |
|
|
283 |
"\n", |
|
|
284 |
"# Pad the notes that are too short:\n", |
|
|
285 |
"train_x = pad_sequences(getsubset(train_x, subset_train), maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n", |
|
|
286 |
"valid_x = pad_sequences(getsubset(valid_x, subset_valid), maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n", |
|
|
287 |
"test_x = pad_sequences(getsubset(test_x, subset_test), maxlen=MAX_NOTE_LEN, padding='post', truncating='post')\n", |
|
|
288 |
"\n", |
|
|
289 |
"train_y = np.array(getsubset(train_y, subset_train))\n", |
|
|
290 |
"valid_y = np.array(getsubset(valid_y, subset_valid))\n", |
|
|
291 |
"test_y = np.array(getsubset(test_y, subset_test))" |
|
|
292 |
] |
|
|
293 |
}, |
|
|
294 |
{ |
|
|
295 |
"cell_type": "markdown", |
|
|
296 |
"metadata": {}, |
|
|
297 |
"source": [ |
|
|
298 |
"# Defining the neural network" |
|
|
299 |
] |
|
|
300 |
}, |
|
|
301 |
{ |
|
|
302 |
"cell_type": "code", |
|
|
303 |
"execution_count": 8, |
|
|
304 |
"metadata": { |
|
|
305 |
"collapsed": true, |
|
|
306 |
"scrolled": true |
|
|
307 |
}, |
|
|
308 |
"outputs": [], |
|
|
309 |
"source": [ |
|
|
310 |
"seq_input_layer = Input(shape=(MAX_NOTE_LEN,), dtype='int32')\n", |
|
|
311 |
"\n", |
|
|
312 |
"embedded_layer = Embedding(embeddings_matrix.shape[0], embeddings_matrix.shape[1],\n", |
|
|
313 |
" weights = [embeddings_matrix],\n", |
|
|
314 |
" input_length = MAX_NOTE_LEN,\n", |
|
|
315 |
" trainable = True)(seq_input_layer)\n", |
|
|
316 |
"\n", |
|
|
317 |
"conv_layer = Conv1D(UNITS, FILTERSIZE, activation='tanh')(embedded_layer)\n", |
|
|
318 |
"\n", |
|
|
319 |
"pool_layer = GlobalMaxPooling1D()(conv_layer)\n", |
|
|
320 |
"\n", |
|
|
321 |
"out_layer = Dense(1, \n", |
|
|
322 |
" activation = 'sigmoid', \n", |
|
|
323 |
" activity_regularizer = l1(REG_FACTOR)\n", |
|
|
324 |
" )(pool_layer)\n", |
|
|
325 |
"\n", |
|
|
326 |
"optimizer = RMSprop(lr = LEARNING_RATE)\n", |
|
|
327 |
"model = Model(inputs=seq_input_layer, outputs=out_layer)\n", |
|
|
328 |
"model.compile(loss=LOSS_FUNC, optimizer=optimizer)" |
|
|
329 |
] |
|
|
330 |
}, |
|
|
331 |
{ |
|
|
332 |
"cell_type": "code", |
|
|
333 |
"execution_count": 9, |
|
|
334 |
"metadata": {}, |
|
|
335 |
"outputs": [ |
|
|
336 |
{ |
|
|
337 |
"name": "stdout", |
|
|
338 |
"output_type": "stream", |
|
|
339 |
"text": [ |
|
|
340 |
"_________________________________________________________________\n", |
|
|
341 |
"Layer (type) Output Shape Param # \n", |
|
|
342 |
"=================================================================\n", |
|
|
343 |
"input_1 (InputLayer) (None, 700) 0 \n", |
|
|
344 |
"_________________________________________________________________\n", |
|
|
345 |
"embedding_1 (Embedding) (None, 700, 1000) 22331000 \n", |
|
|
346 |
"_________________________________________________________________\n", |
|
|
347 |
"conv1d_1 (Conv1D) (None, 698, 450) 1350450 \n", |
|
|
348 |
"_________________________________________________________________\n", |
|
|
349 |
"global_max_pooling1d_1 (Glob (None, 450) 0 \n", |
|
|
350 |
"_________________________________________________________________\n", |
|
|
351 |
"dense_1 (Dense) (None, 1) 451 \n", |
|
|
352 |
"=================================================================\n", |
|
|
353 |
"Total params: 23,681,901\n", |
|
|
354 |
"Trainable params: 23,681,901\n", |
|
|
355 |
"Non-trainable params: 0\n", |
|
|
356 |
"_________________________________________________________________\n" |
|
|
357 |
] |
|
|
358 |
} |
|
|
359 |
], |
|
|
360 |
"source": [ |
|
|
361 |
"model.summary()" |
|
|
362 |
] |
|
|
363 |
}, |
|
|
364 |
{ |
|
|
365 |
"cell_type": "markdown", |
|
|
366 |
"metadata": {}, |
|
|
367 |
"source": [ |
|
|
368 |
"# Training the neural net" |
|
|
369 |
] |
|
|
370 |
}, |
|
|
371 |
{ |
|
|
372 |
"cell_type": "code", |
|
|
373 |
"execution_count": 10, |
|
|
374 |
"metadata": { |
|
|
375 |
"collapsed": true, |
|
|
376 |
"scrolled": true |
|
|
377 |
}, |
|
|
378 |
"outputs": [], |
|
|
379 |
"source": [ |
|
|
380 |
"# Load the weights from a previous run, or train the model anew:\n", |
|
|
381 |
"if isfile(CNN_FILENAME):\n", |
|
|
382 |
" model.load_weights(CNN_FILENAME)\n", |
|
|
383 |
"else:\n", |
|
|
384 |
" model.fit(train_x, train_y, \n", |
|
|
385 |
" batch_size = BATCH_SIZE, \n", |
|
|
386 |
" epochs = EPOCHS, \n", |
|
|
387 |
" validation_data = (valid_x, valid_y), \n", |
|
|
388 |
" verbose = True)" |
|
|
389 |
] |
|
|
390 |
} |
|
|
391 |
], |
|
|
392 |
"metadata": { |
|
|
393 |
"kernelspec": { |
|
|
394 |
"display_name": "Python 3", |
|
|
395 |
"language": "python", |
|
|
396 |
"name": "python3" |
|
|
397 |
}, |
|
|
398 |
"language_info": { |
|
|
399 |
"codemirror_mode": { |
|
|
400 |
"name": "ipython", |
|
|
401 |
"version": 3 |
|
|
402 |
}, |
|
|
403 |
"file_extension": ".py", |
|
|
404 |
"mimetype": "text/x-python", |
|
|
405 |
"name": "python", |
|
|
406 |
"nbconvert_exporter": "python", |
|
|
407 |
"pygments_lexer": "ipython3", |
|
|
408 |
"version": "3.5.2" |
|
|
409 |
} |
|
|
410 |
}, |
|
|
411 |
"nbformat": 4, |
|
|
412 |
"nbformat_minor": 2 |
|
|
413 |
} |