Skip to content

Commit 85e97af

Browse files
committed
Use sentinel conversion functions
1 parent c9c6e25 commit 85e97af

File tree

7 files changed

+328
-154
lines changed

7 files changed

+328
-154
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 83 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
2525
from langgraph.checkpoint.redis.base import BaseRedisSaver
2626
from langgraph.checkpoint.redis.shallow import ShallowRedisSaver
27+
from langgraph.checkpoint.redis.util import (
28+
EMPTY_ID_SENTINEL,
29+
from_storage_safe_id,
30+
from_storage_safe_str,
31+
to_storage_safe_id,
32+
to_storage_safe_str,
33+
)
2734
from langgraph.checkpoint.redis.version import __lib_name__, __version__
2835

2936

@@ -78,16 +85,22 @@ def list(
7885
# Construct the filter expression
7986
filter_expression = []
8087
if config:
81-
thread_id = config["configurable"]["thread_id"]
82-
checkpoint_ns = config["configurable"].get("checkpoint_ns")
83-
checkpoint_id = get_checkpoint_id(config)
84-
filter_expression.append(Tag("thread_id") == thread_id)
88+
filter_expression.append(
89+
Tag("thread_id")
90+
== to_storage_safe_id(config["configurable"]["thread_id"])
91+
)
8592

86-
# Reproducing logic from the Postgres implementation.
87-
if checkpoint_ns is not None:
88-
filter_expression.append(Tag("checkpoint_ns") == checkpoint_ns)
89-
if checkpoint_id:
90-
filter_expression.append(Tag("checkpoint_id") == checkpoint_id)
93+
# Reproducing the logic from the Postgres implementation, we'll
94+
# search for checkpoints with any namespace, including an empty
95+
# string, while `checkpoint_id` has to have a value.
96+
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
97+
filter_expression.append(
98+
Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)
99+
)
100+
if checkpoint_id := get_checkpoint_id(config):
101+
filter_expression.append(
102+
Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)
103+
)
91104

92105
if filter:
93106
for k, v in filter.items():
@@ -125,9 +138,10 @@ def list(
125138

126139
# Process the results
127140
for doc in results.docs:
128-
thread_id = str(getattr(doc, "thread_id", ""))
129-
checkpoint_ns = str(getattr(doc, "checkpoint_ns", ""))
130-
checkpoint_id = str(getattr(doc, "checkpoint_id", ""))
141+
thread_id = from_storage_safe_id(doc["thread_id"])
142+
checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
143+
checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
144+
parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])
131145

