diff --git a/backend/pkg/providers/helpers.go b/backend/pkg/providers/helpers.go index 81d3e42..d42f358 100644 --- a/backend/pkg/providers/helpers.go +++ b/backend/pkg/providers/helpers.go @@ -25,8 +25,9 @@ import ( ) const ( - RepeatingToolCallThreshold = 3 - maxQASectionsAfterRestore = 3 + RepeatingToolCallThreshold = 3 + MaxConsecutiveRepeatingDetections = 5 + maxQASectionsAfterRestore = 3 keepQASectionsAfterRestore = 1 lastSecBytesAfterRestore = 16 * 1024 // 16 KB maxBPBytesAfterRestore = 8 * 1024 // 8 KB @@ -36,7 +37,8 @@ const ( ) type repeatingDetector struct { - funcCalls []llms.FunctionCall + funcCalls []llms.FunctionCall + consecutiveDetections int } func (rd *repeatingDetector) detect(toolCall llms.ToolCall) bool { @@ -59,7 +61,16 @@ func (rd *repeatingDetector) detect(toolCall llms.ToolCall) bool { rd.funcCalls = append(rd.funcCalls, funcCall) - return len(rd.funcCalls) >= RepeatingToolCallThreshold + if len(rd.funcCalls) >= RepeatingToolCallThreshold { + rd.consecutiveDetections++ + return true + } + + return false +} + +func (rd *repeatingDetector) shouldError() bool { + return rd.consecutiveDetections >= MaxConsecutiveRepeatingDetections } func (rd *repeatingDetector) clearCallArguments(toolCall *llms.FunctionCall) llms.FunctionCall { diff --git a/backend/pkg/providers/performer.go b/backend/pkg/providers/performer.go index 446a0ca..b342b63 100644 --- a/backend/pkg/providers/performer.go +++ b/backend/pkg/providers/performer.go @@ -30,6 +30,7 @@ const ( maxRetriesToCallAgentChain = 3 maxRetriesToCallFunction = 3 maxReflectorCallsPerChain = 3 + maxAgentChainIterations = 100 delayBetweenRetries = 5 * time.Second ) @@ -91,7 +92,7 @@ func (fp *flowProvider) performAgentChain( groupID := fmt.Sprintf("flow-%d", fp.flowID) toolTypeMapping := tools.GetToolTypeMapping() - for { + for iteration := 0; iteration < maxAgentChainIterations; iteration++ { result, err := fp.callWithRetries(ctx, chain, optAgentType, executor) if err != nil { logger.WithError(err).Error("failed to call agent chain") @@ -219,6 +220,9 @@ func (fp *flowProvider) performAgentChain( } } } + + logger.Error("agent chain exceeded maximum iterations") + return fmt.Errorf("agent chain exceeded maximum iterations (%d)", maxAgentChainIterations) } func (fp *flowProvider) execToolCall( @@ -271,6 +275,12 @@ func (fp *flowProvider) execToolCall( langfuse.WithEventLevel(langfuse.ObservationLevelError), langfuse.WithEventOutput(response), ) + + if detector.shouldError() { + logger.Error("tool call repeated too many times, aborting agent chain") + return "", fmt.Errorf("tool call '%s' repeated %d consecutive times, aborting", funcName, detector.consecutiveDetections) + } + logger.Warn("failed to exec function: tool call is repeating") return response, nil diff --git a/backend/pkg/providers/performer_test.go b/backend/pkg/providers/performer_test.go new file mode 100644 index 0000000..4811329 --- /dev/null +++ b/backend/pkg/providers/performer_test.go @@ -0,0 +1,84 @@ +package providers + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vxcontrol/langchaingo/llms" +) + +func TestAgentChainIterationCap(t *testing.T) { + t.Run("maxAgentChainIterations is set", func(t *testing.T) { + assert.Equal(t, 100, maxAgentChainIterations, + "maxAgentChainIterations should be 100 to prevent infinite loops") + }) +} + +func TestRepeatingDetectorShouldError(t *testing.T) { + makeToolCall := func(name, args string) llms.ToolCall { + return llms.ToolCall{ + FunctionCall: &llms.FunctionCall{ + Name: name, + Arguments: args, + }, + } + } + + t.Run("no error before threshold", func(t *testing.T) { + rd := &repeatingDetector{} + tc := makeToolCall("test_tool", `{"key":"value"}`) + + // First RepeatingToolCallThreshold calls build up to first detection + for i := 0; i < RepeatingToolCallThreshold; i++ { + rd.detect(tc) + } + assert.Equal(t, 1, rd.consecutiveDetections) + assert.False(t, rd.shouldError(), + "should not error after only 1 consecutive detection") + }) + + t.Run("error after max consecutive detections", func(t *testing.T) { + rd := &repeatingDetector{} + tc := makeToolCall("test_tool", `{"key":"value"}`) + + // Each detection requires RepeatingToolCallThreshold identical calls. + // But after the first detection the counter keeps growing, so every + // subsequent identical call is also a detection. + // We need enough calls to reach MaxConsecutiveRepeatingDetections. + totalCallsNeeded := RepeatingToolCallThreshold + MaxConsecutiveRepeatingDetections - 1 + for i := 0; i < totalCallsNeeded; i++ { + rd.detect(tc) + } + + require.Equal(t, MaxConsecutiveRepeatingDetections, rd.consecutiveDetections) + assert.True(t, rd.shouldError(), + "should error after %d consecutive detections", MaxConsecutiveRepeatingDetections) + }) + + t.Run("different tool call resets detection", func(t *testing.T) { + rd := &repeatingDetector{} + tc1 := makeToolCall("tool_a", `{"key":"value"}`) + tc2 := makeToolCall("tool_b", `{"key":"value"}`) + + // Build up detections for tool_a + for i := 0; i < RepeatingToolCallThreshold; i++ { + rd.detect(tc1) + } + assert.Equal(t, 1, rd.consecutiveDetections) + + // Switch to a different tool - should reset funcCalls + rd.detect(tc2) + assert.Equal(t, 1, rd.consecutiveDetections, + "consecutive detections should not increase after a different tool call") + }) + + t.Run("nil function call does not trigger", func(t *testing.T) { + rd := &repeatingDetector{} + tc := llms.ToolCall{FunctionCall: nil} + + detected := rd.detect(tc) + assert.False(t, detected) + assert.False(t, rd.shouldError()) + }) +}