19 changes: 19 additions & 0 deletions .changeset/tool-race-conditions.md
@@ -0,0 +1,19 @@
---
'@tanstack/ai': patch
'@tanstack/ai-client': patch
---

fix: improve tool execution reliability and prevent race conditions

- Fix client tool execution race conditions by tracking pending tool executions
- Prevent duplicate continuation attempts with continuationPending flag
- Guard against concurrent stream processing in streamResponse
- Add approval info to ToolCall type for server-side decision tracking
- Include approval info in model message conversion for approval workflows
- Check ModelMessage format for approval info extraction in chat activity

This change improves the reliability of tool execution, especially for:

- Client tools with async execute functions
- Approval-based tool workflows
- Sequential tool execution scenarios
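
The tracking pattern behind the first two bullets can be illustrated in isolation. The sketch below uses assumed names (pendingToolExecutions, onToolCall, finalize) and is not the library's actual API: each client tool starts asynchronously, registers its promise in a map keyed by toolCallId, removes itself when it settles, and the stream loop awaits whatever is still in flight before finalizing.

const pendingToolExecutions = new Map<string, Promise<void>>()

function onToolCall(toolCallId: string, execute: () => Promise<unknown>): void {
  const execution = (async () => {
    try {
      await execute() // run the client tool and record its result
    } catch {
      // report the failure as a tool error result instead of rethrowing
    } finally {
      pendingToolExecutions.delete(toolCallId) // always clean up, success or failure
    }
  })()
  pendingToolExecutions.set(toolCallId, execution)
}

async function finalize(): Promise<void> {
  // Await every in-flight client tool before finalizing the stream
  if (pendingToolExecutions.size > 0) {
    await Promise.all(pendingToolExecutions.values())
  }
}

Because errors are caught and reported as tool results rather than rethrown, the Promise.all in finalize never rejects, so a failing tool cannot wedge the stream.
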
6 changes: 5 additions & 1 deletion .gitignore
@@ -52,4 +52,8 @@ vite.config.ts.timestamp-*
test-traces
**/adapters/output
.nitro
.output
.output

# Playwright
playwright-report
test-results
105 changes: 82 additions & 23 deletions packages/typescript/ai-client/src/chat-client.ts
@@ -29,6 +29,10 @@ export class ChatClient {
private currentStreamId: string | null = null
private currentMessageId: string | null = null
private postStreamActions: Array<() => Promise<void>> = []
// Track pending client tool executions to await them before stream finalization
private pendingToolExecutions: Map<string, Promise<void>> = new Map()
// Flag to deduplicate continuation checks during action draining
private continuationPending = false

private callbacksRef: {
current: {
@@ -133,31 +137,41 @@
)
}
},
onToolCall: async (args: {
onToolCall: (args: {
toolCallId: string
toolName: string
input: any
}) => {
// Handle client-side tool execution automatically
const clientTool = this.clientToolsRef.current.get(args.toolName)
if (clientTool?.execute) {
try {
const output = await clientTool.execute(args.input)
await this.addToolResult({
toolCallId: args.toolCallId,
tool: args.toolName,
output,
state: 'output-available',
})
} catch (error: any) {
await this.addToolResult({
toolCallId: args.toolCallId,
tool: args.toolName,
output: null,
state: 'output-error',
errorText: error.message,
})
}
const executeFunc = clientTool?.execute
if (executeFunc) {
// Create and track the execution promise
const executionPromise = (async () => {
try {
const output = await executeFunc(args.input)
await this.addToolResult({
toolCallId: args.toolCallId,
tool: args.toolName,
output,
state: 'output-available',
})
} catch (error: any) {
await this.addToolResult({
toolCallId: args.toolCallId,
tool: args.toolName,
output: null,
state: 'output-error',
errorText: error.message,
})
} finally {
// Remove from pending when complete
this.pendingToolExecutions.delete(args.toolCallId)
}
})()

// Track the pending execution
this.pendingToolExecutions.set(args.toolCallId, executionPromise)
}
},
onApprovalRequest: (args: {
@@ -238,6 +252,12 @@
await new Promise((resolve) => setTimeout(resolve, 0))
}

// Wait for all pending tool executions to complete before finalizing
// This ensures client tools finish before we check for continuation
if (this.pendingToolExecutions.size > 0) {
await Promise.all(this.pendingToolExecutions.values())
}

// Finalize the stream
this.processor.finalizeStream()

@@ -298,21 +318,35 @@
const messages = this.processor.getMessages()
this.processor.setMessages([...messages, uiMessage])

// If stream is in progress, queue the response for after it ends
if (this.isLoading) {
this.queuePostStreamAction(() => this.streamResponse())
return
}

await this.streamResponse()
}

/**
* Stream a response from the LLM
*/
private async streamResponse(): Promise<void> {
// Guard against concurrent streams - if already loading, skip
if (this.isLoading) {
return
}

this.setIsLoading(true)
this.setStatus('submitted')
this.setError(undefined)
this.abortController = new AbortController()
// Reset pending tool executions for the new stream
this.pendingToolExecutions.clear()
let streamCompletedSuccessfully = false

try {
// Get model messages for the LLM
const modelMessages = this.processor.toModelMessages()
// Get UIMessages with parts (preserves approval state and client tool results)
const messages = this.processor.getMessages()

// Call onResponse callback
await this.callbacksRef.current.onResponse()
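
The queue-or-run behavior above (sendMessage defers work to the post-stream queue while a stream is active, and streamResponse refuses to start while isLoading is set) reduces to a small standalone sketch. The names here are illustrative only, not the real client:

class QueueOrRunDemo {
  private isLoading = false
  private queued: Array<() => Promise<void>> = []

  async send(work: () => Promise<void>): Promise<void> {
    if (this.isLoading) {
      // Defer until the current stream finishes
      this.queued.push(() => this.stream(work))
      return
    }
    await this.stream(work)
  }

  private async stream(work: () => Promise<void>): Promise<void> {
    if (this.isLoading) return // guard against concurrent streams
    this.isLoading = true
    try {
      await work()
    } finally {
      this.isLoading = false
    }
    // Drain anything queued while the stream was running
    while (this.queued.length > 0) {
      await this.queued.shift()!()
    }
  }
}

Draining inside the same method that clears isLoading keeps at most one stream active at a time while still guaranteeing that deferred sends eventually run.
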
@@ -325,12 +359,13 @@

// Connect and stream
const stream = this.connection.connect(
modelMessages,
messages,
bodyWithConversationId,
this.abortController.signal,
)

await this.processStream(stream)
streamCompletedSuccessfully = true
} catch (err) {
if (err instanceof Error) {
if (err.name === 'AbortError') {
@@ -346,6 +381,20 @@

// Drain any actions that were queued while the stream was in progress
await this.drainPostStreamActions()

// Continue conversation if the stream ended with a tool result (server tool completed)
if (streamCompletedSuccessfully) {
const messages = this.processor.getMessages()
const lastPart = messages.at(-1)?.parts.at(-1)

if (lastPart?.type === 'tool-result' && this.shouldAutoSend()) {
try {
await this.checkForContinuation()
} catch (error) {
console.error('Failed to continue flow after tool result:', error)
}
}
}
Comment on lines +384 to +397

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the file structure and find the chat-client.ts file
find . -name "chat-client.ts" -type f

Repository: TanStack/ai

Length of output: 105


🏁 Script executed:

# Read the specific lines mentioned in the review (384-397)
sed -n '370,420p' packages/typescript/ai-client/src/chat-client.ts | cat -n

Repository: TanStack/ai

Length of output: 1896


🏁 Script executed:

# Look at the drainPostStreamActions method
rg -n "drainPostStreamActions" packages/typescript/ai-client/src/chat-client.ts -A 20

Repository: TanStack/ai

Length of output: 1426


🏁 Script executed:

# Look at the checkForContinuation method
rg -n "checkForContinuation" packages/typescript/ai-client/src/chat-client.ts -A 15

Repository: TanStack/ai

Length of output: 2180


🏁 Script executed:

# Check for any tests related to tool continuation and nested streams
find . -name "*.test.ts" -o -name "*.spec.ts" | xargs rg -l "tool.*result|continuation" | head -5

Repository: TanStack/ai

Length of output: 444


🏁 Script executed:

# Check the server-client-sequence test for nested tool scenarios
head -100 packages/typescript/smoke-tests/e2e/tests/tools/server-client-sequence.spec.ts

Repository: TanStack/ai

Length of output: 3195


🏁 Script executed:

# Look for tests that involve multiple tool calls or nested tool scenarios
rg -n "tool.*result|nested|continuation" packages/typescript/smoke-tests/e2e/tests/tools/server-client-sequence.spec.ts -B 2 -A 5

Repository: TanStack/ai

Length of output: 702


🏁 Script executed:

# Check if there are any comments in chat-client.ts explaining the streamCompletedSuccessfully logic
rg -n "streamCompletedSuccessfully" packages/typescript/ai-client/src/chat-client.ts -B 5 -A 2

Repository: TanStack/ai

Length of output: 993


🏁 Script executed:

# Look at the broader context where streamCompletedSuccessfully is set
sed -n '300,400p' packages/typescript/ai-client/src/chat-client.ts | cat -n

Repository: TanStack/ai

Length of output: 3857


Consider adding a clarifying comment about state synchronization after nested streams.

The implementation correctly handles nested streams from drainPostStreamActions(): streamCompletedSuccessfully reflects the original stream's completion, while lastPart fetches the current message state (potentially updated by nested streams). This prevents spurious continuations when nested streams produce text responses.

Adding a comment like the following would clarify this pattern for future maintainers:

// Note: lastPart reflects the current message state (potentially updated by
// drainPostStreamActions), not the original stream's final state.
// If nested streams produced a text response, lastPart won't be 'tool-result'.
🤖 Prompt for AI Agents
In `@packages/typescript/ai-client/src/chat-client.ts` around lines 384 - 397, Add
a brief clarifying comment above the block that checks
streamCompletedSuccessfully to explain that lastPart is read from
this.processor.getMessages() which may have been updated by nested streams
(e.g., via drainPostStreamActions), so lastPart reflects the current message
state rather than the original stream's final state and therefore prevents
triggering a continuation when nested streams produced text; reference the
surrounding symbols like streamCompletedSuccessfully,
this.processor.getMessages(), lastPart, shouldAutoSend(), and
checkForContinuation() to show the intent.

}
}

@@ -490,8 +539,18 @@
* Check if we should continue the flow and do so if needed
*/
private async checkForContinuation(): Promise<void> {
// Prevent duplicate continuation attempts
if (this.continuationPending || this.isLoading) {
return
}

if (this.shouldAutoSend()) {
await this.streamResponse()
this.continuationPending = true
try {
await this.streamResponse()
} finally {
this.continuationPending = false
}
}
}

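Why checkForContinuation needs both guards can be shown with a small race: it can be reached from the post-stream action drain and from the end-of-stream tool-result check, and the continuationPending/isLoading pair is what keeps the second caller from starting a second stream. A rough standalone sketch with illustrative names, not the real client:

class ContinuationGuardDemo {
  private isLoading = false
  private continuationPending = false
  started = 0

  async checkForContinuation(shouldAutoSend: () => boolean): Promise<void> {
    if (this.continuationPending || this.isLoading) return // late callers bail out here
    if (!shouldAutoSend()) return
    this.continuationPending = true
    try {
      await this.streamResponse()
    } finally {
      this.continuationPending = false
    }
  }

  private async streamResponse(): Promise<void> {
    if (this.isLoading) return
    this.isLoading = true
    this.started++
    try {
      await new Promise((resolve) => setTimeout(resolve, 0)) // stand-in for the real stream
    } finally {
      this.isLoading = false
    }
  }
}

Running two concurrent checkForContinuation(() => true) calls (for example via Promise.all) leaves started at 1, which is exactly the deduplication the flag exists to provide.
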
32 changes: 14 additions & 18 deletions packages/typescript/ai-client/src/connection-adapters.ts
@@ -1,4 +1,3 @@
import { convertMessagesToModelMessages } from '@tanstack/ai'
import type { ModelMessage, StreamChunk, UIMessage } from '@tanstack/ai'

/**
@@ -138,16 +137,15 @@ export function fetchServerSentEvents(
const resolvedOptions =
typeof options === 'function' ? await options() : options

const modelMessages = convertMessagesToModelMessages(messages)

const requestHeaders: Record<string, string> = {
'Content-Type': 'application/json',
...mergeHeaders(resolvedOptions.headers),
}

// Merge body from options with messages and data
// Send messages as-is (UIMessages with parts preserved)
// Server-side TextEngine handles conversion to ModelMessages
const requestBody = {
messages: modelMessages,
messages,
data,
...resolvedOptions.body,
}
@@ -238,17 +236,15 @@ export function fetchHttpStream(
const resolvedOptions =
typeof options === 'function' ? await options() : options

// Convert UIMessages to ModelMessages if needed
const modelMessages = convertMessagesToModelMessages(messages)

const requestHeaders: Record<string, string> = {
'Content-Type': 'application/json',
...mergeHeaders(resolvedOptions.headers),
}

// Merge body from options with messages and data
// Send messages as-is (UIMessages with parts preserved)
// Server-side TextEngine handles conversion to ModelMessages
const requestBody = {
messages: modelMessages,
messages,
data,
...resolvedOptions.body,
}
@@ -302,14 +298,15 @@
*/
export function stream(
streamFactory: (
messages: Array<ModelMessage>,
messages: Array<UIMessage> | Array<ModelMessage>,
data?: Record<string, any>,
) => AsyncIterable<StreamChunk>,
): ConnectionAdapter {
return {
async *connect(messages, data) {
const modelMessages = convertMessagesToModelMessages(messages)
yield* streamFactory(modelMessages, data)
// Pass messages as-is (UIMessages with parts preserved)
// Server-side chat() handles conversion to ModelMessages
yield* streamFactory(messages, data)
},
}
}
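
A hypothetical usage sketch of the updated stream() adapter follows; chatOnServer is an assumed stand-in for server-side code and not part of the real API, and the import path for stream is illustrative. The point is that UIMessages now flow through unchanged and the ModelMessage conversion happens on the server:

import { stream } from './connection-adapters'
import type { ModelMessage, StreamChunk, UIMessage } from '@tanstack/ai'

// Assumed server entry point, for illustration only
declare function chatOnServer(
  messages: Array<UIMessage> | Array<ModelMessage>,
  data?: Record<string, any>,
): AsyncIterable<StreamChunk>

const adapter = stream(async function* (messages, data) {
  // UIMessages arrive with parts preserved (approval state, client tool results);
  // conversion to ModelMessages happens on the server
  yield* chatOnServer(messages, data)
})
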
@@ -332,16 +329,15 @@ export function stream(
*/
export function rpcStream(
rpcCall: (
messages: Array<ModelMessage>,
messages: Array<UIMessage> | Array<ModelMessage>,
data?: Record<string, any>,
) => AsyncIterable<StreamChunk>,
): ConnectionAdapter {
return {
async *connect(messages, data) {
const modelMessages = convertMessagesToModelMessages(messages)
// Simply yield from the RPC call
// The RPC layer handles WebSocket transport
yield* rpcCall(modelMessages, data)
// Pass messages as-is (UIMessages with parts preserved)
// Server-side chat() handles conversion to ModelMessages
yield* rpcCall(messages, data)
},
}
}