Skip to content

Commit fa0e3bb

Browse files
authored
Merge pull request #62 from redis/fix/change-get-or-create-working-memory
Add new_session field to WorkingMemoryResponse to indicate if a session was created, new client method get_or_create_working_memory_session()
2 parents fc3c19a + cde631a commit fa0e3bb

File tree

19 files changed

+453
-364
lines changed

19 files changed

+453
-364
lines changed

agent-memory-client/agent_memory_client/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
memory management capabilities for AI agents and applications.
66
"""
77

8-
__version__ = "0.11.1"
8+
__version__ = "0.12.0"
99

1010
from .client import MemoryAPIClient, MemoryClientConfig, create_memory_client
1111
from .exceptions import (

agent-memory-client/agent_memory_client/client.py

Lines changed: 93 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import asyncio
8+
import logging # noqa: F401
89
import re
910
from collections.abc import AsyncIterator, Sequence
1011
from typing import TYPE_CHECKING, Any, Literal, TypedDict
@@ -39,7 +40,6 @@
3940
RecencyConfig,
4041
SessionListResponse,
4142
WorkingMemory,
42-
WorkingMemoryGetOrCreateResponse,
4343
WorkingMemoryResponse,
4444
)
4545

@@ -120,10 +120,16 @@ def __init__(self, config: MemoryClientConfig):
120120
Args:
121121
config: MemoryClientConfig instance with server connection details
122122
"""
123+
from . import __version__
124+
123125
self.config = config
124126
self._client = httpx.AsyncClient(
125127
base_url=config.base_url,
126128
timeout=config.timeout,
129+
headers={
130+
"User-Agent": f"agent-memory-client/{__version__}",
131+
"X-Client-Version": __version__,
132+
},
127133
)
128134

129135
async def close(self) -> None:
@@ -289,11 +295,11 @@ async def get_or_create_working_memory(
289295
namespace: str | None = None,
290296
model_name: ModelNameLiteral | None = None,
291297
context_window_max: int | None = None,
292-
) -> WorkingMemoryGetOrCreateResponse:
298+
) -> tuple[bool, WorkingMemory]:
293299
"""
294300
Get working memory for a session, creating it if it doesn't exist.
295301
296-
This method returns both the working memory and whether it was created or found.
302+
This method returns a tuple with the creation status and the working memory.
297303
This is important for applications that need to know if they're working with
298304
a new session or an existing one.
299305
@@ -305,24 +311,24 @@ async def get_or_create_working_memory(
305311
context_window_max: Optional direct specification of context window tokens
306312
307313
Returns:
308-
WorkingMemoryGetOrCreateResponse containing the memory and creation status
314+
Tuple of (created: bool, memory: WorkingMemory)
315+
- created: True if the session was created, False if it already existed
316+
- memory: The WorkingMemory object
309317
310318
Example:
311319
```python
312320
# Get or create session memory
313-
result = await client.get_or_create_working_memory(
321+
created, memory = await client.get_or_create_working_memory(
314322
session_id="chat_session_123",
315323
user_id="user_456"
316324
)
317325
318-
if result.created:
319-
print("Created new session")
326+
if created:
327+
logging.info("Created new session")
320328
else:
321-
print("Found existing session")
329+
logging.info("Found existing session")
322330
323-
# Access the memory
324-
memory = result.memory
325-
print(f"Session has {len(memory.messages)} messages")
331+
logging.info(f"Session has {len(memory.messages)} messages")
326332
```
327333
"""
328334
try:
@@ -334,29 +340,54 @@ async def get_or_create_working_memory(
334340
model_name=model_name,
335341
context_window_max=context_window_max,
336342
)
337-
return WorkingMemoryGetOrCreateResponse(
338-
memory=existing_memory, created=False
339-
)
340-
except Exception:
341-
# Session doesn't exist, create it
342-
empty_memory = WorkingMemory(
343-
session_id=session_id,
344-
namespace=namespace or self.config.default_namespace,
345-
messages=[],
346-
memories=[],
347-
data={},
348-
user_id=user_id,
349-
)
350343

351-
created_memory = await self.put_working_memory(
352-
session_id=session_id,
353-
memory=empty_memory,
354-
user_id=user_id,
355-
model_name=model_name,
356-
context_window_max=context_window_max,
357-
)
344+
# Check if this is an unsaved session (deprecated behavior for old clients)
345+
if getattr(existing_memory, "unsaved", None) is True:
346+
# This is an unsaved session - we need to create it properly
347+
empty_memory = WorkingMemory(
348+
session_id=session_id,
349+
namespace=namespace or self.config.default_namespace,
350+
messages=[],
351+
memories=[],
352+
data={},
353+
user_id=user_id,
354+
)
355+
356+
created_memory = await self.put_working_memory(
357+
session_id=session_id,
358+
memory=empty_memory,
359+
user_id=user_id,
360+
model_name=model_name,
361+
context_window_max=context_window_max,
362+
)
358363

