|
a |
|
b/doc2vec_trainer.py |
|
|
1 |
import argparse |
|
|
2 |
import lockfile |
|
|
3 |
|
|
|
4 |
from daemon import DaemonContext |
|
|
5 |
from gensim.models import Doc2Vec |
|
|
6 |
from gensim.models.doc2vec import LabeledSentence |
|
|
7 |
|
|
|
8 |
from loader import get_data |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
class LabeledDocIterator(object):
    """Stream labeled documents for Doc2Vec, one patient at a time.

    Every call to ``iter()`` starts a fresh pass: each entry of
    ``patient_list`` is loaded through ``get_data`` and one
    ``LabeledSentence`` is yielded per document found under the requested
    categories.

    Args:
        patient_list: Iterable of patient identifiers. Each one is passed to
            ``get_data`` wrapped in a single-element list.
        categories: Keys of the patient record whose documents are yielded.
            Categories missing from a patient record are skipped.
        status: Writable file-like object. The patient's ``NEW_EMPI`` value
            is written to it, one per line, as each patient is loaded.
    """

    def __init__(self, patient_list, categories, status):
        self.patient_list = patient_list
        self.category = categories
        self.status = status

    def __iter__(self):
        """Yield a ``LabeledSentence`` for each document of each patient."""
        for i in self.patient_list:
            p = get_data([i])[0]
            # Progress marker: record which patient is being processed.
            self.status.write(p['NEW_EMPI'] + '\n')
            # Use the categories handed to the constructor. The previous
            # code read a bare ``categories`` name, which only resolved when
            # the script's module-level global happened to exist.
            for category in self.category:
                if category in p:
                    for idx, doc in enumerate(p[category]):
                        # NOTE(review): the tag ends with a newline; kept
                        # as-is because saved models may already use these
                        # tags -- confirm whether that is intentional.
                        tag = p['NEW_EMPI'] + '_' + category + '_' + str(idx) + '\n'
                        yield LabeledSentence(words=doc['free_text'].split(), tags=[tag])
|
|
26 |
|
|
|
27 |
|
|
|
28 |
def train_doc2vec_model(categories, n_patients, output_file, status_file, dm):
    """Train a Doc2Vec model over patient documents and save it to disk.

    Args:
        categories: Document categories forwarded to ``LabeledDocIterator``.
        n_patients: Number of patients; ``range(n_patients)`` is used as the
            patient list.
        output_file: Path handed to ``model.save``.
        status_file: Path opened for writing; receives the epoch banners
            written here and whatever the iterator writes while streaming.
        dm: Passed straight through as the ``dm`` argument of ``Doc2Vec``.
    """
    with open(status_file, 'w') as status:
        documents = LabeledDocIterator(range(n_patients), categories, status)

        # alpha == min_alpha: use fixed learning rate
        model = Doc2Vec(
            size=300,
            window=10,
            dm=dm,
            min_count=5,
            workers=11,
            alpha=0.025,
            min_alpha=0.025,
        )
        model.build_vocab(documents)

        for epoch in range(10):
            banner = "".join(
                ["***********Training Epoch: ", str(epoch), "***********", '\n']
            )
            print(banner)
            status.write(banner)

            model.train(documents)
            # Step the learning rate down by hand, then pin min_alpha to it
            # so there is no decay inside a single training pass.
            model.alpha -= 0.002
            model.min_alpha = model.alpha
            # NOTE(review): a second pass per epoch at the lowered rate. The
            # nesting of this call was reconstructed from a paste that had
            # lost its indentation -- confirm it belongs inside the loop.
            model.train(documents)

        # Persist the trained model.
        model.save(output_file)
|
|
46 |
|
|
|
47 |
|
|
|
48 |
if __name__ == "__main__":
    # Command-line entry point: parse arguments, then run the training as a
    # daemon rooted in the project directory.
    parser = argparse.ArgumentParser()
    parser.add_argument("output_file")
    # Numeric arguments are converted by argparse so that a malformed value
    # is rejected with a usage error before the process daemonizes, instead
    # of raising inside the daemon context.
    parser.add_argument("n_patients", type=int)
    # Comma-separated list of document categories.
    parser.add_argument("categories")
    # Switches between Distributed Memory and Distributed Bag of Words Model
    parser.add_argument("dm", type=int)
    args = parser.parse_args()

    status_file = args.output_file + '.status'
    categories = args.categories.split(',')

    base = '/home/ubuntu/josh_project'
    context = DaemonContext(
        working_directory=base,
        umask=0o002,
        # Keep the pid lock inside the project directory. Without the path
        # separator the lock resolved to
        # '/home/ubuntu/josh_projectdoc2vec_trainer.pid'.
        pidfile=lockfile.FileLock(base + '/doc2vec_trainer.pid'),
    )

    with context:
        train_doc2vec_model(categories, args.n_patients,
                            args.output_file, status_file, args.dm)