Switch to unified view

a b/development/paraphrase/randomize entities.py
1
import re
2
from ruamel import yaml
3
import numpy as np
4
import sys
5
6
7
def load_entities(entity_names):
    """Load the lookup values for each entity name.

    For every name in *entity_names* the file ``./<name>.yml`` (relative to
    the current working directory) is read through ``load_entities_list``.

    Returns:
        A dict mapping each entity name to the list returned for its file.
    """
    return {
        name: load_entities_list(f'./{name}.yml')
        for name in entity_names
    }
13
 
14
15
def load_entities_list(file_adr):
    """Read the entity values stored in a lookup YAML file.

    The document is expected to contain an ``nlu`` list whose first item has
    an ``examples`` string holding one ``- value`` entry per line.

    Args:
        file_adr: Path of the YAML file to read.

    Returns:
        A list of strings, one per entry of ``examples``, with the leading
        two-character list marker removed. Blank entries are skipped.

    Raises:
        OSError: If the file cannot be opened.
        KeyError, IndexError: If the document lacks ``nlu[0]['examples']``.
    """
    # NOTE(review): yaml.load(..., Loader=...) is ruamel.yaml's function-style
    # API -- confirm the installed ruamel.yaml version still provides it.
    with open(file_adr, 'r') as f:
        document = yaml.load(f, Loader=yaml.RoundTripLoader)
    examples = document['nlu'][0]['examples']
    # Drop the two-character "- " marker from each row.
    values = (row[2:] for row in examples.split('\n'))
    # A block scalar usually ends with a newline, so the split leaves an empty
    # trailing row; keeping it would let '' be picked as an entity value.
    return [value for value in values if value.strip()]
22
23
24
def randomize_entity_names(nlu_dict, entities_dict):
    """Replace annotated entity values in NLU examples with random ones.

    Every ``[value](entity)`` annotation whose entity name is a key of
    *entities_dict* gets its value swapped for a randomly chosen,
    whitespace-stripped item of ``entities_dict[entity]``. Each annotation is
    drawn independently, including several annotations of the same entity on
    one line.

    Args:
        nlu_dict: Mapping with an ``nlu`` list; each item carries an
            ``examples`` string with one example per line. Modified in place.
        entities_dict: Mapping of entity name to a sequence of candidate
            values.

    Returns:
        The same *nlu_dict* object, with its ``examples`` strings updated.
    """
    # Compile once per entity instead of once per example line. re.escape
    # keeps entity names containing regex metacharacters literal.
    patterns = {
        entity: re.compile(rf'\[[^\]]*\]\({re.escape(entity)}\)')
        for entity in entities_dict
    }

    def randomize_line(line):
        """Return *line* with every known annotation re-drawn."""
        for entity, pattern in patterns.items():
            candidates = entities_dict[entity]

            # Bind the loop variables as defaults so the callback does not
            # depend on late-binding closure behaviour.
            def pick(_match, entity=entity, candidates=candidates):
                value = np.random.choice(candidates).strip()
                return f'[{value}]({entity})'

            # A callable replacement is inserted literally (no backslash or
            # group-reference processing) and is invoked once per match.
            line = pattern.sub(pick, line)
        return line

    for intent in nlu_dict['nlu']:
        lines = intent['examples'].split('\n')
        # Joining with '\n' reproduces the original line layout exactly,
        # including a trailing newline when the input had one.
        intent['examples'] = '\n'.join(randomize_line(line) for line in lines)
    return nlu_dict
39
40
41
42
if __name__ == '__main__':
    # Input and output locations; all paths are relative to the current
    # working directory.
    NLU_FILE = './nlu_cleaned.yml'
    # One lookup file named ./<entity>.yml is read for each entity listed here.
    ENTITY_NAMES = ['drug', 'lab']
    OUTPUT_FILE = 'nlu_random.yml'

    # Load the NLU document and the lookup values for each entity.
    with open(NLU_FILE, 'r') as nlu_stream:
        nlu_data = yaml.load(nlu_stream, Loader=yaml.RoundTripLoader)
    lookup_values = load_entities(ENTITY_NAMES)

    # Randomize before opening the output file, so a failure here does not
    # leave a truncated result behind.
    randomized = randomize_entity_names(nlu_data, lookup_values)

    # Save the randomized document.
    with open(OUTPUT_FILE, 'w') as out_stream:
        yaml.dump(randomized, out_stream, Dumper=yaml.RoundTripDumper, default_flow_style=None)
56
57
58
    
59
    
60
    
61
62
    
63
    
64
    
65
    
66
    
67