Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions packages/typescript/ai-client/src/chat-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,7 @@ export class ChatClient {
this.abortController = new AbortController()

try {
// Get model messages for the LLM
const modelMessages = this.processor.toModelMessages()
const messages = this.processor.getMessages()

// Call onResponse callback
await this.callbacksRef.current.onResponse()
Expand All @@ -307,7 +306,7 @@ export class ChatClient {

// Connect and stream
const stream = this.connection.connect(
modelMessages,
messages,
bodyWithConversationId,
this.abortController.signal,
)
Expand Down Expand Up @@ -417,6 +416,8 @@ export class ChatClient {
// Find the tool call ID from the approval ID
const messages = this.processor.getMessages()
let foundToolCallId: string | undefined
let foundToolName: string | undefined
let foundToolInput: any | undefined

for (const msg of messages) {
const toolCallPart = msg.parts.find(
Expand All @@ -425,6 +426,12 @@ export class ChatClient {
)
if (toolCallPart) {
foundToolCallId = toolCallPart.id
foundToolName = toolCallPart.name
try {
foundToolInput = JSON.parse(toolCallPart.arguments)
} catch {
// Ignore parse errors
}
break
}
}
Expand All @@ -440,6 +447,32 @@ export class ChatClient {
// Add response via processor
this.processor.addToolApprovalResponse(response.id, response.approved)

// Execute client-side tool if approved
if (response.approved && foundToolCallId && foundToolName) {
const clientTool = this.clientToolsRef.current.get(foundToolName)
if (clientTool?.execute) {
try {
const output = await clientTool.execute(foundToolInput)
await this.addToolResult({
toolCallId: foundToolCallId,
tool: foundToolName,
output,
state: 'output-available',
})
return
} catch (error: any) {
await this.addToolResult({
toolCallId: foundToolCallId,
tool: foundToolName,
output: null,
state: 'output-error',
errorText: error.message,
})
return
}
}
}

// If stream is in progress, queue continuation check for after it ends
if (this.isLoading) {
this.queuePostStreamAction(() => this.checkForContinuation())
Expand Down
19 changes: 14 additions & 5 deletions packages/typescript/ai-client/src/connection-adapters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ export interface FetchConnectionOptions {
signal?: AbortSignal
body?: Record<string, any>
fetchClient?: typeof globalThis.fetch
/**
* Send full UIMessage objects (including `parts`) instead of ModelMessages.
* Required for advanced server features that depend on UIMessage metadata
* (e.g. tool approvals and client tool results tracked in parts).
*/
sendFullMessages?: boolean
}

/**
Expand Down Expand Up @@ -138,7 +144,9 @@ export function fetchServerSentEvents(
const resolvedOptions =
typeof options === 'function' ? await options() : options

const modelMessages = convertMessagesToModelMessages(messages)
const requestMessages = resolvedOptions.sendFullMessages
? messages
: convertMessagesToModelMessages(messages)

const requestHeaders: Record<string, string> = {
'Content-Type': 'application/json',
Expand All @@ -147,7 +155,7 @@ export function fetchServerSentEvents(

// Merge body from options with messages and data
const requestBody = {
messages: modelMessages,
messages: requestMessages,
data,
...resolvedOptions.body,
}
Expand Down Expand Up @@ -238,8 +246,9 @@ export function fetchHttpStream(
const resolvedOptions =
typeof options === 'function' ? await options() : options

// Convert UIMessages to ModelMessages if needed
const modelMessages = convertMessagesToModelMessages(messages)
const requestMessages = resolvedOptions.sendFullMessages
? messages
: convertMessagesToModelMessages(messages)

const requestHeaders: Record<string, string> = {
'Content-Type': 'application/json',
Expand All @@ -248,7 +257,7 @@ export function fetchHttpStream(

// Merge body from options with messages and data
const requestBody = {
messages: modelMessages,
messages: requestMessages,
data,
...resolvedOptions.body,
}
Expand Down
133 changes: 133 additions & 0 deletions packages/typescript/ai-client/tests/chat-client-approval.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import { describe, expect, it, vi } from 'vitest'
import { ChatClient } from '../src/chat-client'
import { stream } from '../src/connection-adapters'
import type { StreamChunk } from '@tanstack/ai'

/**
 * Builds a connection adapter whose stream simply replays the supplied
 * chunks in order, so tests can script an exact server response.
 */
function createMockConnectionAdapter(options: { chunks: StreamChunk[] }) {
  return stream(async function* () {
    yield* options.chunks
  })
}

/**
 * Builds the chunk sequence a server would stream for an assistant turn
 * containing one or more tool calls that each require user approval.
 *
 * Layout: one empty assistant `content` chunk, then a `tool_call` chunk
 * followed by an `approval-requested` chunk per tool call, and a final
 * `done` chunk with `finishReason: 'tool_calls'`.
 */
function createApprovalToolCallChunks(
  toolCalls: Array<{
    id: string
    name: string
    arguments: string
    approvalId: string
  }>,
): StreamChunk[] {
  const chunks: StreamChunk[] = []
  const timestamp = Date.now()

  // Start assistant message
  chunks.push({
    type: 'content',
    id: 'msg-1',
    model: 'test-model',
    timestamp,
    delta: '',
    content: '',
    role: 'assistant',
  })

  toolCalls.forEach((toolCall, index) => {
    // 1. Tool Call Chunk. Each call carries its own stream index — the
    // previous hard-coded 0 made multiple tool calls collide.
    chunks.push({
      type: 'tool_call',
      id: 'msg-1',
      model: 'test-model',
      timestamp,
      toolCall: {
        id: toolCall.id,
        type: 'function',
        function: {
          name: toolCall.name,
          arguments: toolCall.arguments,
        },
      },
      index,
    })

    // 2. Approval Requested Chunk referencing the tool call above.
    // Cast needed until the StreamChunk union includes approval chunks.
    chunks.push({
      type: 'approval-requested',
      id: 'msg-1',
      model: 'test-model',
      timestamp,
      toolCallId: toolCall.id,
      toolName: toolCall.name,
      input: JSON.parse(toolCall.arguments),
      approval: {
        id: toolCall.approvalId,
        needsApproval: true,
      },
    } as any)
  })

  // Done chunk
  chunks.push({
    type: 'done',
    id: 'msg-1',
    model: 'test-model',
    timestamp,
    finishReason: 'tool_calls',
  })

  return chunks
}

describe('ChatClient Approval Flow', () => {
  // Lets queued timers/microtasks run before the test asserts.
  const wait = (ms: number) =>
    new Promise<void>((resolve) => setTimeout(resolve, ms))

  it('should execute client tool when approved', async () => {
    const input = { key: 'test-key' }
    const approvalId = 'approval_123'
    const toolCallId = 'call_123'
    const toolName = 'delete_local_data'

    const execute = vi.fn().mockResolvedValue({ deleted: true })

    const client = new ChatClient({
      connection: createMockConnectionAdapter({
        chunks: createApprovalToolCallChunks([
          {
            id: toolCallId,
            name: toolName,
            arguments: JSON.stringify(input),
            approvalId,
          },
        ]),
      }),
      tools: [
        {
          name: toolName,
          description: 'Delete data',
          execute,
        },
      ],
    })

    // Kick off a turn; the mocked stream ends with a pending approval.
    await client.sendMessage('Delete data')
    await wait(100)

    // The client tool must not run before the user has approved it.
    expect(execute).not.toHaveBeenCalled()

    await client.addToolApprovalResponse({
      id: approvalId,
      approved: true,
    })
    await wait(100)

    // Approving should trigger client-side execution with the parsed input.
    expect(execute).toHaveBeenCalledWith(input)
  })
})
101 changes: 101 additions & 0 deletions packages/typescript/ai-client/tests/connection-adapters.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,107 @@ describe('connection-adapters', () => {
expect(body.data).toEqual({ key: 'value' })
})

it('should send model messages by default', async () => {
  const reader = {
    read: vi.fn().mockResolvedValue({ done: true, value: undefined }),
    releaseLock: vi.fn(),
  }
  fetchMock.mockResolvedValue({
    ok: true,
    body: { getReader: () => reader },
  } as any)

  const adapter = fetchServerSentEvents('/api/chat')

  const history = [
    {
      id: 'msg_1',
      role: 'assistant',
      parts: [
        {
          type: 'tool-call',
          id: 'tool_1',
          name: 'testTool',
          arguments: '{}',
          state: 'approval-responded',
          approval: { id: 'approval_tool_1', needsApproval: true, approved: true },
        },
      ],
      createdAt: new Date(),
    },
  ]

  // Drain the stream so the request is actually issued.
  for await (const _ of adapter.connect(history as any)) {
    // Consume
  }

  const body = JSON.parse(fetchMock.mock.calls[0]?.[1]?.body as string)
  // Without sendFullMessages, UIMessages are converted and `parts` is dropped.
  expect(body.messages[0]).not.toHaveProperty('parts')
  expect(body.messages[0]).toMatchObject({
    role: 'assistant',
  })
})

it('should send full UI messages when configured', async () => {
  const reader = {
    read: vi.fn().mockResolvedValue({ done: true, value: undefined }),
    releaseLock: vi.fn(),
  }
  fetchMock.mockResolvedValue({
    ok: true,
    body: { getReader: () => reader },
  } as any)

  // Opting in via sendFullMessages must preserve the UIMessage structure.
  const adapter = fetchServerSentEvents('/api/chat', {
    sendFullMessages: true,
  })

  for await (const _ of adapter.connect([
    {
      id: 'msg_1',
      role: 'assistant',
      parts: [
        {
          type: 'tool-call',
          id: 'tool_1',
          name: 'testTool',
          arguments: '{}',
          state: 'approval-responded',
          approval: { id: 'approval_tool_1', needsApproval: true, approved: true },
        },
      ],
      createdAt: new Date(),
    },
  ] as any)) {
    // Drain the stream so the request fires.
  }

  const body = JSON.parse(fetchMock.mock.calls[0]?.[1]?.body as string)
  expect(body.messages[0]).toHaveProperty('parts')
  expect(body.messages[0]).toMatchObject({
    role: 'assistant',
    parts: [
      {
        type: 'tool-call',
        id: 'tool_1',
        approval: { id: 'approval_tool_1', approved: true },
      },
    ],
  })
})

it('should use custom fetchClient when provided', async () => {
const customFetch = vi.fn()
const mockReader = {
Expand Down
Loading