Skip to content

Commit c4c5ff6

Browse files
committed
fix: Fix types and transaction usage
1 parent 2fd1f01 commit c4c5ff6

File tree

2 files changed

+46
-32
lines changed

2 files changed

+46
-32
lines changed

langgraph/checkpoint/redis/aio.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
from collections.abc import AsyncIterator
88
from contextlib import asynccontextmanager
9+
from functools import partial
910
from types import TracebackType
1011
from typing import Any, List, Optional, Sequence, Tuple, Type, cast
1112

@@ -27,6 +28,26 @@
2728
from langgraph.checkpoint.redis.base import BaseRedisSaver
2829
from langgraph.constants import TASKS
2930
from redis.asyncio import Redis as AsyncRedis
31+
from redis.asyncio.client import Pipeline
32+
33+
34+
async def _write_obj_tx(
35+
pipe: Pipeline,
36+
key: str,
37+
write_obj: dict[str, Any],
38+
upsert_case: bool,
39+
) -> None:
40+
exists: int = await pipe.exists(key)
41+
if upsert_case:
42+
if exists:
43+
await pipe.json().set(key, "$.channel", write_obj["channel"])
44+
await pipe.json().set(key, "$.type", write_obj["type"])
45+
await pipe.json().set(key, "$.blob", write_obj["blob"])
46+
else:
47+
await pipe.json().set(key, "$", write_obj)
48+
else:
49+
if not exists:
50+
await pipe.json().set(key, "$", write_obj)
3051

3152

3253
class AsyncRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
@@ -428,24 +449,9 @@ async def aput_writes(
428449
task_id,
429450
write_obj["idx"],
430451
)
431-
432-
async def tx(
433-
pipe, key=key, write_obj=write_obj, upsert_case=upsert_case
434-
):
435-
exists = await pipe.exists(key)
436-
if upsert_case:
437-
if exists:
438-
await pipe.json().set(
439-
key, "$.channel", write_obj["channel"]
440-
)
441-
await pipe.json().set(key, "$.type", write_obj["type"])
442-
await pipe.json().set(key, "$.blob", write_obj["blob"])
443-
else:
444-
await pipe.json().set(key, "$", write_obj)
445-
else:
446-
if not exists:
447-
await pipe.json().set(key, "$", write_obj)
448-
452+
tx = partial(
453+
_write_obj_tx, key=key, write_obj=write_obj, upsert_case=upsert_case
454+
)
449455
await self._redis.transaction(tx, key)
450456

451457
def put_writes(

langgraph/checkpoint/redis/ashallow.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import asyncio
66
import json
77
from contextlib import asynccontextmanager
8-
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, cast
8+
from functools import partial
9+
from types import TracebackType
10+
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast
911

1012
from langchain_core.runnables import RunnableConfig
1113
from redisvl.index import AsyncSearchIndex
@@ -30,6 +32,7 @@
3032
)
3133
from langgraph.constants import TASKS
3234
from redis.asyncio import Redis as AsyncRedis
35+
from redis.asyncio.client import Pipeline
3336

3437
SCHEMAS = [
3538
{
@@ -77,6 +80,17 @@
7780
]
7881

7982

83+
# func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
84+
async def _write_obj_tx(pipe: Pipeline, key: str, write_obj: dict[str, Any]) -> None:
85+
exists: int = await pipe.exists(key)
86+
if exists:
87+
await pipe.json().set(key, "$.channel", write_obj["channel"])
88+
await pipe.json().set(key, "$.type", write_obj["type"])
89+
await pipe.json().set(key, "$.blob", write_obj["blob"])
90+
else:
91+
await pipe.json().set(key, "$", write_obj)
92+
93+
8094
class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
8195
"""Async Redis implementation that only stores the most recent checkpoint."""
8296

@@ -104,7 +118,12 @@ def __init__(
104118
async def __aenter__(self) -> AsyncShallowRedisSaver:
105119
return self
106120

107-
async def __aexit__(self, exc_type, exc, tb) -> None:
121+
async def __aexit__(
122+
self,
123+
exc_type: Optional[Type[BaseException]],
124+
exc: Optional[BaseException],
125+
tb: Optional[TracebackType],
126+
) -> None:
108127
if self._owns_its_client:
109128
await self._redis.aclose() # type: ignore[attr-defined]
110129
await self._redis.connection_pool.disconnect()
@@ -403,18 +422,7 @@ async def aput_writes(
403422
write_obj["idx"],
404423
)
405424
if upsert_case:
406-
407-
async def tx(pipe, key=key, write_obj=write_obj):
408-
exists = await pipe.exists(key)
409-
if exists:
410-
await pipe.json().set(
411-
key, "$.channel", write_obj["channel"]
412-
)
413-
await pipe.json().set(key, "$.type", write_obj["type"])
414-
await pipe.json().set(key, "$.blob", write_obj["blob"])
415-
else:
416-
await pipe.json().set(key, "$", write_obj)
417-
425+
tx = partial(_write_obj_tx, key=key, write_obj=write_obj)
418426
await self._redis.transaction(tx, key)
419427
else:
420428
# Unlike AsyncRedisSaver, the shallow implementation always overwrites

0 commit comments

Comments
 (0)