Skip to content

Commit 9b7c429

Browse files
authored
Combine consecutive AG-UI user and assistant messages into the same model request/response (#2912)
1 parent b26a687 commit 9b7c429

File tree

2 files changed

+180
-35
lines changed

2 files changed

+180
-35
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424

2525
from pydantic import BaseModel, ValidationError
26+
from typing_extensions import assert_never
2627

2728
from . import _utils
2829
from ._agent_graph import CallToolsNode, ModelRequestNode
@@ -32,7 +33,9 @@
3233
FunctionToolResultEvent,
3334
ModelMessage,
3435
ModelRequest,
36+
ModelRequestPart,
3537
ModelResponse,
38+
ModelResponsePart,
3639
ModelResponseStreamEvent,
3740
PartDeltaEvent,
3841
PartStartEvent,
@@ -573,49 +576,57 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
573576
"""Convert a AG-UI history to a Pydantic AI one."""
574577
result: list[ModelMessage] = []
575578
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
579+
request_parts: list[ModelRequestPart] | None = None
580+
response_parts: list[ModelResponsePart] | None = None
576581
for msg in messages:
577-
if isinstance(msg, UserMessage):
578-
result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
582+
if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage | ToolMessage):
583+
if request_parts is None:
584+
request_parts = []
585+
result.append(ModelRequest(parts=request_parts))
586+
response_parts = None
587+
588+
if isinstance(msg, UserMessage):
589+
request_parts.append(UserPromptPart(content=msg.content))
590+
elif isinstance(msg, SystemMessage | DeveloperMessage):
591+
request_parts.append(SystemPromptPart(content=msg.content))
592+
elif isinstance(msg, ToolMessage):
593+
tool_name = tool_calls.get(msg.tool_call_id)
594+
if tool_name is None: # pragma: no cover
595+
raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
596+
597+
request_parts.append(
598+
ToolReturnPart(
599+
tool_name=tool_name,
600+
content=msg.content,
601+
tool_call_id=msg.tool_call_id,
602+
)
603+
)
604+
else:
605+
assert_never(msg)
606+
579607
elif isinstance(msg, AssistantMessage):
608+
if response_parts is None:
609+
response_parts = []
610+
result.append(ModelResponse(parts=response_parts))
611+
request_parts = None
612+
580613
if msg.content:
581-
result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
614+
response_parts.append(TextPart(content=msg.content))
582615

583616
if msg.tool_calls:
584617
for tool_call in msg.tool_calls:
585618
tool_calls[tool_call.id] = tool_call.function.name
586619