359-
return WorkingMemoryGetOrCreateResponse(memory=created_memory, created=True)
364+
return (True, created_memory)
365+
366+
return (False, existing_memory)
367+
except httpx.HTTPStatusError as e:
368+
if e.response.status_code == 404:
369+
# Session doesn't exist, create it
370+
empty_memory = WorkingMemory(
371+
session_id=session_id,
372+
namespace=namespace or self.config.default_namespace,
373+
messages=[],
374+
memories=[],
375+
data={},
376+
user_id=user_id,
377+
)
378+
379+
created_memory = await self.put_working_memory(
380+
session_id=session_id,
381+
memory=empty_memory,
382+
user_id=user_id,
383+
model_name=model_name,
384+
context_window_max=context_window_max,
385+
)
386+
387+
return (True, created_memory)
388+
else:
389+
# Re-raise other HTTP errors
390+
raise
360391

361392
async def put_working_memory(
362393
self,
@@ -484,11 +515,10 @@ async def set_working_memory_data(
484515
existing_memory = None
485516
if preserve_existing:
486517
try:
487-
result_obj = await self.get_or_create_working_memory(
518+
created, existing_memory = await self.get_or_create_working_memory(
488519
session_id=session_id,
489520
namespace=namespace,
490521
)
491-
existing_memory = result_obj.memory
492522
except Exception:
493523
existing_memory = None
494524

@@ -544,11 +574,10 @@ async def add_memories_to_working_memory(
544574
```
545575
"""
546576
# Get existing memory
547-
result_obj = await self.get_or_create_working_memory(
577+
created, existing_memory = await self.get_or_create_working_memory(
548578
session_id=session_id,
549579
namespace=namespace,
550580
)
551-
existing_memory = result_obj.memory
552581

553582
# Determine final memories list
554583
if replace or not existing_memory:
@@ -610,7 +639,7 @@ async def create_long_term_memory(
610639
]
611640
612641
response = await client.create_long_term_memory(memories)
613-
print(f"Stored memories: {response.status}")
642+
logging.info(f"Stored memories: {response.status}")
614643
```
615644
"""
616645
# Apply default namespace and ensure IDs are present
@@ -764,9 +793,9 @@ async def search_long_term_memory(
764793
distance_threshold=0.3
765794
)
766795
767-
print(f"Found {results.total} memories")
796+
logging.info(f"Found {results.total} memories")
768797
for memory in results.memories:
769-
print(f"- {memory.text[:100]}... (distance: {memory.dist})")
798+
logging.info(f"- {memory.text[:100]}... (distance: {memory.dist})")
770799
```
771800
"""
772801
# Convert dictionary filters to their proper filter objects if needed
@@ -916,9 +945,9 @@ async def search_memory_tool(
916945
min_relevance=0.7
917946
)
918947
919-
print(result["summary"]) # "Found 2 relevant memories for: user preferences about UI themes"
948+
logging.info(result["summary"]) # "Found 2 relevant memories for: user preferences about UI themes"
920949
for memory in result["memories"]:
921-
print(f"- {memory['text']} (score: {memory['relevance_score']})")
950+
logging.info(f"- {memory['text']} (score: {memory['relevance_score']})")
922951
```
923952
924953
LLM Framework Integration:
@@ -1119,18 +1148,17 @@ async def get_working_memory_tool(
11191148
session_id="current_session"
11201149
)
11211150
1122-
print(memory_state["summary"]) # Human-readable summary
1123-
print(f"Messages: {memory_state['message_count']}")
1124-
print(f"Memories: {len(memory_state['memories'])}")
1151+
logging.info(memory_state["summary"]) # Human-readable summary
1152+
logging.info(f"Messages: {memory_state['message_count']}")
1153+
logging.info(f"Memories: {len(memory_state['memories'])}")
11251154
```
11261155
"""
11271156
try:
1128-
result_obj = await self.get_or_create_working_memory(
1157+
created, result = await self.get_or_create_working_memory(
11291158
session_id=session_id,
11301159
namespace=namespace or self.config.default_namespace,
11311160
user_id=user_id,
11321161
)
1133-
result = result_obj.memory
11341162

11351163
# Format for LLM consumption
11361164
message_count = len(result.messages) if result.messages else 0
@@ -1200,24 +1228,23 @@ async def get_or_create_working_memory_tool(
12001228
)
12011229
12021230
if memory_state["created"]:
1203-
print("Created new session")
1231+
logging.info("Created new session")
12041232
else:
1205-
print("Found existing session")
1233+
logging.info("Found existing session")
12061234
1207-
print(memory_state["summary"]) # Human-readable summary
1208-
print(f"Messages: {memory_state['message_count']}")
1209-
print(f"Memories: {len(memory_state['memories'])}")
1235+
logging.info(memory_state["summary"]) # Human-readable summary
1236+
logging.info(f"Messages: {memory_state['message_count']}")
1237+
logging.info(f"Memories: {len(memory_state['memories'])}")
12101238
```
12111239
"""
12121240
try:
1213-
result_obj = await self.get_or_create_working_memory(
1241+
created, result = await self.get_or_create_working_memory(
12141242
session_id=session_id,
12151243
namespace=namespace or self.config.default_namespace,
12161244
user_id=user_id,
12171245
)
12181246

12191247
# Format for LLM consumption
1220-
result = result_obj.memory
12211248
message_count = len(result.messages) if result.messages else 0
12221249
memory_count = len(result.memories) if result.memories else 0
12231250
data_keys = list(result.data.keys()) if result.data else []
@@ -1238,11 +1265,11 @@ async def get_or_create_working_memory_tool(
12381265
}
12391266
)
12401267

1241-
status_text = "new session" if result_obj.created else "existing session"
1268+
status_text = "new session" if created else "existing session"
12421269

12431270
return {
12441271
"session_id": session_id,
1245-
"created": result_obj.created,
1272+
"created": created,
12461273
"message_count": message_count,
12471274
"memory_count": memory_count,
12481275
"memories": formatted_memories,
@@ -1299,7 +1326,7 @@ async def add_memory_tool(
12991326
entities=["vegetarian", "restaurants"]
13001327
)
13011328
1302-
print(result["summary"]) # "Successfully stored semantic memory"
1329+
logging.info(result["summary"]) # "Successfully stored semantic memory"
13031330
```
13041331
"""
13051332
try:
@@ -1373,7 +1400,7 @@ async def update_memory_data_tool(
13731400
}
13741401
)
13751402
1376-
print(result["summary"]) # "Successfully updated 3 data entries"
1403+
logging.info(result["summary"]) # "Successfully updated 3 data entries"
13771404
```
13781405
"""
13791406
try:
@@ -1948,9 +1975,9 @@ async def resolve_tool_call(
19481975
)
19491976
19501977
if result["success"]:
1951-
print(result["formatted_response"])
1978+
logging.info(result["formatted_response"])
19521979
else:
1953-
print(f"Error: {result['error']}")
1980+
logging.error(f"Error: {result['error']}")
19541981
```
19551982
"""
19561983
try:
@@ -2004,7 +2031,7 @@ async def resolve_tool_calls(
20042031
20052032
for result in results:
20062033
if result["success"]:
2007-
print(f"{result['function_name']}: {result['formatted_response']}")
2034+
logging.info(f"{result['function_name']}: {result['formatted_response']}")
20082035
```
20092036
"""
20102037
results = []
@@ -2062,9 +2089,9 @@ async def resolve_function_call(
20622089
)
20632090
20642091
if result["success"]:
2065-
print(result["formatted_response"])
2092+
logging.info(result["formatted_response"])
20662093
else:
2067-
print(f"Error: {result['error']}")
2094+
logging.error(f"Error: {result['error']}")
20682095
```
20692096
"""
20702097
import json
@@ -2352,7 +2379,7 @@ async def resolve_function_calls(
23522379
results = await client.resolve_function_calls(calls, "session123")
23532380
for result in results:
23542381
if result["success"]:
2355-
print(f"{result['function_name']}: {result['formatted_response']}")
2382+
logging.info(f"{result['function_name']}: {result['formatted_response']}")
23562383
```
23572384
"""
23582385
results = []
@@ -2395,10 +2422,9 @@ async def promote_working_memories_to_long_term(
23952422
Acknowledgement of promotion operation
23962423
"""
23972424
# Get current working memory
2398-
result_obj = await self.get_or_create_working_memory(
2425+
created, working_memory = await self.get_or_create_working_memory(
23992426
session_id=session_id, namespace=namespace
24002427
)
2401-
working_memory = result_obj.memory
24022428

24032429
# Filter memories if specific IDs are requested
24042430
memories_to_promote = working_memory.memories
@@ -2611,10 +2637,9 @@ async def update_working_memory_data(
26112637
WorkingMemoryResponse with updated memory
26122638
"""
26132639
# Get existing memory
2614-
result_obj = await self.get_or_create_working_memory(
2640+
created, existing_memory = await self.get_or_create_working_memory(
26152641
session_id=session_id, namespace=namespace, user_id=user_id
26162642
)
2617-
existing_memory = result_obj.memory
26182643

26192644
# Determine final data based on merge strategy
26202645
if existing_memory and existing_memory.data:
@@ -2667,10 +2692,9 @@ async def append_messages_to_working_memory(
26672692
WorkingMemoryResponse with updated memory (potentially summarized if token limit exceeded)
26682693
"""
26692694
# Get existing memory
2670-
result_obj = await self.get_or_create_working_memory(
2695+
created, existing_memory = await self.get_or_create_working_memory(
26712696
session_id=session_id, namespace=namespace, user_id=user_id
26722697
)
2673-
existing_memory = result_obj.memory
26742698

26752699
# Convert messages to MemoryMessage objects
26762700
converted_messages = []

0 commit comments

Comments
 (0)