@@ -94,7 +94,7 @@ async def __aexit__(
9494 _tb : Optional [TracebackType ],
9595 ) -> None :
9696 if self ._owns_its_client :
97- await self ._redis .aclose ()
97+ await self ._redis .aclose () # type: ignore[attr-defined]
9898 # RedisCluster doesn't have connection_pool attribute
9999 if getattr (self ._redis , "connection_pool" , None ):
100100 coro = self ._redis .connection_pool .disconnect ()
@@ -229,9 +229,6 @@ async def aput(
229229 # Create pipeline for all operations
230230 pipeline = self ._redis .pipeline (transaction = False )
231231
232- # Get the previous checkpoint ID to potentially clean up its writes
233- pipeline .json ().get (checkpoint_key )
234-
235232 # Set the new checkpoint data
236233 pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
237234
@@ -240,41 +237,21 @@ async def aput(
240237 ttl_seconds = int (self .ttl_config .get ("default_ttl" ) * 60 )
241238 pipeline .expire (checkpoint_key , ttl_seconds )
242239
243- # Execute pipeline to get prev data and set new data
244- results = await pipeline .execute ()
245- prev_checkpoint_data = results [0 ]
246-
247- # Check if we need to clean up old writes
248- prev_checkpoint_id = None
249- if prev_checkpoint_data and isinstance (prev_checkpoint_data , dict ):
250- prev_checkpoint_id = prev_checkpoint_data .get ("checkpoint_id" )
251-
252- # If checkpoint changed, clean up old writes in a second pipeline
253- if prev_checkpoint_id and prev_checkpoint_id != checkpoint ["id" ]:
254- thread_zset_key = f"write_keys_zset:{ thread_id } :{ checkpoint_ns } :shallow"
255-
256- # Create cleanup pipeline
257- cleanup_pipeline = self ._redis .pipeline (transaction = False )
258-
259- # Get all existing write keys
260- cleanup_pipeline .zrange (thread_zset_key , 0 , - 1 )
261-
262- # Delete the registry
263- cleanup_pipeline .delete (thread_zset_key )
264-
265- # Execute to get keys and delete registry
266- cleanup_results = await cleanup_pipeline .execute ()
267- existing_write_keys = cleanup_results [0 ]
240+ # Execute pipeline to set new checkpoint data
241+ await pipeline .execute ()
268242
269- # If there are keys to delete, do it in another pipeline
270- if existing_write_keys :
271- delete_pipeline = self ._redis .pipeline (transaction = False )
272- for old_key in existing_write_keys :
273- old_key_str = (
274- old_key .decode () if isinstance (old_key , bytes ) else old_key
275- )
276- delete_pipeline .delete (old_key_str )
277- await delete_pipeline .execute ()
243+ # NOTE: We intentionally do NOT clean up old writes here.
244+ # In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
245+ # put_writes BEFORE the new checkpoint is saved. If we clean up writes
246+ # when the checkpoint changes, we would delete the interrupt writes
247+ # before they can be consumed when resuming.
248+ #
249+ # Writes are cleaned up in the following scenarios:
250+ # 1. When delete_thread is called
251+ # 2. When TTL expires (if configured)
252+ # 3. When put_writes is called again for the same task/idx (overwrites)
253+ #
254+ # See Issue #133 for details on this bug fix.
278255
279256 return next_config
280257
@@ -388,7 +365,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
388365 )
389366
390367 # Single fetch gets everything inline - matching sync implementation
391- full_checkpoint_data = await self ._redis .json ().get (checkpoint_key ) # type: ignore[misc]
368+ full_checkpoint_data = await self ._redis .json ().get (checkpoint_key )
392369 if not full_checkpoint_data or not isinstance (full_checkpoint_data , dict ):
393370 return None
394371
@@ -505,7 +482,11 @@ async def aput_writes(
505482 writes_objects .append (write_obj )
506483
507484 # Thread-level sorted set for write keys
508- thread_zset_key = f"write_keys_zset:{ thread_id } :{ checkpoint_ns } :shallow"
485+ # Use to_storage_safe_str for consistent key naming
486+ safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
487+ thread_zset_key = (
488+ f"write_keys_zset:{ thread_id } :{ safe_checkpoint_ns } :shallow"
489+ )
509490
510491 # Collect all write keys
511492 write_keys = []
@@ -529,7 +510,7 @@ async def aput_writes(
529510
530511 # Use thread-level sorted set
531512 zadd_mapping = {key : idx for idx , key in enumerate (write_keys )}
532- pipeline .zadd (thread_zset_key , zadd_mapping )
513+ pipeline .zadd (thread_zset_key , zadd_mapping ) # type: ignore[arg-type]
533514
534515 # Apply TTL to registry key if configured
535516 if self .ttl_config and "default_ttl" in self .ttl_config :
@@ -563,7 +544,7 @@ async def aget_channel_values(
563544 )
564545
565546 # Single JSON.GET operation to retrieve checkpoint with inline channel_values
566- checkpoint_data = await self ._redis .json ().get (checkpoint_key , "$.checkpoint" ) # type: ignore[misc]
547+ checkpoint_data = await self ._redis .json ().get (checkpoint_key , "$.checkpoint" )
567548
568549 if not checkpoint_data :
569550 return {}
@@ -631,7 +612,9 @@ async def _aload_pending_writes(
631612 return []
632613
633614 # Use thread-level sorted set
634- thread_zset_key = f"write_keys_zset:{ thread_id } :{ checkpoint_ns } :shallow"
615+ # Use to_storage_safe_str for consistent key naming
616+ safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
617+ thread_zset_key = f"write_keys_zset:{ thread_id } :{ safe_checkpoint_ns } :shallow"
635618
636619 try :
637620 # Check if we have any writes in the thread sorted set