Skip to content

Commit 6259604

Browse files
committed
WIP
1 parent 3723cf0 commit 6259604

File tree

1 file changed

+108
-115
lines changed

1 file changed

+108
-115
lines changed

redisvl/index/index.py

Lines changed: 108 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from redis.commands.search.document import Document
2222
from redis.commands.search.result import Result
2323
from redisvl.query.query import BaseQuery
24+
import redis.asyncio
2425

2526
import redis
2627
import redis.asyncio as aredis
@@ -793,16 +794,29 @@ class AsyncSearchIndex(BaseSearchIndex):
793794
794795
"""
795796

797+
# TODO: The `aredis.Redis` type is not working for type checks.
798+
_redis_client: Optional[redis.asyncio.Redis] = None
799+
_redis_url: Optional[str] = None
800+
_redis_kwargs: Dict[str, Any] = {}
801+
796802
def __init__(
797803
self,
798804
schema: IndexSchema,
805+
*,
806+
redis_url: Optional[str] = None,
807+
redis_client: Optional[aredis.Redis] = None,
808+
redis_kwargs: Optional[Dict[str, Any]] = None,
799809
**kwargs,
800810
):
801811
"""Initialize the RedisVL async search index with a schema.
802812
803813
Args:
804814
schema (IndexSchema): Index schema object.
805-
connection_args (Dict[str, Any], optional): Redis client connection
815+
redis_url (Optional[str], optional): The URL of the Redis server to
816+
connect to.
817+
redis_client (Optional[aredis.Redis], optional): An
818+
instantiated redis client.
819+
redis_kwargs (Dict[str, Any], optional): Redis client connection
806820
args.
807821
"""
808822
# final validation on schema object
@@ -813,39 +827,21 @@ def __init__(
813827

814828
self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
815829

816-
# set up empty redis connection
817-
self._redis_client: Optional[aredis.Redis] = None
818-
819-
if "redis_client" in kwargs or "redis_url" in kwargs:
820-
logger.warning(
821-
"Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex"
822-
)
823-
824-
atexit.register(self._cleanup_connection)
825-
826-
def _cleanup_connection(self):
827-
if self._redis_client:
828-
829-
def run_in_thread():
830-
try:
831-
loop = asyncio.new_event_loop()
832-
asyncio.set_event_loop(loop)
833-
loop.run_until_complete(self._redis_client.aclose())
834-
loop.close()
835-
except RuntimeError:
836-
pass
837-
838-
# Run cleanup in a background thread to avoid event loop issues
839-
thread = threading.Thread(target=run_in_thread)
840-
thread.start()
841-
thread.join()
830+
# Store connection parameters
831+
if redis_client and redis_url:
832+
raise ValueError("Cannot provide both redis_client and redis_url")
842833

834+
self._redis_client = redis_client
835+
self._redis_url = redis_url
836+
self._redis_kwargs = redis_kwargs or {}
837+
self._lock = asyncio.Lock()
838+
839+
async def disconnect(self):
    """Asynchronously disconnect and clean up the underlying async Redis connection.

    Closes the currently attached client (if any) via ``aclose()`` and clears
    the stored reference. The reference is cleared *before* the close is
    awaited, so the index never keeps pointing at a client whose close
    failed or is still in progress.

    Raises:
        Exception: Whatever ``aclose()`` raises is propagated unchanged.
    """
    client = self._redis_client
    if client is None:
        return
    # Detach first: even if aclose() raises, the index no longer holds the
    # (possibly half-closed) client.
    self._redis_client = None
    await client.aclose()  # type: ignore
844844

845-
def disconnect(self):
846-
"""Disconnect and cleanup the underlying async redis connection."""
847-
self._cleanup_connection()
848-
849845
@classmethod
850846
async def from_existing(
851847
cls,
@@ -902,69 +898,59 @@ def client(self) -> Optional[aredis.Redis]:
902898
return self._redis_client
903899

904900
async def connect(self, redis_url: Optional[str] = None, **kwargs):
    """[DEPRECATED] Connect to a Redis instance.

    Use the connection parameters in ``__init__`` instead.

    Builds an async client with ``RedisConnectionFactory.connect`` and
    installs it on this index.

    Args:
        redis_url (Optional[str], optional): The URL of the Redis server to
            connect to. NOTE(review): earlier documentation said this falls
            back to the ``REDIS_URL`` environment variable when omitted;
            presumably the factory handles that -- confirm.
        **kwargs: Extra options forwarded to
            ``RedisConnectionFactory.connect``.

    Returns:
        The value returned by ``_set_client`` (this index instance).
    """
    import warnings

    warnings.warn(
        "connect() is deprecated; pass connection parameters in __init__",
        DeprecationWarning,
        # Attribute the warning to the caller's line, not to this module.
        stacklevel=2,
    )
    client = RedisConnectionFactory.connect(
        redis_url=redis_url, use_async=True, **kwargs
    )
    # Use the private setter: going through set_client() would emit a second,
    # misleading "set_client() is deprecated" warning for a connect() call.
    return await self._set_client(client)
931912

932-
@setup_async_redis()
933-
async def set_client(self, redis_client: aredis.Redis):
934-
"""Manually set the Redis client to use with the search index.
935-
936-
This method configures the search index to use a specific
937-
Async Redis client. It is useful for cases where an external,
938-
custom-configured client is preferred instead of creating a new one.
939-
940-
Args:
941-
redis_client (aredis.Redis): An Async Redis
942-
client instance to be used for the connection.
943-
944-
Raises:
945-
TypeError: If the provided client is not valid.
946-
947-
.. code-block:: python
913+
async def set_client(self, redis_client: Optional[aredis.Redis]):
    """[DEPRECATED] Manually set the Redis client to use with the search index.

    This method is deprecated; please provide connection parameters in __init__.

    Args:
        redis_client (Optional[aredis.Redis]): The async Redis client to
            attach to this index.

    Returns:
        The value returned by ``_set_client`` (this index instance).
    """
    import warnings

    warnings.warn(
        "set_client() is deprecated; pass connection parameters in __init__",
        DeprecationWarning,
        # Point the warning at the caller; with the default stacklevel it is
        # attributed to this module and hidden by the default warning filters.
        stacklevel=2,
    )
    return await self._set_client(redis_client)
