Skip to content

Commit e9d4b77

Browse files
committed
fix: resolve HITL interrupt persistence in shallow savers (#133)
The shallow savers were prematurely cleaning up writes when a new checkpoint was saved, causing Human-in-the-Loop interrupts to be deleted before they could be consumed during resume. Changes: - Remove aggressive write cleanup from aput/put methods in shallow savers - Fix registry key consistency using to_storage_safe_str() for checkpoint_ns - Add comprehensive HITL integration tests (7 new tests) - Update existing tests to reflect new write persistence behavior Writes are now cleaned up only via delete_thread, TTL expiration, or explicit overwrite - not during checkpoint transitions.
1 parent 963efbf commit e9d4b77

File tree

5 files changed

+665
-87
lines changed

5 files changed

+665
-87
lines changed

langgraph/checkpoint/redis/ashallow.py

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,6 @@ async def aput(
229229
# Create pipeline for all operations
230230
pipeline = self._redis.pipeline(transaction=False)
231231

232-
# Get the previous checkpoint ID to potentially clean up its writes
233-
pipeline.json().get(checkpoint_key)
234-
235232
# Set the new checkpoint data
236233
pipeline.json().set(checkpoint_key, "$", checkpoint_data)
237234

@@ -240,41 +237,21 @@ async def aput(
240237
ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
241238
pipeline.expire(checkpoint_key, ttl_seconds)
242239

243-
# Execute pipeline to get prev data and set new data
244-
results = await pipeline.execute()
245-
prev_checkpoint_data = results[0]
246-
247-
# Check if we need to clean up old writes
248-
prev_checkpoint_id = None
249-
if prev_checkpoint_data and isinstance(prev_checkpoint_data, dict):
250-
prev_checkpoint_id = prev_checkpoint_data.get("checkpoint_id")
251-
252-
# If checkpoint changed, clean up old writes in a second pipeline
253-
if prev_checkpoint_id and prev_checkpoint_id != checkpoint["id"]:
254-
thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
255-
256-
# Create cleanup pipeline
257-
cleanup_pipeline = self._redis.pipeline(transaction=False)
258-
259-
# Get all existing write keys
260-
cleanup_pipeline.zrange(thread_zset_key, 0, -1)
261-
262-
# Delete the registry
263-
cleanup_pipeline.delete(thread_zset_key)
264-
265-
# Execute to get keys and delete registry
266-
cleanup_results = await cleanup_pipeline.execute()
267-
existing_write_keys = cleanup_results[0]
240+
# Execute pipeline to set new checkpoint data
241+
await pipeline.execute()
268242

269-
# If there are keys to delete, do it in another pipeline
270-
if existing_write_keys:
271-
delete_pipeline = self._redis.pipeline(transaction=False)
272-
for old_key in existing_write_keys:
273-
old_key_str = (
274-
old_key.decode() if isinstance(old_key, bytes) else old_key
275-
)
276-
delete_pipeline.delete(old_key_str)
277-
await delete_pipeline.execute()
243+
# NOTE: We intentionally do NOT clean up old writes here.
244+
# In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
245+
# put_writes BEFORE the new checkpoint is saved. If we clean up writes
246+
# when the checkpoint changes, we would delete the interrupt writes
247+
# before they can be consumed when resuming.
248+
#
249+
# Writes are cleaned up in the following scenarios:
250+
# 1. When delete_thread is called
251+
# 2. When TTL expires (if configured)
252+
# 3. When put_writes is called again for the same task/idx (overwrites)
253+
#
254+
# See Issue #133 for details on this bug fix.
278255

279256
return next_config
280257

@@ -505,7 +482,11 @@ async def aput_writes(
505482
writes_objects.append(write_obj)
506483

507484
# Thread-level sorted set for write keys
508-
thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
485+
# Use to_storage_safe_str for consistent key naming
486+
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
487+
thread_zset_key = (
488+
f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"
489+
)
509490

510491
# Collect all write keys
511492
write_keys = []
@@ -631,7 +612,9 @@ async def _aload_pending_writes(
631612
return []
632613

633614
# Use thread-level sorted set
634-
thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
615+
# Use to_storage_safe_str for consistent key naming
616+
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
617+
thread_zset_key = f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"
635618

636619
try:
637620
# Check if we have any writes in the thread sorted set

langgraph/checkpoint/redis/shallow.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -200,42 +200,29 @@ def put(
200200
thread_id, checkpoint_ns
201201
)
202202

203-
# Get the previous checkpoint ID to clean up its writes
204-
prev_checkpoint_data = self._redis.json().get(checkpoint_key)
205-
prev_checkpoint_id = None
206-
if prev_checkpoint_data and isinstance(prev_checkpoint_data, dict):
207-
prev_checkpoint_id = prev_checkpoint_data.get("checkpoint_id")
208-
209203
with self._redis.pipeline(transaction=False) as pipeline:
210204
pipeline.json().set(checkpoint_key, "$", checkpoint_data)
211205

212-
# If checkpoint changed, clean up old writes
213-
if prev_checkpoint_id and prev_checkpoint_id != checkpoint["id"]:
214-
# Clean up writes from the previous checkpoint
215-
thread_write_registry_key = (
216-
f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
217-
)
218-
219-
# Get all existing write keys and delete them
220-
existing_write_keys = self._redis.zrange(
221-
thread_write_registry_key, 0, -1
222-
)
223-
for old_key in existing_write_keys:
224-
old_key_str = (
225-
old_key.decode() if isinstance(old_key, bytes) else old_key
226-
)
227-
pipeline.delete(old_key_str)
228-
229-
# Clear the registry
230-
pipeline.delete(thread_write_registry_key)
231-
232206
# Apply TTL if configured
233207
if self.ttl_config and "default_ttl" in self.ttl_config:
234208
ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
235209
pipeline.expire(checkpoint_key, ttl_seconds)
236210

237211
pipeline.execute()
238212

213+
# NOTE: We intentionally do NOT clean up old writes here.
214+
# In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
215+
# put_writes BEFORE the new checkpoint is saved. If we clean up writes
216+
# when the checkpoint changes, we would delete the interrupt writes
217+
# before they can be consumed when resuming.
218+
#
219+
# Writes are cleaned up in the following scenarios:
220+
# 1. When delete_thread is called
221+
# 2. When TTL expires (if configured)
222+
# 3. When put_writes is called again for the same task/idx (overwrites)
223+
#
224+
# See Issue #133 for details on this bug fix.
225+
239226
return next_config
240227

241228
def list(
@@ -501,8 +488,10 @@ def put_writes(
501488
writes_objects.append(write_obj)
502489

503490
# THREAD-LEVEL REGISTRY: Only keep writes for the current checkpoint
491+
# Use to_storage_safe_str for consistent key naming with delete_thread
492+
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
504493
thread_write_registry_key = (
505-
f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
494+
f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
506495
)
507496

508497
# Collect all write keys
@@ -550,8 +539,10 @@ def _load_pending_writes(
550539

551540
# Use thread-level registry that only contains current checkpoint writes
552541
# All writes belong to the current checkpoint
542+
# Use to_storage_safe_str for consistent key naming with delete_thread
543+
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
553544
thread_write_registry_key = (
554-
f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
545+
f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
555546
)
556547

557548
# Get all write keys from the thread's registry (already sorted by index)

0 commit comments

Comments
 (0)