24
24
from langgraph .checkpoint .redis .ashallow import AsyncShallowRedisSaver
25
25
from langgraph .checkpoint .redis .base import BaseRedisSaver
26
26
from langgraph .checkpoint .redis .shallow import ShallowRedisSaver
27
+ from langgraph .checkpoint .redis .util import (
28
+ EMPTY_ID_SENTINEL ,
29
+ from_storage_safe_id ,
30
+ from_storage_safe_str ,
31
+ to_storage_safe_id ,
32
+ to_storage_safe_str ,
33
+ )
27
34
from langgraph .checkpoint .redis .version import __lib_name__ , __version__
28
35
29
36
@@ -78,16 +85,22 @@ def list(
78
85
# Construct the filter expression
79
86
filter_expression = []
80
87
if config :
81
- thread_id = config [ "configurable" ][ "thread_id" ]
82
- checkpoint_ns = config [ "configurable" ]. get ( "checkpoint_ns " )
83
- checkpoint_id = get_checkpoint_id (config )
84
- filter_expression . append ( Tag ( "thread_id" ) == thread_id )
88
+ filter_expression . append (
89
+ Tag ( "thread_id " )
90
+ == to_storage_safe_id (config [ "configurable" ][ "thread_id" ] )
91
+ )
85
92
86
- # Reproducing logic from the Postgres implementation.
87
- if checkpoint_ns is not None :
88
- filter_expression .append (Tag ("checkpoint_ns" ) == checkpoint_ns )
89
- if checkpoint_id :
90
- filter_expression .append (Tag ("checkpoint_id" ) == checkpoint_id )
93
+ # Reproducing the logic from the Postgres implementation, we'll
94
+ # search for checkpoints with any namespace, including an empty
95
+ # string, while `checkpoint_id` has to have a value.
96
+ if checkpoint_ns := config ["configurable" ].get ("checkpoint_ns" ):
97
+ filter_expression .append (
98
+ Tag ("checkpoint_ns" ) == to_storage_safe_str (checkpoint_ns )
99
+ )
100
+ if checkpoint_id := get_checkpoint_id (config ):
101
+ filter_expression .append (
102
+ Tag ("checkpoint_id" ) == to_storage_safe_id (checkpoint_id )
103
+ )
91
104
92
105
if filter :
93
106
for k , v in filter .items ():
@@ -125,9 +138,10 @@ def list(
125
138
126
139
# Process the results
127
140
for doc in results .docs :
128
- thread_id = str (getattr (doc , "thread_id" , "" ))
129
- checkpoint_ns = str (getattr (doc , "checkpoint_ns" , "" ))
130
- checkpoint_id = str (getattr (doc , "checkpoint_id" , "" ))
141
+ thread_id = from_storage_safe_id (doc ["thread_id" ])
142
+ checkpoint_ns = from_storage_safe_str (doc ["checkpoint_ns" ])
143
+ checkpoint_id = from_storage_safe_id (doc ["checkpoint_id" ])
144
+ parent_checkpoint_id = from_storage_safe_id (doc ["parent_checkpoint_id" ])
131
145
132
146
# Fetch channel_values
133
147
channel_values = self .get_channel_values (
@@ -138,11 +152,11 @@ def list(
138
152
139
153
# Fetch pending_sends from parent checkpoint
140
154
pending_sends = []
141
- if doc [ " parent_checkpoint_id" ] :
155
+ if parent_checkpoint_id :
142
156
pending_sends = self ._load_pending_sends (
143
157
thread_id = thread_id ,
144
158
checkpoint_ns = checkpoint_ns ,
145
- parent_checkpoint_id = doc [ " parent_checkpoint_id" ] ,
159
+ parent_checkpoint_id = parent_checkpoint_id ,
146
160
)
147
161
148
162
# Fetch and parse metadata
@@ -166,7 +180,7 @@ def list(
166
180
"configurable" : {
167
181
"thread_id" : thread_id ,
168
182
"checkpoint_ns" : checkpoint_ns ,
169
- "checkpoint_id" : doc [ " checkpoint_id" ] ,
183
+ "checkpoint_id" : checkpoint_id ,
170
184
}
171
185
}
172
186
@@ -197,27 +211,36 @@ def put(
197
211
) -> RunnableConfig :
198
212
"""Store a checkpoint to Redis."""
199
213
configurable = config ["configurable" ].copy ()
214
+
200
215
thread_id = configurable .pop ("thread_id" )
201
- checkpoint_ns = configurable .pop ("checkpoint_ns" , "" )
202
- checkpoint_id = configurable .pop (
216
+ checkpoint_ns = configurable .pop ("checkpoint_ns" )
217
+ checkpoint_id = checkpoint_id = configurable .pop (
203
218
"checkpoint_id" , configurable .pop ("thread_ts" , "" )
204
219
)
205
220
221
+ # For values we store in Redis, we need to convert empty strings to the
222
+ # sentinel value.
223
+ storage_safe_thread_id = to_storage_safe_id (thread_id )
224
+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
225
+ storage_safe_checkpoint_id = to_storage_safe_id (checkpoint_id )
226
+
206
227
copy = checkpoint .copy ()
228
+ # When we return the config, we need to preserve empty strings that
229
+ # were passed in, instead of the sentinel value.
207
230
next_config = {
208
231
"configurable" : {
209
232
"thread_id" : thread_id ,
210
233
"checkpoint_ns" : checkpoint_ns ,
211
- "checkpoint_id" : checkpoint [ "id" ] ,
234
+ "checkpoint_id" : checkpoint_id ,
212
235
}
213
236
}
214
237
215
- # Store checkpoint data
238
+ # Store checkpoint data.
216
239
checkpoint_data = {
217
- "thread_id" : thread_id or "" ,
218
- "checkpoint_ns" : checkpoint_ns or "" ,
219
- "checkpoint_id" : checkpoint [ "id" ] or "" ,
220
- "parent_checkpoint_id" : checkpoint_id or "" ,
240
+ "thread_id" : storage_safe_thread_id ,
241
+ "checkpoint_ns" : storage_safe_checkpoint_ns ,
242
+ "checkpoint_id" : storage_safe_checkpoint_id ,
243
+ "parent_checkpoint_id" : storage_safe_checkpoint_id ,
221
244
"checkpoint" : self ._dump_checkpoint (copy ),
222
245
"metadata" : self ._dump_metadata (metadata ),
223
246
}
@@ -231,15 +254,17 @@ def put(
231
254
[checkpoint_data ],
232
255
keys = [
233
256
BaseRedisSaver ._make_redis_checkpoint_key (
234
- thread_id , checkpoint_ns , checkpoint ["id" ]
257
+ storage_safe_thread_id ,
258
+ storage_safe_checkpoint_ns ,
259
+ storage_safe_checkpoint_id ,
235
260
)
236
261
],
237
262
)
238
263
239
- # Store blob values
264
+ # Store blob values.
240
265
blobs = self ._dump_blobs (
241
- thread_id ,
242
- checkpoint_ns ,
266
+ storage_safe_thread_id ,
267
+ storage_safe_checkpoint_ns ,
243
268
copy .get ("channel_values" , {}),
244
269
new_versions ,
245
270
)
@@ -264,17 +289,19 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
264
289
checkpoint_id = get_checkpoint_id (config )
265
290
checkpoint_ns = config ["configurable" ].get ("checkpoint_ns" , "" )
266
291
267
- # Reproducing logic from the Postgres implementation.
268
- if checkpoint_id :
292
+ ascending = True
293
+
294
+ if checkpoint_id and checkpoint_id != EMPTY_ID_SENTINEL :
269
295
checkpoint_filter_expression = (
270
- (Tag ("thread_id" ) == thread_id )
271
- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
272
- & (Tag ("checkpoint_id" ) == str (checkpoint_id ))
296
+ (Tag ("thread_id" ) == to_storage_safe_id ( thread_id ) )
297
+ & (Tag ("checkpoint_ns" ) == to_storage_safe_str ( checkpoint_ns ) )
298
+ & (Tag ("checkpoint_id" ) == to_storage_safe_id (checkpoint_id ))
273
299
)
274
300
else :
275
- checkpoint_filter_expression = (Tag ("thread_id" ) == thread_id ) & (
276
- Tag ("checkpoint_ns" ) == checkpoint_ns
277
- )
301
+ checkpoint_filter_expression = (
302
+ Tag ("thread_id" ) == to_storage_safe_id (thread_id )
303
+ ) & (Tag ("checkpoint_ns" ) == to_storage_safe_str (checkpoint_ns ))
304
+ ascending = False
278
305
279
306
# Construct the query
280
307
checkpoints_query = FilterQuery (
@@ -289,19 +316,18 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
289
316
],
290
317
num_results = 1 ,
291
318
)
292
- checkpoints_query .sort_by ("checkpoint_id" , asc = False )
319
+ checkpoints_query .sort_by ("checkpoint_id" , asc = ascending )
293
320
294
321
# Execute the query
295
322
results = self .checkpoints_index .search (checkpoints_query )
296
323
if not results .docs :
297
324
return None
298
325
299
326
doc = results .docs [0 ]
300
-
301
- doc_thread_id = doc ["thread_id" ]
302
- doc_checkpoint_ns = doc ["checkpoint_ns" ]
303
- doc_checkpoint_id = doc ["checkpoint_id" ]
304
- doc_parent_checkpoint_id = doc ["parent_checkpoint_id" ]
327
+ doc_thread_id = from_storage_safe_id (doc ["thread_id" ])
328
+ doc_checkpoint_ns = from_storage_safe_str (doc ["checkpoint_ns" ])
329
+ doc_checkpoint_id = from_storage_safe_id (doc ["checkpoint_id" ])
330
+ doc_parent_checkpoint_id = from_storage_safe_id (doc ["parent_checkpoint_id" ])
305
331
306
332
# Fetch channel_values
307
333
channel_values = self .get_channel_values (
@@ -388,10 +414,14 @@ def get_channel_values(
388
414
self , thread_id : str , checkpoint_ns : str = "" , checkpoint_id : str = ""
389
415
) -> dict [str , Any ]:
390
416
"""Retrieve channel_values dictionary with properly constructed message objects."""
417
+ storage_safe_thread_id = to_storage_safe_id (thread_id )
418
+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
419
+ storage_safe_checkpoint_id = to_storage_safe_id (checkpoint_id )
420
+
391
421
checkpoint_query = FilterQuery (
392
- filter_expression = (Tag ("thread_id" ) == thread_id )
393
- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
394
- & (Tag ("checkpoint_id" ) == checkpoint_id ),
422
+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
423
+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
424
+ & (Tag ("checkpoint_id" ) == storage_safe_checkpoint_id ),
395
425
return_fields = ["$.checkpoint.channel_versions" ],
396
426
num_results = 1 ,
397
427
)
@@ -409,8 +439,8 @@ def get_channel_values(
409
439
channel_values = {}
410
440
for channel , version in channel_versions .items ():
411
441
blob_query = FilterQuery (
412
- filter_expression = (Tag ("thread_id" ) == thread_id )
413
- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
442
+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
443
+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
414
444
& (Tag ("channel" ) == channel )
415
445
& (Tag ("version" ) == version ),
416
446
return_fields = ["type" , "$.blob" ],
@@ -446,11 +476,15 @@ def _load_pending_sends(
446
476
Returns:
447
477
List of (type, blob) tuples representing pending sends
448
478
"""
479
+ storage_safe_thread_id = to_storage_safe_str (thread_id )
480
+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
481
+ storage_safe_parent_checkpoint_id = to_storage_safe_str (parent_checkpoint_id )
482
+
449
483
# Query checkpoint_writes for parent checkpoint's TASKS channel
450
484
parent_writes_query = FilterQuery (
451
- filter_expression = (Tag ("thread_id" ) == thread_id )
452
- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
453
- & (Tag ("checkpoint_id" ) == parent_checkpoint_id )
485
+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
486
+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
487
+ & (Tag ("checkpoint_id" ) == storage_safe_parent_checkpoint_id )
454
488
& (Tag ("channel" ) == TASKS ),
455
489
return_fields = ["type" , "blob" , "task_path" , "task_id" , "idx" ],
456
490
num_results = 100 , # Adjust as needed
0 commit comments