951924

952-
# async Redis client and index
953-
client = aredis.Redis.from_url("redis://localhost:6379")
954-
index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
955-
await index.set_client(client)
925+
async def _set_client(self, redis_client: Optional[redis.asyncio.Redis]):
    """
    Set the Redis client to use with the search index.

    The new client is swapped in while holding ``self._lock``; the client it
    replaces (if any, and if it is a different object) is closed afterwards.

    NOTE: Remove this method once the deprecation period is over.

    Args:
        redis_client (Optional[redis.asyncio.Redis]): The client to attach,
            or None to detach the current one.

    Returns:
        This index instance.
    """
    async with self._lock:
        previous = self._redis_client
        self._redis_client = redis_client

    # Close the replaced client only after the swap so the index never
    # exposes an already-closed client, and never close the client that was
    # just (re)installed.
    if previous is not None and previous is not redis_client:
        await previous.aclose()  # type: ignore
    return self
967936

937+
async def _get_client(self) -> aredis.Redis:
    """Lazily instantiate and return the async Redis client.

    If no client is attached yet, one is created with
    ``RedisConnectionFactory.get_async_redis_connection`` from the stored
    connection kwargs (plus ``redis_url`` when one was stored), then passed
    to ``RedisConnectionFactory.validate_async_redis`` together with
    ``self._lib_name``.

    Returns:
        The attached async Redis client.
    """
    if self._redis_client is None:
        async with self._lock:
            # Re-check under the lock: another coroutine may have created
            # the client while this one was waiting.
            if self._redis_client is None:
                # Copy before adding redis_url so the stored dict (which may
                # be the caller's own object) is not mutated.
                kwargs = dict(self._redis_kwargs)
                if self._redis_url:
                    kwargs["redis_url"] = self._redis_url
                self._redis_client = (
                    RedisConnectionFactory.get_async_redis_connection(**kwargs)
                )
                await RedisConnectionFactory.validate_async_redis(
                    self._redis_client, self._lib_name
                )
    return self._redis_client
