121 changes: 103 additions & 18 deletions py/core/agent/base.py
@@ -3,7 +3,7 @@
import logging
import re
from abc import ABCMeta
from typing import AsyncGenerator, Optional, Tuple
from typing import AsyncGenerator, Optional, Tuple, Dict, Any

from core.base import AsyncSyncMeta, LLMChatCompletion, Message, syncable
from core.base.agent import Agent, Conversation
@@ -42,9 +42,34 @@ def wrapper():
class R2RAgent(Agent, metaclass=CombinedMeta):
def __init__(self, *args, **kwargs):
self.search_results_collector = SearchResultsCollector()
self.memory_enabled = kwargs.get("memory_enabled", False)

# Initialize mem0 client if memory is enabled
if self.memory_enabled:
try:
self.user_id = "r2r-user"
self.agent_id = "r2r-agent"
from mem0 import AsyncMemoryClient
self.mem0_client = AsyncMemoryClient()
except ImportError:
logger.warning("mem0 is not installed. Memory functionality will be disabled.")
self.memory_enabled = False

super().__init__(*args, **kwargs)
self._reset()

def _prepare_mem0_params(self) -> Dict[str, Any]:
"""Prepare parameters for mem0 client operations"""
if not self.memory_enabled:
return {}

mem0_params = {}
if self.user_id:
mem0_params["user_id"] = self.user_id
if self.agent_id:
mem0_params["agent_id"] = self.agent_id
return mem0_params

async def _generate_llm_summary(self, iterations_count: int) -> str:
"""
Generate a summary of the conversation using the LLM when max iterations are exceeded.
@@ -56,8 +81,11 @@ async def _generate_llm_summary(self, iterations_count: int) -> str:
A string containing the LLM-generated summary
"""
try:
# Get all messages in the conversation
all_messages = await self.conversation.get_messages()
# Get all messages in the conversation - from mem0 if enabled, otherwise from conversation
if self.memory_enabled:
all_messages = await self.mem0_client.get_all(**self._prepare_mem0_params())
else:
all_messages = await self.conversation.get_messages()

# Create a prompt for the LLM to summarize
summary_prompt = {
@@ -95,7 +123,64 @@ async def _generate_llm_summary(self, iterations_count: int) -> str:

def _reset(self):
self._completed = False
self.conversation = Conversation()

# Reset conversation or mem0 based on memory_enabled
if self.memory_enabled:
try:
# delete_all is async but _reset is sync; asyncio.run executes it on
# a fresh event loop (assumes no loop is already running here)
asyncio.run(self.mem0_client.delete_all(**self._prepare_mem0_params()))
except Exception as e:
logger.error(f"Failed to reset mem0: {str(e)}")
# Fallback to regular conversation if mem0 fails
self.conversation = Conversation()
else:
self.conversation = Conversation()

async def _add_message_to_storage(self, message):
"""Add a message to either mem0 or conversation based on memory_enabled"""
if not self.memory_enabled:
return await self.conversation.add_message(message)

try:
# Convert to the format expected by Mem0
if isinstance(message, Message):
mem0_message = {
"role": message.role,
"content": message.content
}
else:
mem0_message = {
"role": message.get("role"),
"content": message.get("content")
}

# Get metadata from additional fields
metadata = {}
if hasattr(message, "tool_calls") and message.tool_calls:
metadata["type"] = "tool_calls"
elif hasattr(message, "structured_content") and message.structured_content:
metadata["type"] = "structured_content"

# Add the message to memory
await self.mem0_client.add([mem0_message], **self._prepare_mem0_params(), version="v2", metadata=metadata)
return True
except Exception as e:
logger.error(f"Failed to store message in Mem0: {str(e)}")
# Fallback to regular conversation
return await self.conversation.add_message(message)

async def _get_messages_from_storage(self):
"""Get messages from either mem0 or conversation based on memory_enabled"""
if not self.memory_enabled:
return await self.conversation.get_messages()

try:
return await self.mem0_client.get_all(**self._prepare_mem0_params())
except Exception as e:
logger.error(f"Failed to retrieve messages from Mem0: {str(e)}")
# Fallback to regular conversation
return await self.conversation.get_messages()

@syncable
async def arun(
@@ -110,14 +195,15 @@ async def arun(

if messages:
for message in messages:
await self.conversation.add_message(message)
await self._add_message_to_storage(message)

iterations_count = 0
while (
not self._completed
and iterations_count < self.config.max_iterations
):
iterations_count += 1
messages_list = await self.conversation.get_messages()
messages_list = await self._get_messages_from_storage()
generation_config = self.get_generation_config(messages_list[-1])
response = await self.llm_provider.aget_completion(
messages_list,
@@ -129,18 +215,17 @@
if not self._completed:
# Generate a summary of the conversation using the LLM
summary = await self._generate_llm_summary(iterations_count)
await self.conversation.add_message(
await self._add_message_to_storage(
Message(role="assistant", content=summary)
)

# Return final content
all_messages: list[dict] = await self.conversation.get_messages()
all_messages: list[dict] = await self._get_messages_from_storage()
all_messages.reverse()

output_messages = []
for message_2 in all_messages:
if (
# message_2.get("content")
message_2.get("content") != messages[-1].content
):
output_messages.append(message_2)
@@ -173,7 +258,7 @@ async def process_llm_response(
content="",
tool_calls=[msg.dict() for msg in message.tool_calls],
)
await self.conversation.add_message(assistant_msg)
await self._add_message_to_storage(assistant_msg)

# If there are multiple tool_calls, call them sequentially here
for tool_call in message.tool_calls:
@@ -185,7 +270,7 @@
**kwargs,
)
else:
await self.conversation.add_message(
await self._add_message_to_storage(
Message(role="assistant", content=message.content)
)
self._completed = True
@@ -208,7 +293,7 @@ async def process_llm_response(
role="assistant",
structured_content=message.structured_content, # Use structured_content field
)
await self.conversation.add_message(assistant_msg)
await self._add_message_to_storage(assistant_msg)

# Add explicit tool_use blocks in a separate message
tool_uses = []
@@ -242,7 +327,7 @@

# Add tool_use blocks as a separate assistant message with structured content
if tool_uses:
await self.conversation.add_message(
await self._add_message_to_storage(
Message(
role="assistant",
structured_content=tool_uses,
@@ -255,11 +340,11 @@
role="assistant",
structured_content=message.structured_content,
)
await self.conversation.add_message(assistant_msg)
await self._add_message_to_storage(assistant_msg)

elif message.content:
# For regular text content
await self.conversation.add_message(
await self._add_message_to_storage(
Message(role="assistant", content=message.content)
)

@@ -294,7 +379,7 @@ async def process_llm_response(
}
)

await self.conversation.add_message(
await self._add_message_to_storage(
Message(
role="assistant", structured_content=tool_uses
)
@@ -329,7 +414,7 @@ async def process_llm_response(

# Add tool_use blocks as a message before processing tools
if tool_uses:
await self.conversation.add_message(
await self._add_message_to_storage(
Message(
role="assistant",
structured_content=tool_uses,
@@ -1481,4 +1566,4 @@ def _parse_single_tool_call(
# If all else fails, treat as a plain string value
tool_params = {"value": raw_params}

return tool_name, tool_params
return tool_name, tool_params
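
Taken together, the new helpers give `R2RAgent` a "mem0 if enabled, `Conversation` otherwise" storage layer that falls back to the regular conversation whenever a mem0 call fails. Below is a minimal, self-contained sketch of that pattern which runs without mem0 installed; `Mem0Stub`, `MiniAgent`, and `demo` are hypothetical stand-ins for illustration, not code from this PR (the real client is `mem0.AsyncMemoryClient`).

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class Mem0Stub:
    """Stand-in exposing the three AsyncMemoryClient calls used in the diff."""

    def __init__(self):
        self._store: list[dict] = []

    async def add(self, messages, **kwargs):
        self._store.extend(messages)

    async def get_all(self, **kwargs):
        return list(self._store)

    async def delete_all(self, **kwargs):
        self._store.clear()


class MiniAgent:
    """Sketch of the storage-with-fallback pattern from R2RAgent."""

    def __init__(self, memory_enabled: bool = False):
        self.memory_enabled = memory_enabled
        self.mem0_client = Mem0Stub()
        self.conversation: list[dict] = []  # stands in for Conversation()

    def _prepare_mem0_params(self) -> dict:
        # Mirrors _prepare_mem0_params: scope memories to a user and an agent.
        return {"user_id": "r2r-user", "agent_id": "r2r-agent"}

    async def add_message(self, message: dict):
        if not self.memory_enabled:
            self.conversation.append(message)
            return
        try:
            await self.mem0_client.add(
                [message], **self._prepare_mem0_params(), version="v2"
            )
        except Exception as e:
            logger.error(f"Failed to store message in Mem0: {e}")
            self.conversation.append(message)  # fall back rather than drop


async def demo():
    agent = MiniAgent(memory_enabled=True)
    await agent.add_message({"role": "user", "content": "hello"})
    print(await agent.mem0_client.get_all(**agent._prepare_mem0_params()))


asyncio.run(demo())
```

One caveat worth noting: when memory is enabled, `_get_messages_from_storage` returns whatever `mem0_client.get_all(...)` yields, and mem0 memory records do not necessarily share the shape of `Conversation.get_messages()` output, so code such as `messages_list[-1]` in `arun` assumes the two formats are interchangeable.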
4 changes: 3 additions & 1 deletion py/core/agent/rag.py
@@ -33,7 +33,7 @@
R2RAgent,
R2RStreamingAgent,
R2RXMLStreamingAgent,
R2RXMLToolsAgent,
R2RXMLToolsAgent
)

logger = logging.getLogger(__name__)
@@ -487,6 +487,7 @@ def __init__(
content_method: Callable,
file_search_method: Callable,
max_tool_context_length: int = 20_000,
memory_enabled: bool = False,
):
# Initialize base R2RAgent
R2RAgent.__init__(
@@ -495,6 +496,7 @@
llm_provider=llm_provider,
config=config,
rag_generation_config=rag_generation_config,
memory_enabled=memory_enabled,
)
# Initialize the RAGAgentMixin
RAGAgentMixin.__init__(
4 changes: 4 additions & 0 deletions py/core/main/services/retrieval_service.py
@@ -129,6 +129,9 @@ def create_agent(
# Set streaming mode based on generation config
is_streaming = generation_config.stream

# Set memory mode based on generation config
is_memory = generation_config.memory

# Create the appropriate agent based on all factors
if mode == "rag":
# RAG mode agents
@@ -181,6 +184,7 @@
knowledge_search_method=knowledge_search_method,
content_method=content_method,
file_search_method=file_search_method,
memory_enabled=is_memory,
)
else:
# Research mode agents
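
The plumbing across `retrieval_service.py` and `rag.py` is a single boolean threaded from the request's generation config down to the agent constructor. A simplified stand-in sketch of that flow (class and function names here are illustrative, not the real R2R types):

```python
from dataclasses import dataclass


@dataclass
class FakeGenerationConfig:  # stand-in for GenerationConfig
    stream: bool = False
    memory: bool = False


class FakeRagAgent:  # stand-in for the RAG agent constructor in rag.py
    def __init__(self, memory_enabled: bool = False):
        self.memory_enabled = memory_enabled  # forwarded on to R2RAgent kwargs


def create_agent(generation_config: FakeGenerationConfig) -> FakeRagAgent:
    # Mirrors retrieval_service.py: read the flag, pass it to the agent.
    is_memory = generation_config.memory
    return FakeRagAgent(memory_enabled=is_memory)


agent = create_agent(FakeGenerationConfig(memory=True))
assert agent.memory_enabled is True
```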
5 changes: 5 additions & 0 deletions py/shared/abstractions/llm.py
@@ -174,6 +174,7 @@ class GenerationConfig(R2RSerializable):
"extended_thinking": False,
"thinking_budget": None,
"reasoning_effort": None,
"memory": False,
}

model: Optional[str] = Field(
@@ -226,6 +227,10 @@ class GenerationConfig(R2RSerializable):
"Only applicable to OpenAI providers."
),
)
memory: bool = Field(
default=False,
description="Flag to enable memory mode to use Mem0 for context.",
)

@classmethod
def set_default(cls, **kwargs):
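
For callers, enabling the feature is one extra field on the generation config. A usage sketch of the intended request shape, using a local model rather than the real `GenerationConfig` (the real class uses Pydantic-style `Field` declarations via `R2RSerializable`; this sketch assumes Pydantic v2):

```python
from pydantic import BaseModel, Field


class GenerationConfigSketch(BaseModel):
    """Local approximation of GenerationConfig, showing only the new field."""

    stream: bool = False
    memory: bool = Field(
        default=False,
        description="Flag to enable memory mode to use Mem0 for context.",
    )


config = GenerationConfigSketch(memory=True)
print(config.model_dump())  # {'stream': False, 'memory': True}
```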