2121 from redis .commands .search .document import Document
2222 from redis .commands .search .result import Result
2323 from redisvl .query .query import BaseQuery
24+ import redis .asyncio
2425
2526import redis
2627import redis .asyncio as aredis
@@ -793,16 +794,29 @@ class AsyncSearchIndex(BaseSearchIndex):
793794
794795 """
795796
797+ # TODO: The `aredis.Redis` type is not working for type checks.
798+ _redis_client : Optional [redis .asyncio .Redis ] = None
799+ _redis_url : Optional [str ] = None
800+ _redis_kwargs : Dict [str , Any ] = {}
801+
796802 def __init__ (
797803 self ,
798804 schema : IndexSchema ,
805+ * ,
806+ redis_url : Optional [str ] = None ,
807+ redis_client : Optional [aredis .Redis ] = None ,
808+ redis_kwargs : Optional [Dict [str , Any ]] = None ,
799809 ** kwargs ,
800810 ):
801811 """Initialize the RedisVL async search index with a schema.
802812
803813 Args:
804814 schema (IndexSchema): Index schema object.
805- connection_args (Dict[str, Any], optional): Redis client connection
815+ redis_url (Optional[str], optional): The URL of the Redis server to
816+ connect to.
817+ redis_client (Optional[aredis.Redis], optional): An
818+ instantiated redis client.
819+ redis_kwargs (Dict[str, Any], optional): Redis client connection
806820 args.
807821 """
808822 # final validation on schema object
@@ -813,39 +827,21 @@ def __init__(
813827
814828 self ._lib_name : Optional [str ] = kwargs .pop ("lib_name" , None )
815829
816- # set up empty redis connection
817- self ._redis_client : Optional [aredis .Redis ] = None
818-
819- if "redis_client" in kwargs or "redis_url" in kwargs :
820- logger .warning (
821- "Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex"
822- )
823-
824- atexit .register (self ._cleanup_connection )
825-
826- def _cleanup_connection (self ):
827- if self ._redis_client :
828-
829- def run_in_thread ():
830- try :
831- loop = asyncio .new_event_loop ()
832- asyncio .set_event_loop (loop )
833- loop .run_until_complete (self ._redis_client .aclose ())
834- loop .close ()
835- except RuntimeError :
836- pass
837-
838- # Run cleanup in a background thread to avoid event loop issues
839- thread = threading .Thread (target = run_in_thread )
840- thread .start ()
841- thread .join ()
830+ # Store connection parameters
831+ if redis_client and redis_url :
832+ raise ValueError ("Cannot provide both redis_client and redis_url" )
842833
834+ self ._redis_client = redis_client
835+ self ._redis_url = redis_url
836+ self ._redis_kwargs = redis_kwargs or {}
837+ self ._lock = asyncio .Lock ()
838+
839+ async def disconnect (self ):
840+ """Asynchronously disconnect and cleanup the underlying async redis connection."""
841+ if self ._redis_client is not None :
842+ await self ._redis_client .aclose () # type: ignore
843843 self ._redis_client = None
844844
845- def disconnect (self ):
846- """Disconnect and cleanup the underlying async redis connection."""
847- self ._cleanup_connection ()
848-
849845 @classmethod
850846 async def from_existing (
851847 cls ,
@@ -902,69 +898,59 @@ def client(self) -> Optional[aredis.Redis]:
902898 return self ._redis_client
903899
904900 async def connect (self , redis_url : Optional [str ] = None , ** kwargs ):
905- """Connect to a Redis instance using the provided `redis_url`, falling
906- back to the `REDIS_URL` environment variable (if available).
901+ """[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__."""
902+ import warnings
907903
908- Note: Additional keyword arguments (`**kwargs`) can be used to provide
909- extra options specific to the Redis connection.
910-
911- Args:
912- redis_url (Optional[str], optional): The URL of the Redis server to
913- connect to. If not provided, the method defaults to using the
914- `REDIS_URL` environment variable.
915-
916- Raises:
917- redis.exceptions.ConnectionError: If the connection to the Redis
918- server fails.
919- ValueError: If the Redis URL is not provided nor accessible
920- through the `REDIS_URL` environment variable.
921-
922- .. code-block:: python
923-
924- index.connect(redis_url="redis://localhost:6379")
925-
926- """
904+ warnings .warn (
905+ "connect() is deprecated; pass connection parameters in __init__" ,
906+ DeprecationWarning ,
907+ )
927908 client = RedisConnectionFactory .connect (
928909 redis_url = redis_url , use_async = True , ** kwargs
929910 )
930911 return await self .set_client (client )
931912
932- @setup_async_redis ()
933- async def set_client (self , redis_client : aredis .Redis ):
934- """Manually set the Redis client to use with the search index.
935-
936- This method configures the search index to use a specific
937- Async Redis client. It is useful for cases where an external,
938- custom-configured client is preferred instead of creating a new one.
939-
940- Args:
941- redis_client (aredis.Redis): An Async Redis
942- client instance to be used for the connection.
943-
944- Raises:
945- TypeError: If the provided client is not valid.
946-
947- .. code-block:: python
913+ async def set_client (self , redis_client : Optional [aredis .Redis ]):
914+ """[DEPRECATED] Manually set the Redis client to use with the search index.
915+ This method is deprecated; please provide connection parameters in __init__.
916+ """
917+ import warnings
948918
949- import redis.asyncio as aredis
950- from redisvl.index import AsyncSearchIndex
919+ warnings .warn (
920+ "set_client() is deprecated; pass connection parameters in __init__" ,
921+ DeprecationWarning ,
922+ )
923+ return await self ._set_client (redis_client )
951924
952- # async Redis client and index
953- client = aredis.Redis.from_url("redis://localhost:6379")
954- index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
955- await index.set_client(client)
925+ async def _set_client (self , redis_client : Optional [redis .asyncio .Redis ]):
926+ """
927+ Set the Redis client to use with the search index.
956928
929+ NOTE: Remove this method once the deprecation period is over.
957930 """
958- if isinstance (redis_client , redis .Redis ):
959- print ("Setting client and converting from async" , flush = True )
960- self ._redis_client = RedisConnectionFactory .sync_to_async_redis (
961- redis_client
962- )
963- else :
931+ if self ._redis_client is not None :
932+ await self ._redis_client .aclose () # type: ignore
933+ async with self ._lock :
964934 self ._redis_client = redis_client
965-
966935 return self
967936
937+ async def _get_client (self ) -> aredis .Redis :
938+ """Lazily instantiate and return the async Redis client."""
939+ if self ._redis_client is None :
940+ async with self ._lock :
941+ # Double-check to protect against concurrent access
942+ if self ._redis_client is None :
943+ kwargs = self ._redis_kwargs
944+ if self ._redis_url :
945+ kwargs ["redis_url" ] = self ._redis_url
946+ self ._redis_client = (
947+ RedisConnectionFactory .get_async_redis_connection (** kwargs )
948+ )
949+ await RedisConnectionFactory .validate_async_redis (
950+ self ._redis_client , self ._lib_name
951+ )
952+ return self ._redis_client
953+
968954 async def create (self , overwrite : bool = False , drop : bool = False ) -> None :
969955 """Asynchronously create an index in Redis with the current schema
970956 and properties.
@@ -990,6 +976,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
990976 # overwrite an index in Redis; drop associated data (clean slate)
991977 await index.create(overwrite=True, drop=True)
992978 """
979+ client = await self ._get_client ()
993980 redis_fields = self .schema .redis_fields
994981
995982 if not redis_fields :
@@ -1005,7 +992,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
1005992 await self .delete (drop )
1006993
1007994 try :
1008- await self . _redis_client . ft (self .schema .index .name ).create_index ( # type: ignore
995+ await client . ft (self .schema .index .name ).create_index (
1009996 fields = redis_fields ,
1010997 definition = IndexDefinition (
1011998 prefix = [self .schema .index .prefix ], index_type = self ._storage .type
@@ -1025,10 +1012,9 @@ async def delete(self, drop: bool = True):
10251012 Raises:
10261013 redis.exceptions.ResponseError: If the index does not exist.
10271014 """
1015+ client = await self ._get_client ()
10281016 try :
1029- await self ._redis_client .ft (self .schema .index .name ).dropindex ( # type: ignore
1030- delete_documents = drop
1031- )
1017+ await client .ft (self .schema .index .name ).dropindex (delete_documents = drop )
10321018 except Exception as e :
10331019 raise RedisSearchError (f"Error while deleting index: { str (e )} " ) from e
10341020
@@ -1039,16 +1025,15 @@ async def clear(self) -> int:
10391025 Returns:
10401026 int: Count of records deleted from Redis.
10411027 """
1042- # Track deleted records
1028+ client = await self . _get_client ()
10431029 total_records_deleted : int = 0
10441030
1045- # Paginate using queries and delete in batches
10461031 async for batch in self .paginate (
10471032 FilterQuery (FilterExpression ("*" ), return_fields = ["id" ]), page_size = 500
10481033 ):
10491034 batch_keys = [record ["id" ] for record in batch ]
1050- records_deleted = await self . _redis_client . delete (* batch_keys ) # type: ignore
1051- total_records_deleted += records_deleted # type: ignore
1035+ records_deleted = await client . delete (* batch_keys )
1036+ total_records_deleted += records_deleted
10521037
10531038 return total_records_deleted
10541039
@@ -1061,10 +1046,11 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int:
10611046 Returns:
10621047 int: Count of records deleted from Redis.
10631048 """
1064- if isinstance (keys , List ):
1065- return await self ._redis_client .delete (* keys ) # type: ignore
1049+ client = await self ._get_client ()
1050+ if isinstance (keys , list ):
1051+ return await client .delete (* keys )
10661052 else :
1067- return await self . _redis_client . delete (keys ) # type: ignore
1053+ return await client . delete (keys )
10681054
10691055 async def load (
10701056 self ,
@@ -1124,9 +1110,10 @@ async def add_field(d):
11241110 keys = await index.load(data, preprocess=add_field)
11251111
11261112 """
1113+ client = await self ._get_client ()
11271114 try :
11281115 return await self ._storage .awrite (
1129- self . _redis_client , # type: ignore
1116+ client ,
11301117 objects = data ,
11311118 id_field = id_field ,
11321119 keys = keys ,
@@ -1150,7 +1137,8 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
11501137 Returns:
11511138 Dict[str, Any]: The fetched object.
11521139 """
1153- obj = await self ._storage .aget (self ._redis_client , [self .key (id )]) # type: ignore
1140+ client = await self ._get_client ()
1141+ obj = await self ._storage .aget (client , [self .key (id )])
11541142 if obj :
11551143 return convert_bytes (obj [0 ])
11561144 return None
@@ -1165,10 +1153,10 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
11651153 Returns:
11661154 Result: Raw Redis aggregation results.
11671155 """
1156+ client = await self ._get_client ()
11681157 try :
1169- return await self ._redis_client .ft (self .schema .index .name ).aggregate ( # type: ignore
1170- * args , ** kwargs
1171- )
1158+ # TODO: Typing
1159+ return await client .ft (self .schema .index .name ).aggregate (* args , ** kwargs )
11721160 except Exception as e :
11731161 raise RedisSearchError (f"Error while aggregating: { str (e )} " ) from e
11741162
@@ -1182,10 +1170,10 @@ async def search(self, *args, **kwargs) -> "Result":
11821170 Returns:
11831171 Result: Raw Redis search results.
11841172 """
1173+ client = await self ._get_client ()
11851174 try :
1186- return await self ._redis_client .ft (self .schema .index .name ).search ( # type: ignore
1187- * args , ** kwargs
1188- )
1175+ # TODO: Typing
1176+ return await client .ft (self .schema .index .name ).search (* args , ** kwargs )
11891177 except Exception as e :
11901178 raise RedisSearchError (f"Error while searching: { str (e )} " ) from e
11911179
@@ -1256,7 +1244,7 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato
12561244
12571245 """
12581246 if not isinstance (page_size , int ):
1259- raise TypeError ("page_size must be an integer " )
1247+ raise TypeError ("page_size must be of type int " )
12601248
12611249 if page_size <= 0 :
12621250 raise ValueError ("page_size must be greater than 0" )
@@ -1268,7 +1256,6 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato
12681256 if not results :
12691257 break
12701258 yield results
1271- # increment the pagination tracker
12721259 first += page_size
12731260
12741261 async def listall (self ) -> List [str ]:
@@ -1277,9 +1264,8 @@ async def listall(self) -> List[str]:
12771264 Returns:
12781265 List[str]: The list of indices in the database.
12791266 """
1280- return convert_bytes (
1281- await self ._redis_client .execute_command ("FT._LIST" ) # type: ignore
1282- )
1267+ client : aredis .Redis = await self ._get_client ()
1268+ return convert_bytes (await client .execute_command ("FT._LIST" ))
12831269
12841270 async def exists (self ) -> bool :
12851271 """Check if the index exists in Redis.
@@ -1289,15 +1275,6 @@ async def exists(self) -> bool:
12891275 """
12901276 return self .schema .index .name in await self .listall ()
12911277
1292- @staticmethod
1293- async def _info (name : str , redis_client : aredis .Redis ) -> Dict [str , Any ]:
1294- try :
1295- return convert_bytes (await redis_client .ft (name ).info ()) # type: ignore
1296- except Exception as e :
1297- raise RedisSearchError (
1298- f"Error while fetching { name } index info: { str (e )} "
1299- ) from e
1300-
13011278 async def info (self , name : Optional [str ] = None ) -> Dict [str , Any ]:
13021279 """Get information about the index.
13031280
@@ -1308,5 +1285,21 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
13081285 Returns:
13091286 dict: A dictionary containing the information about the index.
13101287 """
1288+ client : aredis .Redis = await self ._get_client ()
13111289 index_name = name or self .schema .index .name
1312- return await self ._info (index_name , self ._redis_client ) # type: ignore
1290+ return await type (self )._info (index_name , client )
1291+
1292+ @staticmethod
1293+ async def _info (name : str , redis_client : aredis .Redis ) -> Dict [str , Any ]:
1294+ try :
1295+ return convert_bytes (await redis_client .ft (name ).info ()) # type: ignore
1296+ except Exception as e :
1297+ raise RedisSearchError (
1298+ f"Error while fetching { name } index info: { str (e )} "
1299+ ) from e
1300+
1301+ async def __aenter__ (self ):
1302+ return self
1303+
1304+ async def __aexit__ (self , exc_type , exc_val , exc_tb ):
1305+ await self .disconnect ()
0 commit comments