587-
result.append(
588-
ModelResponse(
589-
parts=[
590-
ToolCallPart(
591-
tool_name=tool_call.function.name,
592-
tool_call_id=tool_call.id,
593-
args=tool_call.function.arguments,
594-
)
595-
for tool_call in msg.tool_calls
596-
]
620+
response_parts.extend(
621+
ToolCallPart(
622+
tool_name=tool_call.function.name,
623+
tool_call_id=tool_call.id,
624+
args=tool_call.function.arguments,
597625
)
626+
for tool_call in msg.tool_calls
598627
)
599-
elif isinstance(msg, SystemMessage):
600-
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
601-
elif isinstance(msg, ToolMessage):
602-
tool_name = tool_calls.get(msg.tool_call_id)
603-
if tool_name is None: # pragma: no cover
604-
raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
605-
606-
result.append(
607-
ModelRequest(
608-
parts=[
609-
ToolReturnPart(
610-
tool_name=tool_name,
611-
content=msg.content,
612-
tool_call_id=msg.tool_call_id,
613-
)
614-
]
615-
)
616-
)
617-
elif isinstance(msg, DeveloperMessage): # pragma: no branch
618-
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
628+
else:
629+
assert_never(msg)
619630

620631
return result
621632

tests/test_ag_ui.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,16 @@
2121
from pydantic_ai._run_context import RunContext
2222
from pydantic_ai.agent import Agent, AgentRunResult
2323
from pydantic_ai.exceptions import UserError
24-
from pydantic_ai.messages import ModelMessage
24+
from pydantic_ai.messages import (
25+
ModelMessage,
26+
ModelRequest,
27+
ModelResponse,
28+
SystemPromptPart,
29+
TextPart,
30+
ToolCallPart,
31+
ToolReturnPart,
32+
UserPromptPart,
33+
)
2534
from pydantic_ai.models.function import (
2635
AgentInfo,
2736
DeltaThinkingCalls,
@@ -34,7 +43,7 @@
3443
from pydantic_ai.output import OutputDataT
3544
from pydantic_ai.tools import AgentDepsT, ToolDefinition
3645

37-
from .conftest import IsSameStr
46+
from .conftest import IsDatetime, IsSameStr
3847

3948
has_ag_ui: bool = False
4049
with contextlib.suppress(ImportError):
@@ -59,6 +68,7 @@
5968
SSE_CONTENT_TYPE,
6069
OnCompleteFunc,
6170
StateDeps,
71+
_messages_from_ag_ui, # type: ignore[reportPrivateUsage]
6272
run_ag_ui,
6373
)
6474

@@ -1347,3 +1357,127 @@ def error_callback(run_result: AgentRunResult[Any]) -> None:
13471357
assert len(events) > 0
13481358
assert events[0]['type'] == 'RUN_STARTED'
13491359
assert any(event['type'] == 'RUN_ERROR' for event in events)
1360+
1361+
1362+
async def test_messages_from_ag_ui() -> None:
1363+
messages = [
1364+
SystemMessage(
1365+
id='msg_1',
1366+
content='System message',
1367+
),
1368+
DeveloperMessage(
1369+
id='msg_2',
1370+
content='Developer message',
1371+
),
1372+
UserMessage(
1373+
id='msg_3',
1374+
content='User message',
1375+
),
1376+
UserMessage(
1377+
id='msg_4',
1378+
content='User message',
1379+
),
1380+
AssistantMessage(
1381+
id='msg_5',
1382+
content='Assistant message',
1383+
),
1384+
AssistantMessage(
1385+
id='msg_6',
1386+
tool_calls=[
1387+
ToolCall(
1388+
id='tool_call_1',
1389+
function=FunctionCall(
1390+
name='tool_call_1',
1391+
arguments='{}',
1392+
),
1393+
),
1394+
],
1395+
),
1396+
AssistantMessage(
1397+
id='msg_7',
1398+
tool_calls=[
1399+
ToolCall(
1400+
id='tool_call_2',
1401+
function=FunctionCall(
1402+
name='tool_call_2',
1403+
arguments='{}',
1404+
),
1405+
),
1406+
],
1407+
),
1408+
ToolMessage(
1409+
id='msg_8',
1410+
content='Tool message',
1411+
tool_call_id='tool_call_1',
1412+
),
1413+
ToolMessage(
1414+
id='msg_9',
1415+
content='Tool message',
1416+
tool_call_id='tool_call_2',
1417+
),
1418+
UserMessage(
1419+
id='msg_10',
1420+
content='User message',
1421+
),
1422+
AssistantMessage(
1423+
id='msg_11',
1424+
content='Assistant message',
1425+
),
1426+
]
1427+
1428+
assert _messages_from_ag_ui(messages) == snapshot(
1429+
[
1430+
ModelRequest(
1431+
parts=[
1432+
SystemPromptPart(
1433+
content='System message',
1434+
timestamp=IsDatetime(),
1435+
),
1436+
SystemPromptPart(
1437+
content='Developer message',
1438+
timestamp=IsDatetime(),
1439+
),
1440+
UserPromptPart(
1441+
content='User message',
1442+
timestamp=IsDatetime(),
1443+
),
1444+
UserPromptPart(
1445+
content='User message',
1446+
timestamp=IsDatetime(),
1447+
),
1448+
]
1449+
),
1450+
ModelResponse(
1451+
parts=[
1452+
TextPart(content='Assistant message'),
1453+
ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'),
1454+
ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'),
1455+
],
1456+
timestamp=IsDatetime(),
1457+
),
1458+
ModelRequest(
1459+
parts=[
1460+
ToolReturnPart(
1461+
tool_name='tool_call_1',
1462+
content='Tool message',
1463+
tool_call_id='tool_call_1',
1464+
timestamp=IsDatetime(),
1465+
),
1466+
ToolReturnPart(
1467+
tool_name='tool_call_2',
1468+
content='Tool message',
1469+
tool_call_id='tool_call_2',
1470+
timestamp=IsDatetime(),
1471+
),
1472+
UserPromptPart(
1473+
content='User message',
1474+
timestamp=IsDatetime(),
1475+
),
1476+
]
1477+
),
1478+
ModelResponse(
1479+
parts=[TextPart(content='Assistant message')],
1480+
timestamp=IsDatetime(),
1481+
),
1482+
]
1483+
)

0 commit comments

Comments
 (0)