132146
# Fetch channel_values
133147
channel_values = self.get_channel_values(
@@ -138,11 +152,11 @@ def list(
138152

139153
# Fetch pending_sends from parent checkpoint
140154
pending_sends = []
141-
if doc["parent_checkpoint_id"]:
155+
if parent_checkpoint_id:
142156
pending_sends = self._load_pending_sends(
143157
thread_id=thread_id,
144158
checkpoint_ns=checkpoint_ns,
145-
parent_checkpoint_id=doc["parent_checkpoint_id"],
159+
parent_checkpoint_id=parent_checkpoint_id,
146160
)
147161

148162
# Fetch and parse metadata
@@ -166,7 +180,7 @@ def list(
166180
"configurable": {
167181
"thread_id": thread_id,
168182
"checkpoint_ns": checkpoint_ns,
169-
"checkpoint_id": doc["checkpoint_id"],
183+
"checkpoint_id": checkpoint_id,
170184
}
171185
}
172186

@@ -197,27 +211,36 @@ def put(
197211
) -> RunnableConfig:
198212
"""Store a checkpoint to Redis."""
199213
configurable = config["configurable"].copy()
214+
200215
thread_id = configurable.pop("thread_id")
201-
checkpoint_ns = configurable.pop("checkpoint_ns", "")
202-
checkpoint_id = configurable.pop(
216+
checkpoint_ns = configurable.pop("checkpoint_ns")
217+
checkpoint_id = checkpoint_id = configurable.pop(
203218
"checkpoint_id", configurable.pop("thread_ts", "")
204219
)
205220

221+
# For values we store in Redis, we need to convert empty strings to the
222+
# sentinel value.
223+
storage_safe_thread_id = to_storage_safe_id(thread_id)
224+
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
225+
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)
226+
206227
copy = checkpoint.copy()
228+
# When we return the config, we need to preserve empty strings that
229+
# were passed in, instead of the sentinel value.
207230
next_config = {
208231
"configurable": {
209232
"thread_id": thread_id,
210233
"checkpoint_ns": checkpoint_ns,
211-
"checkpoint_id": checkpoint["id"],
234+
"checkpoint_id": checkpoint_id,
212235
}
213236
}
214237

215-
# Store checkpoint data
238+
# Store checkpoint data.
216239
checkpoint_data = {
217-
"thread_id": thread_id or "",
218-
"checkpoint_ns": checkpoint_ns or "",
219-
"checkpoint_id": checkpoint["id"] or "",
220-
"parent_checkpoint_id": checkpoint_id or "",
240+
"thread_id": storage_safe_thread_id,
241+
"checkpoint_ns": storage_safe_checkpoint_ns,
242+
"checkpoint_id": storage_safe_checkpoint_id,
243+
"parent_checkpoint_id": storage_safe_checkpoint_id,
221244
"checkpoint": self._dump_checkpoint(copy),
222245
"metadata": self._dump_metadata(metadata),
223246
}
@@ -231,15 +254,17 @@ def put(
231254
[checkpoint_data],
232255
keys=[
233256
BaseRedisSaver._make_redis_checkpoint_key(
234-
thread_id, checkpoint_ns, checkpoint["id"]
257+
storage_safe_thread_id,
258+
storage_safe_checkpoint_ns,
259+
storage_safe_checkpoint_id,
235260
)
236261
],
237262
)
238263

239-
# Store blob values
264+
# Store blob values.
240265
blobs = self._dump_blobs(
241-
thread_id,
242-
checkpoint_ns,
266+
storage_safe_thread_id,
267+
storage_safe_checkpoint_ns,
243268
copy.get("channel_values", {}),
244269
new_versions,
245270
)
@@ -264,17 +289,19 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
264289
checkpoint_id = get_checkpoint_id(config)
265290
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
266291

267-
# Reproducing logic from the Postgres implementation.
268-
if checkpoint_id:
292+
ascending = True
293+
294+
if checkpoint_id and checkpoint_id != EMPTY_ID_SENTINEL:
269295
checkpoint_filter_expression = (
270-
(Tag("thread_id") == thread_id)
271-
& (Tag("checkpoint_ns") == checkpoint_ns)
272-
& (Tag("checkpoint_id") == str(checkpoint_id))
296+
(Tag("thread_id") == to_storage_safe_id(thread_id))
297+
& (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
298+
& (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id))
273299
)
274300
else:
275-
checkpoint_filter_expression = (Tag("thread_id") == thread_id) & (
276-
Tag("checkpoint_ns") == checkpoint_ns
277-
)
301+
checkpoint_filter_expression = (
302+
Tag("thread_id") == to_storage_safe_id(thread_id)
303+
) & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
304+
ascending = False
278305

279306
# Construct the query
280307
checkpoints_query = FilterQuery(
@@ -289,19 +316,18 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
289316
],
290317
num_results=1,
291318
)
292-
checkpoints_query.sort_by("checkpoint_id", asc=False)
319+
checkpoints_query.sort_by("checkpoint_id", asc=ascending)
293320

294321
# Execute the query
295322
results = self.checkpoints_index.search(checkpoints_query)
296323
if not results.docs:
297324
return None
298325

299326
doc = results.docs[0]
300-
301-
doc_thread_id = doc["thread_id"]
302-
doc_checkpoint_ns = doc["checkpoint_ns"]
303-
doc_checkpoint_id = doc["checkpoint_id"]
304-
doc_parent_checkpoint_id = doc["parent_checkpoint_id"]
327+
doc_thread_id = from_storage_safe_id(doc["thread_id"])
328+
doc_checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
329+
doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
330+
doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])
305331

306332
# Fetch channel_values
307333
channel_values = self.get_channel_values(
@@ -388,10 +414,14 @@ def get_channel_values(
388414
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
389415
) -> dict[str, Any]:
390416
"""Retrieve channel_values dictionary with properly constructed message objects."""
417+
storage_safe_thread_id = to_storage_safe_id(thread_id)
418+
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
419+
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)
420+
391421
checkpoint_query = FilterQuery(
392-
filter_expression=(Tag("thread_id") == thread_id)
393-
& (Tag("checkpoint_ns") == checkpoint_ns)
394-
& (Tag("checkpoint_id") == checkpoint_id),
422+
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
423+
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
424+
& (Tag("checkpoint_id") == storage_safe_checkpoint_id),
395425
return_fields=["$.checkpoint.channel_versions"],
396426
num_results=1,
397427
)
@@ -409,8 +439,8 @@ def get_channel_values(
409439
channel_values = {}
410440
for channel, version in channel_versions.items():
411441
blob_query = FilterQuery(
412-
filter_expression=(Tag("thread_id") == thread_id)
413-
& (Tag("checkpoint_ns") == checkpoint_ns)
442+
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
443+
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
414444
& (Tag("channel") == channel)
415445
& (Tag("version") == version),
416446
return_fields=["type", "$.blob"],
@@ -446,11 +476,15 @@ def _load_pending_sends(
446476
Returns:
447477
List of (type, blob) tuples representing pending sends
448478
"""
479+
storage_safe_thread_id = to_storage_safe_str(thread_id)
480+
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
481+
storage_safe_parent_checkpoint_id = to_storage_safe_str(parent_checkpoint_id)
482+
449483
# Query checkpoint_writes for parent checkpoint's TASKS channel
450484
parent_writes_query = FilterQuery(
451-
filter_expression=(Tag("thread_id") == thread_id)
452-
& (Tag("checkpoint_ns") == checkpoint_ns)
453-
& (Tag("checkpoint_id") == parent_checkpoint_id)
485+
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
486+
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
487+
& (Tag("checkpoint_id") == storage_safe_parent_checkpoint_id)
454488
& (Tag("channel") == TASKS),
455489
return_fields=["type", "blob", "task_path", "task_id", "idx"],
456490
num_results=100, # Adjust as needed

0 commit comments

Comments
 (0)