--- a +++ b/aiagents4pharma/talk2biomodels/tests/test_integration.py @@ -0,0 +1,128 @@ +''' +Test cases for Talk2Biomodels. +''' + +import pandas as pd +from langchain_core.messages import HumanMessage, ToolMessage +from langchain_openai import ChatOpenAI +from ..agents.t2b_agent import get_app + +LLM_MODEL = ChatOpenAI(model='gpt-4o-mini', temperature=0) + +def test_integration(): + ''' + Test the integration of the tools. + ''' + unique_id = 1234567 + app = get_app(unique_id, llm_model=LLM_MODEL) + config = {"configurable": {"thread_id": unique_id}} + # ########################################## + # ## Test simulate_model tool + # ########################################## + prompt = '''Simulate the model BIOMD0000000537 for 100 hours and time intervals + 100 with an initial concentration of `DoseQ2W` set to 300 and `Dose` + set to 0. Reset the concentration of `Ab{serum}` to 100 every 25 hours.''' + # Test the tool get_modelinfo + response = app.invoke( + {"messages": [HumanMessage(content=prompt)]}, + config=config + ) + assistant_msg = response["messages"][-1].content + print (assistant_msg) + # Check if the assistant message is a string + assert isinstance(assistant_msg, str) + ########################################## + # Test ask_question tool when simulation + # results are available + ########################################## + # Update state + app.update_state(config, {"llm_model": LLM_MODEL}) + prompt = """What is the concentration of CRP in serum after 100 hours? + Round off the value to 2 decimal places.""" + # Test the tool get_modelinfo + response = app.invoke( + {"messages": [HumanMessage(content=prompt)]}, + config=config + ) + assistant_msg = response["messages"][-1].content + # print (assistant_msg) + # Check if the assistant message is a string + assert '211' in assistant_msg + + ########################################## + # Test the custom_plotter tool when the + # simulation results are available but + # the species is not available + ########################################## + prompt = """Call the custom_plotter tool to make a plot + showing only species 'Infected cases'. Let me + know if these species were not found. Do not + invoke any other tool.""" + # Update state + app.update_state(config, {"llm_model": LLM_MODEL} + ) + # Test the tool get_modelinfo + response = app.invoke( + {"messages": [HumanMessage(content=prompt)]}, + config=config + ) + assistant_msg = response["messages"][-1].content + current_state = app.get_state(config) + # Get the messages from the current state + # and reverse the order + reversed_messages = current_state.values["messages"][::-1] + # Loop through the reversed messages until a + # ToolMessage is found. + predicted_artifact = [] + for msg in reversed_messages: + if isinstance(msg, ToolMessage): + # Work on the message if it is a ToolMessage + # These may contain additional visuals that + # need to be displayed to the user. + if msg.name == "custom_plotter": + predicted_artifact = msg.artifact + break + # Check if the the predicted artifact is `None` + assert predicted_artifact is None + + ########################################## + # Test custom_plotter tool when the + # simulation results are available + ########################################## + prompt = "Plot only CRP related species." + + # Update state + app.update_state(config, {"llm_model": LLM_MODEL} + ) + # Test the tool get_modelinfo + response = app.invoke( + {"messages": [HumanMessage(content=prompt)]}, + config=config + ) + assistant_msg = response["messages"][-1].content + current_state = app.get_state(config) + # Get the messages from the current state + # and reverse the order + reversed_messages = current_state.values["messages"][::-1] + # Loop through the reversed messages + # until a ToolMessage is found. + expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular'] + expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)'] + expected_header += ['CRP{liver}'] + predicted_artifact = [] + for msg in reversed_messages: + if isinstance(msg, ToolMessage): + # Work on the message if it is a ToolMessage + # These may contain additional visuals that + # need to be displayed to the user. + if msg.name == "custom_plotter": + predicted_artifact = msg.artifact['dic_data'] + break + # Convert the artifact into a pandas dataframe + # for easy comparison + df = pd.DataFrame(predicted_artifact) + # Extract the headers from the dataframe + predicted_header = df.columns.tolist() + # Check if the header is in the expected_header + # assert expected_header in predicted_artifact + assert set(expected_header).issubset(set(predicted_header))