--- a
+++ b/src/clustering/search.py
@@ -0,0 +1,122 @@
+import config
+from config import similarity
+from .utils import mean_pooling, encode, find_cluster, text_splitter, semantic_search_base, forward, forward_doc
+import pickle
+import random
+import os
+import json
+from annoy import AnnoyIndex
+
+class Buffer_best_k:
+    """Fixed-size buffer of the k highest values seen (descending) with their payloads."""
+    def __init__(self, k, initia_value=-float("inf")):
+        self.k = k
+        self.values, self.data = [initia_value] * k, [None] * k
+
+    def new_val(self, value, data=None):
+        """Rank-insert value/data, dropping the current last entry; return True if kept."""
+        slot = next((j for j in range(self.k) if self.values[j] < value), None)
+        if slot is None:
+            return False
+        self.values[slot + 1 :] = self.values[slot:-1]
+        self.data[slot + 1 :] = self.data[slot:-1]
+        self.values[slot], self.data[slot] = value, data
+        return True
+
+    def get_data(self):
+        return self.data
+
+    def get_values(self):
+        return self.values
+
+# # ---------------------------------- Kmeans ---------------------------------- #
+# with open(config.embeddings_path + os.sep + "clustered_data_concepts.pkl", "rb") as f:
+#     clustered_data = pickle.load(f)
+
+# ----------------------------------- Annoy ---------------------------------- #
+with open(os.sep.join((config.embeddings_path, "index_to_name.pkl")), "rb") as f:
+    sample_names_list = pickle.load(f)  # position = Annoy item id (see search_query)
+search_index = AnnoyIndex(config.embedding_size, config.annoy_metric)
+search_index.load(os.sep.join((config.embeddings_path, "annoy_index_concepts.ann")))
+
+
+# --------------------------------- Functions -------------------------------- #
+def parse_metadata(filename):
+    """Load the metadata JSON for filename; "age" becomes an int, or None if null/absent."""
+    with open(config.metadata_path + os.sep + filename + ".json", encoding="utf-8") as f:
+        metadata = json.load(f)
+    print("metadata", metadata)
+    metadata["age"] = None if metadata.get("age") is None else int(metadata["age"])
+    return metadata
+
+
+def _passes_filters(metadata, filters):
+    """Tell whether metadata satisfies every range and multiselect filter in filters."""
+    for key in ("age", "birthdate", "admission_date", "discharge_date"):
+        value = metadata.get(key)
+        # A missing or null value cannot be range-checked, so it is let through.
+        if key in filters and value is not None:
+            if filters[key][0] > value or filters[key][1] < value:
+                print("filtered", value, filters[key])
+                return False
+    for key in ("sexe",):
+        if key in filters and metadata.get(key) not in filters[key]:
+            print("filtered", metadata.get(key), filters[key])
+            return False
+    return True
+
+
+def search_query(query, filters=None, top=30):
+    """Return the indexed paragraphs nearest to query that pass the metadata filters.
+
+    Args:
+        query: text passed to encode() and looked up in the Annoy index.
+        filters: optional dict; "age", "birthdate", "admission_date" and
+            "discharge_date" map to an inclusive (low, high) pair, "sexe" to a
+            collection of accepted values. Absent keys are not filtered on.
+        top: number of neighbours requested from the index.
+
+    Returns:
+        A (results, count) tuple. Each result is a dict with "score" (the
+        distance reported by Annoy), "filename", "id", "preview" and
+        "metadata"; results keep the order returned by the index and count
+        is how many of the retrieved neighbours passed the filters.
+    """
+    filters = {} if filters is None else filters  # avoids a shared mutable default
+    query_emb = encode(query)
+
+    # NOTE: a K-means variant (scan one cluster, rank with Buffer_best_k) was
+    # commented out here; neighbours come from the Annoy index loaded at import time.
+    indeces, scores = search_index.get_nns_by_vector(
+        query_emb.numpy().reshape(-1), top, include_distances=True
+    )
+
+    # TODO: need a better filtering (it runs after retrieval, so it can only shrink the hits)
+    filtered_results = []
+    for index, score in zip(indeces, scores):
+        name = sample_names_list[index]
+        # Split on the last separator only, so the key may also occur in the filename.
+        filename, paragraph = name.rsplit(config.filename_split_key, 1)
+        metadata = parse_metadata(filename)
+        if not _passes_filters(metadata, filters):
+            continue  # rejected hits never need their text read and split
+        file_path = config.data_path + os.sep + filename + ".txt"
+        with open(file_path) as f:
+            text = f.read()
+        filtered_results.append(
+            {
+                "score": float(score),
+                "filename": filename,
+                "id": name,
+                "preview": text_splitter(text, file_path)[int(paragraph)],
+                "metadata": metadata,
+            }
+        )
+
+    count_filtered = len(filtered_results)
+    return filtered_results[:top], count_filtered
+
+
+if __name__ == "__main__":
+    # Manual smoke test: run one sample query against the loaded index and print the result.
+    print(search_query("What is the best way to train a neural network?"))