Skip to content

Commit bf23a3b

Browse files
committed
Fix --no-worker mode and add deduplicate parameter
- Fix background task execution in --no-worker mode by using HybridBackgroundTasks directly as a type annotation instead of Depends(get_background_tasks)
- Add deduplicate parameter to create_long_term_memory endpoint with default true
- Pass background_tasks parameter through search_long_term_memory and memory_prompt endpoints
- Update MCP server to create HybridBackgroundTasks instance for memory_prompt calls
- Update tests to patch actual functions instead of BackgroundTasks class
- Add test to verify tasks run inline in --no-worker mode
- Bump server version to 0.12.3 and client version to 0.12.7
1 parent 08b51f2 commit bf23a3b

File tree

12 files changed

+202
-59
lines changed

12 files changed

+202
-59
lines changed

agent-memory-client/agent_memory_client/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
memory management capabilities for AI agents and applications.
66
"""
77

8-
__version__ = "0.12.6"
8+
__version__ = "0.12.7"
99

1010
from .client import MemoryAPIClient, MemoryClientConfig, create_memory_client
1111
from .exceptions import (

agent-memory-client/agent_memory_client/client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,13 +618,16 @@ async def add_memories_to_working_memory(
618618
return await self.put_working_memory(session_id, working_memory)
619619

620620
async def create_long_term_memory(
621-
self, memories: Sequence[ClientMemoryRecord | MemoryRecord]
621+
self,
622+
memories: Sequence[ClientMemoryRecord | MemoryRecord],
623+
deduplicate: bool = True,
622624
) -> AckResponse:
623625
"""
624626
Create long-term memories for later retrieval.
625627
626628
Args:
627629
memories: List of MemoryRecord objects to store
630+
deduplicate: Whether to deduplicate memories before indexing (default: True)
628631
629632
Returns:
630633
AckResponse indicating success
@@ -668,7 +671,10 @@ async def create_long_term_memory(
668671
memory.id = str(ULID())
669672

670673
payload = {
671-
"memories": [m.model_dump(exclude_none=True, mode="json") for m in memories]
674+
"memories": [
675+
m.model_dump(exclude_none=True, mode="json") for m in memories
676+
],
677+
"deduplicate": deduplicate,
672678
}
673679

674680
try:

agent_memory_server/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Redis Agent Memory Server - A memory system for conversational AI."""
22

3-
__version__ = "0.12.2"
3+
__version__ = "0.12.3"

agent_memory_server/api.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from agent_memory_server import long_term_memory, working_memory
1010
from agent_memory_server.auth import UserInfo, get_current_user
1111
from agent_memory_server.config import settings
12-
from agent_memory_server.dependencies import get_background_tasks
12+
from agent_memory_server.dependencies import HybridBackgroundTasks
1313
from agent_memory_server.filters import SessionId, UserId
1414
from agent_memory_server.llms import get_model_client, get_model_config
1515
from agent_memory_server.logging import get_logger
@@ -456,9 +456,9 @@ async def get_working_memory(
456456
async def put_working_memory(
457457
session_id: str,
458458
memory: UpdateWorkingMemory,
459+
background_tasks: HybridBackgroundTasks,
459460
model_name: ModelNameLiteral | None = None,
460461
context_window_max: int | None = None,
461-
background_tasks=Depends(get_background_tasks),
462462
current_user: UserInfo = Depends(get_current_user),
463463
):
464464
"""
@@ -593,7 +593,7 @@ async def delete_working_memory(
593593
@router.post("/v1/long-term-memory/", response_model=AckResponse)
594594
async def create_long_term_memory(
595595
payload: CreateMemoryRecordRequest,
596-
background_tasks=Depends(get_background_tasks),
596+
background_tasks: HybridBackgroundTasks,
597597
current_user: UserInfo = Depends(get_current_user),
598598
):
599599
"""
@@ -609,9 +609,9 @@ async def create_long_term_memory(
609609
if not settings.long_term_memory:
610610
raise HTTPException(status_code=400, detail="Long-term memory is disabled")
611611

612-
# Validate and process memories according to Stage 2 requirements
612+
# Validate and process memories
613613
for memory in payload.memories:
614-
# Enforce that id is required on memory sent from clients
614+
# Enforce that ID is required on memory sent from clients
615615
if not memory.id:
616616
raise HTTPException(
617617
status_code=400, detail="id is required for all memory records"
@@ -624,13 +624,15 @@ async def create_long_term_memory(
624624
background_tasks.add_task(
625625
long_term_memory.index_long_term_memories,
626626
memories=payload.memories,
627+
deduplicate=payload.deduplicate,
627628
)
628629
return AckResponse(status="ok")
629630

630631

631632
@router.post("/v1/long-term-memory/search", response_model=MemoryRecordResultsResponse)
632633
async def search_long_term_memory(
633634
payload: SearchRequest,
635+
background_tasks: HybridBackgroundTasks,
634636
optimize_query: bool = False,
635637
current_user: UserInfo = Depends(get_current_user),
636638
):
@@ -752,7 +754,6 @@ def _vals(f):
752754
# Update last_accessed in background with rate limiting
753755
ids = [m.id for m in ranked if m.id]
754756
if ids:
755-
background_tasks = get_background_tasks()
756757
background_tasks.add_task(long_term_memory.update_last_accessed, ids)
757758

758759
raw_results.memories = ranked
@@ -853,6 +854,7 @@ async def update_long_term_memory(
853854
@router.post("/v1/memory/prompt", response_model=MemoryPromptResponse)
854855
async def memory_prompt(
855856
params: MemoryPromptRequest,
857+
background_tasks: HybridBackgroundTasks,
856858
optimize_query: bool = False,
857859
current_user: UserInfo = Depends(get_current_user),
858860
) -> MemoryPromptResponse:
@@ -992,6 +994,7 @@ async def memory_prompt(
992994
logger.debug(f"[memory_prompt] Search payload: {search_payload}")
993995
long_term_memories = await search_long_term_memory(
994996
search_payload,
997+
background_tasks,
995998
optimize_query=optimize_query,
996999
)
9971000

agent_memory_server/dependencies.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,20 @@ def get_background_tasks() -> HybridBackgroundTasks:
6161
"""
6262
Dependency function that returns a HybridBackgroundTasks instance.
6363
64-
This is used by API endpoints to inject a consistent background tasks object.
64+
NOTE: This function is deprecated. Use HybridBackgroundTasks directly as a type
65+
annotation in your endpoint instead of Depends(get_background_tasks).
66+
67+
Example:
68+
# Old way (deprecated):
69+
async def endpoint(background_tasks=Depends(get_background_tasks)):
70+
...
71+
72+
# New way (correct):
73+
async def endpoint(background_tasks: HybridBackgroundTasks):
74+
...
75+
76+
FastAPI will automatically inject the correct instance when you use
77+
HybridBackgroundTasks as a type annotation.
6578
"""
6679
logger.info("Getting background tasks class")
6780
return HybridBackgroundTasks()

agent_memory_server/long_term_memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ async def compact_long_term_memories(
431431
session_id: str | None = None,
432432
llm_client: OpenAIClientWrapper | AnthropicClientWrapper | None = None,
433433
redis_client: Redis | None = None,
434-
vector_distance_threshold: float = 0.12,
434+
vector_distance_threshold: float = 0.2,
435435
compact_hash_duplicates: bool = True,
436436
compact_semantic_duplicates: bool = True,
437437
perpetual: Perpetual = Perpetual(
@@ -1173,7 +1173,7 @@ async def deduplicate_by_semantic_search(
11731173
namespace: str | None = None,
11741174
user_id: str | None = None,
11751175
session_id: str | None = None,
1176-
vector_distance_threshold: float = 0.12,
1176+
vector_distance_threshold: float = 0.2,
11771177
) -> tuple[MemoryRecord | None, bool]:
11781178
"""
11791179
Check if a memory has semantic duplicates and merge if found.

agent_memory_server/mcp.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,8 +689,14 @@ async def memory_prompt(
689689
if search_payload is not None:
690690
_params["long_term_search"] = search_payload
691691

692+
# Create a background tasks instance for the MCP call
693+
from agent_memory_server.dependencies import HybridBackgroundTasks
694+
695+
background_tasks = HybridBackgroundTasks()
696+
692697
return await core_memory_prompt(
693698
params=MemoryPromptRequest(query=query, **_params),
699+
background_tasks=background_tasks,
694700
optimize_query=optimize_query,
695701
)
696702

agent_memory_server/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,10 @@ class CreateMemoryRecordRequest(BaseModel):
526526
"""Payload for creating memory records"""
527527

528528
memories: list[ExtractedMemoryRecord]
529+
deduplicate: bool = Field(
530+
default=True,
531+
description="Whether to deduplicate memories before indexing",
532+
)
529533

530534

531535
class GetSessionsQuery(BaseModel):

tests/conftest.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from agent_memory_server.api import router as memory_router
1616
from agent_memory_server.config import settings
17-
from agent_memory_server.dependencies import HybridBackgroundTasks, get_background_tasks
17+
from agent_memory_server.dependencies import HybridBackgroundTasks
1818
from agent_memory_server.healthcheck import router as health_router
1919
from agent_memory_server.llms import OpenAIClientWrapper
2020
from agent_memory_server.models import (
@@ -407,8 +407,8 @@ def app(use_test_redis_connection):
407407

408408

409409
@pytest.fixture()
410-
def app_with_mock_background_tasks(use_test_redis_connection, mock_background_tasks):
411-
"""Create a test FastAPI app with routers"""
410+
def app_with_mock_background_tasks(use_test_redis_connection):
411+
"""Create a test FastAPI app with routers and mocked background tasks"""
412412
app = FastAPI()
413413

414414
# Include routers
@@ -423,7 +423,6 @@ async def mock_get_redis_conn(*args, **kwargs):
423423
from agent_memory_server.utils.redis import get_redis_conn
424424

425425
app.dependency_overrides[get_redis_conn] = mock_get_redis_conn
426-
app.dependency_overrides[get_background_tasks] = lambda: mock_background_tasks
427426

428427
return app
429428

@@ -447,9 +446,27 @@ async def client(app):
447446

448447

449448
@pytest.fixture()
450-
async def client_with_mock_background_tasks(app_with_mock_background_tasks):
451-
async with AsyncClient(
452-
transport=ASGITransport(app=app_with_mock_background_tasks),
453-
base_url="http://test",
454-
) as client:
455-
yield client
449+
async def client_with_mock_background_tasks(
450+
app_with_mock_background_tasks, mock_background_tasks
451+
):
452+
"""Client with mocked background tasks - patches the HybridBackgroundTasks class"""
453+
# Patch the HybridBackgroundTasks class to return our mock
454+
# We need to patch it in multiple places since FastAPI creates instances directly
455+
patches = [
456+
mock.patch(
457+
"agent_memory_server.api.HybridBackgroundTasks",
458+
return_value=mock_background_tasks,
459+
),
460+
mock.patch(
461+
"agent_memory_server.dependencies.HybridBackgroundTasks",
462+
return_value=mock_background_tasks,
463+
),
464+
mock.patch("fastapi.BackgroundTasks", return_value=mock_background_tasks),
465+
]
466+
467+
with patches[0], patches[1], patches[2]:
468+
async with AsyncClient(
469+
transport=ASGITransport(app=app_with_mock_background_tasks),
470+
base_url="http://test",
471+
) as client:
472+
yield client

tests/test_api.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
import pytest
55

66
from agent_memory_server.config import Settings
7-
from agent_memory_server.long_term_memory import (
8-
promote_working_memory_to_long_term,
9-
)
107
from agent_memory_server.models import (
118
MemoryMessage,
129
MemoryRecordResult,
@@ -509,11 +506,8 @@ async def test_working_memory_reconstruction_from_long_term(
509506

510507
@pytest.mark.requires_api_keys
511508
@pytest.mark.asyncio
512-
async def test_put_memory_stores_messages_in_long_term_memory(
513-
self, client_with_mock_background_tasks, mock_background_tasks
514-
):
509+
async def test_put_memory_stores_messages_in_long_term_memory(self, client):
515510
"""Test the put_memory endpoint"""
516-
client = client_with_mock_background_tasks
517511
payload = {
518512
"messages": [
519513
{"role": "user", "content": "Hello"},
@@ -526,7 +520,13 @@ async def test_put_memory_stores_messages_in_long_term_memory(
526520
}
527521
mock_settings = Settings(long_term_memory=True)
528522

529-
with patch("agent_memory_server.api.settings", mock_settings):
523+
# Mock the promote function to verify it's called
524+
with (
525+
patch("agent_memory_server.api.settings", mock_settings),
526+
patch(
527+
"agent_memory_server.api.long_term_memory.promote_working_memory_to_long_term"
528+
) as mock_promote,
529+
):
530530
response = await client.put("/v1/working-memory/test-session", json=payload)
531531

532532
assert response.status_code == 200
@@ -537,22 +537,15 @@ async def test_put_memory_stores_messages_in_long_term_memory(
537537
assert "context" in data
538538
assert data["context"] == "Previous context"
539539

540-
# Check that background tasks were called
541-
assert mock_background_tasks.add_task.call_count == 1
542-
543-
# Check that the last call was for long-term memory promotion
544-
assert (
545-
mock_background_tasks.add_task.call_args_list[-1][0][0]
546-
== promote_working_memory_to_long_term
547-
)
540+
# Check that the promotion function was called as a background task
541+
# In --no-worker mode, it runs inline, so it should have been called
542+
assert mock_promote.call_count == 1
543+
assert mock_promote.call_args[1]["session_id"] == "test-session"
548544

549545
@pytest.mark.requires_api_keys
550546
@pytest.mark.asyncio
551-
async def test_put_memory_with_structured_memories_triggers_promotion(
552-
self, client_with_mock_background_tasks, mock_background_tasks
553-
):
547+
async def test_put_memory_with_structured_memories_triggers_promotion(self, client):
554548
"""Test that structured memories trigger background promotion task"""
555-
client = client_with_mock_background_tasks
556549
payload = {
557550
"messages": [],
558551
"memories": [
@@ -569,7 +562,13 @@ async def test_put_memory_with_structured_memories_triggers_promotion(
569562
}
570563
mock_settings = Settings(long_term_memory=True)
571564

572-
with patch("agent_memory_server.api.settings", mock_settings):
565+
# Mock the promote function to verify it's called
566+
with (
567+
patch("agent_memory_server.api.settings", mock_settings),
568+
patch(
569+
"agent_memory_server.api.long_term_memory.promote_working_memory_to_long_term"
570+
) as mock_promote,
571+
):
573572
response = await client.put("/v1/working-memory/test-session", json=payload)
574573

575574
assert response.status_code == 200
@@ -579,20 +578,11 @@ async def test_put_memory_with_structured_memories_triggers_promotion(
579578
assert len(data["memories"]) == 1
580579
assert data["memories"][0]["text"] == "User prefers dark mode"
581580

582-
# Check that promotion background task was called
583-
assert mock_background_tasks.add_task.call_count == 1
584-
585-
# Check that it was the promotion task, not indexing
586-
assert (
587-
mock_background_tasks.add_task.call_args_list[0][0][0]
588-
== promote_working_memory_to_long_term
589-
)
590-
591-
# Check the arguments passed to the promotion task
592-
task_call = mock_background_tasks.add_task.call_args_list[0]
593-
task_kwargs = task_call[1]
594-
assert task_kwargs["session_id"] == "test-session"
595-
assert task_kwargs["namespace"] == "test-namespace"
581+
# Check that the promotion function was called as a background task
582+
# In --no-worker mode, it runs inline, so it should have been called
583+
assert mock_promote.call_count == 1
584+
assert mock_promote.call_args[1]["session_id"] == "test-session"
585+
assert mock_promote.call_args[1]["namespace"] == "test-namespace"
596586

597587
@pytest.mark.requires_api_keys
598588
@pytest.mark.asyncio

0 commit comments

Comments
 (0)