Diff of /src/Matcher/BM25.py [000000] .. [f87529]

import math
from typing import Any, Dict, List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class BM25Retriever(BaseRetriever):
    """Retriever that ranks documents with BM25 over extracted entities."""

    documents: List[Document]
    k: int
    entity_extractor: Any
    document_entity_index: Dict[int, List[str]] = {}
    idf_scores: Dict[str, float] = {}
    average_doc_length: float = 0.0

    def __init__(self, documents: List[Document], k: int, entity_extractor: Any):
        # BaseRetriever is a pydantic model: declared fields must be passed to the
        # parent constructor rather than assigned after a bare super().__init__().
        super().__init__(documents=documents, k=k, entity_extractor=entity_extractor)
        self._prepare_documents()

    def _prepare_documents(self) -> None:
        """Prepare documents by extracting entities and computing the BM25 statistics."""
        num_docs = len(self.documents)
        doc_lengths = []
        df: Dict[str, int] = {}

        # Extract entities and accumulate document frequency (DF) per entity.
        # Entity bags are kept in document_entity_index (keyed by document position)
        # because Document is a pydantic model and rejects ad-hoc attributes.
        for i, doc in enumerate(self.documents):
            entities = self.entity_extractor.extract(doc.page_content)
            self.document_entity_index[i] = entities
            doc_lengths.append(len(entities))

            for entity in set(entities):
                df[entity] = df.get(entity, 0) + 1

        self.average_doc_length = sum(doc_lengths) / num_docs if num_docs else 0.0
        self.idf_scores = {
            term: math.log((num_docs - count + 0.5) / (count + 0.5))
            for term, count in df.items()
        }

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Use BM25 to rank documents based on entities extracted from the query."""
        query_entities = self.entity_extractor.extract(query)
        k1, b = 1.2, 0.75  # standard BM25 term-saturation and length-normalization parameters
        scores = []

        for i, doc in enumerate(self.documents):
            entity_bag = self.document_entity_index[i]
            doc_length = len(entity_bag)
            score = 0.0
            for entity in query_entities:
                if entity in entity_bag:
                    term_frequency = entity_bag.count(entity)
                    idf = self.idf_scores.get(entity, 0.0)
                    score += idf * (
                        term_frequency * (k1 + 1)
                        / (term_frequency + k1 * (1 - b + b * (doc_length / self.average_doc_length)))
                    )
            if score > 0:
                scores.append((score, doc))

        # Sort documents by score and return the top k.
        sorted_docs = sorted(scores, key=lambda x: x[0], reverse=True)
        return [doc for _, doc in sorted_docs[: self.k]]


# Note: This implementation assumes an entity_extractor object that exposes an
# extract(text: str) -> List[str] method.
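

# Usage sketch (illustration only, not part of the original diff): shows how the
# retriever might be wired up. WhitespaceEntityExtractor is a hypothetical stand-in
# for the real entity extractor; any object exposing extract(text: str) -> List[str]
# can be substituted.
if __name__ == "__main__":

    class WhitespaceEntityExtractor:
        """Naive stand-in extractor: treats lowercase whitespace tokens as entities."""

        def extract(self, text: str) -> List[str]:
            return text.lower().split()

    docs = [
        Document(page_content="Paris is the capital of France"),
        Document(page_content="Berlin is the capital of Germany"),
        Document(page_content="The Eiffel Tower is in Paris"),
    ]
    retriever = BM25Retriever(docs, k=2, entity_extractor=WhitespaceEntityExtractor())

    # BaseRetriever implements the Runnable interface, so invoke() dispatches to
    # _get_relevant_documents (older langchain-core versions use get_relevant_documents).
    for doc in retriever.invoke("Eiffel Tower"):
        print(doc.page_content)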