Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,32 @@ with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as ch
# Use the checkpointer...
```

This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up.
### Removing TTL (Pinning Threads)

You can make specific checkpoints persistent by removing their TTL. This is useful for "pinning" important threads that should never expire:

```python
from langgraph.checkpoint.redis import RedisSaver

# Create saver with default TTL
saver = RedisSaver.from_conn_string("redis://localhost:6379", ttl={"default_ttl": 60})
saver.setup()

# Save a checkpoint
config = {"configurable": {"thread_id": "important-thread", "checkpoint_ns": ""}}
saved_config = saver.put(config, checkpoint, metadata, {})

# Remove TTL from the checkpoint to make it persistent
checkpoint_id = saved_config["configurable"]["checkpoint_id"]
checkpoint_key = f"checkpoint:important-thread:__empty__:{checkpoint_id}"
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)

# The checkpoint is now persistent and won't expire
```

When no TTL configuration is provided, checkpoints are persistent by default (no expiration).

This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up while keeping important data persistent.

## Redis Stores

Expand Down Expand Up @@ -370,11 +395,13 @@ For Redis Stores with vector search:

Both Redis checkpoint savers and stores leverage Redis's native key expiration:

- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command
- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command for setting TTL
- **TTL Removal**: Uses Redis's `PERSIST` command to remove TTL (with `ttl_minutes=-1`)
- **Automatic Cleanup**: Redis automatically removes expired keys
- **Configurable Default TTL**: Set a default TTL for all keys in minutes
- **TTL Refresh on Read**: Optionally refresh TTL when keys are accessed
- **Applied to All Related Keys**: TTL is applied to all related keys (checkpoint, blobs, writes)
- **Persistent by Default**: When no TTL is configured, keys are persistent (no expiration)

## Contributing

Expand Down
27 changes: 27 additions & 0 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ async def _apply_ttl_to_keys(
main_key: The primary Redis key
related_keys: Additional Redis keys that should expire at the same time
ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
Use -1 to remove TTL (make keys persistent)

Returns:
Result of the Redis operation
Expand All @@ -305,6 +306,32 @@ async def _apply_ttl_to_keys(
ttl_minutes = self.ttl_config.get("default_ttl")

if ttl_minutes is not None:
# Special case: -1 means remove TTL (make persistent)
if ttl_minutes == -1:
if self.cluster_mode:
# For cluster mode, execute PERSIST operations individually
await self._redis.persist(main_key)

if related_keys:
for key in related_keys:
await self._redis.persist(key)

return True
else:
# For non-cluster mode, use pipeline for efficiency
pipeline = self._redis.pipeline()

# Remove TTL for main key
pipeline.persist(main_key)

# Remove TTL for related keys
if related_keys:
for key in related_keys:
pipeline.persist(key)

return await pipeline.execute()

# Regular TTL setting
ttl_seconds = int(ttl_minutes * 60)

if self.cluster_mode:
Expand Down
30 changes: 30 additions & 0 deletions langgraph/checkpoint/redis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def _apply_ttl_to_keys(
main_key: The primary Redis key
related_keys: Additional Redis keys that should expire at the same time
ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided
Use -1 to remove TTL (make keys persistent)

Returns:
Result of the Redis operation
Expand All @@ -248,6 +249,35 @@ def _apply_ttl_to_keys(
ttl_minutes = self.ttl_config.get("default_ttl")

if ttl_minutes is not None:
# Special case: -1 means remove TTL (make persistent)
if ttl_minutes == -1:
# Check if cluster mode is detected (for sync checkpoint savers)
cluster_mode = getattr(self, "cluster_mode", False)

if cluster_mode:
# For cluster mode, execute PERSIST operations individually
self._redis.persist(main_key)

if related_keys:
for key in related_keys:
self._redis.persist(key)

return True
else:
# For non-cluster mode, use pipeline for efficiency
pipeline = self._redis.pipeline()

# Remove TTL for main key
pipeline.persist(main_key)

# Remove TTL for related keys
if related_keys:
for key in related_keys:
pipeline.persist(key)

return pipeline.execute()

# Regular TTL setting
ttl_seconds = int(ttl_minutes * 60)

# Check if cluster mode is detected (for sync checkpoint savers)
Expand Down
234 changes: 234 additions & 0 deletions tests/test_ttl_removal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
"""Tests for TTL removal feature (issue #66)."""

import time
from uuid import uuid4

import pytest
from langgraph.checkpoint.base import create_checkpoint, empty_checkpoint

from langgraph.checkpoint.redis import AsyncRedisSaver, RedisSaver


def test_ttl_removal_with_negative_one(redis_url: str) -> None:
"""Test that ttl_minutes=-1 removes TTL from keys."""
saver = RedisSaver(redis_url, ttl={"default_ttl": 1}) # 1 minute default TTL
saver.setup()

thread_id = str(uuid4())
checkpoint = create_checkpoint(
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
)
checkpoint["channel_values"]["messages"] = ["test"]

config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

# Save checkpoint (will have TTL)
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})

checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"

# Verify TTL is set
ttl = saver._redis.ttl(checkpoint_key)
assert 50 <= ttl <= 60, f"TTL should be around 60 seconds, got {ttl}"

# Remove TTL using -1
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)

# Verify TTL is removed
ttl_after = saver._redis.ttl(checkpoint_key)
assert ttl_after == -1, "Key should be persistent after setting ttl_minutes=-1"


def test_ttl_removal_with_related_keys(redis_url: str) -> None:
"""Test that TTL removal works for main key and related keys."""
saver = RedisSaver(redis_url, ttl={"default_ttl": 1})
saver.setup()

thread_id = str(uuid4())

# Create a checkpoint with writes (to have related keys)
checkpoint = create_checkpoint(
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
)
checkpoint["channel_values"]["messages"] = ["test"]

config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
"checkpoint_id": "test-checkpoint",
}
}

# Save checkpoint and writes
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
saver.put_writes(
saved_config, [("channel1", "value1"), ("channel2", "value2")], "task-1"
)

# Get the keys
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
write_key1 = f"checkpoint_write:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}:task-1:0"
write_key2 = f"checkpoint_write:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}:task-1:1"

# All keys should have TTL
assert 50 <= saver._redis.ttl(checkpoint_key) <= 60
assert 50 <= saver._redis.ttl(write_key1) <= 60
assert 50 <= saver._redis.ttl(write_key2) <= 60

# Remove TTL from all keys
saver._apply_ttl_to_keys(checkpoint_key, [write_key1, write_key2], ttl_minutes=-1)

# All keys should be persistent
assert saver._redis.ttl(checkpoint_key) == -1
assert saver._redis.ttl(write_key1) == -1
assert saver._redis.ttl(write_key2) == -1


def test_no_ttl_means_persistent(redis_url: str) -> None:
"""Test that no TTL configuration means keys are persistent."""
# Create saver with no TTL config
saver = RedisSaver(redis_url) # No TTL config
saver.setup()

thread_id = str(uuid4())
checkpoint = create_checkpoint(
checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1
)
checkpoint["channel_values"]["messages"] = ["test"]

config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

# Save checkpoint
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})

# Check TTL
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
ttl = saver._redis.ttl(checkpoint_key)

# Should be -1 (persistent) when no TTL configured
assert ttl == -1, "Key should be persistent when no TTL configured"


def test_ttl_removal_preserves_data(redis_url: str) -> None:
"""Test that removing TTL doesn't affect the data."""
saver = RedisSaver(redis_url, ttl={"default_ttl": 1})
saver.setup()

thread_id = str(uuid4())
checkpoint = create_checkpoint(
checkpoint=empty_checkpoint(), channels={"messages": ["original data"]}, step=1
)
checkpoint["channel_values"]["messages"] = ["original data"]

config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

# Save checkpoint
saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})

# Load data before TTL removal
loaded_before = saver.get_tuple(saved_config)
assert loaded_before.checkpoint["channel_values"]["messages"] == ["original data"]

# Remove TTL
checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"
saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)

# Load data after TTL removal
loaded_after = saver.get_tuple(saved_config)
assert loaded_after.checkpoint["channel_values"]["messages"] == ["original data"]

# Verify TTL is removed
assert saver._redis.ttl(checkpoint_key) == -1


@pytest.mark.asyncio
async def test_async_ttl_removal(redis_url: str) -> None:
"""Test TTL removal with async saver."""
async with AsyncRedisSaver.from_conn_string(
redis_url, ttl={"default_ttl": 1}
) as saver:
thread_id = str(uuid4())
checkpoint = create_checkpoint(
checkpoint=empty_checkpoint(), channels={"messages": ["async test"]}, step=1
)
checkpoint["channel_values"]["messages"] = ["async test"]

config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

# Save checkpoint
saved_config = await saver.aput(
config, checkpoint, {"source": "test", "step": 1}, {}
)

checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}"

# Verify TTL is set
ttl = await saver._redis.ttl(checkpoint_key)
assert 50 <= ttl <= 60, f"TTL should be around 60 seconds, got {ttl}"

# Remove TTL using -1
await saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)

# Verify TTL is removed
ttl_after = await saver._redis.ttl(checkpoint_key)
assert ttl_after == -1, "Key should be persistent after setting ttl_minutes=-1"


def test_pin_thread_use_case(redis_url: str) -> None:
"""Test the 'pin thread' use case from issue #66.

This simulates pinning a specific thread by removing its TTL,
making it persistent while other threads expire.
"""
saver = RedisSaver(
redis_url, ttl={"default_ttl": 0.1}
) # 6 seconds TTL for quick test
saver.setup()

# Create two threads
thread_to_pin = str(uuid4())
thread_to_expire = str(uuid4())

# Store checkpoint IDs to avoid using wildcards (more efficient and precise)
checkpoint_ids = {}

for thread_id in [thread_to_pin, thread_to_expire]:
checkpoint = create_checkpoint(
checkpoint=empty_checkpoint(),
channels={"messages": [f"Thread {thread_id}"]},
step=1,
)
checkpoint["channel_values"]["messages"] = [f"Thread {thread_id}"]

config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
checkpoint_ids[thread_id] = saved_config["configurable"]["checkpoint_id"]

# Pin the first thread by removing its TTL using exact key
pinned_checkpoint_key = (
f"checkpoint:{thread_to_pin}:__empty__:{checkpoint_ids[thread_to_pin]}"
)
saver._apply_ttl_to_keys(pinned_checkpoint_key, ttl_minutes=-1)

# Verify pinned thread has no TTL
assert saver._redis.exists(pinned_checkpoint_key) == 1
assert saver._redis.ttl(pinned_checkpoint_key) == -1

# Verify other thread still has TTL
expiring_checkpoint_key = (
f"checkpoint:{thread_to_expire}:__empty__:{checkpoint_ids[thread_to_expire]}"
)
assert saver._redis.exists(expiring_checkpoint_key) == 1
ttl = saver._redis.ttl(expiring_checkpoint_key)
assert 0 < ttl <= 6

# Wait for expiring thread to expire
time.sleep(7)

# Pinned thread should still exist
assert saver._redis.exists(pinned_checkpoint_key) == 1

# Expiring thread should be gone
assert saver._redis.exists(expiring_checkpoint_key) == 0