[3af7d7]: / aiagents4pharma / talk2knowledgegraphs / tests / test_tools_subgraph_extraction.py

Download this file

175 lines (151 with data), 6.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Test cases for tools/subgraph_extraction.py
"""
import pytest
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from ..agents.t2kg_agent import get_app
# Directory containing the test assets used below (the pickled PrimeKG IBD
# graphs and the sample endotype PDF).
# NOTE(review): this is a relative path, so it presumably resolves against the
# directory pytest is launched from (looks like the repository root) — confirm.
DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
@pytest.fixture(name="input_dict")
def input_dict_fixture():
    """
    Baseline state dictionary shared by the tests in this module.

    The LLM and embedding model slots are left as ``None`` so that each test
    can fill them in; no files are uploaded, the top-k limits for nodes and
    edges are both 3, and the only source graph is the PrimeKG entry whose
    pickles live under ``DATA_PATH``.
    """
    # The single knowledge-graph source offered to the extraction tool
    primekg_source = {
        "name": "PrimeKG",
        "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
        "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
    }
    return {
        "llm_model": None,  # TBA for each test case
        "embedding_model": None,  # TBA for each test case
        "uploaded_files": [],
        "topk_nodes": 3,
        "topk_edges": 3,
        "dic_source_graph": [primekg_source],
    }
def test_extract_subgraph_wo_docs(input_dict):
    """
    Exercise the subgraph extraction tool with OpenAI models and no uploaded
    documents.

    Args:
        input_dict: Input dictionary.
    """
    # Fill in the model slots that the fixture leaves empty
    input_dict.update(
        llm_model=ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
        embedding_model=OpenAIEmbeddings(model="text-embedding-3-small"),
    )

    # Build the app; the same id is used for the app and for the thread
    thread_id = 12345
    app = get_app(thread_id, llm_model=input_dict["llm_model"])
    config = {"configurable": {"thread_id": thread_id}}

    # Seed the app state before sending the prompt
    app.update_state(config, input_dict)

    # The prompt text is a runtime string and is kept verbatim
    prompt = """
Please directly invoke `subgraph_extraction` tool without calling any other tools
to respond to the following prompt:
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
(IBD) that existed in the knowledge graph.
Please set the extraction name for this process as `subkg_12345`.
"""

    # Invoke the app, which is expected to call subgraph_extraction
    request = {"messages": [HumanMessage(content=prompt)]}
    response = app.invoke(request, config=config)

    # The final message is the assistant's reply
    assert isinstance(response["messages"][-1].content, str)

    # The message right before it must come from the tool
    assert response["messages"][-2].name == "subgraph_extraction"

    # Inspect the first extracted subgraph stored in the state
    extracted = app.get_state(config).values["dic_extracted_graph"][0]
    assert isinstance(extracted, dict)
    assert extracted["name"] == "subkg_12345"
    assert extracted["graph_source"] == "PrimeKG"
    assert extracted["topk_nodes"] == 3
    assert extracted["topk_edges"] == 3

    graph_dict = extracted["graph_dict"]
    assert isinstance(graph_dict, dict)
    nodes = graph_dict["nodes"]
    assert len(nodes) > 0
    edges = graph_dict["edges"]
    assert len(edges) > 0
    graph_text = extracted["graph_text"]
    assert isinstance(graph_text, str)

    # Every node identifier must show up in the textual form of the graph
    assert all(node[0] in graph_text for node in nodes)

    # Every edge must show up as `head,"<relation tuple>",tail`
    assert all(
        ",".join(
            [edge[0], '"' + str(tuple(edge[2]["relation"])) + '"', edge[1]]
        )
        in graph_text
        for edge in edges
    )
def test_extract_subgraph_w_docs(input_dict):
    """
    Exercise the subgraph extraction tool with OpenAI models while an
    endotype document is attached to the state as reference.

    Args:
        input_dict: Input dictionary.
    """
    # Fill in the model slots that the fixture leaves empty
    input_dict.update(
        llm_model=ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
        embedding_model=OpenAIEmbeddings(model="text-embedding-3-small"),
    )

    # Build the app; the same id is used for the app and for the thread
    thread_id = 12345
    app = get_app(thread_id, llm_model=input_dict["llm_model"])
    config = {"configurable": {"thread_id": thread_id}}

    # Attach the endotype document, then seed the app state
    endotype_doc = {
        "file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
        "file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
        "file_type": "endotype",
        "uploaded_by": "VPEUser",
        "uploaded_timestamp": "2024-11-05 00:00:00",
    }
    input_dict["uploaded_files"] = [endotype_doc]
    app.update_state(config, input_dict)

    # The prompt text is a runtime string and is kept verbatim
    prompt = """
Please ONLY invoke `subgraph_extraction` tool without calling any other tools
to respond to the following prompt:
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
(IBD) that existed in the knowledge graph.
Please set the extraction name for this process as `subkg_12345`.
"""

    # Invoke the app, which is expected to call subgraph_extraction
    request = {"messages": [HumanMessage(content=prompt)]}
    response = app.invoke(request, config=config)

    # The final message is the assistant's reply
    assert isinstance(response["messages"][-1].content, str)

    # The message right before it must come from the tool
    assert response["messages"][-2].name == "subgraph_extraction"

    # Inspect the first extracted subgraph stored in the state
    extracted = app.get_state(config).values["dic_extracted_graph"][0]
    assert isinstance(extracted, dict)
    assert extracted["name"] == "subkg_12345"
    assert extracted["graph_source"] == "PrimeKG"
    assert extracted["topk_nodes"] == 3
    assert extracted["topk_edges"] == 3

    graph_dict = extracted["graph_dict"]
    assert isinstance(graph_dict, dict)
    nodes = graph_dict["nodes"]
    assert len(nodes) > 0
    edges = graph_dict["edges"]
    assert len(edges) > 0
    graph_text = extracted["graph_text"]
    assert isinstance(graph_text, str)

    # Every node identifier must show up in the textual form of the graph
    assert all(node[0] in graph_text for node in nodes)

    # Every edge must show up as `head,"<relation tuple>",tail`
    assert all(
        ",".join(
            [edge[0], '"' + str(tuple(edge[2]["relation"])) + '"', edge[1]]
        )
        in graph_text
        for edge in edges
    )