diff --git a/py/core/agent/base.py b/py/core/agent/base.py
index 84aae3f23..dc2c825d4 100644
--- a/py/core/agent/base.py
+++ b/py/core/agent/base.py
@@ -3,7 +3,7 @@
 import logging
 import re
 from abc import ABCMeta
-from typing import AsyncGenerator, Optional, Tuple
+from typing import AsyncGenerator, Optional, Tuple, Dict, Any
 
 from core.base import AsyncSyncMeta, LLMChatCompletion, Message, syncable
 from core.base.agent import Agent, Conversation
@@ -42,9 +42,34 @@ def wrapper():
 class R2RAgent(Agent, metaclass=CombinedMeta):
     def __init__(self, *args, **kwargs):
         self.search_results_collector = SearchResultsCollector()
+        self.memory_enabled = kwargs.get("memory_enabled", False)
+
+        # Initialize the mem0 client if memory is enabled
+        if self.memory_enabled:
+            try:
+                self.user_id = "r2r-user"
+                self.agent_id = "r2r-agent"
+                from mem0 import AsyncMemoryClient
+                self.mem0_client = AsyncMemoryClient()
+            except ImportError:
+                logger.warning("mem0 is not installed. Memory functionality will be disabled.")
+                self.memory_enabled = False
+
         super().__init__(*args, **kwargs)
         self._reset()
 
+    def _prepare_mem0_params(self) -> Dict[str, Any]:
+        """Prepare the user/agent scoping parameters for mem0 client operations."""
+        if not self.memory_enabled:
+            return {}
+
+        mem0_params = {}
+        if self.user_id:
+            mem0_params["user_id"] = self.user_id
+        if self.agent_id:
+            mem0_params["agent_id"] = self.agent_id
+        return mem0_params
+
     async def _generate_llm_summary(self, iterations_count: int) -> str:
         """
         Generate a summary of the conversation using the LLM when max iterations are exceeded.
@@ -56,8 +81,11 @@ async def _generate_llm_summary(self, iterations_count: int) -> str:
             A string containing the LLM-generated summary
         """
         try:
-            # Get all messages in the conversation
-            all_messages = await self.conversation.get_messages()
+            # Get all messages - from mem0 if enabled, otherwise from the conversation
+            if self.memory_enabled:
+                all_messages = await self.mem0_client.get_all(**self._prepare_mem0_params())
+            else:
+                all_messages = await self.conversation.get_messages()
 
             # Create a prompt for the LLM to summarize
             summary_prompt = {
@@ -95,7 +123,64 @@ async def _generate_llm_summary(self, iterations_count: int) -> str:
     def _reset(self):
         self._completed = False
-        self.conversation = Conversation()
+        # Always keep a Conversation available so the mem0 paths can fall back to it
+        self.conversation = Conversation()
+
+        # Additionally clear any stored memories when memory is enabled
+        if self.memory_enabled:
+            try:
+                # delete_all is async but _reset is sync, so drive it on the event loop
+                loop = asyncio.get_event_loop()
+                loop.run_until_complete(
+                    self.mem0_client.delete_all(**self._prepare_mem0_params())
+                )
+            except Exception as e:
+                logger.error(f"Failed to reset mem0: {str(e)}")
+
+    async def _add_message_to_storage(self, message):
+        """Add a message to either mem0 or the conversation, based on memory_enabled."""
+        if not self.memory_enabled:
+            return await self.conversation.add_message(message)
+
+        try:
+            # Convert to the format expected by mem0
+            if isinstance(message, Message):
+                mem0_message = {
+                    "role": message.role,
+                    "content": message.content,
+                }
+            else:
+                mem0_message = {
+                    "role": message.get("role"),
+                    "content": message.get("content"),
+                }
+
+            # Derive metadata from any additional message fields
+            metadata = {}
+            if hasattr(message, "tool_calls") and message.tool_calls:
+                metadata["type"] = "tool_calls"
+            elif hasattr(message, "structured_content") and message.structured_content:
+                metadata["type"] = "structured_content"
+
+            # Add the message to memory
+            await self.mem0_client.add([mem0_message], **self._prepare_mem0_params(), version="v2", metadata=metadata)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to store message in mem0: {str(e)}")
+            # Fall back to the regular conversation
+            return await self.conversation.add_message(message)
+
+    async def _get_messages_from_storage(self):
+        """Get messages from either mem0 or the conversation, based on memory_enabled."""
+        if not self.memory_enabled:
+            return await self.conversation.get_messages()
+
+        try:
+            return await self.mem0_client.get_all(**self._prepare_mem0_params())
+        except Exception as e:
+            logger.error(f"Failed to retrieve messages from mem0: {str(e)}")
+            # Fall back to the regular conversation
+            return await self.conversation.get_messages()
 
     @syncable
     async def arun(
@@ -110,14 +195,15 @@ async def arun(
 
         if messages:
             for message in messages:
-                await self.conversation.add_message(message)
+                await self._add_message_to_storage(message)
+
         iterations_count = 0
         while (
             not self._completed
             and iterations_count < self.config.max_iterations
         ):
             iterations_count += 1
-            messages_list = await self.conversation.get_messages()
+            messages_list = await self._get_messages_from_storage()
             generation_config = self.get_generation_config(messages_list[-1])
             response = await self.llm_provider.aget_completion(
                 messages_list,
@@ -129,18 +215,17 @@ async def arun(
         if not self._completed:
             # Generate a summary of the conversation using the LLM
             summary = await self._generate_llm_summary(iterations_count)
-            await self.conversation.add_message(
+            await self._add_message_to_storage(
                 Message(role="assistant", content=summary)
             )
 
         # Return final content
-        all_messages: list[dict] = await self.conversation.get_messages()
+        all_messages: list[dict] = await self._get_messages_from_storage()
         all_messages.reverse()
 
         output_messages = []
         for message_2 in all_messages:
             if (
-                # message_2.get("content")
                 message_2.get("content") != messages[-1].content
             ):
                 output_messages.append(message_2)
@@ -173,7 +258,7 @@ async def process_llm_response(
                     content="",
                     tool_calls=[msg.dict() for msg in message.tool_calls],
                 )
-                await self.conversation.add_message(assistant_msg)
+                await self._add_message_to_storage(assistant_msg)
 
                 # If there are multiple tool_calls, call them sequentially here
                 for tool_call in message.tool_calls:
@@ -185,7 +270,7 @@ async def process_llm_response(
                         **kwargs,
                     )
             else:
-                await self.conversation.add_message(
+                await self._add_message_to_storage(
                     Message(role="assistant", content=message.content)
                 )
                 self._completed = True
@@ -208,7 +293,7 @@ async def process_llm_response(
                     role="assistant",
                     structured_content=message.structured_content,  # Use structured_content field
                 )
-                await self.conversation.add_message(assistant_msg)
+                await self._add_message_to_storage(assistant_msg)
 
                 # Add explicit tool_use blocks in a separate message
                 tool_uses = []
@@ -242,7 +327,7 @@ async def process_llm_response(
 
                 # Add tool_use blocks as a separate assistant message with structured content
                 if tool_uses:
-                    await self.conversation.add_message(
+                    await self._add_message_to_storage(
                         Message(
                             role="assistant",
                             structured_content=tool_uses,
@@ -255,11 +340,11 @@ async def process_llm_response(
                         role="assistant",
                         structured_content=message.structured_content,
                     )
-                    await self.conversation.add_message(assistant_msg)
+                    await self._add_message_to_storage(assistant_msg)
 
             elif message.content:
                 # For regular text content
-                await self.conversation.add_message(
+                await self._add_message_to_storage(
                     Message(role="assistant", content=message.content)
                 )
 
@@ -294,7 +379,7 @@ async def process_llm_response(
                         }
                     )
-                    await self.conversation.add_message(
+                    await self._add_message_to_storage(
                         Message(
                             role="assistant", structured_content=tool_uses
                         )
                     )
@@ -329,7 +414,7 @@ async def process_llm_response(
 
                 # Add tool_use blocks as a message before processing tools
                 if tool_uses:
-                    await self.conversation.add_message(
+                    await self._add_message_to_storage(
                         Message(
                             role="assistant",
                             structured_content=tool_uses,
@@ -1481,4 +1566,4 @@ def _parse_single_tool_call(
         # If all else fails, treat as a plain string value
         tool_params = {"value": raw_params}
 
-    return tool_name, tool_params
+    return tool_name, tool_params
\ No newline at end of file
diff --git a/py/core/agent/rag.py b/py/core/agent/rag.py
index 6f3ab630a..89afd8391 100644
--- a/py/core/agent/rag.py
+++ b/py/core/agent/rag.py
@@ -33,7 +33,7 @@
     R2RAgent,
     R2RStreamingAgent,
     R2RXMLStreamingAgent,
-    R2RXMLToolsAgent,
+    R2RXMLToolsAgent
 )
 
 logger = logging.getLogger(__name__)
@@ -487,6 +487,7 @@ def __init__(
         content_method: Callable,
         file_search_method: Callable,
         max_tool_context_length: int = 20_000,
+        memory_enabled: bool = False,
     ):
         # Initialize base R2RAgent
         R2RAgent.__init__(
@@ -495,6 +496,7 @@ def __init__(
             llm_provider=llm_provider,
             config=config,
             rag_generation_config=rag_generation_config,
+            memory_enabled=memory_enabled,
         )
         # Initialize the RAGAgentMixin
         RAGAgentMixin.__init__(
diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py
index 2ae4af31d..2201fc4e3 100644
--- a/py/core/main/services/retrieval_service.py
+++ b/py/core/main/services/retrieval_service.py
@@ -129,6 +129,9 @@ def create_agent(
         # Set streaming mode based on generation config
         is_streaming = generation_config.stream
 
+        # Set memory mode based on generation config
+        is_memory = generation_config.memory
+
         # Create the appropriate agent based on all factors
         if mode == "rag":
             # RAG mode agents
@@ -181,6 +184,7 @@ def create_agent(
                     knowledge_search_method=knowledge_search_method,
                     content_method=content_method,
                     file_search_method=file_search_method,
+                    memory_enabled=is_memory,
                 )
         else:
             # Research mode agents
diff --git a/py/shared/abstractions/llm.py b/py/shared/abstractions/llm.py
index d71e279e5..e48694e5d 100644
--- a/py/shared/abstractions/llm.py
+++ b/py/shared/abstractions/llm.py
@@ -174,6 +174,7 @@ class GenerationConfig(R2RSerializable):
         "extended_thinking": False,
         "thinking_budget": None,
         "reasoning_effort": None,
+        "memory": False,
     }
 
     model: Optional[str] = Field(
@@ -226,6 +227,10 @@ class GenerationConfig(R2RSerializable):
             "Only applicable to OpenAI providers."
         ),
     )
+    memory: bool = Field(
+        default=False,
+        description="Flag to enable memory mode to use Mem0 for context.",
+    )
 
     @classmethod
     def set_default(cls, **kwargs):
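
Usage sketch (illustrative, not part of the patch): with these changes applied, memory mode is driven entirely by the new `memory` flag on `GenerationConfig`, which `create_agent` reads as `is_memory` and threads down to the agents as `memory_enabled`. The import path and model name below are assumptions; mem0 must also be installed (`pip install mem0ai`) with its API key configured, otherwise `R2RAgent.__init__` logs a warning and falls back to the in-process `Conversation`.

    from shared.abstractions.llm import GenerationConfig  # assumed import path

    # With memory=True, the agent stores and retrieves messages via mem0's
    # AsyncMemoryClient instead of the in-process Conversation object.
    rag_generation_config = GenerationConfig(
        model="openai/gpt-4o-mini",  # illustrative model name
        stream=False,
        memory=True,  # the new flag introduced in this diff
    )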