src/Matcher/crossencoder_reranker.py

import json
import logging
from typing import List

import torch
from sagemaker_inference import encoder
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Keys of the JSON request ("pairs") and response ("scores") payloads.
PAIRS = "pairs"
SCORES = "scores"


class CrossEncoder:
    def __init__(self) -> None:
        # Run on GPU when one is available, otherwise fall back to CPU.
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        logging.info(f"Using device: {self.device}")
        # Load the BGE reranker and move it to the target device.
        model_name = "BAAI/bge-reranker-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model = self.model.to(self.device)

    def __call__(self, pairs: List[List[str]]) -> List[float]:
        # Score a batch of [query, passage] pairs; higher scores mean more relevant.
        with torch.inference_mode():
            inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            inputs = inputs.to(self.device)
            # The cross-encoder emits one logit per pair; flatten to a 1-D score vector.
            scores = self.model(**inputs, return_dict=True).logits.view(-1).float()

        return scores.detach().cpu().tolist()


def model_fn(model_dir: str) -> CrossEncoder:
    # SageMaker model-loading entry point; the weights are pulled from the
    # Hugging Face Hub rather than from model_dir.
    try:
        return CrossEncoder()
    except Exception:
        logging.exception(f"Failed to load model from: {model_dir}")
        raise


def transform_fn(
    cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str
) -> bytes:
    # SageMaker request handler: decode the JSON body, score the pairs, and
    # encode the scores in the requested accept format.
    payload = json.loads(input_data)
    model_output = cross_encoder(**payload)
    output = {SCORES: model_output}
    return encoder.encode(output, accept)
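

# --- Illustrative local smoke test (an assumption, not part of the SageMaker
# entry points above): it shows the JSON shapes the handler expects, i.e.
# {"pairs": [[query, passage], ...]} in and {"scores": [...]} out. The example
# strings and content types below are made up for illustration.
if __name__ == "__main__":
    example_request = json.dumps(
        {
            PAIRS: [
                ["what does a reranker do?", "A cross-encoder scores each query-passage pair jointly."],
                ["what does a reranker do?", "Bananas are a good source of potassium."],
            ]
        }
    ).encode("utf-8")
    response = transform_fn(
        model_fn(model_dir="."), example_request, "application/json", "application/json"
    )
    print(response)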