a b/doc2vec_trainer.py
1
import argparse
2
import lockfile
3
4
from daemon import DaemonContext
5
from gensim.models import Doc2Vec
6
from gensim.models.doc2vec import LabeledSentence
7
8
from loader import get_data
9
10
11
class LabeledDocIterator(object):
    """Re-iterable stream of labeled documents for Doc2Vec training.

    Every pass over the iterator loads the patients in ``patient_list`` one
    at a time through ``get_data``, writes the patient's ``NEW_EMPI`` value
    to ``status``, and yields one ``LabeledSentence`` per document found
    under each requested category of the patient record.
    """

    def __init__(self, patient_list, categories, status):
        """Store the inputs used on each pass.

        Args:
            patient_list: iterable of patient identifiers; each one is handed
                to ``get_data`` as a single-element list. It is walked again
                on every call to ``__iter__``.
            categories: iterable of keys to look up in each patient record.
            status: object with a ``write`` method; receives one line per
                patient processed.
        """
        self.patient_list = patient_list
        self.category = categories
        self.status = status

    def __iter__(self):
        """Yield a ``LabeledSentence`` for each document of each patient."""
        for i in self.patient_list:
            p = get_data([i])[0]
            self.status.write(p['NEW_EMPI'] + '\n')
            # Iterate the categories given to this instance. The previous
            # code read a bare `categories` name, which only resolved when a
            # module-level global of that name happened to exist.
            for category in self.category:
                if category in p:
                    for idx, doc in enumerate(p[category]):
                        # NOTE(review): the tag ends with '\n'; kept as-is
                        # because consumers may look documents up by this
                        # exact string -- confirm whether it is intended.
                        tag = p['NEW_EMPI'] + '_' + category + '_' + str(idx) + '\n'
                        yield LabeledSentence(words=doc['free_text'].split(), tags=[tag])
26
27
28
def train_doc2vec_model(categories, n_patients, output_file, status_file, dm):
    """Train a Doc2Vec model on patient documents and save it to disk.

    Args:
        categories: category keys forwarded to ``LabeledDocIterator``.
        n_patients: number of patients; ids ``0 .. n_patients - 1`` are
            streamed through the iterator.
        output_file: path handed to ``model.save``.
        status_file: path of the progress log; opened in ``'w'`` mode, so an
            existing file is truncated.
        dm: forwarded unchanged as the ``dm`` argument of ``Doc2Vec``.
    """
    with open(status_file, 'w') as status:
        it = LabeledDocIterator(range(n_patients), categories, status)

        # alpha == min_alpha keeps the learning rate constant within a pass;
        # it is lowered by hand between epochs below.
        model = Doc2Vec(size=300, window=10, dm=dm, min_count=5, workers=11,
                        alpha=0.025, min_alpha=0.025)
        model.build_vocab(it)
        for epoch in range(10):
            message = ("***********Training Epoch: " + str(epoch)
                       + ("***********") + '\n')
            print(message)
            status.write(message)
            # Push progress to disk now so the status file can be inspected
            # while the (long) training pass is still running.
            status.flush()
            # Exactly one training pass per announced epoch. The previous
            # code called train() a second time after lowering alpha, which
            # doubled the work per epoch.
            model.train(it)
            model.alpha -= 0.002  # decrease the learning rate
            model.min_alpha = model.alpha  # fix the learning rate, no decay

        # Save the model
        model.save(output_file)
46
47
48
if __name__ == "__main__":
    # Command line: output_file n_patients categories dm
    parser = argparse.ArgumentParser()
    parser.add_argument("output_file")
    # Numeric arguments are converted by argparse so that a bad value is
    # reported on the terminal, before the process is daemonized.
    parser.add_argument("n_patients", type=int)
    # Comma-separated list of category keys.
    parser.add_argument("categories")
    # Switches between Distributed Memory and Distributed Bag of Words Model
    parser.add_argument("dm", type=int)
    args = parser.parse_args()
    status_file = args.output_file + '.status'
    categories = args.categories.split(',')

    # NOTE(review): DaemonContext is given working_directory=base, so a
    # relative output_file is presumably resolved against `base` rather than
    # the directory the script was launched from -- confirm this is intended.
    base = '/home/ubuntu/josh_project'
    context = DaemonContext(
        working_directory=base,
        umask=0o002,
        # The separator is required: without it the lock file was created
        # next to the project directory as '<base>doc2vec_trainer.pid'.
        pidfile=lockfile.FileLock(base + '/doc2vec_trainer.pid'),
    )

    with context:
        train_doc2vec_model(categories, args.n_patients,
                            args.output_file, status_file, args.dm)