953+
968954
async def create(self, overwrite: bool = False, drop: bool = False) -> None:
969955
"""Asynchronously create an index in Redis with the current schema
970956
and properties.
@@ -990,6 +976,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
990976
# overwrite an index in Redis; drop associated data (clean slate)
991977
await index.create(overwrite=True, drop=True)
992978
"""
979+
client = await self._get_client()
993980
redis_fields = self.schema.redis_fields
994981

995982
if not redis_fields:
@@ -1005,7 +992,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
1005992
await self.delete(drop)
1006993

1007994
try:
1008-
await self._redis_client.ft(self.schema.index.name).create_index( # type: ignore
995+
await client.ft(self.schema.index.name).create_index(
1009996
fields=redis_fields,
1010997
definition=IndexDefinition(
1011998
prefix=[self.schema.index.prefix], index_type=self._storage.type
@@ -1025,10 +1012,9 @@ async def delete(self, drop: bool = True):
10251012
Raises:
10261013
redis.exceptions.ResponseError: If the index does not exist.
10271014
"""
1015+
client = await self._get_client()
10281016
try:
1029-
await self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore
1030-
delete_documents=drop
1031-
)
1017+
await client.ft(self.schema.index.name).dropindex(delete_documents=drop)
10321018
except Exception as e:
10331019
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e
10341020

@@ -1039,16 +1025,15 @@ async def clear(self) -> int:
10391025
Returns:
10401026
int: Count of records deleted from Redis.
10411027
"""
1042-
# Track deleted records
1028+
client = await self._get_client()
10431029
total_records_deleted: int = 0
10441030

1045-
# Paginate using queries and delete in batches
10461031
async for batch in self.paginate(
10471032
FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500
10481033
):
10491034
batch_keys = [record["id"] for record in batch]
1050-
records_deleted = await self._redis_client.delete(*batch_keys) # type: ignore
1051-
total_records_deleted += records_deleted # type: ignore
1035+
records_deleted = await client.delete(*batch_keys)
1036+
total_records_deleted += records_deleted
10521037

10531038
return total_records_deleted
10541039

@@ -1061,10 +1046,11 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int:
10611046
Returns:
10621047
int: Count of records deleted from Redis.
10631048
"""
1064-
if isinstance(keys, List):
1065-
return await self._redis_client.delete(*keys) # type: ignore
1049+
client = await self._get_client()
1050+
if isinstance(keys, list):
1051+
return await client.delete(*keys)
10661052
else:
1067-
return await self._redis_client.delete(keys) # type: ignore
1053+
return await client.delete(keys)
10681054

10691055
async def load(
10701056
self,
@@ -1124,9 +1110,10 @@ async def add_field(d):
11241110
keys = await index.load(data, preprocess=add_field)
11251111
11261112
"""
1113+
client = await self._get_client()
11271114
try:
11281115
return await self._storage.awrite(
1129-
self._redis_client, # type: ignore
1116+
client,
11301117
objects=data,
11311118
id_field=id_field,
11321119
keys=keys,
@@ -1150,7 +1137,8 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
11501137
Returns:
11511138
Dict[str, Any]: The fetched object.
11521139
"""
1153-
obj = await self._storage.aget(self._redis_client, [self.key(id)]) # type: ignore
1140+
client = await self._get_client()
1141+
obj = await self._storage.aget(client, [self.key(id)])
11541142
if obj:
11551143
return convert_bytes(obj[0])
11561144
return None
@@ -1165,10 +1153,10 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
11651153
Returns:
11661154
Result: Raw Redis aggregation results.
11671155
"""
1156+
client = await self._get_client()
11681157
try:
1169-
return await self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore
1170-
*args, **kwargs
1171-
)
1158+
# TODO: Typing
1159+
return await client.ft(self.schema.index.name).aggregate(*args, **kwargs)
11721160
except Exception as e:
11731161
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
11741162

@@ -1182,10 +1170,10 @@ async def search(self, *args, **kwargs) -> "Result":
11821170
Returns:
11831171
Result: Raw Redis search results.
11841172
"""
1173+
client = await self._get_client()
11851174
try:
1186-
return await self._redis_client.ft(self.schema.index.name).search( # type: ignore
1187-
*args, **kwargs
1188-
)
1175+
# TODO: Typing
1176+
return await client.ft(self.schema.index.name).search(*args, **kwargs)
11891177
except Exception as e:
11901178
raise RedisSearchError(f"Error while searching: {str(e)}") from e
11911179

@@ -1256,7 +1244,7 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato
12561244
12571245
"""
12581246
if not isinstance(page_size, int):
1259-
raise TypeError("page_size must be an integer")
1247+
raise TypeError("page_size must be of type int")
12601248

12611249
if page_size <= 0:
12621250
raise ValueError("page_size must be greater than 0")
@@ -1268,7 +1256,6 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato
12681256
if not results:
12691257
break
12701258
yield results
1271-
# increment the pagination tracker
12721259
first += page_size
12731260

12741261
async def listall(self) -> List[str]:
@@ -1277,9 +1264,8 @@ async def listall(self) -> List[str]:
12771264
Returns:
12781265
List[str]: The list of indices in the database.
12791266
"""
1280-
return convert_bytes(
1281-
await self._redis_client.execute_command("FT._LIST") # type: ignore
1282-
)
1267+
client: aredis.Redis = await self._get_client()
1268+
return convert_bytes(await client.execute_command("FT._LIST"))
12831269

12841270
async def exists(self) -> bool:
12851271
"""Check if the index exists in Redis.
@@ -1289,15 +1275,6 @@ async def exists(self) -> bool:
12891275
"""
12901276
return self.schema.index.name in await self.listall()
12911277

1292-
@staticmethod
1293-
async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
1294-
try:
1295-
return convert_bytes(await redis_client.ft(name).info()) # type: ignore
1296-
except Exception as e:
1297-
raise RedisSearchError(
1298-
f"Error while fetching {name} index info: {str(e)}"
1299-
) from e
1300-
13011278
async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
13021279
"""Get information about the index.
13031280
@@ -1308,5 +1285,21 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
13081285
Returns:
13091286
dict: A dictionary containing the information about the index.
13101287
"""
1288+
client: aredis.Redis = await self._get_client()
13111289
index_name = name or self.schema.index.name
1312-
return await self._info(index_name, self._redis_client) # type: ignore
1290+
return await type(self)._info(index_name, client)
1291+
1292+
@staticmethod
async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
    """Return the info payload of the index called ``name``.

    Awaits ``redis_client.ft(name).info()`` and passes the result through
    ``convert_bytes``.

    Args:
        name (str): Name of the index to inspect.
        redis_client (aredis.Redis): Async Redis client used for the call.

    Returns:
        Dict[str, Any]: The converted info result.

    Raises:
        RedisSearchError: If any step of the lookup or conversion fails; the
            original exception is chained as the cause.
    """
    try:
        search_client = redis_client.ft(name)
        raw_info = await search_client.info()  # type: ignore
        return convert_bytes(raw_info)
    except Exception as e:
        message = f"Error while fetching {name} index info: {str(e)}"
        raise RedisSearchError(message) from e
1300+
1301+
async def __aenter__(self):
    """Enter the ``async with`` block and hand back this index instance."""
    index = self
    return index
1303+
1304+
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave the ``async with`` block by awaiting ``self.disconnect()``.

    The exception details are not inspected and nothing is returned, so an
    exception raised inside the block continues to propagate.
    """
    closing = self.disconnect()
    await closing

0 commit comments

Comments
 (0)