b/diagnosis_rnn.py

'''
Trains an RNN on a medical disease-diagnosis dataset.

The data is obtained from various online sources. The network must
predict the disease from many symptoms given as natural-language
sentences.
'''

from __future__ import print_function
from functools import reduce
from itertools import izip_longest  # Python 2; on Python 3 use itertools.zip_longest
import pickle
import random

import numpy as np
np.random.seed(1337)  # for reproducibility
random.seed(1337)

import h5py  # h5py must be importable for Keras to save weights in HDF5 format
from keras.layers.core import Dense, Dropout
from keras.layers.recurrent import GRU
from keras.models import Sequential

from utils import create_vectors_dataset, get_spacy_vectors
#from glove import Glove  # only needed for the GloVe loading commented out below
from spacy.en import English
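
# NOTE: this script targets Python 2 and the Keras 1.x API (output_dim=,
# init=, overwrite= in save_weights). It is assumed that
# utils.get_spacy_vectors returns X of shape (batch, max_len, word_vec_dim)
# and one-hot answer targets Y indexed by answer_dict.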

# Model hyper-parameters
RNN = GRU
NUM_HIDDEN_UNITS = 128
BATCH_SIZE = 32
EPOCHS = 10
DROPOUT_FACTOR = 0.5
print('RNN / HIDDENS = {}, {}'.format(RNN, NUM_HIDDEN_UNITS))

# Input representation: each story is a sequence of up to max_len tokens,
# each token a word_vec_dim-dimensional spaCy word vector
max_len = 500
word_vec_dim = 300
vocab_size = 2350  # number of candidate answers (size of the softmax layer)

training_set_file = 'data/training_set.dat'
test_set_file = 'data/test_set.dat'

# Pickled (facts, answer) pairs; open in binary mode for pickle
with open(training_set_file, 'rb') as f:
    train_stories = pickle.load(f)
with open(test_set_file, 'rb') as f:
    test_stories = pickle.load(f)

# Flatten each story's list of facts into a single token sequence
train_stories = [(reduce(lambda x, y: x + y, map(list, fact)), q) for fact, q in train_stories]
test_stories = [(reduce(lambda x, y: x + y, map(list, fact)), q) for fact, q in test_stories]

# Build the answer vocabulary over train and test; each answer word gets
# a 0-based integer index
answer_vocab = sorted(reduce(lambda x, y: x | y, (set([answer]) for _, answer in train_stories + test_stories)))
answer_dict = dict((word, i) for i, word in enumerate(answer_vocab))
print('Answers dict len: {0}'.format(len(answer_dict)))
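# NOTE: vocab_size above is hard-coded; it must equal len(answer_dict) so
# that the one-hot targets line up with the softmax output layer.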

# TODO: check that the GloVe vectors file exists before loading it
#word_vectors_dir = 'word_vectors/glove.42B.300d.txt'
#word_vectors_model = Glove.load_stanford(word_vectors_dir)
nlp = English()  # spaCy pipeline; its vocabulary supplies the word vectors
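# It is assumed the installed spaCy English model ships with 300-d word
# vectors; English() is loaded once up front since initialization is costly.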


print('Build model...')

# Two stacked GRUs over the word-vector sequence, then a softmax over the
# answer vocabulary
model = Sequential()
model.add(GRU(output_dim=NUM_HIDDEN_UNITS, activation='tanh',
              return_sequences=True, input_shape=(max_len, word_vec_dim)))
model.add(Dropout(DROPOUT_FACTOR))
model.add(GRU(NUM_HIDDEN_UNITS, return_sequences=False))
model.add(Dense(vocab_size, init='uniform', activation='softmax'))
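# Shapes: input (batch, 500, 300) -> GRU (batch, 500, 128) -> Dropout
# -> GRU (batch, 128) -> Dense+softmax (batch, 2350)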

#json_string = model.to_json()
#model_file_name = 'models/lstm_num_hidden_units_' + str(NUM_HIDDEN_UNITS) + '_num_lstm_layers_' + str(2) + '_dropout_' + str(0.3)
#open(model_file_name + '.json', 'w').write(json_string)

print('Compiling model...')
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
print('Compilation done...')

print('Training')

base_dir = '.'

NUM_DATA_TRAIN = len(train_stories)
NUM_DATA_TEST = len(test_stories)

# Shuffle, then hold out the last 5% of the training stories for validation
random.shuffle(train_stories)
valid_stories = train_stories[int(len(train_stories) * 0.95):]
train_stories = train_stories[:int(len(train_stories) * 0.95)]
print('Validation size: {0}'.format(len(valid_stories)))
print('Training size: {0}'.format(len(train_stories)))

def grouper(iterable, n, fillvalue=None):
    '''Collect items into fixed-length chunks, padding the last chunk
    with fillvalue: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx'''
    args = [iter(iterable)] * n
    return izip_longest(*args, fillvalue=fillvalue)
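
# The training and test loops below pad the final batch with the last story
# and later truncate predictions back to the true set size, e.g.
# grouper([s1, s2, s3], 2, fillvalue=s3) --> (s1, s2), (s3, s3)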

acc_hist = [0.]  # validation-accuracy history; seeded so max() is defined
show_batch_interval = 50
for k in xrange(EPOCHS):
    for b, train_batch in enumerate(grouper(train_stories, BATCH_SIZE,
                                            fillvalue=train_stories[-1])):
        X, Y = get_spacy_vectors(train_batch, answer_dict,
                                 max_len, nlp)
        loss = model.train_on_batch(X, Y)
        if b % show_batch_interval == 0:
            print('Epoch: {0}, Batch: {1}, loss: {2}'.format(k, b, loss))

    # Validate once per epoch
    X, Y = get_spacy_vectors(valid_stories, answer_dict,
                             max_len, nlp)
    loss, acc = model.evaluate(X, Y, batch_size=BATCH_SIZE)
    print('Epoch {0}, valid loss / valid accuracy = {1:.4f} / {2:.4f}'.format(k, loss, acc))
    # Logging results
    with open(base_dir + '/logs/log_{0}_{1}_drop_{2}.txt'.format(
            'GRU', str(NUM_HIDDEN_UNITS), str(DROPOUT_FACTOR)), 'a') as fil:
        fil.write(str(loss) + ' ' + str(acc) + '\n')
    # Saving weights whenever validation accuracy improves
    if max(acc_hist) < acc:
        model.save_weights(base_dir + '/models/weights_{0}_{1}_drop_{2}.hdf5'.format(
            'GRU', str(NUM_HIDDEN_UNITS), str(DROPOUT_FACTOR)), overwrite=True)
    acc_hist.append(acc)
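
# Design note: weights are saved only when validation accuracy improves,
# so the .hdf5 file always holds the best-so-far epoch.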

# Obtaining test results: evaluating top-5 and top-1 accuracy
SAVE_ERRORS = False

acc_5 = 0.
acc = 0.
for b, test_batch in enumerate(grouper(test_stories, BATCH_SIZE,
                                       fillvalue=test_stories[-1])):
    X, Y = get_spacy_vectors(test_batch, answer_dict,
                             max_len, nlp)
    answers_test = Y if b == 0 else np.vstack((answers_test, Y))
    preds = model.predict(X)
    # Keep all predictions for later visualizations
    all_predictions = preds if b == 0 else np.vstack((all_predictions, preds))
    if b % 50 == 0:
        print('Batch: {0}'.format(b))

# Drop the rows that came from grouper's fill value
all_predictions = all_predictions[:len(test_stories)]
answers_test = answers_test[:len(test_stories)]
inv_answer_dict = dict((i, word) for word, i in answer_dict.items())
for pred, answer in zip(all_predictions, answers_test):
    prediction = np.argsort(pred)[-5:][::-1]  # five most probable answer indices
    pred_words = [inv_answer_dict[p] for p in prediction]
    answer_word = inv_answer_dict[answer.argmax()]
    if answer_word in pred_words:
        acc_5 += 1.
    if pred_words[0] == answer_word:
        acc += 1.
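
# A prediction counts toward acc_5 when the true answer is among the five
# highest-probability outputs, and toward acc only when it is the argmax.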

if SAVE_ERRORS:
    # Per-example cross-entropy of the true answer
    all_err = -np.log(all_predictions[range(all_predictions.shape[0]), answers_test.argmax(axis=1)])
    np.savetxt('logs/error.dat', all_err)

# Normalize the counters into accuracies over the full test set
acc /= len(test_stories)
acc_5 /= len(test_stories)
print('Accuracy: {0}'.format(acc))
print('5 most prob. accuracy: {0}'.format(acc_5))