[3af7d7]: / aiagents4pharma / talk2knowledgegraphs / tests / test_utils_embeddings_huggingface.py

Download this file

45 lines (39 with data), 1.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
Test cases for utils/embeddings/huggingface.py
"""
import re

import pytest

from ..utils.embeddings.huggingface import EmbeddingWithHuggingFace
@pytest.fixture(name="embedding_model")
def embedding_model_fixture():
    """Build the EmbeddingWithHuggingFace object exposed to tests as ``embedding_model``."""
    # Constructor settings kept together so the model, cache location and
    # truncation flag can be read at a glance.
    settings = {
        "model_name": "NeuML/pubmedbert-base-embeddings",
        "model_cache_dir": "../../cache",
        "truncation": True,
    }
    return EmbeddingWithHuggingFace(**settings)
def test_embedding_with_huggingface_embed_documents(embedding_model):
    """Embed three drug names and check the batch size and the first vector's length."""
    drug_names = ["Adalimumab", "Infliximab", "Vedolizumab"]
    vectors = embedding_model.embed_documents(drug_names)

    # One embedding is expected per input text.
    assert len(vectors) == 3

    # The first embedding is expected to hold 768 values.
    first_vector = vectors[0]
    assert len(first_vector) == 768
def test_embedding_with_huggingface_embed_query(embedding_model):
    """Embed a single query string and check that the vector holds 768 values."""
    query_vector = embedding_model.embed_query("Adalimumab")

    assert len(query_vector) == 768
def test_embedding_with_huggingface_failed():
    """Test that constructing EmbeddingWithHuggingFace with an unavailable model fails.

    The constructor is expected to raise ``ValueError`` carrying the exact
    "not available on HuggingFace Hub" message for the requested model name.
    """
    model_name = "aiagents4pharma/embeddings"
    err_msg = f"Model {model_name} is not available on HuggingFace Hub."
    # ``pytest.raises(match=...)`` interprets its argument as a regular
    # expression, so the message is escaped to compare it literally; left
    # unescaped, every "." in the message would match any character.
    with pytest.raises(ValueError, match=re.escape(err_msg)):
        EmbeddingWithHuggingFace(
            model_name=model_name,
            model_cache_dir="../../cache",
            truncation=True,
        )