|
5 | 5 | import asyncio
|
6 | 6 | import json
|
7 | 7 | from contextlib import asynccontextmanager
|
8 |
| -from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, cast |
| 8 | +from functools import partial |
| 9 | +from types import TracebackType |
| 10 | +from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast |
9 | 11 |
|
10 | 12 | from langchain_core.runnables import RunnableConfig
|
11 | 13 | from redisvl.index import AsyncSearchIndex
|
|
30 | 32 | )
|
31 | 33 | from langgraph.constants import TASKS
|
32 | 34 | from redis.asyncio import Redis as AsyncRedis
|
| 35 | +from redis.asyncio.client import Pipeline |
33 | 36 |
|
34 | 37 | SCHEMAS = [
|
35 | 38 | {
|
|
77 | 80 | ]
|
78 | 81 |
|
79 | 82 |
|
| 83 | +# func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], |
| 84 | +async def _write_obj_tx(pipe: Pipeline, key: str, write_obj: dict[str, Any]) -> None: |
| 85 | + exists: int = await pipe.exists(key) |
| 86 | + if exists: |
| 87 | + await pipe.json().set(key, "$.channel", write_obj["channel"]) |
| 88 | + await pipe.json().set(key, "$.type", write_obj["type"]) |
| 89 | + await pipe.json().set(key, "$.blob", write_obj["blob"]) |
| 90 | + else: |
| 91 | + await pipe.json().set(key, "$", write_obj) |
| 92 | + |
| 93 | + |
80 | 94 | class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
|
81 | 95 | """Async Redis implementation that only stores the most recent checkpoint."""
|
82 | 96 |
|
@@ -104,7 +118,12 @@ def __init__(
|
104 | 118 | async def __aenter__(self) -> AsyncShallowRedisSaver:
|
105 | 119 | return self
|
106 | 120 |
|
107 |
| - async def __aexit__(self, exc_type, exc, tb) -> None: |
| 121 | + async def __aexit__( |
| 122 | + self, |
| 123 | + exc_type: Optional[Type[BaseException]], |
| 124 | + exc: Optional[BaseException], |
| 125 | + tb: Optional[TracebackType], |
| 126 | + ) -> None: |
108 | 127 | if self._owns_its_client:
|
109 | 128 | await self._redis.aclose() # type: ignore[attr-defined]
|
110 | 129 | await self._redis.connection_pool.disconnect()
|
@@ -403,18 +422,7 @@ async def aput_writes(
|
403 | 422 | write_obj["idx"],
|
404 | 423 | )
|
405 | 424 | if upsert_case:
|
406 |
| - |
407 |
| - async def tx(pipe, key=key, write_obj=write_obj): |
408 |
| - exists = await pipe.exists(key) |
409 |
| - if exists: |
410 |
| - await pipe.json().set( |
411 |
| - key, "$.channel", write_obj["channel"] |
412 |
| - ) |
413 |
| - await pipe.json().set(key, "$.type", write_obj["type"]) |
414 |
| - await pipe.json().set(key, "$.blob", write_obj["blob"]) |
415 |
| - else: |
416 |
| - await pipe.json().set(key, "$", write_obj) |
417 |
| - |
| 425 | + tx = partial(_write_obj_tx, key=key, write_obj=write_obj) |
418 | 426 | await self._redis.transaction(tx, key)
|
419 | 427 | else:
|
420 | 428 | # Unlike AsyncRedisSaver, the shallow implementation always overwrites
|
|
0 commit comments