--- a +++ b/src/Matcher/BM25.py @@ -0,0 +1,66 @@ +from typing import List +import math + +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +class BM25Retriever(BaseRetriever): + documents: List[Document] + k: int + document_entity_index: dict + idf_scores: dict + average_doc_length: float + + def __init__(self, documents: List[Document], k: int, entity_extractor): + super().__init__() + self.documents = documents + self.k = k + self.entity_extractor = entity_extractor + self._prepare_documents() + + def _prepare_documents(self): + """Prepare documents by extracting entities and calculating necessary BM25 metrics.""" + num_docs = len(self.documents) + doc_lengths = [] + df = {} + + # Extract entities and calculate document frequency (DF) + for doc in self.documents: + entities = self.entity_extractor.extract(doc.page_content) + doc.entity_bag = entities + doc_lengths.append(len(entities)) + + unique_entities = set(entities) + for entity in unique_entities: + if entity in df: + df[entity] += 1 + else: + df[entity] = 1 + + self.average_doc_length = sum(doc_lengths) / num_docs + self.idf_scores = {term: math.log((num_docs - df[term] + 0.5) / (df[term] + 0.5)) for term in df} + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """Use BM25 to rank documents based on entities extracted from the query.""" + query_entities = self.entity_extractor.extract(query) + scores = [] + + for doc in self.documents: + score = 0 + for entity in query_entities: + if entity in doc.entity_bag: + term_frequency = doc.entity_bag.count(entity) + idf = self.idf_scores.get(entity, 0) + doc_length = len(doc.entity_bag) + score += idf * (term_frequency * (1.2 + 1) / (term_frequency + 1.2 * (1 - 0.75 + 0.75 * (doc_length / self.average_doc_length)))) + if score > 0: + scores.append((score, doc)) + + # Sort documents by their score and return top k + sorted_docs = sorted(scores, key=lambda x: x[0], reverse=True) + return [doc for _, doc in sorted_docs[:self.k]] + +# Note: This implementation assumes the presence of an entity_extractor with an extract method.