--- a/src/clustering/search.py
+++ b/src/clustering/search.py
@@ -0,0 +1,122 @@
import json
import os
import pickle

from annoy import AnnoyIndex

import config
from config import similarity
from .utils import encode, find_cluster, text_splitter


class Buffer_best_k:
    """Keep the k best (highest) values seen so far, in descending order,
    together with an optional payload per value.

    >>> buf = Buffer_best_k(k=2)
    >>> for score, name in [(0.1, "a"), (0.9, "b"), (0.5, "c")]:
    ...     buf.new_val(score, name)
    >>> buf.get_values(), buf.get_data()
    ([0.9, 0.5], ['b', 'c'])
    """

    def __init__(self, k, initial_value=-float("inf")):
        self.k = k
        self.values = [initial_value] * self.k
        self.data = [None] * self.k

    def new_val(self, value, data=None):
        """Insert `value` at its sorted position if it beats a current entry.
        Returns True if the value was inserted, False otherwise."""
        for i in range(self.k):
            if self.values[i] < value:
                # Shift the tail down one slot (dropping the current worst
                # entry), then insert the new value at position i.
                self.values[i + 1 :] = self.values[i:-1]
                self.data[i + 1 :] = self.data[i:-1]
                self.values[i] = value
                self.data[i] = data
                return True
        return False

    def get_data(self):
        return self.data

    def get_values(self):
        return self.values


# ---------------------------------- Kmeans ---------------------------------- #
# Alternative search path: exhaustive scoring inside a pre-computed cluster.
# with open(config.embeddings_path + os.sep + "clustered_data_concepts.pkl", "rb") as f:
#     clustered_data = pickle.load(f)

# ----------------------------------- Annoy ---------------------------------- #
with open(config.embeddings_path + os.sep + "index_to_name.pkl", "rb") as f:
    sample_names_list = pickle.load(f)
search_index = AnnoyIndex(config.embedding_size, config.annoy_metric)
search_index.load(config.embeddings_path + os.sep + "annoy_index_concepts.ann")


# --------------------------------- Functions -------------------------------- #
def parse_metadata(filename):
    with open(config.metadata_path + os.sep + filename + ".json") as f:
        metadata = json.load(f)
    if metadata["age"] is not None:
        metadata["age"] = int(metadata["age"])
    return metadata


def search_query(query, filters=None, top=30):
    filters = filters or {}

    # Encode the query into the same embedding space as the documents.
    query_emb = encode(query)

    # ---------------------------------- Kmeans ---------------------------------- #
    # Alternative: find the cluster the query belongs to, then score every
    # document in that cluster and keep the best `top` matches.
    # cluster = find_cluster(query_emb, clustered_data)
    # buffer = Buffer_best_k(k=top)
    # for name, doc_emb in clustered_data[cluster]["elements"].items():
    #     score = similarity(query_emb, doc_emb)
    #     buffer.new_val(score, name)
    # scores, data_names = buffer.get_values(), buffer.get_data()

    # ----------------------------------- Annoy ---------------------------------- #
    # Note: Annoy returns distances (lower is closer), not similarity scores.
    indices, scores = search_index.get_nns_by_vector(
        query_emb.numpy().reshape(-1), top, include_distances=True
    )
    data_names = [sample_names_list[i] for i in indices]

    results = []
    for i, name in enumerate(data_names):
        filename, paragraph = name.split(config.filename_split_key)
        paragraph = int(paragraph)
        file_path = config.data_path + os.sep + filename + ".txt"
        with open(file_path) as f:
            text = f.read()
        results.append(
            {
                "score": float(scores[i]),
                "filename": filename,
                "id": name,
                "preview": text_splitter(text, file_path)[paragraph],
                "metadata": parse_metadata(filename),
            }
        )
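    # Hypothetical example of the `filters` payload the checks below expect
    # (keys mirror the metadata JSON fields; the values are illustrative,
    # not part of the original code):
    #
    #     filters = {
    #         "age": [30, 65],   # range filter: keep 30 <= age <= 65
    #         "sexe": ["F"],     # multiselect filter: keep listed values only
    #     }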
    # TODO: replace this post-hoc filtering with a proper filtered search.
    range_filters = ["age", "birthdate", "admission_date", "discharge_date"]
    multiselect_filters = ["sexe"]
    filtered_results = []
    for result in results:
        valid = True
        for key in range_filters:
            if key in filters and result["metadata"][key] is not None:
                if filters[key][0] > result["metadata"][key] or filters[key][1] < result["metadata"][key]:
                    valid = False
                    print("filtered", result["metadata"][key], filters[key])
                    break
        if valid:
            for key in multiselect_filters:
                if key in filters:
                    if result["metadata"][key] not in filters[key]:
                        valid = False
                        print("filtered", result["metadata"][key], filters[key])
                        break
        if valid:
            filtered_results.append(result)

    # Report how many results survived filtering, but cap the returned list.
    count_filtered = len(filtered_results)
    filtered_results = filtered_results[:top]
    return filtered_results, count_filtered


if __name__ == "__main__":
    query = "What is the best way to train a neural network?"
    print(search_query(query))
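# Illustrative usage sketch (assumes the pickled index files referenced above
# exist and the metadata JSON contains the filtered keys; the query string and
# filter values are made up):
#
#     results, total = search_query(
#         "patient admitted with chest pain",
#         filters={"age": [40, 70], "sexe": ["M"]},
#         top=10,
#     )
#     for r in results:
#         print(f'{r["score"]:.3f}\t{r["filename"]}\t{r["preview"][:80]}')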