diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts index 1d1ba091..11e2ff4c 100644 --- a/packages/typescript/ai-client/src/chat-client.ts +++ b/packages/typescript/ai-client/src/chat-client.ts @@ -293,8 +293,7 @@ export class ChatClient { this.abortController = new AbortController() try { - // Get model messages for the LLM - const modelMessages = this.processor.toModelMessages() + const messages = this.processor.getMessages() // Call onResponse callback await this.callbacksRef.current.onResponse() @@ -307,7 +306,7 @@ export class ChatClient { // Connect and stream const stream = this.connection.connect( - modelMessages, + messages, bodyWithConversationId, this.abortController.signal, ) @@ -417,6 +416,8 @@ export class ChatClient { // Find the tool call ID from the approval ID const messages = this.processor.getMessages() let foundToolCallId: string | undefined + let foundToolName: string | undefined + let foundToolInput: any | undefined for (const msg of messages) { const toolCallPart = msg.parts.find( @@ -425,6 +426,12 @@ export class ChatClient { ) if (toolCallPart) { foundToolCallId = toolCallPart.id + foundToolName = toolCallPart.name + try { + foundToolInput = JSON.parse(toolCallPart.arguments) + } catch { + // Ignore parse errors + } break } } @@ -440,6 +447,32 @@ export class ChatClient { // Add response via processor this.processor.addToolApprovalResponse(response.id, response.approved) + // Execute client-side tool if approved + if (response.approved && foundToolCallId && foundToolName) { + const clientTool = this.clientToolsRef.current.get(foundToolName) + if (clientTool?.execute) { + try { + const output = await clientTool.execute(foundToolInput) + await this.addToolResult({ + toolCallId: foundToolCallId, + tool: foundToolName, + output, + state: 'output-available', + }) + return + } catch (error: any) { + await this.addToolResult({ + toolCallId: foundToolCallId, + tool: foundToolName, + output: null, + state: 'output-error', + errorText: error.message, + }) + return + } + } + } + // If stream is in progress, queue continuation check for after it ends if (this.isLoading) { this.queuePostStreamAction(() => this.checkForContinuation()) diff --git a/packages/typescript/ai-client/src/connection-adapters.ts b/packages/typescript/ai-client/src/connection-adapters.ts index 302b306e..cf6a6b04 100644 --- a/packages/typescript/ai-client/src/connection-adapters.ts +++ b/packages/typescript/ai-client/src/connection-adapters.ts @@ -89,6 +89,12 @@ export interface FetchConnectionOptions { signal?: AbortSignal body?: Record fetchClient?: typeof globalThis.fetch + /** + * Send full UIMessage objects (including `parts`) instead of ModelMessages. + * Required for advanced server features that depend on UIMessage metadata + * (e.g. tool approvals and client tool results tracked in parts). + */ + sendFullMessages?: boolean } /** @@ -138,7 +144,9 @@ export function fetchServerSentEvents( const resolvedOptions = typeof options === 'function' ? await options() : options - const modelMessages = convertMessagesToModelMessages(messages) + const requestMessages = resolvedOptions.sendFullMessages + ? messages + : convertMessagesToModelMessages(messages) const requestHeaders: Record = { 'Content-Type': 'application/json', @@ -147,7 +155,7 @@ export function fetchServerSentEvents( // Merge body from options with messages and data const requestBody = { - messages: modelMessages, + messages: requestMessages, data, ...resolvedOptions.body, } @@ -238,8 +246,9 @@ export function fetchHttpStream( const resolvedOptions = typeof options === 'function' ? await options() : options - // Convert UIMessages to ModelMessages if needed - const modelMessages = convertMessagesToModelMessages(messages) + const requestMessages = resolvedOptions.sendFullMessages + ? messages + : convertMessagesToModelMessages(messages) const requestHeaders: Record = { 'Content-Type': 'application/json', @@ -248,7 +257,7 @@ export function fetchHttpStream( // Merge body from options with messages and data const requestBody = { - messages: modelMessages, + messages: requestMessages, data, ...resolvedOptions.body, } diff --git a/packages/typescript/ai-client/tests/chat-client-approval.test.ts b/packages/typescript/ai-client/tests/chat-client-approval.test.ts new file mode 100644 index 00000000..88667b0b --- /dev/null +++ b/packages/typescript/ai-client/tests/chat-client-approval.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, it, vi } from 'vitest' +import { ChatClient } from '../src/chat-client' +import { stream } from '../src/connection-adapters' +import type { StreamChunk } from '@tanstack/ai' + +function createMockConnectionAdapter(options: { chunks: StreamChunk[] }) { + return stream(async function* () { + for (const chunk of options.chunks) { + yield chunk + } + }) +} + +function createApprovalToolCallChunks( + toolCalls: Array<{ + id: string + name: string + arguments: string + approvalId: string + }>, +): StreamChunk[] { + const chunks: StreamChunk[] = [] + const timestamp = Date.now() + + // Start assistant message + chunks.push({ + type: 'content', + id: 'msg-1', + model: 'test-model', + timestamp, + delta: '', + content: '', + role: 'assistant', + }) + + for (const toolCall of toolCalls) { + // 1. Tool Call Chunk + chunks.push({ + type: 'tool_call', + id: 'msg-1', + model: 'test-model', + timestamp, + toolCall: { + id: toolCall.id, + type: 'function', + function: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + }, + index: 0, + }) + + // 2. Approval Requested Chunk + chunks.push({ + type: 'approval-requested', + id: 'msg-1', + model: 'test-model', + timestamp, + toolCallId: toolCall.id, + toolName: toolCall.name, + input: JSON.parse(toolCall.arguments), + approval: { + id: toolCall.approvalId, + needsApproval: true, + }, + } as any) // Cast to any if types are not perfectly aligned yet, or use correct type + } + + // Done chunk + chunks.push({ + type: 'done', + id: 'msg-1', + model: 'test-model', + timestamp, + finishReason: 'tool_calls', + }) + + return chunks +} + +describe('ChatClient Approval Flow', () => { + it('should execute client tool when approved', async () => { + const toolName = 'delete_local_data' + const toolCallId = 'call_123' + const approvalId = 'approval_123' + const input = { key: 'test-key' } + + const chunks = createApprovalToolCallChunks([ + { + id: toolCallId, + name: toolName, + arguments: JSON.stringify(input), + approvalId, + }, + ]) + + const adapter = createMockConnectionAdapter({ chunks }) + + const execute = vi.fn().mockResolvedValue({ deleted: true }) + const clientTool = { + name: toolName, + description: 'Delete data', + execute, + } + + const client = new ChatClient({ + connection: adapter, + tools: [clientTool], + }) + + // Start the flow + await client.sendMessage('Delete data') + + // Wait for stream to finish (approval request should be pending) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify tool execution hasn't happened yet + expect(execute).not.toHaveBeenCalled() + + // Approve the tool + await client.addToolApprovalResponse({ + id: approvalId, + approved: true, + }) + + // Wait for execution (this is where it currently hangs/fails) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Expect execute to have been called + expect(execute).toHaveBeenCalledWith(input) + }) +}) diff --git a/packages/typescript/ai-client/tests/connection-adapters.test.ts b/packages/typescript/ai-client/tests/connection-adapters.test.ts index cfce39cd..28365ff1 100644 --- a/packages/typescript/ai-client/tests/connection-adapters.test.ts +++ b/packages/typescript/ai-client/tests/connection-adapters.test.ts @@ -318,6 +318,107 @@ describe('connection-adapters', () => { expect(body.data).toEqual({ key: 'value' }) }) + it('should send model messages by default', async () => { + const mockReader = { + read: vi.fn().mockResolvedValue({ done: true, value: undefined }), + releaseLock: vi.fn(), + } + + const mockResponse = { + ok: true, + body: { + getReader: () => mockReader, + }, + } + + fetchMock.mockResolvedValue(mockResponse as any) + + const adapter = fetchServerSentEvents('/api/chat') + + for await (const _ of adapter.connect([ + { + id: 'msg_1', + role: 'assistant', + parts: [ + { + type: 'tool-call', + id: 'tool_1', + name: 'testTool', + arguments: '{}', + state: 'approval-responded', + approval: { id: 'approval_tool_1', needsApproval: true, approved: true }, + }, + ], + createdAt: new Date(), + }, + ] as any)) { + // Consume + } + + const call = fetchMock.mock.calls[0] + const body = JSON.parse(call?.[1]?.body as string) + expect(body.messages[0]).not.toHaveProperty('parts') + expect(body.messages[0]).toMatchObject({ + role: 'assistant', + }) + }) + + it('should send full UI messages when configured', async () => { + const mockReader = { + read: vi.fn().mockResolvedValue({ done: true, value: undefined }), + releaseLock: vi.fn(), + } + + const mockResponse = { + ok: true, + body: { + getReader: () => mockReader, + }, + } + + fetchMock.mockResolvedValue(mockResponse as any) + + const adapter = fetchServerSentEvents('/api/chat', { + sendFullMessages: true, + }) + + const uiMessages = [ + { + id: 'msg_1', + role: 'assistant', + parts: [ + { + type: 'tool-call', + id: 'tool_1', + name: 'testTool', + arguments: '{}', + state: 'approval-responded', + approval: { id: 'approval_tool_1', needsApproval: true, approved: true }, + }, + ], + createdAt: new Date(), + }, + ] + + for await (const _ of adapter.connect(uiMessages as any)) { + // Consume + } + + const call = fetchMock.mock.calls[0] + const body = JSON.parse(call?.[1]?.body as string) + expect(body.messages[0]).toHaveProperty('parts') + expect(body.messages[0]).toMatchObject({ + role: 'assistant', + parts: [ + { + type: 'tool-call', + id: 'tool_1', + approval: { id: 'approval_tool_1', approved: true }, + }, + ], + }) + }) + it('should use custom fetchClient when provided', async () => { const customFetch = vi.fn() const mockReader = { diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index 95c01374..2791f530 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -8,6 +8,7 @@ import { aiEventClient } from '../../event-client.js' import { streamToText } from '../../stream-to-response.js' import { ToolCallManager, executeToolCalls } from './tools/tool-calls' +import { convertMessagesToModelMessages } from './messages' import { convertSchemaToJsonSchema, isStandardSchema, @@ -31,6 +32,7 @@ import type { TextOptions, Tool, ToolCall, + UIMessage, } from '../../types' // =========================== @@ -206,7 +208,7 @@ class TextEngine< private readonly effectiveRequest?: Request | RequestInit private readonly effectiveSignal?: AbortSignal - private messages: Array + private messages: Array private iterationCount = 0 private lastFinishReason: string | null = null private streamStartTime = 0 @@ -244,7 +246,7 @@ class TextEngine< /** Get the final messages array after the chat loop completes */ getMessages(): Array { - return this.messages + return convertMessagesToModelMessages(this.messages) } async *run(): AsyncGenerator { @@ -386,7 +388,7 @@ class TextEngine< for await (const chunk of this.adapter.chatStream({ model: this.params.model, - messages: this.messages, + messages: convertMessagesToModelMessages(this.messages), tools: toolsWithJsonSchemas, temperature, topP, @@ -718,17 +720,13 @@ class TextEngine< for (const part of parts) { if ( part.type === 'tool-call' && - part.state === 'approval-responded' && - part.approval + part.approval && + part.approval.approved !== undefined ) { approvals.set(part.approval.id, part.approval.approved) } - if ( - part.type === 'tool-call' && - part.output !== undefined && - !part.approval - ) { + if (part.type === 'tool-call' && part.output !== undefined) { clientToolResults.set(part.id, part.output) } } @@ -847,15 +845,16 @@ class TextEngine< } private getPendingToolCallsFromMessages(): Array { + const modelMessages = convertMessagesToModelMessages(this.messages) const completedToolIds = new Set( - this.messages + modelMessages .filter((message) => message.role === 'tool' && message.toolCallId) .map((message) => message.toolCallId!), // toolCallId exists due to filter ) const pending: Array = [] - for (const message of this.messages) { + for (const message of modelMessages) { if (message.role === 'assistant' && message.toolCalls) { for (const toolCall of message.toolCalls) { if (!completedToolIds.has(toolCall.id)) { @@ -886,7 +885,7 @@ class TextEngine< return ( this.loopStrategy({ iterationCount: this.iterationCount, - messages: this.messages, + messages: convertMessagesToModelMessages(this.messages), finishReason: this.lastFinishReason, }) && this.toolPhase === 'continue' ) @@ -985,46 +984,31 @@ export function chat< // If outputSchema is provided, run agentic structured output if (outputSchema) { return runAgenticStructuredOutput( - options as unknown as TextActivityOptions< - AnyTextAdapter, - SchemaInput, - boolean - >, + options as TextActivityOptions, ) as TextActivityResult } // If stream is explicitly false, run non-streaming text if (stream === false) { - return runNonStreamingText( - options as unknown as TextActivityOptions< - AnyTextAdapter, - undefined, - false - >, - ) as TextActivityResult + return runNonStreamingText(options) as TextActivityResult } // Otherwise, run streaming text (default) - return runStreamingText( - options as unknown as TextActivityOptions, - ) as TextActivityResult + return runStreamingText(options) as TextActivityResult } /** * Run streaming text (agentic or one-shot depending on tools) */ -async function* runStreamingText( - options: TextActivityOptions, +async function* runStreamingText( + options: TextActivityOptions, ): AsyncIterable { const { adapter, ...textOptions } = options const model = adapter.model const engine = new TextEngine({ adapter, - params: { ...textOptions, model } as TextOptions< - Record, - Record - >, + params: { ...textOptions, model, messages: textOptions.messages ?? [] }, }) for await (const chunk of engine.run()) { @@ -1036,13 +1020,11 @@ async function* runStreamingText( * Run non-streaming text - collects all content and returns as a string. * Runs the full agentic loop (if tools are provided) but returns collected text. */ -function runNonStreamingText( - options: TextActivityOptions, +function runNonStreamingText( + options: TextActivityOptions, ): Promise { // Run the streaming text and collect all text using streamToText - const stream = runStreamingText( - options as unknown as TextActivityOptions, - ) + const stream = runStreamingText(options) return streamToText(stream) } @@ -1053,8 +1035,11 @@ function runNonStreamingText( * 2. Once complete, call adapter.structuredOutput with the conversation context * 3. Validate and return the structured result */ -async function runAgenticStructuredOutput( - options: TextActivityOptions, +async function runAgenticStructuredOutput< + TAdapter extends AnyTextAdapter, + TSchema extends SchemaInput, +>( + options: TextActivityOptions, ): Promise> { const { adapter, outputSchema, ...textOptions } = options const model = adapter.model @@ -1066,10 +1051,7 @@ async function runAgenticStructuredOutput( // Create the engine and run the agentic loop const engine = new TextEngine({ adapter, - params: { ...textOptions, model } as TextOptions< - Record, - Record - >, + params: { ...textOptions, model, messages: textOptions.messages ?? [] }, }) // Consume the stream to run the agentic loop diff --git a/packages/typescript/ai/src/activities/chat/stream/message-updaters.ts b/packages/typescript/ai/src/activities/chat/stream/message-updaters.ts index 5deb39f8..7bcfaee4 100644 --- a/packages/typescript/ai/src/activities/chat/stream/message-updaters.ts +++ b/packages/typescript/ai/src/activities/chat/stream/message-updaters.ts @@ -68,19 +68,25 @@ export function updateToolCallPart( (p): p is ToolCallPart => p.type === 'tool-call' && p.id === toolCall.id, ) - const toolCallPart: ToolCallPart = { - type: 'tool-call', - id: toolCall.id, - name: toolCall.name, - arguments: toolCall.arguments, - state: toolCall.state, - } - if (existingPartIndex >= 0) { // Update existing tool call - parts[existingPartIndex] = toolCallPart + const existing = parts[existingPartIndex] as ToolCallPart + parts[existingPartIndex] = { + ...existing, + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + state: toolCall.state, + } } else { // Add new tool call at the end (preserve natural streaming order) + const toolCallPart: ToolCallPart = { + type: 'tool-call', + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + state: toolCall.state, + } parts.push(toolCallPart) } diff --git a/packages/typescript/ai/tests/ai-text.test.ts b/packages/typescript/ai/tests/ai-text.test.ts index 3eee78e8..d7033e65 100644 --- a/packages/typescript/ai/tests/ai-text.test.ts +++ b/packages/typescript/ai/tests/ai-text.test.ts @@ -1472,6 +1472,86 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { expect(toolExecute).toHaveBeenCalledWith({ path: '/tmp/test.txt' }) expect(adapter.chatStreamCallCount).toBe(1) }) + + it('should execute pending tool calls from UI message parts', async () => { + const toolExecute = vi + .fn() + .mockResolvedValue(JSON.stringify({ success: true })) + + const approvalTool: Tool = { + name: 'approval_tool', + description: 'Needs approval', + inputSchema: z.object({ + path: z.string(), + }), + needsApproval: true, + execute: toolExecute, + } + + class PendingToolAdapter extends MockAdapter { + async *chatStream(options: TextOptions): AsyncIterable { + this.trackStreamCall(options) + + const toolMessage = options.messages.find((msg) => msg.role === 'tool') + expect(toolMessage).toBeDefined() + expect(toolMessage?.toolCallId).toBe('call-1') + expect(toolMessage?.content).toBe(JSON.stringify({ success: true })) + + yield { + type: 'content', + model: 'test-model', + id: 'done-id', + timestamp: Date.now(), + delta: 'Finished', + content: 'Finished', + role: 'assistant', + } + yield { + type: 'done', + model: 'test-model', + id: 'done-id', + timestamp: Date.now(), + finishReason: 'stop', + } + } + } + + const adapter = new PendingToolAdapter() + + const messages = [ + { role: 'user', content: 'Delete file' }, + { + id: 'msg-1', + role: 'assistant', + parts: [ + { + type: 'tool-call', + id: 'call-1', + name: 'approval_tool', + arguments: '{"path":"/tmp/test.txt"}', + state: 'approval-responded', + approval: { + id: 'approval_call-1', + needsApproval: true, + approved: true, + }, + }, + ], + createdAt: new Date(), + } as any, + ] + + const stream = chat({ + adapter, + messages: messages as any, + tools: [approvalTool], + }) + + const chunks = await collectChunks(stream) + expect(chunks[0]?.type).toBe('tool_result') + expect(toolExecute).toHaveBeenCalledWith({ path: '/tmp/test.txt' }) + expect(adapter.chatStreamCallCount).toBe(1) + }) }) describe('Agent Loop Strategy Paths', () => { diff --git a/testing/panel/src/routes/api.chat.ts b/testing/panel/src/routes/api.chat.ts index 4a0d29a0..4ed85388 100644 --- a/testing/panel/src/routes/api.chat.ts +++ b/testing/panel/src/routes/api.chat.ts @@ -44,13 +44,16 @@ Step 1: Call getGuitars() Step 2: Call recommendGuitar(id: "6") Step 3: Done - do NOT add any text after calling recommendGuitar ` -const addToCartToolServer = addToCartToolDef.server((args) => ({ - success: true, - cartId: 'CART_' + Date.now(), - guitarId: args.guitarId, - quantity: args.quantity, - totalItems: args.quantity, -})) +const addToCartToolServer = addToCartToolDef.server((args) => { + console.log('Executing addToCartToolServer', args) + return { + success: true, + cartId: 'CART_' + Date.now(), + guitarId: args.guitarId, + quantity: args.quantity, + totalItems: args.quantity, + } +}) type Provider = 'openai' | 'anthropic' | 'gemini' | 'ollama' | 'grok' diff --git a/testing/panel/src/routes/index.tsx b/testing/panel/src/routes/index.tsx index ed51a3e5..3d630ec8 100644 --- a/testing/panel/src/routes/index.tsx +++ b/testing/panel/src/routes/index.tsx @@ -16,7 +16,6 @@ import type { ModelOption } from '@/lib/model-selection' import GuitarRecommendation from '@/components/example-GuitarRecommendation' import { - addToCartToolDef, addToWishListToolDef, getPersonalGuitarPreferenceToolDef, recommendGuitarToolDef, @@ -46,14 +45,6 @@ const addToWishListToolClient = addToWishListToolDef.client((args) => { } }) -const addToCartToolClient = addToCartToolDef.client((args) => ({ - success: true, - cartId: 'CART_CLIENT_' + Date.now(), - guitarId: args.guitarId, - quantity: args.quantity, - totalItems: args.quantity, -})) - const recommendGuitarToolClient = recommendGuitarToolDef.client(({ id }) => ({ id: +id, })) @@ -61,7 +52,6 @@ const recommendGuitarToolClient = recommendGuitarToolDef.client(({ id }) => ({ const tools = clientTools( getPersonalGuitarPreferenceToolClient, addToWishListToolClient, - addToCartToolClient, recommendGuitarToolClient, ) @@ -455,7 +445,9 @@ function ChatPage() { const { messages, sendMessage, isLoading, addToolApprovalResponse, stop } = useChat({ - connection: fetchServerSentEvents('/api/chat'), + connection: fetchServerSentEvents('/api/chat', { + sendFullMessages: true, + }), tools, onChunk: (chunk: any) => { setChunks((prev) => [...prev, chunk])