|
9 | 9 | from strands import Agent as StrandsAgent |
10 | 10 | from strands.models import BedrockModel |
11 | 11 |
|
12 | | -from .config import extract_model_info, get_system_prompt |
| 12 | +from .config import extract_model_info, get_system_prompt, get_max_iterations |
13 | 13 | from .tools import ToolManager |
14 | 14 | from .types import Message, ModelInfo |
15 | 15 | from .utils import process_messages, process_prompt |
16 | 16 |
|
17 | 17 | logger = logging.getLogger(__name__) |
18 | 18 |
|
| 19 | +class IterationLimitExceededError(Exception): |
| 20 | + """Exception raised when iteration limit is exceeded""" |
| 21 | + pass |
19 | 22 |
|
20 | 23 | class AgentManager: |
21 | 24 | """Manages Strands agent creation and execution.""" |
22 | 25 |
|
23 | 26 | def __init__(self): |
24 | 27 | self.tool_manager = ToolManager() |
| 28 | + self.max_iterations = get_max_iterations() |
| 29 | + self.iteration_count = 0 |
25 | 30 |
|
26 | 31 | def set_session_info(self, session_id: str, trace_id: str): |
27 | 32 | """Set session and trace IDs""" |
28 | 33 | self.tool_manager.set_session_info(session_id, trace_id) |
29 | | - |
| 34 | + self.iteration_count = 0 |
| 35 | + |
| 36 | + def iteration_limit_handler(self, **ev): |
| 37 | + if ev.get("start_event_loop"): |
| 38 | + self.iteration_count += 1 |
| 39 | + if self.iteration_count > self.max_iterations: |
| 40 | + raise IterationLimitExceededError( |
| 41 | + f"Event loop reached maximum iteration count ({self.max_iterations}). Please contact the administrator." |
| 42 | + ) |
30 | 43 | async def process_request_streaming( |
31 | 44 | self, |
32 | 45 | messages: list[Message] | list[dict[str, Any]], |
@@ -64,6 +77,7 @@ async def process_request_streaming( |
64 | 77 | messages=processed_messages, |
65 | 78 | model=bedrock_model, |
66 | 79 | tools=tools, |
| 80 | + callback_handler=self.iteration_limit_handler, |
67 | 81 | ) |
68 | 82 |
|
69 | 83 | async for event in agent.stream_async(processed_prompt): |
|
0 commit comments