Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions backend/pkg/providers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ import (
)

const (
RepeatingToolCallThreshold = 3
maxQASectionsAfterRestore = 3
RepeatingToolCallThreshold = 3
MaxConsecutiveRepeatingDetections = 5
maxQASectionsAfterRestore = 3
keepQASectionsAfterRestore = 1
lastSecBytesAfterRestore = 16 * 1024 // 16 KB
maxBPBytesAfterRestore = 8 * 1024 // 8 KB
Expand All @@ -36,7 +37,8 @@ const (
)

type repeatingDetector struct {
funcCalls []llms.FunctionCall
funcCalls []llms.FunctionCall
consecutiveDetections int
}

func (rd *repeatingDetector) detect(toolCall llms.ToolCall) bool {
Expand All @@ -59,7 +61,16 @@ func (rd *repeatingDetector) detect(toolCall llms.ToolCall) bool {

rd.funcCalls = append(rd.funcCalls, funcCall)

return len(rd.funcCalls) >= RepeatingToolCallThreshold
if len(rd.funcCalls) >= RepeatingToolCallThreshold {
rd.consecutiveDetections++
return true
}

return false
}

func (rd *repeatingDetector) shouldError() bool {
return rd.consecutiveDetections >= MaxConsecutiveRepeatingDetections
}

func (rd *repeatingDetector) clearCallArguments(toolCall *llms.FunctionCall) llms.FunctionCall {
Expand Down
12 changes: 11 additions & 1 deletion backend/pkg/providers/performer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (
maxRetriesToCallAgentChain = 3
maxRetriesToCallFunction = 3
maxReflectorCallsPerChain = 3
maxAgentChainIterations = 100
delayBetweenRetries = 5 * time.Second
)

Expand Down Expand Up @@ -91,7 +92,7 @@ func (fp *flowProvider) performAgentChain(
groupID := fmt.Sprintf("flow-%d", fp.flowID)
toolTypeMapping := tools.GetToolTypeMapping()

for {
for iteration := 0; iteration < maxAgentChainIterations; iteration++ {
result, err := fp.callWithRetries(ctx, chain, optAgentType, executor)
if err != nil {
logger.WithError(err).Error("failed to call agent chain")
Expand Down Expand Up @@ -219,6 +220,9 @@ func (fp *flowProvider) performAgentChain(
}
}
}

logger.Error("agent chain exceeded maximum iterations")
return fmt.Errorf("agent chain exceeded maximum iterations (%d)", maxAgentChainIterations)
}

func (fp *flowProvider) execToolCall(
Expand Down Expand Up @@ -271,6 +275,12 @@ func (fp *flowProvider) execToolCall(
langfuse.WithEventLevel(langfuse.ObservationLevelError),
langfuse.WithEventOutput(response),
)

if detector.shouldError() {
logger.Error("tool call repeated too many times, aborting agent chain")
return "", fmt.Errorf("tool call '%s' repeated %d consecutive times, aborting", funcName, detector.consecutiveDetections)
}

logger.Warn("failed to exec function: tool call is repeating")

return response, nil
Expand Down
84 changes: 84 additions & 0 deletions backend/pkg/providers/performer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package providers

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vxcontrol/langchaingo/llms"
)

func TestAgentChainIterationCap(t *testing.T) {
t.Run("maxAgentChainIterations is set", func(t *testing.T) {
assert.Equal(t, 100, maxAgentChainIterations,
"maxAgentChainIterations should be 100 to prevent infinite loops")
})
}

func TestRepeatingDetectorShouldError(t *testing.T) {
makeToolCall := func(name, args string) llms.ToolCall {
return llms.ToolCall{
FunctionCall: &llms.FunctionCall{
Name: name,
Arguments: args,
},
}
}

t.Run("no error before threshold", func(t *testing.T) {
rd := &repeatingDetector{}
tc := makeToolCall("test_tool", `{"key":"value"}`)

// First RepeatingToolCallThreshold calls build up to first detection
for i := 0; i < RepeatingToolCallThreshold; i++ {
rd.detect(tc)
}
assert.Equal(t, 1, rd.consecutiveDetections)
assert.False(t, rd.shouldError(),
"should not error after only 1 consecutive detection")
})

t.Run("error after max consecutive detections", func(t *testing.T) {
rd := &repeatingDetector{}
tc := makeToolCall("test_tool", `{"key":"value"}`)

// Each detection requires RepeatingToolCallThreshold identical calls.
// But after the first detection the counter keeps growing, so every
// subsequent identical call is also a detection.
// We need enough calls to reach MaxConsecutiveRepeatingDetections.
totalCallsNeeded := RepeatingToolCallThreshold + MaxConsecutiveRepeatingDetections - 1
for i := 0; i < totalCallsNeeded; i++ {
rd.detect(tc)
}

require.Equal(t, MaxConsecutiveRepeatingDetections, rd.consecutiveDetections)
assert.True(t, rd.shouldError(),
"should error after %d consecutive detections", MaxConsecutiveRepeatingDetections)
})

t.Run("different tool call resets detection", func(t *testing.T) {
rd := &repeatingDetector{}
tc1 := makeToolCall("tool_a", `{"key":"value"}`)
tc2 := makeToolCall("tool_b", `{"key":"value"}`)

// Build up detections for tool_a
for i := 0; i < RepeatingToolCallThreshold; i++ {
rd.detect(tc1)
}
assert.Equal(t, 1, rd.consecutiveDetections)

// Switch to a different tool - should reset funcCalls
rd.detect(tc2)
assert.Equal(t, 1, rd.consecutiveDetections,
"consecutive detections should not increase after a different tool call")
})

t.Run("nil function call does not trigger", func(t *testing.T) {
rd := &repeatingDetector{}
tc := llms.ToolCall{FunctionCall: nil}

detected := rd.detect(tc)
assert.False(t, detected)
assert.False(t, rd.shouldError())
})
}