|
a |
|
b/src/Matcher/crossencoder_reranker.py |
|
|
1 |
import json |
|
|
2 |
import logging |
|
|
3 |
from typing import List |
|
|
4 |
|
|
|
5 |
import torch |
|
|
6 |
from sagemaker_inference import encoder |
|
|
7 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
8 |
|
|
|
9 |
# JSON field names for the inference payload.
PAIRS = "pairs"    # request key (presumably what callers send; not referenced in this file — confirm)
SCORES = "scores"  # response key used by transform_fn's output dict
|
|
11 |
|
|
|
12 |
|
|
|
13 |
class CrossEncoder:
    """Cross-encoder reranker that scores (query, passage) text pairs.

    Wraps a Hugging Face sequence-classification model (default:
    ``BAAI/bge-reranker-base``) and runs inference on GPU when available.
    """

    def __init__(self, model_name: str = "BAAI/bge-reranker-base") -> None:
        """Load tokenizer and model, moving the model to the best device.

        Args:
            model_name: Hugging Face model id to load. The default preserves
                the original hard-coded behavior; the parameter generalizes
                the class to any compatible reranker checkpoint.
        """
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        # Lazy %-style args avoid string formatting when INFO is disabled.
        logging.info("Using device: %s", self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model = self.model.to(self.device)

    def __call__(self, pairs: List[List[str]]) -> List[float]:
        """Score each [query, passage] pair; higher means more relevant.

        Args:
            pairs: Two-element [query, passage] string pairs to score.

        Returns:
            One relevance score (float) per input pair, in input order.
        """
        with torch.inference_mode():
            inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,  # truncate pairs to the encoder's max length
            )
            inputs = inputs.to(self.device)
            # Logits have one column; flatten to a single score per pair.
            scores = (
                self.model(**inputs, return_dict=True)
                .logits.view(-1)
                .float()
            )
            return scores.detach().cpu().tolist()
|
|
43 |
|
|
|
44 |
|
|
|
45 |
def model_fn(model_dir: str) -> CrossEncoder:
    """SageMaker model-loading hook: build and return the CrossEncoder.

    Args:
        model_dir: Directory containing the model artifacts. Only used in
            the failure log message; the weights themselves are pulled from
            the Hugging Face hub by ``CrossEncoder.__init__``.

    Returns:
        A ready-to-use ``CrossEncoder`` instance.

    Raises:
        Exception: Any load failure is logged with traceback, then re-raised
            so SageMaker surfaces the error instead of serving a broken model.
    """
    try:
        return CrossEncoder()
    except Exception:
        # logging.exception records the traceback; lazy %-args defer
        # formatting (preferred over f-strings in logging calls).
        logging.exception("Failed to load model from: %s", model_dir)
        raise
|
|
51 |
|
|
|
52 |
|
|
|
53 |
def transform_fn(
    cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str
) -> bytes:
    """SageMaker inference hook: decode the request, score it, encode the reply.

    Args:
        cross_encoder: Model object produced by ``model_fn``.
        input_data: JSON request body whose top-level keys are forwarded as
            keyword arguments to the model (e.g. ``{"pairs": [...]}``).
        content_type: Request MIME type (unused; the body is assumed JSON).
        accept: Desired response MIME type, forwarded to the encoder.

    Returns:
        The serialized response bytes containing the scores.
    """
    request = json.loads(input_data)
    scores = cross_encoder(**request)
    return encoder.encode({SCORES: scores}, accept)