
Commit 397e7f3

fix: resolve HITL interrupt persistence in shallow savers (#133)
The shallow savers were prematurely cleaning up writes when a new checkpoint was saved, causing Human-in-the-Loop interrupts to be deleted before they could be consumed during resume.

Changes:
- Remove aggressive write cleanup from aput/put methods in shallow savers
- Fix registry key consistency using to_storage_safe_str() for checkpoint_ns
- Add comprehensive HITL integration tests (7 new tests)
- Update existing tests to reflect new write persistence behavior

Writes are now cleaned up only via delete_thread, TTL expiration, or explicit overwrite - not during checkpoint transitions.
1 parent 963efbf commit 397e7f3

File tree

10 files changed (+705, -111 lines)

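The ordering is what made the old cleanup dangerous: in a Human-in-the-Loop run, the interrupt is persisted as pending writes via put_writes before the next shallow checkpoint is saved, so a cleanup triggered by the checkpoint transition deletes exactly the writes the subsequent resume needs to read back. The sketch below is a minimal, self-contained model of that failure mode (no Redis required); the dict-based store, key layout, and the "__interrupt__" channel name are illustrative stand-ins, not the library's actual storage schema.

```python
# Minimal model of the bug: pending writes stored before a checkpoint save are
# lost if the save also purges the write registry.
store: dict[str, object] = {}

def put_writes(thread_id: str, writes: list[tuple[str, object]]) -> None:
    """Persist pending writes (e.g. an interrupt) for the current checkpoint."""
    for idx, (channel, value) in enumerate(writes):
        store[f"write:{thread_id}:{idx}"] = (channel, value)

def put_checkpoint(thread_id: str, checkpoint_id: str, cleanup_old_writes: bool) -> None:
    """Save the (single) shallow checkpoint, optionally purging existing writes."""
    if cleanup_old_writes:  # old behavior: writes vanish on the checkpoint transition
        for key in [k for k in store if k.startswith(f"write:{thread_id}:")]:
            del store[key]
    store[f"checkpoint:{thread_id}"] = checkpoint_id

def pending_writes(thread_id: str) -> list[object]:
    """What a resume would see as the pending writes for this thread."""
    return [v for k, v in store.items() if k.startswith(f"write:{thread_id}:")]

# HITL sequence: the interrupt write lands *before* the next checkpoint.
put_writes("thread-1", [("__interrupt__", "please review")])
put_checkpoint("thread-1", "ckpt-2", cleanup_old_writes=True)
print(pending_writes("thread-1"))   # [] -> interrupt lost, resume has nothing to consume

store.clear()
put_writes("thread-1", [("__interrupt__", "please review")])
put_checkpoint("thread-1", "ckpt-2", cleanup_old_writes=False)   # behavior after this fix
print(pending_writes("thread-1"))   # interrupt still present for the resume
```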

langgraph/checkpoint/redis/aio.py

Lines changed: 10 additions & 8 deletions
```diff
@@ -1063,7 +1063,7 @@ async def aput(

         if self.cluster_mode:
             # For cluster mode, execute operation directly
-            await self._redis.json().set(  # type: ignore[misc]
+            await self._redis.json().set(
                 checkpoint_key, "$", checkpoint_data
             )
         else:
@@ -1146,7 +1146,7 @@ async def aput_writes(
             )

             # Redis JSON.SET is an UPSERT by default
-            await self._redis.json().set(key, "$", cast(Any, write_obj))  # type: ignore[misc]
+            await self._redis.json().set(key, "$", cast(Any, write_obj))
             created_keys.append(key)

         # Apply TTL to newly created keys
@@ -1181,7 +1181,7 @@ async def aput_writes(

             # Add all write keys with their index as score for ordering
             zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
-            await self._redis.zadd(zset_key, zadd_mapping)
+            await self._redis.zadd(zset_key, zadd_mapping)  # type: ignore[arg-type]

             # Apply TTL to registry key if configured
             if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -1243,7 +1243,7 @@ async def aput_writes(

             # Add all write keys with their index as score for ordering
             zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
-            pipeline.zadd(zset_key, zadd_mapping)
+            pipeline.zadd(zset_key, zadd_mapping)  # type: ignore[arg-type]

             # Apply TTL to registry key if configured
             if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -1291,7 +1291,7 @@ async def aput_writes(
                     zadd_mapping = {
                         key: idx for idx, key in enumerate(write_keys)
                     }
-                    fallback_pipeline.zadd(zset_key, zadd_mapping)
+                    fallback_pipeline.zadd(zset_key, zadd_mapping)  # type: ignore[arg-type]
                     if self.ttl_config and "default_ttl" in self.ttl_config:
                         ttl_seconds = int(
                             self.ttl_config.get("default_ttl") * 60
@@ -1304,14 +1304,16 @@ async def aput_writes(
             # Update has_writes flag separately for older Redis
             if checkpoint_key:
                 try:
-                    checkpoint_data = await self._redis.json().get(checkpoint_key)  # type: ignore[misc]
+                    checkpoint_data = await self._redis.json().get(
+                        checkpoint_key
+                    )
                     if isinstance(
                         checkpoint_data, dict
                     ) and not checkpoint_data.get("has_writes"):
                         checkpoint_data["has_writes"] = True
                         await self._redis.json().set(
                             checkpoint_key, "$", checkpoint_data
-                        )  # type: ignore[misc]
+                        )
                 except Exception:
                     # If this fails, it's not critical - the writes are still saved
                     pass
@@ -1477,7 +1479,7 @@ async def aget_channel_values(
         )

         # Single JSON.GET operation to retrieve checkpoint with inline channel_values
-        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")  # type: ignore[misc]
+        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")

         if not checkpoint_data:
             return {}
```
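For context on the registry these hunks leave in place: write keys are tracked in a sorted set whose scores are the list indexes, so reading the registry back with ZRANGE preserves the original write order. The redis-py sketch below shows the pattern in isolation; the key names are illustrative and a locally running Redis is assumed.

```python
import redis

r = redis.Redis(decode_responses=True)

# Illustrative registry and member key names (not the saver's exact schema).
zset_key = "write_keys_zset:thread-1:ns:shallow"
write_keys = [
    "checkpoint_write:thread-1:ns:ckpt-2:task-a:0",
    "checkpoint_write:thread-1:ns:ckpt-2:task-a:1",
]

# ZADD with a {member: score} mapping; using the index as the score keeps the
# members in the order the writes were produced.
r.zadd(zset_key, {key: idx for idx, key in enumerate(write_keys)})

# ZRANGE returns members ordered by score, i.e. in insertion order.
print(r.zrange(zset_key, 0, -1))
```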

langgraph/checkpoint/redis/ashallow.py

Lines changed: 26 additions & 43 deletions
```diff
@@ -94,7 +94,7 @@ async def __aexit__(
         _tb: Optional[TracebackType],
     ) -> None:
         if self._owns_its_client:
-            await self._redis.aclose()
+            await self._redis.aclose()  # type: ignore[attr-defined]
             # RedisCluster doesn't have connection_pool attribute
             if getattr(self._redis, "connection_pool", None):
                 coro = self._redis.connection_pool.disconnect()
@@ -229,9 +229,6 @@ async def aput(
         # Create pipeline for all operations
         pipeline = self._redis.pipeline(transaction=False)

-        # Get the previous checkpoint ID to potentially clean up its writes
-        pipeline.json().get(checkpoint_key)
-
         # Set the new checkpoint data
         pipeline.json().set(checkpoint_key, "$", checkpoint_data)

@@ -240,41 +237,21 @@ async def aput(
             ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
             pipeline.expire(checkpoint_key, ttl_seconds)

-        # Execute pipeline to get prev data and set new data
-        results = await pipeline.execute()
-        prev_checkpoint_data = results[0]
-
-        # Check if we need to clean up old writes
-        prev_checkpoint_id = None
-        if prev_checkpoint_data and isinstance(prev_checkpoint_data, dict):
-            prev_checkpoint_id = prev_checkpoint_data.get("checkpoint_id")
-
-        # If checkpoint changed, clean up old writes in a second pipeline
-        if prev_checkpoint_id and prev_checkpoint_id != checkpoint["id"]:
-            thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
-
-            # Create cleanup pipeline
-            cleanup_pipeline = self._redis.pipeline(transaction=False)
-
-            # Get all existing write keys
-            cleanup_pipeline.zrange(thread_zset_key, 0, -1)
-
-            # Delete the registry
-            cleanup_pipeline.delete(thread_zset_key)
-
-            # Execute to get keys and delete registry
-            cleanup_results = await cleanup_pipeline.execute()
-            existing_write_keys = cleanup_results[0]
+        # Execute pipeline to set new checkpoint data
+        await pipeline.execute()

-            # If there are keys to delete, do it in another pipeline
-            if existing_write_keys:
-                delete_pipeline = self._redis.pipeline(transaction=False)
-                for old_key in existing_write_keys:
-                    old_key_str = (
-                        old_key.decode() if isinstance(old_key, bytes) else old_key
-                    )
-                    delete_pipeline.delete(old_key_str)
-                await delete_pipeline.execute()
+        # NOTE: We intentionally do NOT clean up old writes here.
+        # In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
+        # put_writes BEFORE the new checkpoint is saved. If we clean up writes
+        # when the checkpoint changes, we would delete the interrupt writes
+        # before they can be consumed when resuming.
+        #
+        # Writes are cleaned up in the following scenarios:
+        # 1. When delete_thread is called
+        # 2. When TTL expires (if configured)
+        # 3. When put_writes is called again for the same task/idx (overwrites)
+        #
+        # See Issue #133 for details on this bug fix.

         return next_config

@@ -388,7 +365,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
         )

         # Single fetch gets everything inline - matching sync implementation
-        full_checkpoint_data = await self._redis.json().get(checkpoint_key)  # type: ignore[misc]
+        full_checkpoint_data = await self._redis.json().get(checkpoint_key)
         if not full_checkpoint_data or not isinstance(full_checkpoint_data, dict):
             return None

@@ -505,7 +482,11 @@ async def aput_writes(
             writes_objects.append(write_obj)

         # Thread-level sorted set for write keys
-        thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
+        # Use to_storage_safe_str for consistent key naming
+        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
+        thread_zset_key = (
+            f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"
+        )

         # Collect all write keys
         write_keys = []
@@ -529,7 +510,7 @@ async def aput_writes(

             # Use thread-level sorted set
             zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
-            pipeline.zadd(thread_zset_key, zadd_mapping)
+            pipeline.zadd(thread_zset_key, zadd_mapping)  # type: ignore[arg-type]

             # Apply TTL to registry key if configured
             if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -563,7 +544,7 @@ async def aget_channel_values(
         )

         # Single JSON.GET operation to retrieve checkpoint with inline channel_values
-        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")  # type: ignore[misc]
+        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")

         if not checkpoint_data:
             return {}
@@ -631,7 +612,9 @@ async def _aload_pending_writes(
             return []

         # Use thread-level sorted set
-        thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
+        # Use to_storage_safe_str for consistent key naming
+        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
+        thread_zset_key = f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"

         try:
             # Check if we have any writes in the thread sorted set
```
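With the cleanup removed, aput in the async shallow saver boils down to queuing a JSON.SET for the checkpoint document plus an optional EXPIRE and executing the pipeline, leaving the write registry alone. The async redis-py sketch below mirrors that shape; the key name, document fields, and TTL value are illustrative, and a local Redis Stack (RedisJSON) instance is assumed.

```python
import asyncio

import redis.asyncio as redis


async def save_shallow_checkpoint() -> None:
    r = redis.Redis()
    checkpoint_key = "checkpoint:thread-1:ns:shallow"   # illustrative key
    checkpoint_data = {"checkpoint_id": "ckpt-2", "has_writes": False}

    pipe = r.pipeline(transaction=False)
    pipe.json().set(checkpoint_key, "$", checkpoint_data)       # upsert the document
    default_ttl_minutes = 5                                     # illustrative TTL config
    pipe.expire(checkpoint_key, int(default_ttl_minutes * 60))  # only if TTL is configured
    await pipe.execute()

    # Pending writes registered earlier via put_writes are intentionally untouched.
    await r.aclose()


asyncio.run(save_shallow_checkpoint())
```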

langgraph/checkpoint/redis/base.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -563,7 +563,7 @@ def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]:
             return []

         writes = []
-        for write in result["writes"]:  # type: ignore[call-overload]
+        for write in result["writes"]:
             writes.append(
                 (
                     write["task_id"],
@@ -636,17 +636,17 @@ def put_writes(
                 # UPSERT case - only update specific fields
                 if key_exists:
                     # Update only channel, type, and blob fields
-                    pipeline.set(key, "$.channel", write_obj["channel"])
-                    pipeline.set(key, "$.type", write_obj["type"])
-                    pipeline.set(key, "$.blob", write_obj["blob"])
+                    pipeline.json().set(key, "$.channel", write_obj["channel"])
+                    pipeline.json().set(key, "$.type", write_obj["type"])
+                    pipeline.json().set(key, "$.blob", write_obj["blob"])
                 else:
                     # For new records, set the complete object
-                    pipeline.set(key, "$", write_obj)
+                    pipeline.json().set(key, "$", write_obj)
                     created_keys.append(key)
             else:
                 # INSERT case - only insert if doesn't exist
                 if not key_exists:
-                    pipeline.set(key, "$", write_obj)
+                    pipeline.json().set(key, "$", write_obj)
                     created_keys.append(key)

         pipeline.execute()
```
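The base.py change is more than a typing tweak: on a redis-py pipeline, .set() queues the plain string SET command, not RedisJSON's JSON.SET, so the "$.channel"-style arguments were never applied as JSON path updates. The sketch below contrasts the two calls on a plain client; the key names are illustrative and a local Redis Stack (RedisJSON) is assumed.

```python
import redis

r = redis.Redis(decode_responses=True)

# Plain string SET: one opaque string value, no concept of a "$" JSON path.
r.set("demo:string", "hello")
print(r.type("demo:string"))   # "string"

# JSON.SET: stores a JSON document and can update a sub-path in place,
# which is what put_writes needs for the channel/type/blob fields.
write_obj = {"channel": "messages", "type": "json", "blob": "..."}
r.json().set("demo:write", "$", write_obj)
r.json().set("demo:write", "$.channel", "updated-channel")
print(r.json().get("demo:write", "$.channel"))   # ['updated-channel']
```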

langgraph/checkpoint/redis/key_registry.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -105,7 +105,7 @@ def register_write_keys_batch(
         )
         # Use index as score to maintain order
         mapping = {key: idx for idx, key in enumerate(write_keys)}
-        self._redis.zadd(zset_key, mapping)
+        self._redis.zadd(zset_key, mapping)  # type: ignore[arg-type]

     def get_write_keys(
         self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
@@ -215,7 +215,7 @@ async def register_write_keys_batch(
             thread_id, checkpoint_ns, checkpoint_id
         )
         mapping = {key: idx for idx, key in enumerate(write_keys)}
-        await self._redis.zadd(zset_key, mapping)
+        await self._redis.zadd(zset_key, mapping)  # type: ignore[arg-type]

     async def get_write_keys(
         self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
```

langgraph/checkpoint/redis/shallow.py

Lines changed: 20 additions & 29 deletions
```diff
@@ -200,42 +200,29 @@ def put(
             thread_id, checkpoint_ns
         )

-        # Get the previous checkpoint ID to clean up its writes
-        prev_checkpoint_data = self._redis.json().get(checkpoint_key)
-        prev_checkpoint_id = None
-        if prev_checkpoint_data and isinstance(prev_checkpoint_data, dict):
-            prev_checkpoint_id = prev_checkpoint_data.get("checkpoint_id")
-
         with self._redis.pipeline(transaction=False) as pipeline:
             pipeline.json().set(checkpoint_key, "$", checkpoint_data)

-            # If checkpoint changed, clean up old writes
-            if prev_checkpoint_id and prev_checkpoint_id != checkpoint["id"]:
-                # Clean up writes from the previous checkpoint
-                thread_write_registry_key = (
-                    f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
-                )
-
-                # Get all existing write keys and delete them
-                existing_write_keys = self._redis.zrange(
-                    thread_write_registry_key, 0, -1
-                )
-                for old_key in existing_write_keys:
-                    old_key_str = (
-                        old_key.decode() if isinstance(old_key, bytes) else old_key
-                    )
-                    pipeline.delete(old_key_str)
-
-                # Clear the registry
-                pipeline.delete(thread_write_registry_key)
-
             # Apply TTL if configured
             ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
             pipeline.expire(checkpoint_key, ttl_seconds)

             pipeline.execute()


+        # NOTE: We intentionally do NOT clean up old writes here.
+        # In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
+        # put_writes BEFORE the new checkpoint is saved. If we clean up writes
+        # when the checkpoint changes, we would delete the interrupt writes
+        # before they can be consumed when resuming.
+        #
+        # Writes are cleaned up in the following scenarios:
+        # 1. When delete_thread is called
+        # 2. When TTL expires (if configured)
+        # 3. When put_writes is called again for the same task/idx (overwrites)
+        #
+        # See Issue #133 for details on this bug fix.
+
         return next_config

     def list(
@@ -501,8 +488,10 @@ def put_writes(
             writes_objects.append(write_obj)

         # THREAD-LEVEL REGISTRY: Only keep writes for the current checkpoint
+        # Use to_storage_safe_str for consistent key naming with delete_thread
+        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
         thread_write_registry_key = (
-            f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
+            f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
         )

         # Collect all write keys
@@ -525,7 +514,7 @@ def put_writes(
             # THREAD-LEVEL REGISTRY: Store write keys in thread-level sorted set
             # These will be cleared when checkpoint changes
             zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
-            pipeline.zadd(thread_write_registry_key, zadd_mapping)
+            pipeline.zadd(thread_write_registry_key, zadd_mapping)  # type: ignore[arg-type]

             # Note: We don't update has_writes on the checkpoint anymore
             # because put_writes can be called before the checkpoint exists
@@ -550,8 +539,10 @@ def _load_pending_writes(

         # Use thread-level registry that only contains current checkpoint writes
         # All writes belong to the current checkpoint
+        # Use to_storage_safe_str for consistent key naming with delete_thread
+        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
         thread_write_registry_key = (
-            f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
+            f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
         )

         # Get all write keys from the thread's registry (already sorted by index)
```
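The second fix in this file is the registry key itself: put_writes and _load_pending_writes now pass checkpoint_ns through to_storage_safe_str(), matching how delete_thread builds the key, so every path computes the same registry name. The self-contained sketch below illustrates the mismatch using a stand-in sanitizer; the real to_storage_safe_str() mapping (shown here as replacing an empty namespace with a sentinel) is an assumption for illustration only.

```python
# Hypothetical stand-in for to_storage_safe_str(); the exact transform is assumed.
def storage_safe(checkpoint_ns: str) -> str:
    return checkpoint_ns if checkpoint_ns else "__empty__"


thread_id, checkpoint_ns = "thread-1", ""   # default (empty) namespace

# Before: put_writes built the registry key from the raw namespace...
old_put_writes_key = f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
# ...while delete_thread built it from the storage-safe namespace.
delete_thread_key = f"write_registry:{thread_id}:{storage_safe(checkpoint_ns)}:shallow"
print(old_put_writes_key == delete_thread_key)   # False -> cleanup misses the registry

# After: every path derives the key from the storage-safe namespace.
safe_ns = storage_safe(checkpoint_ns)
new_put_writes_key = f"write_registry:{thread_id}:{safe_ns}:shallow"
print(new_put_writes_key == delete_thread_key)   # True -> one consistent key name
```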

langgraph/store/redis/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -541,7 +541,7 @@ def _batch_search_ops(
                 if not isinstance(store_doc, dict):
                     try:
                         store_doc = json.loads(
-                            store_doc  # type: ignore[arg-type]
+                            store_doc
                         )  # Attempt to parse if it's a JSON string
                     except (json.JSONDecodeError, TypeError):
                         logger.error(f"Failed to parse store_doc: {store_doc}")
```

langgraph/store/redis/aio.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -376,7 +376,7 @@ async def __aexit__(

         # Close Redis connections if we own them
         if self._owns_its_client:
-            await self._redis.aclose()
+            await self._redis.aclose()  # type: ignore[attr-defined]
             await self._redis.connection_pool.disconnect()

     async def abatch(self, ops: Iterable[Op]) -> list[Result]:
@@ -781,7 +781,7 @@ async def _batch_search_ops(
                     )
                     result_map[store_key] = doc
                     # Fetch individually in cluster mode
-                    store_doc_item = await self._redis.json().get(store_key)  # type: ignore
+                    store_doc_item = await self._redis.json().get(store_key)
                     store_docs.append(store_doc_item)
                 store_docs_raw = store_docs
             else:
```
