Skip to content

Commit d462b8a

Browse files
committed
Update test_agent.py
1 parent 9117215 commit d462b8a

File tree

1 file changed

+53
-80
lines changed

1 file changed

+53
-80
lines changed

examples/tutorials/00_sync/020_streaming/tests/test_agent.py

Lines changed: 53 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
Run: pytest tests/test_agent.py -v
1515
"""
1616

17+
import pytest
18+
1719
from agentex import Agentex
1820
from agentex.lib.testing import (
1921
test_sync_agent,
@@ -24,67 +26,53 @@
2426
AGENT_NAME = "s020-streaming"
2527

2628

27-
def test_multiturn_conversation_with_state():
28-
"""Test multi-turn non-streaming conversation with state management validation."""
29-
# Need direct client for state checks
30-
client = Agentex(api_key="test", base_url="http://localhost:5003")
29+
@pytest.fixture
30+
def agent_name():
31+
"""Return the agent name for testing."""
32+
return AGENT_NAME
33+
34+
@pytest.fixture
35+
def test_agent(agent_name: str):
36+
"""Fixture to create a test sync agent."""
37+
with test_sync_agent(agent_name=agent_name) as test:
38+
yield test
3139

32-
# Get agent
33-
agents = client.agents.list()
34-
agent = next((a for a in agents if a.name == AGENT_NAME), None)
35-
assert agent is not None, f"Agent {AGENT_NAME} not found"
40+
class TestNonStreamingMessages:
41+
"""Test non-streaming message sending."""
3642

37-
with test_sync_agent(agent_name=AGENT_NAME) as test:
43+
def test_send_message(self, test_agent):
3844
messages = [
39-
"Hello, can you tell me a little bit about tennis? I want you to make sure you use the word 'tennis' in each response.",
45+
"Hello, can you tell me a little bit about tennis? I want you to make sure you use the word 'tennis' in each response.",
4046
"Pick one of the things you just mentioned, and dive deeper into it.",
4147
"Can you now output a summary of this conversation",
4248
]
4349

4450
for i, msg in enumerate(messages):
45-
# Send message
46-
response = test.send_message(msg)
51+
response = test_agent.send_message(msg)
4752

48-
# Validate response structure
53+
# Validate response (agent may require OpenAI key)
4954
assert_valid_agent_response(response)
5055

51-
# Check message history count
52-
message_history = client.messages.list(task_id=test.task_id)
53-
expected_count = (i + 1) * 2 # Each turn: user + agent
54-
assert (
55-
len(message_history) == expected_count
56-
), f"Expected {expected_count} messages, got {len(message_history)}"
57-
58-
# Check state (agent should maintain system prompt)
59-
# Note: states.list API may have changed - handle gracefully
60-
try:
61-
states = client.states.list(agent_id=agent.id, task_id=test.task_id)
62-
if states and len(states) > 0:
63-
# Filter to our task
64-
task_states = [s for s in states if s.task_id == test.task_id]
65-
if task_states:
66-
state = task_states[0]
67-
assert state.state is not None
68-
assert (
69-
state.state.get("system_prompt")
70-
== "You are a helpful assistant that can answer questions."
71-
)
72-
except Exception as e:
73-
# If states API has changed, skip this check
74-
print(f"State check skipped (API may have changed): {e}")
75-
76-
77-
def test_multiturn_streaming_with_state():
78-
"""Test multi-turn streaming conversation with state management validation."""
79-
# Need direct client for state checks
80-
client = Agentex(api_key="test", base_url="http://localhost:5003")
81-
82-
# Get agent
83-
agents = client.agents.list()
84-
agent = next((a for a in agents if a.name == AGENT_NAME), None)
85-
assert agent is not None, f"Agent {AGENT_NAME} not found"
86-
87-
with test_sync_agent(agent_name=AGENT_NAME) as test:
56+
# Validate that "tennis" appears in the response, since the first prompt asks the agent to use that word in every reply
57+
assert "tennis" in response.content.lower()
58+
59+
states = test_agent.client.states.list(task_id=test_agent.task_id)
60+
assert len(states) == 1
61+
62+
state = states[0]
63+
assert state.state is not None
64+
assert state.state.get("system_prompt") == "You are a helpful assistant that can answer questions."
65+
66+
# Verify conversation history
67+
message_history = test_agent.client.messages.list(task_id=test_agent.task_id)
68+
assert len(message_history) == (i + 1) * 2 # user + agent messages
69+
70+
71+
class TestStreamingMessages:
72+
"""Test streaming message sending."""
73+
74+
def test_send_stream_message(self, test_agent):
75+
"""Test streaming messages in a multi-turn conversation."""
8876
messages = [
8977
"Hello, can you tell me a little bit about tennis? I want you to make sure you use the word 'tennis' in each response.",
9078
"Pick one of the things you just mentioned, and dive deeper into it.",
@@ -93,39 +81,24 @@ def test_multiturn_streaming_with_state():
9381

9482
for i, msg in enumerate(messages):
9583
# Get streaming response
96-
response_gen = test.send_message_streaming(msg)
84+
response_gen = test_agent.send_message_streaming(msg)
9785

98-
# Collect streaming deltas
86+
# Collect the streaming response
9987
aggregated_content, chunks = collect_streaming_deltas(response_gen)
10088

101-
# Validate streaming response
102-
assert aggregated_content is not None, "Should receive aggregated content"
103-
assert len(chunks) > 1, "Should receive multiple chunks in streaming response"
104-
105-
# Check message history count
106-
message_history = client.messages.list(task_id=test.task_id)
107-
expected_count = (i + 1) * 2
108-
assert (
109-
len(message_history) == expected_count
110-
), f"Expected {expected_count} messages, got {len(message_history)}"
111-
112-
# Check state (agent should maintain system prompt)
113-
# Note: states.list API may have changed - handle gracefully
114-
try:
115-
states = client.states.list(agent_id=agent.id, task_id=test.task_id)
116-
if states and len(states) > 0:
117-
# Filter to our task
118-
task_states = [s for s in states if s.task_id == test.task_id]
119-
if task_states:
120-
state = task_states[0]
121-
assert state.state is not None
122-
assert (
123-
state.state.get("system_prompt")
124-
== "You are a helpful assistant that can answer questions."
125-
)
126-
except Exception as e:
127-
# If states API has changed, skip this check
128-
print(f"State check skipped (API may have changed): {e}")
89+
assert len(chunks) > 1
90+
91+
# Validate we got content
92+
assert len(aggregated_content) > 0, "Should receive content"
93+
94+
# Validate that "tennis" appears in the streamed content, since the first prompt asks the agent to use that word in every reply
95+
assert "tennis" in aggregated_content.lower()
96+
97+
states = test_agent.client.states.list(task_id=test_agent.task_id)
98+
assert len(states) == 1
99+
100+
message_history = test_agent.client.messages.list(task_id=test_agent.task_id)
101+
assert len(message_history) == (i + 1) * 2 # user + agent messages
129102

130103

131104
if __name__ == "__main__":

0 commit comments

Comments
 (0)