7
7
import logging
8
8
import os
9
9
from contextlib import asynccontextmanager
10
- from functools import partial
11
10
from types import TracebackType
12
11
from typing import (
13
12
Any ,
34
33
)
35
34
from langgraph .constants import TASKS
36
35
from redis .asyncio import Redis as AsyncRedis
37
- from redis .asyncio .client import Pipeline
38
36
from redis .asyncio .cluster import RedisCluster as AsyncRedisCluster
39
37
from redisvl .index import AsyncSearchIndex
40
38
from redisvl .query import FilterQuery
41
39
from redisvl .query .filter import Num , Tag
42
- from redisvl .redis .connection import RedisConnectionFactory
43
40
44
41
from langgraph .checkpoint .redis .base import BaseRedisSaver
45
42
from langgraph .checkpoint .redis .util import (
54
51
logger = logging .getLogger (__name__ )
55
52
56
53
57
- async def _write_obj_tx (
58
- pipe : Pipeline ,
59
- key : str ,
60
- write_obj : Dict [str , Any ],
61
- upsert_case : bool ,
62
- ) -> None :
63
- exists : int = await pipe .exists (key )
64
- if upsert_case :
65
- if exists :
66
- await pipe .json ().set (key , "$.channel" , write_obj ["channel" ])
67
- await pipe .json ().set (key , "$.type" , write_obj ["type" ])
68
- await pipe .json ().set (key , "$.blob" , write_obj ["blob" ])
69
- else :
70
- await pipe .json ().set (key , "$" , write_obj )
71
- else :
72
- if not exists :
73
- await pipe .json ().set (key , "$" , write_obj )
74
-
75
-
76
54
class AsyncRedisSaver (
77
55
BaseRedisSaver [Union [AsyncRedis , AsyncRedisCluster ], AsyncSearchIndex ]
78
56
):
@@ -568,7 +546,7 @@ async def aput(
568
546
# store at top-level for filters in list()
569
547
if all (key in metadata for key in ["source" , "step" ]):
570
548
checkpoint_data ["source" ] = metadata ["source" ]
571
- checkpoint_data ["step" ] = metadata ["step" ] # type: ignore
549
+ checkpoint_data ["step" ] = metadata ["step" ]
572
550
573
551
# Prepare checkpoint key
574
552
checkpoint_key = BaseRedisSaver ._make_redis_checkpoint_key (
@@ -587,11 +565,11 @@ async def aput(
587
565
588
566
if self .cluster_mode :
589
567
# For cluster mode, execute operations individually
590
- await self ._redis .json ().set (checkpoint_key , "$" , checkpoint_data )
568
+ await self ._redis .json ().set (checkpoint_key , "$" , checkpoint_data ) # type: ignore[misc]
591
569
592
570
if blobs :
593
571
for key , data in blobs :
594
- await self ._redis .json ().set (key , "$" , data )
572
+ await self ._redis .json ().set (key , "$" , data ) # type: ignore[misc]
595
573
596
574
# Apply TTL if configured
597
575
if self .ttl_config and "default_ttl" in self .ttl_config :
@@ -604,12 +582,12 @@ async def aput(
604
582
pipeline = self ._redis .pipeline (transaction = True )
605
583
606
584
# Add checkpoint data to pipeline
607
- await pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
585
+ pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
608
586
609
587
if blobs :
610
588
# Add all blob operations to the pipeline
611
589
for key , data in blobs :
612
- await pipeline .json ().set (key , "$" , data )
590
+ pipeline .json ().set (key , "$" , data )
613
591
614
592
# Execute all operations atomically
615
593
await pipeline .execute ()
@@ -654,13 +632,13 @@ async def aput(
654
632
655
633
if self .cluster_mode :
656
634
# For cluster mode, execute operation directly
657
- await self ._redis .json ().set (
635
+ await self ._redis .json ().set ( # type: ignore[misc]
658
636
checkpoint_key , "$" , checkpoint_data
659
637
)
660
638
else :
661
639
# For non-cluster mode, use pipeline
662
640
pipeline = self ._redis .pipeline (transaction = True )
663
- await pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
641
+ pipeline .json ().set (checkpoint_key , "$" , checkpoint_data )
664
642
await pipeline .execute ()
665
643
except Exception :
666
644
# If this also fails, we just propagate the original cancellation
@@ -739,24 +717,18 @@ async def aput_writes(
739
717
exists = await self ._redis .exists (key )
740
718
if exists :
741
719
# Update existing key
742
- await self ._redis .json ().set (
743
- key , "$.channel" , write_obj ["channel" ]
744
- )
745
- await self ._redis .json ().set (
746
- key , "$.type" , write_obj ["type" ]
747
- )
748
- await self ._redis .json ().set (
749
- key , "$.blob" , write_obj ["blob" ]
750
- )
720
+ await self ._redis .json ().set (key , "$.channel" , write_obj ["channel" ]) # type: ignore[misc, arg-type]
721
+ await self ._redis .json ().set (key , "$.type" , write_obj ["type" ]) # type: ignore[misc, arg-type]
722
+ await self ._redis .json ().set (key , "$.blob" , write_obj ["blob" ]) # type: ignore[misc, arg-type]
751
723
else :
752
724
# Create new key
753
- await self ._redis .json ().set (key , "$" , write_obj )
725
+ await self ._redis .json ().set (key , "$" , write_obj ) # type: ignore[misc]
754
726
created_keys .append (key )
755
727
else :
756
728
# For non-upsert case, only set if key doesn't exist
757
729
exists = await self ._redis .exists (key )
758
730
if not exists :
759
- await self ._redis .json ().set (key , "$" , write_obj )
731
+ await self ._redis .json ().set (key , "$" , write_obj ) # type: ignore[misc]
760
732
created_keys .append (key )
761
733
762
734
# Apply TTL to newly created keys
@@ -788,20 +760,30 @@ async def aput_writes(
788
760
exists = await self ._redis .exists (key )
789
761
if exists :
790
762
# Update existing key
791
- await pipeline .json ().set (
792
- key , "$.channel" , write_obj ["channel" ]
763
+ pipeline .json ().set (
764
+ key ,
765
+ "$.channel" ,
766
+ write_obj ["channel" ], # type: ignore[arg-type]
767
+ )
768
+ pipeline .json ().set (
769
+ key ,
770
+ "$.type" ,
771
+ write_obj ["type" ], # type: ignore[arg-type]
772
+ )
773
+ pipeline .json ().set (
774
+ key ,
775
+ "$.blob" ,
776
+ write_obj ["blob" ], # type: ignore[arg-type]
793
777
)
794
- await pipeline .json ().set (key , "$.type" , write_obj ["type" ])
795
- await pipeline .json ().set (key , "$.blob" , write_obj ["blob" ])
796
778
else :
797
779
# Create new key
798
- await pipeline .json ().set (key , "$" , write_obj )
780
+ pipeline .json ().set (key , "$" , write_obj )
799
781
created_keys .append (key )
800
782
else :
801
783
# For non-upsert case, only set if key doesn't exist
802
784
exists = await self ._redis .exists (key )
803
785
if not exists :
804
- await pipeline .json ().set (key , "$" , write_obj )
786
+ pipeline .json ().set (key , "$" , write_obj )
805
787
created_keys .append (key )
806
788
807
789
# Execute all operations atomically
0 commit comments