diff --git a/lib/crewai/src/crewai/events/lifecycle_decorator.py b/lib/crewai/src/crewai/events/lifecycle_decorator.py new file mode 100644 index 0000000000..fcd4584330 --- /dev/null +++ b/lib/crewai/src/crewai/events/lifecycle_decorator.py @@ -0,0 +1,388 @@ +"""Decorators for automatic event lifecycle management. + +This module provides decorators that automatically emit started/completed/failed +events for methods, reducing boilerplate code across the codebase. +""" + +from collections.abc import Callable +from functools import wraps +import time +from typing import Any, Concatenate, Literal, ParamSpec, TypeVar, TypedDict, cast + +from crewai.events.base_events import BaseEvent +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.crew_events import ( + CrewKickoffCompletedEvent, + CrewKickoffFailedEvent, + CrewKickoffStartedEvent, + CrewTestCompletedEvent, + CrewTestFailedEvent, + CrewTestStartedEvent, + CrewTrainCompletedEvent, + CrewTrainFailedEvent, + CrewTrainStartedEvent, +) +from crewai.events.types.memory_events import ( + MemoryQueryCompletedEvent, + MemoryQueryFailedEvent, + MemoryQueryStartedEvent, + MemorySaveCompletedEvent, + MemorySaveFailedEvent, + MemorySaveStartedEvent, +) +from crewai.events.types.task_events import ( + TaskCompletedEvent, + TaskFailedEvent, + TaskStartedEvent, +) + + +P = ParamSpec("P") +R = TypeVar("R") + +EventPrefix = Literal[ + "task", + "memory_save", + "memory_query", + "crew_kickoff", + "crew_train", + "crew_test", +] + +EventParams = dict[str, Any] + +StartedParamsFn = Callable[[Any, tuple[Any, ...], dict[str, Any]], EventParams] +CompletedParamsFn = Callable[ + [Any, tuple[Any, ...], dict[str, Any], Any, float], EventParams +] +FailedParamsFn = Callable[ + [Any, tuple[Any, ...], dict[str, Any], Exception], EventParams +] + + +class LifecycleEventClasses(TypedDict): + """Mapping of lifecycle event types to their corresponding event classes.""" + + started: type[BaseEvent] + completed: type[BaseEvent] + failed: type[BaseEvent] + + +class EventClassMap(TypedDict): + """Mapping of event prefixes to their lifecycle event classes.""" + + task: LifecycleEventClasses + memory_save: LifecycleEventClasses + memory_query: LifecycleEventClasses + crew_kickoff: LifecycleEventClasses + crew_train: LifecycleEventClasses + crew_test: LifecycleEventClasses + + +class LifecycleParamExtractors(TypedDict): + """Parameter extractors for lifecycle events.""" + + started_params: StartedParamsFn + completed_params: CompletedParamsFn + failed_params: FailedParamsFn + + +EVENT_CLASS_MAP: EventClassMap = { + "task": { + "started": TaskStartedEvent, + "completed": TaskCompletedEvent, + "failed": TaskFailedEvent, + }, + "memory_save": { + "started": MemorySaveStartedEvent, + "completed": MemorySaveCompletedEvent, + "failed": MemorySaveFailedEvent, + }, + "memory_query": { + "started": MemoryQueryStartedEvent, + "completed": MemoryQueryCompletedEvent, + "failed": MemoryQueryFailedEvent, + }, + "crew_kickoff": { + "started": CrewKickoffStartedEvent, + "completed": CrewKickoffCompletedEvent, + "failed": CrewKickoffFailedEvent, + }, + "crew_train": { + "started": CrewTrainStartedEvent, + "completed": CrewTrainCompletedEvent, + "failed": CrewTrainFailedEvent, + }, + "crew_test": { + "started": CrewTestStartedEvent, + "completed": CrewTestCompletedEvent, + "failed": CrewTestFailedEvent, + }, +} + + +def _extract_arg( + position: str | int, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> Any: + """Extract argument by name from kwargs or by position from args. + + Args: + position: Argument name (str) or positional index (int). + args: Positional arguments tuple. + kwargs: Keyword arguments dict. + + Returns: + Extracted argument value or None if not found. + """ + if isinstance(position, str): + return kwargs.get(position) + try: + return args[position] + except IndexError: + return None + + +def lifecycle_params( + *, + args_map: dict[str, str | int] | None = None, + context: dict[str, Any | Callable[[Any], Any]] | None = None, + result_name: str | None = None, + elapsed_name: str = "elapsed_ms", +) -> LifecycleParamExtractors: + """Helper to create lifecycle event parameter extractors with reduced boilerplate. + + This function generates the three parameter extractors (started_params, completed_params, + failed_params) needed by @with_lifecycle_events, following common patterns and reducing + code duplication. + + Args: + args_map: Maps event parameter names to function argument names (str) or positions (int). + Example: {"query": "query", "value": 0} extracts kwargs["query"] and args[0] + context: Static or dynamic context fields included in all events. + Values can be static (Any) or callables that receive self and return a value. + Example: {"source_type": "external_memory", "from_agent": lambda self: self.agent} + result_name: Name for the result in completed_params (e.g., "results", "output"). + If None, result is not included in the event. + elapsed_name: Name for elapsed time in completed_params (default: "elapsed_ms"). + + Returns: + Dictionary with keys "started_params", "completed_params", "failed_params" + containing the appropriate lambda functions for @with_lifecycle_events. + + Example: + >>> param_extractors = lifecycle_params( + ... args_map={"value": "value", "metadata": "metadata"}, + ... context={ + ... "source_type": "external_memory", + ... "from_agent": lambda self: self.agent, + ... "from_task": lambda self: self.task, + ... }, + ... elapsed_name="save_time_ms", + ... ) + >>> param_extractors["started_params"] # doctest: +ELLIPSIS + .started_params_fn at 0x...> + """ + args_map = args_map or {} + context = context or {} + + static_context: EventParams = {} + dynamic_context: dict[str, Callable[[Any], Any]] = {} + for ctx_key, ctx_value in context.items(): + if callable(ctx_value): + dynamic_context[ctx_key] = ctx_value + else: + static_context[ctx_key] = ctx_value + + def started_params_fn( + self: Any, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> EventParams: + """Extract parameters for started event. + + Args: + self: Instance emitting the event. + args: Positional arguments from decorated method. + kwargs: Keyword arguments from decorated method. + + Returns: + Parameters for started event. + """ + params: EventParams = {**static_context} + for param_name, arg_spec in args_map.items(): + params[param_name] = _extract_arg(arg_spec, args, kwargs) + for key, func in dynamic_context.items(): + params[key] = func(self) + return params + + def completed_params_fn( + self: Any, + args: tuple[Any, ...], + kwargs: dict[str, Any], + result: Any, + elapsed_ms: float, + ) -> EventParams: + """Extract parameters for completed event. + + Args: + self: Instance emitting the event. + args: Positional arguments from decorated method. + kwargs: Keyword arguments from decorated method. + result: Return value from decorated method. + elapsed_ms: Elapsed execution time in milliseconds. + + Returns: + Parameters for completed event. + """ + params: EventParams = {**static_context} + for param_name, arg_spec in args_map.items(): + params[param_name] = _extract_arg(arg_spec, args, kwargs) + if result_name is not None: + params[result_name] = result + params[elapsed_name] = elapsed_ms + for key, func in dynamic_context.items(): + params[key] = func(self) + return params + + def failed_params_fn( + self: Any, args: tuple[Any, ...], kwargs: dict[str, Any], exc: Exception + ) -> EventParams: + """Extract parameters for failed event. + + Args: + self: Instance emitting the event. + args: Positional arguments from decorated method. + kwargs: Keyword arguments from decorated method. + exc: Exception raised during execution. + + Returns: + Parameters for failed event. + """ + params: EventParams = {**static_context} + for param_name, arg_spec in args_map.items(): + params[param_name] = _extract_arg(arg_spec, args, kwargs) + params["error"] = str(exc) + for key, func in dynamic_context.items(): + params[key] = func(self) + return params + + return { + "started_params": started_params_fn, + "completed_params": completed_params_fn, + "failed_params": failed_params_fn, + } + + +def with_lifecycle_events( + prefix: EventPrefix, + *, + args_map: dict[str, str | int] | None = None, + context: dict[str, Any | Callable[[Any], Any]] | None = None, + result_name: str | None = None, + elapsed_name: str = "elapsed_ms", +) -> Callable[[Callable[Concatenate[Any, P], R]], Callable[Concatenate[Any, P], R]]: + """Decorator to automatically emit lifecycle events (started/completed/failed). + + This decorator wraps a method to emit events at different stages of execution: + - StartedEvent: Emitted before method execution + - CompletedEvent: Emitted after successful execution (includes timing via monotonic_ns) + - FailedEvent: Emitted if an exception occurs (re-raises the exception) + + Args: + prefix: Event prefix from the EventPrefix Literal type. Determines which + event classes to use (e.g., "task" -> TaskStartedEvent, etc.) + args_map: Maps event parameter names to function argument names (str) or positions (int). + Example: {"query": "query", "value": 0} extracts kwargs["query"] and args[0] + context: Static or dynamic context fields included in all events. + Values can be static (Any) or callables that receive self and return a value. + Example: {"source_type": "external_memory", "from_agent": lambda self: self.agent} + result_name: Name for the result in completed_params (e.g., "results", "output"). + If None, result is not included in the event. + elapsed_name: Name for elapsed time in completed_params (default: "elapsed_ms"). + + Returns: + Decorated function that emits lifecycle events. + + Example: + >>> @with_lifecycle_events( + ... "memory_save", + ... args_map={"value": "value", "metadata": "metadata"}, + ... context={ + ... "source_type": "external_memory", + ... "from_agent": lambda self: self.agent, + ... }, + ... elapsed_name="save_time_ms", + ... ) + ... def save(self, value: Any, metadata: dict[str, Any] | None = None) -> None: + ... pass + """ + param_extractors = lifecycle_params( + args_map=args_map, + context=context, + result_name=result_name, + elapsed_name=elapsed_name, + ) + started_params: StartedParamsFn = param_extractors["started_params"] + completed_params: CompletedParamsFn = param_extractors["completed_params"] + failed_params: FailedParamsFn = param_extractors["failed_params"] + + event_classes = EVENT_CLASS_MAP[prefix] + + def decorator( + func: Callable[Concatenate[Any, P], R], + ) -> Callable[Concatenate[Any, P], R]: + """Apply lifecycle event emission to the decorated function. + + Args: + func: Function to decorate. + + Returns: + Decorated function with lifecycle event emission. + """ + + @wraps(func) + def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R: + """Execute function with lifecycle event emission. + + Args: + self: Instance calling the method. + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + Result from the decorated function. + + Raises: + Exception: Re-raises any exception after emitting failed event. + """ + started_event_params = started_params(self, args, kwargs) + crewai_event_bus.emit( + self, + event_classes["started"](**started_event_params), + ) + + start_time = time.monotonic_ns() + try: + result = func(self, *args, **kwargs) + completed_event_params = completed_params( + self, + args, + kwargs, + result, + (time.monotonic_ns() - start_time) / 1_000_000, + ) + crewai_event_bus.emit( + self, + event_classes["completed"](**completed_event_params), + ) + + return result + except Exception as e: + failed_event_params = failed_params(self, args, kwargs, e) + crewai_event_bus.emit( + self, + event_classes["failed"](**failed_event_params), + ) + raise + + return cast(Callable[Concatenate[Any, P], R], wrapper) + + return decorator diff --git a/lib/crewai/src/crewai/memory/external/external_memory.py b/lib/crewai/src/crewai/memory/external/external_memory.py index c48ffd1e3d..a92005c5e4 100644 --- a/lib/crewai/src/crewai/memory/external/external_memory.py +++ b/lib/crewai/src/crewai/memory/external/external_memory.py @@ -1,17 +1,9 @@ from __future__ import annotations -import time -from typing import TYPE_CHECKING, Any - -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.memory_events import ( - MemoryQueryCompletedEvent, - MemoryQueryFailedEvent, - MemoryQueryStartedEvent, - MemorySaveCompletedEvent, - MemorySaveFailedEvent, - MemorySaveStartedEvent, -) +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, cast + +from crewai.events.lifecycle_decorator import with_lifecycle_events from crewai.memory.external.external_memory_item import ExternalMemoryItem from crewai.memory.memory import Memory from crewai.memory.storage.interface import Storage @@ -19,29 +11,31 @@ if TYPE_CHECKING: - from crewai.memory.storage.mem0_storage import Mem0Storage + from crewai.crew import Crew class ExternalMemory(Memory): - def __init__(self, storage: Storage | None = None, **data: Any): + def __init__(self, storage: Storage | None = None, **data: Any) -> None: super().__init__(storage=storage, **data) @staticmethod - def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage: - from crewai.memory.storage.mem0_storage import Mem0Storage + def _configure_mem0(crew: Crew, config: dict[str, Any]) -> Storage: + from crewai.memory.storage.mem0_storage import Mem0Config, Mem0Storage - return Mem0Storage(type="external", crew=crew, config=config) + return Mem0Storage( + type="external", crew=crew, config=cast(Mem0Config, cast(object, config)) + ) @staticmethod - def external_supported_storages() -> dict[str, Any]: + def external_supported_storages() -> dict[ + str, Callable[[Crew, dict[str, Any]], Storage] + ]: return { "mem0": ExternalMemory._configure_mem0, } @staticmethod - def create_storage( - crew: Any, embedder_config: dict[str, Any] | ProviderSpec | None - ) -> Storage: + def create_storage(crew: Crew, embedder_config: ProviderSpec | None) -> Storage: if not embedder_config: raise ValueError("embedder_config is required") @@ -53,115 +47,59 @@ def create_storage( if provider not in supported_storages: raise ValueError(f"Provider {provider} not supported") - return supported_storages[provider](crew, embedder_config.get("config", {})) - + config = embedder_config.get("config", {}) + return supported_storages[provider](crew, cast(dict[str, Any], config)) + + @with_lifecycle_events( + "memory_save", + args_map={"value": "value", "metadata": "metadata"}, + context={ + "source_type": "external_memory", + "from_agent": lambda self: self.agent, + "from_task": lambda self: self.task, + }, + elapsed_name="save_time_ms", + ) def save( self, value: Any, metadata: dict[str, Any] | None = None, ) -> None: """Saves a value into the external storage.""" - crewai_event_bus.emit( - self, - event=MemorySaveStartedEvent( - value=value, - metadata=metadata, - source_type="external_memory", - from_agent=self.agent, - from_task=self.task, - ), + item = ExternalMemoryItem( + value=value, + metadata=metadata, + agent=self.agent.role if self.agent else None, ) - - start_time = time.time() - try: - item = ExternalMemoryItem( - value=value, - metadata=metadata, - agent=self.agent.role if self.agent else None, - ) - super().save(value=item.value, metadata=item.metadata) - - crewai_event_bus.emit( - self, - event=MemorySaveCompletedEvent( - value=value, - metadata=metadata, - save_time_ms=(time.time() - start_time) * 1000, - source_type="external_memory", - from_agent=self.agent, - from_task=self.task, - ), - ) - except Exception as e: - crewai_event_bus.emit( - self, - event=MemorySaveFailedEvent( - value=value, - metadata=metadata, - error=str(e), - source_type="external_memory", - from_agent=self.agent, - from_task=self.task, - ), - ) - raise - + super().save(value=item.value, metadata=item.metadata) + + @with_lifecycle_events( + "memory_query", + args_map={ + "query": "query", + "limit": "limit", + "score_threshold": "score_threshold", + }, + context={ + "source_type": "external_memory", + "from_agent": lambda self: self.agent, + "from_task": lambda self: self.task, + }, + result_name="results", + elapsed_name="query_time_ms", + ) def search( self, query: str, limit: int = 5, score_threshold: float = 0.6, - ): - crewai_event_bus.emit( - self, - event=MemoryQueryStartedEvent( - query=query, - limit=limit, - score_threshold=score_threshold, - source_type="external_memory", - from_agent=self.agent, - from_task=self.task, - ), - ) - - start_time = time.time() - try: - results = super().search( - query=query, limit=limit, score_threshold=score_threshold - ) - - crewai_event_bus.emit( - self, - event=MemoryQueryCompletedEvent( - query=query, - results=results, - limit=limit, - score_threshold=score_threshold, - query_time_ms=(time.time() - start_time) * 1000, - source_type="external_memory", - from_agent=self.agent, - from_task=self.task, - ), - ) - - return results - except Exception as e: - crewai_event_bus.emit( - self, - event=MemoryQueryFailedEvent( - query=query, - limit=limit, - score_threshold=score_threshold, - error=str(e), - source_type="external_memory", - ), - ) - raise + ) -> Any: + return super().search(query=query, limit=limit, score_threshold=score_threshold) def reset(self) -> None: self.storage.reset() - def set_crew(self, crew: Any) -> ExternalMemory: + def set_crew(self, crew: Crew) -> ExternalMemory: super().set_crew(crew) if not self.storage: diff --git a/lib/crewai/src/crewai/memory/memory.py b/lib/crewai/src/crewai/memory/memory.py index 74297f9e4c..fbc72760a9 100644 --- a/lib/crewai/src/crewai/memory/memory.py +++ b/lib/crewai/src/crewai/memory/memory.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from pydantic import BaseModel @@ -24,9 +24,6 @@ class Memory(BaseModel): _agent: Agent | None = None _task: Task | None = None - def __init__(self, storage: Any, **data: Any): - super().__init__(storage=storage, **data) - @property def task(self) -> Task | None: """Get the current task associated with this memory.""" @@ -62,8 +59,11 @@ def search( limit: int = 5, score_threshold: float = 0.6, ) -> list[Any]: - return self.storage.search( - query=query, limit=limit, score_threshold=score_threshold + return cast( + list[Any], + self.storage.search( + query=query, limit=limit, score_threshold=score_threshold + ), ) def set_crew(self, crew: Any) -> Memory: diff --git a/lib/crewai/src/crewai/memory/storage/mem0_storage.py b/lib/crewai/src/crewai/memory/storage/mem0_storage.py index 73820ab117..90445d98b0 100644 --- a/lib/crewai/src/crewai/memory/storage/mem0_storage.py +++ b/lib/crewai/src/crewai/memory/storage/mem0_storage.py @@ -1,16 +1,83 @@ -from collections import defaultdict +from __future__ import annotations + from collections.abc import Iterable import os import re -from typing import Any +from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict -from mem0 import Memory, MemoryClient # type: ignore[import-untyped,import-not-found] +from mem0 import Memory, MemoryClient # type: ignore[import-untyped] from crewai.memory.storage.interface import Storage from crewai.rag.chromadb.utils import _sanitize_collection_name -MAX_AGENT_ID_LENGTH_MEM0 = 255 +if TYPE_CHECKING: + from crewai.crew import Crew + from crewai.utilities.types import LLMMessage, MessageRole + + +MAX_AGENT_ID_LENGTH_MEM0: Final[int] = 255 +_ASSISTANT_MESSAGE_MARKER: Final[str] = "Final Answer:" +_USER_MESSAGE_PATTERN: Final[re.Pattern[str]] = re.compile(r"User message:\s*(.*)") + + +class BaseMetadata(TypedDict): + short_term: Literal["short_term"] + long_term: Literal["long_term"] + entities: Literal["entity"] + external: Literal["external"] + + +BASE_METADATA: Final[BaseMetadata] = { + "short_term": "short_term", + "long_term": "long_term", + "entities": "entity", + "external": "external", +} + +MEMORY_TYPE_MAP: Final[dict[str, dict[str, str]]] = { + "short_term": {"type": "short_term"}, + "long_term": {"type": "long_term"}, + "entities": {"type": "entity"}, + "external": {"type": "external"}, +} + + +class BaseParams(TypedDict, total=False): + """Parameters for Mem0 memory operations.""" + + metadata: dict[str, Any] + infer: bool + includes: Any + excludes: Any + output_format: str + version: str + run_id: str + user_id: str + agent_id: str + + +class Mem0Config(TypedDict, total=False): + """Configuration for Mem0Storage.""" + + run_id: str + includes: Any + excludes: Any + custom_categories: Any + infer: bool + api_key: str + org_id: str + project_id: str + local_mem0_config: Any + user_id: str + agent_id: str + + +class Mem0Filter(TypedDict, total=False): + """Filter dictionary for Mem0 search operations.""" + + AND: list[dict[str, Any]] + OR: list[dict[str, Any]] class Mem0Storage(Storage): @@ -18,33 +85,22 @@ class Mem0Storage(Storage): Extends Storage to handle embedding and searching across entities using Mem0. """ - def __init__(self, type, crew=None, config=None): - super().__init__() - - self._validate_type(type) + def __init__( + self, + type: Literal["short_term", "long_term", "entities", "external"], + crew: Crew | None = None, + config: Mem0Config | None = None, + ) -> None: self.memory_type = type self.crew = crew - self.config = config or {} - - self._extract_config_values() - self._initialize_memory() - - def _validate_type(self, type): - supported_types = {"short_term", "long_term", "entities", "external"} - if type not in supported_types: - raise ValueError( - f"Invalid type '{type}' for Mem0Storage. " - f"Must be one of: {', '.join(supported_types)}" - ) - - def _extract_config_values(self): - self.mem0_run_id = self.config.get("run_id") - self.includes = self.config.get("includes") - self.excludes = self.config.get("excludes") - self.custom_categories = self.config.get("custom_categories") - self.infer = self.config.get("infer", True) - - def _initialize_memory(self): + if config is None: + config = {} + self.config: Mem0Config = config + self.mem0_run_id = config.get("run_id") + self.includes = config.get("includes") + self.excludes = config.get("excludes") + self.custom_categories = config.get("custom_categories") + self.infer = config.get("infer", True) api_key = self.config.get("api_key") or os.getenv("MEM0_API_KEY") org_id = self.config.get("org_id") project_id = self.config.get("project_id") @@ -65,47 +121,39 @@ def _initialize_memory(self): else Memory() ) - def _create_filter_for_search(self): - """ + def _create_filter_for_search(self) -> Mem0Filter: + """Create filter dictionary for search operations. + Returns: - dict: A filter dictionary containing AND conditions for querying data. - - Includes user_id and agent_id if both are present. - - Includes user_id if only user_id is present. - - Includes agent_id if only agent_id is present. - - Includes run_id if memory_type is 'short_term' and - mem0_run_id is present. + Filter dictionary containing AND/OR conditions for querying data. """ - filter = defaultdict(list) - if self.memory_type == "short_term" and self.mem0_run_id: - filter["AND"].append({"run_id": self.mem0_run_id}) - else: - user_id = self.config.get("user_id", "") - agent_id = self.config.get("agent_id", "") + return {"AND": [{"run_id": self.mem0_run_id}]} - if user_id and agent_id: - filter["OR"].append({"user_id": user_id}) - filter["OR"].append({"agent_id": agent_id}) - elif user_id: - filter["AND"].append({"user_id": user_id}) - elif agent_id: - filter["AND"].append({"agent_id": agent_id}) - - return filter + user_id = self.config.get("user_id") + agent_id = self.config.get("agent_id") + if user_id and agent_id: + return {"OR": [{"user_id": user_id}, {"agent_id": agent_id}]} + if user_id: + return {"AND": [{"user_id": user_id}]} + if agent_id: + return {"AND": [{"agent_id": agent_id}]} + return {} def save(self, value: Any, metadata: dict[str, Any]) -> None: - def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str: - return next( + def _last_content(messages_: Iterable[LLMMessage], role: MessageRole) -> str: + content = next( ( m.get("content", "") - for m in reversed(list(messages)) + for m in reversed(list(messages_)) if m.get("role") == role ), "", ) + return str(content) if content else "" conversations = [] - messages = metadata.pop("messages", None) + messages: Iterable[LLMMessage] = metadata.pop("messages", []) if messages: last_user = _last_content(messages, "user") last_assistant = _last_content(messages, "assistant") @@ -120,20 +168,11 @@ def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str: user_id = self.config.get("user_id", "") - base_metadata = { - "short_term": "short_term", - "long_term": "long_term", - "entities": "entity", - "external": "external", - } - - # Shared base params - params: dict[str, Any] = { - "metadata": {"type": base_metadata[self.memory_type], **metadata}, + params: BaseParams = { + "metadata": {"type": BASE_METADATA[self.memory_type], **metadata}, "infer": self.infer, } - # MemoryClient-specific overrides if isinstance(self.memory, MemoryClient): params["includes"] = self.includes params["excludes"] = self.excludes @@ -154,7 +193,7 @@ def _last_content(messages: Iterable[dict[str, Any]], role: str) -> str: def search( self, query: str, limit: int = 5, score_threshold: float = 0.6 ) -> list[Any]: - params = { + params: dict[str, Any] = { "query": query, "limit": limit, "version": "v2", @@ -164,15 +203,8 @@ def search( if user_id := self.config.get("user_id", ""): params["user_id"] = user_id - memory_type_map = { - "short_term": {"type": "short_term"}, - "long_term": {"type": "long_term"}, - "entities": {"type": "entity"}, - "external": {"type": "external"}, - } - - if self.memory_type in memory_type_map: - params["metadata"] = memory_type_map[self.memory_type] + if self.memory_type in MEMORY_TYPE_MAP: + params["metadata"] = MEMORY_TYPE_MAP[self.memory_type] if self.memory_type == "short_term": params["run_id"] = self.mem0_run_id @@ -195,11 +227,12 @@ def search( return [r for r in results["results"]] - def reset(self): + def reset(self) -> None: if self.memory: self.memory.reset() - def _sanitize_role(self, role: str) -> str: + @staticmethod + def _sanitize_role(role: str) -> str: """ Sanitizes agent roles to ensure valid directory names. """ @@ -210,21 +243,20 @@ def _get_agent_name(self) -> str: return "" agents = self.crew.agents - agents = [self._sanitize_role(agent.role) for agent in agents] - agents = "_".join(agents) + agents_roles = "".join([self._sanitize_role(agent.role) for agent in agents]) return _sanitize_collection_name( - name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0 + name=agents_roles, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0 ) - def _get_assistant_message(self, text: str) -> str: - marker = "Final Answer:" - if marker in text: - return text.split(marker, 1)[1].strip() + @staticmethod + def _get_assistant_message(text: str) -> str: + if _ASSISTANT_MESSAGE_MARKER in text: + return text.split(_ASSISTANT_MESSAGE_MARKER, 1)[1].strip() return text - def _get_user_message(self, text: str) -> str: - pattern = r"User message:\s*(.*)" - match = re.search(pattern, text) + @staticmethod + def _get_user_message(text: str) -> str: + match = _USER_MESSAGE_PATTERN.search(text) if match: return match.group(1).strip() return text diff --git a/lib/crewai/src/crewai/utilities/types.py b/lib/crewai/src/crewai/utilities/types.py index bc331a97e9..d5cd832dbf 100644 --- a/lib/crewai/src/crewai/utilities/types.py +++ b/lib/crewai/src/crewai/utilities/types.py @@ -3,6 +3,9 @@ from typing import Any, Literal, TypedDict +MessageRole = Literal["user", "assistant", "system"] + + class LLMMessage(TypedDict): """Type for formatted LLM messages. @@ -11,5 +14,5 @@ class LLMMessage(TypedDict): instead of str | list[dict[str, str]] """ - role: Literal["user", "assistant", "system"] + role: MessageRole content: str | list[dict[str, Any]]