Diff of /data_augment.py [000000] .. [5c3b8b]

Switch to unified view

a b/data_augment.py
1
'''
2
    Data augmentation strategies for facts lists
3
'''
4
5
from bs4 import BeautifulSoup
6
import urllib
7
import pdb
8
import json
9
import os
10
import pandas as pd
11
import re
12
import numpy as np
13
import pickle
14
15
from nltk.corpus import wordnet
16
from nltk.corpus import stopwords
17
from pattern.en import singularize,pluralize
18
import random
19
20
from utils import get_stories
21
22
# Seed both RNGs used by the augmentation routines so runs are reproducible.
_SEED = 1234
random.seed(_SEED)
np.random.seed(_SEED)
24
25
def check_repeated(name,repeated_list):
    """Normalise *name* and map it to its canonical form.

    The name is lowercased and stripped; if that key appears in the
    ``repeated_list`` mapping, the mapped (canonical) value is returned,
    otherwise the normalised name itself.
    """
    key = name.lower().strip()
    return repeated_list.get(key, key)
28
29
def synonyms(data):
    """Augment facts by replacing non-stopword tokens with WordNet lemmas.

    Parameters
    ----------
    data : list of (facts_list, name, _) triples, where ``facts_list`` is a
        list of facts and each fact is a list of word tokens.

    Returns
    -------
    dict mapping each name to a list of facts lists: the original list plus
    ``augment_n`` noisy copies.
    """
    augment_n = 10
    data_dict = dict((key, [val]) for val, key, _ in data)

    # `<>` was Python-2-only syntax; `!=` is equivalent and portable.
    is_plural = lambda word: singularize(word) != word
    stops = set(stopwords.words('english') + ['l'])

    for disease in data:
        for _ in range(augment_n):
            new_facts_list = []
            for fact in disease[0]:
                new_fact = fact[:]
                for k, word in enumerate(fact):
                    if word in stops:
                        continue
                    syn = wordnet.synsets(word)
                    if not syn:
                        continue
                    # NOTE(review): only the lemma is sampled randomly; the
                    # synset is always the first one returned by WordNet.
                    # Confirm that sampling from syn[0] only is intended.
                    random_lemma = random.choice(syn[0].lemma_names())
                    if is_plural(word):
                        random_lemma = pluralize(random_lemma)
                    random_lemma = random_lemma.lower()
                    random_lemma = random_lemma.replace('_', ' ')
                    random_lemma = random_lemma.replace('-', ' ')
                    # Multi-word lemmas would break the one-token-per-slot
                    # fact format, so leave the original word in place.
                    if ' ' in random_lemma:
                        continue
                    new_fact[k] = random_lemma
                new_facts_list.append(new_fact)
            data_dict[disease[1]].append(new_facts_list[:])
    # TODO: this is adding the name of the disease by synonym, check it!
    return data_dict
61
62
def remove(data,data_dict):
    """Augment by deleting a random subset of facts from sampled facts lists.

    For each name in ``data``, ``n_augment`` facts lists are sampled with
    replacement from ``data_dict[name]`` (including lists appended earlier in
    the same loop).  When a sampled list has more than ``min_delete`` facts, a
    copy with several facts removed is appended to ``data_dict[name]``.

    Mutates and returns ``data_dict``.
    """
    num_delete = (3, 15)  # (min, max) number of facts to delete
    min_delete = 5        # delete only from lists with more than 5 facts
    n_augment = 20

    for (values, name, _) in data:
        facts = data_dict[name]
        for _ in range(n_augment):
            facts_list = random.choice(facts)
            n_facts = len(facts_list)
            if n_facts <= min_delete:
                continue
            # Never delete everything: cap the draw below n_facts.
            # NOTE(review): randint's upper bound is exclusive, so at most
            # num_delete[1] - 1 facts are ever deleted — confirm intended.
            max_facts = num_delete[1] if n_facts > 15 else n_facts - 1
            n_choice = np.random.randint(num_delete[0], max_facts)
            choice = np.random.choice(n_facts, n_choice, replace=False)
            new_facts_list = [f for i, f in enumerate(facts_list) if i not in choice]
            data_dict[name].append(new_facts_list)
    return data_dict
83
84
85
def permute(data,data_dict):
    """Augment by appending randomly reordered copies of sampled facts lists.

    For each name in ``data``, ``n_augment`` facts lists are sampled with
    replacement from ``data_dict[name]`` (including lists appended earlier in
    the same loop) and a shuffled copy of each is appended back.

    Mutates and returns ``data_dict``.
    """
    n_augment = 10

    for (values, name, _) in data:
        facts = data_dict[name]
        for _ in range(n_augment):
            facts_list = random.choice(facts)
            # Permute via an index list instead of np.array(list_of_lists):
            # building an ndarray from unequal-length facts is an error in
            # modern NumPy (and silently produced object arrays before),
            # and it also turned each fact into an ndarray on append.
            order = np.random.permutation(len(facts_list))
            data_dict[name].append([facts_list[i] for i in order])
    return data_dict
101
102
103
# ----------------------------------------------------------------------
# Script entry: read the raw facts lists, run the three augmentation
# strategies, split each disease's lists 70/30 into train/test, and
# pickle both sets.
# ----------------------------------------------------------------------
file_names = ['facts_list_all.txt']

training_set = []
test_set = []
data = []

for file_name in file_names:
    print('Reading {0} .....'.format(file_name))
    # Data in format:
    # [([[fact1],[fact2],..][answer])...]
    # where each fact is a list of words
    with open('data/{0}'.format(file_name), 'r') as read_file:
        data += get_stories(read_file)

# Data augmenting strategies
# 1. Changing randomly nouns by synonyms
print('Data augmentation: synonyms')
data_dict = synonyms(data)
print('Number of diseases: {0}'.format(len(data_dict)))
# 2. Removing facts randomly
print('Data augmentation: removing')
data_dict = remove(data, data_dict)
# 3. Changing facts order
print('Data augmentation: permutation')
data_dict = permute(data, data_dict)

for (values, name, _) in data:
    # Per-disease 70/30 split so every disease appears in both sets.
    data_len = len(data_dict[name])
    training_size = int(0.7 * data_len)
    # dtype=object: facts lists have different lengths after `remove`,
    # and modern NumPy rejects ragged arrays without it.
    facts = np.array(data_dict[name], dtype=object)
    indexes = np.random.permutation(len(facts))
    training_facts = facts[indexes[:training_size]]
    test_facts = facts[indexes[training_size:]]
    training_set += zip(list(training_facts), [name] * len(training_facts))
    test_set += zip(list(test_facts), [name] * len(test_facts))

print(len(training_set))
print(len(test_set))
# Pickle streams are binary: open with 'wb' (text mode fails on Python 3
# and can corrupt the stream on platforms with newline translation).
with open('data/training_set.dat', 'wb') as out_file:
    pickle.dump(training_set, out_file)
with open('data/test_set.dat', 'wb') as out_file:
    pickle.dump(test_set, out_file)
print('Saved')