[507a54]: / production / action-server / actions / retrievers.py

Download this file

144 lines (126 with data), 4.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from abc import ABC, abstractmethod
from decouple import config
import json
import requests
from .database import DatabaseConnector
from .utils import get_retriever_conf, get_columns
class Retriever(ABC):
    '''
    Abstract base class for retrievers.

    A concrete retriever is constructed from a configuration object and
    answers through ``retrieve``, which receives the conversation tracker.
    '''

    @abstractmethod
    def __init__(self, conf):
        '''
        Sets the retriever up from its configuration ``conf``.
        '''

    @abstractmethod
    def retrieve(self, tracker):
        '''
        Returns the information retrieved for the given ``tracker``.
        '''
class SQLRetriever(Retriever):
    '''
    A retriever to retrieve information from database. Uses intent and
    extracted entity.

    ``retrieve`` returns either the answer text or one of the status codes
    defined below.
    '''

    # Status codes returned by ``retrieve`` when no answer can be built.
    NO_ENTITY = '__CODE0__'         # no entity was extracted
    ENTITY_NOT_FOUND = '__CODE1__'  # entity is not found in any of the tables
    NO_COLUMNS = '__CODE2__'        # intent maps to none of the table's columns

    def __init__(self, conf):
        '''
        Opens the database connection described by ``conf``.

        ``conf`` must provide ``host``, ``user``, ``database`` and ``tables``
        (see ``parse_conf``). The password is read through
        ``config('SQL_PASSWORD')``.
        '''
        host, user, database, self.tables = self.parse_conf(conf)
        # read password from .env
        password = config('SQL_PASSWORD')
        self.db = DatabaseConnector(
            host=host,
            user=user,
            password=password,
            database=database,
        )

    def retrieve(self, tracker):
        '''
        Retrieves information from database. Collects parameters from tracker.

        Returns the collected answer, or ``NO_ENTITY`` / ``ENTITY_NOT_FOUND`` /
        ``NO_COLUMNS`` when the corresponding lookup step fails.
        '''
        input_entity = self.get_entity(tracker=tracker)
        if input_entity is None:
            return self.NO_ENTITY

        intent = tracker.get_slot('intent_name')

        # For now, if information is found in multiple tables or multiple
        # rows of a table, just the first one is used, so the search stops
        # at the first table that contains the entity.
        table = next(
            (
                candidate
                for candidate in self.tables
                if self.db.search_in_table(candidate, input_entity)
            ),
            None,
        )
        if table is None:
            return self.ENTITY_NOT_FOUND

        columns = get_columns(intent, table)
        if not columns:
            # Detected intent does not map to any of table's columns.
            return self.NO_COLUMNS

        return self.collect_answer(table, input_entity, columns)

    @staticmethod
    def parse_conf(conf):
        '''
        Returns database information and list of tables.

        A single (non-list) ``tables`` entry is wrapped into a one-item list.
        '''
        host = conf['host']
        user = conf['user']
        database = conf['database']
        tables = conf['tables']
        if not isinstance(tables, list):
            tables = [tables]
        return host, user, database, tables

    @staticmethod
    def get_entity(tracker):
        '''
        Returns the extracted entity. Prioritizes the last extractor in
        pipeline.

        Returns ``None`` when the slot holds an empty list, so callers see
        the same "no entity" signal as for an unset slot.
        '''
        entity = tracker.get_slot('entity_name')
        if isinstance(entity, list):
            return entity[-1] if entity else None
        return entity

    def collect_answer(self, table, name, columns):
        '''
        Collects answers from different columns of table and puts them
        together, separated by single spaces.

        Columns whose lookup yields no record (or a ``None`` value) are
        skipped; values are converted with ``str`` before joining.
        '''
        answers = []
        for column in columns:
            record = self.db.retrieve_from_table(table, name, column)  # is a list
            # Guard against empty results and null values instead of
            # failing on record[0] or inside join().
            if record and record[0] is not None:
                answers.append(str(record[0]))
        return ' '.join(answers)
class SemanticRetriever(Retriever):
    '''
    A retriever that posts the user's latest message to an HTTP retrieval
    service and returns the contexts found in the JSON response.
    '''

    # Seconds to wait for the service when conf has no 'timeout' entry.
    DEFAULT_TIMEOUT = 30

    def __init__(self, conf):
        '''
        Reads the service settings from ``conf``.

        ``conf`` must provide ``host`` and ``top_k``; ``timeout`` (seconds)
        is optional. The API key is read through ``config('DPR_API_KEY')``.
        '''
        self.host = conf['host']
        self.top_k = conf['top_k']
        self.timeout = conf.get('timeout', self.DEFAULT_TIMEOUT)
        self.api_key = config('DPR_API_KEY')
        self.headers = {'Content-Type': 'application/json'}

    def retrieve(self, tracker):
        '''
        Queries the retrieval service with the latest message text from
        tracker and returns the ``contexts`` of the response joined by '. '.
        '''
        query = tracker.latest_message['text']
        payload = json.dumps({
            "query": query,
            "top_k": self.top_k,
            "api_key": self.api_key,
        })
        # Without a timeout a stalled service would block this call forever.
        response = requests.post(
            self.host,
            headers=self.headers,
            data=payload,
            timeout=self.timeout,
        )
        # NOTE(review): the HTTP status is not checked here; an error
        # response is only noticed when its body lacks 'contexts' or is
        # not JSON.
        contexts = response.json()['contexts']
        return '. '.join(contexts)
def create_retriever():
    '''
    Builds the retriever selected by the ``type`` entry of the retriever
    configuration.

    Gives a ``SQLRetriever`` for ``'SQL_table'``, a ``SemanticRetriever`` for
    ``'semantic'`` and ``None`` for any other type.
    '''
    settings = get_retriever_conf()
    if settings['type'] == 'SQL_table':
        return SQLRetriever(settings)
    if settings['type'] == 'semantic':
        return SemanticRetriever(settings)
    # Unrecognized type: there is no retriever to build.
    return None