aiagents4pharma/talk2biomodels/tests/test_integration.py

'''
Test cases for Talk2Biomodels.
'''

import pandas as pd
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from ..agents.t2b_agent import get_app

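# Note: ChatOpenAI reads the API key from the OPENAI_API_KEY environment
# variable, so this integration test makes live calls to the OpenAI API.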
LLM_MODEL = ChatOpenAI(model='gpt-4o-mini', temperature=0)

def test_integration():
    '''
    Test the integration of the tools.
    '''
    unique_id = 1234567
    app = get_app(unique_id, llm_model=LLM_MODEL)
    config = {"configurable": {"thread_id": unique_id}}
    ##########################################
    # Test simulate_model tool
    ##########################################
    prompt = '''Simulate the model BIOMD0000000537 for 100 hours and time intervals
    100 with an initial concentration of `DoseQ2W` set to 300 and `Dose`
    set to 0. Reset the concentration of `Ab{serum}` to 100 every 25 hours.'''
    # Invoke the agent with the simulation prompt
    response = app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    assistant_msg = response["messages"][-1].content
    print(assistant_msg)
    # Check if the assistant message is a string
    assert isinstance(assistant_msg, str)
    ##########################################
    # Test ask_question tool when simulation
    # results are available
    ##########################################
    # Update state
    app.update_state(config, {"llm_model": LLM_MODEL})
    prompt = """What is the concentration of CRP in serum after 100 hours?
    Round off the value to 2 decimal places."""
    # Invoke the agent with the question
    response = app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    assistant_msg = response["messages"][-1].content
    # Check that the expected concentration value
    # appears in the assistant message
    assert '211' in assistant_msg

    ##########################################
    # Test the custom_plotter tool when the
    # simulation results are available but
    # the species is not available
    ##########################################
    prompt = """Call the custom_plotter tool to make a plot
        showing only species 'Infected cases'. Let me
        know if these species were not found. Do not
        invoke any other tool."""
    # Update state
    app.update_state(config, {"llm_model": LLM_MODEL})
    # Invoke the agent with the plotting prompt
    response = app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    assistant_msg = response["messages"][-1].content
    current_state = app.get_state(config)
    # Get the messages from the current state
    # and reverse the order
    reversed_messages = current_state.values["messages"][::-1]
    # Loop through the reversed messages until a
    # ToolMessage is found.
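    # (ToolMessage.artifact is set only when the tool returns one, e.g. a
    # LangChain tool declared with response_format="content_and_artifact";
    # here it should be None because the requested species is not in the model.)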
    predicted_artifact = []
    for msg in reversed_messages:
        if isinstance(msg, ToolMessage):
            # Work on the message if it is a ToolMessage
            # These may contain additional visuals that
            # need to be displayed to the user.
            if msg.name == "custom_plotter":
                predicted_artifact = msg.artifact
                break
    # Check if the predicted artifact is `None`
    assert predicted_artifact is None

    ##########################################
    # Test custom_plotter tool when the
    # simulation results are available
    ##########################################
    prompt = "Plot only CRP related species."

    # Update state
    app.update_state(config, {"llm_model": LLM_MODEL})
    # Invoke the agent with the plotting prompt
    response = app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    assistant_msg = response["messages"][-1].content
    current_state = app.get_state(config)
    # Get the messages from the current state
    # and reverse the order
    reversed_messages = current_state.values["messages"][::-1]
    # Loop through the reversed messages
    # until a ToolMessage is found.
    expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular']
    expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
    expected_header += ['CRP{liver}']
    predicted_artifact = []
    for msg in reversed_messages:
        if isinstance(msg, ToolMessage):
            # Work on the message if it is a ToolMessage
            # These may contain additional visuals that
            # need to be displayed to the user.
            if msg.name == "custom_plotter":
                predicted_artifact = msg.artifact['dic_data']
                break
    # Convert the artifact into a pandas dataframe
    # for easy comparison
    df = pd.DataFrame(predicted_artifact)
    # Extract the headers from the dataframe
    predicted_header = df.columns.tolist()
    # Check that every expected column appears
    # in the predicted header
    assert set(expected_header).issubset(set(predicted_header))