11import asyncio
2- import atexit
32import json
4- import threading
53import warnings
64from functools import wraps
75from typing import (
1513 List ,
1614 Optional ,
1715 Union ,
18- cast ,
1916)
2017
2118from redisvl .utils .utils import deprecated_function
@@ -524,6 +521,23 @@ def drop_keys(self, keys: Union[str, List[str]]) -> int:
524521 else :
525522 return self ._redis_client .delete (keys ) # type: ignore
526523
524+ def expire_keys (
525+ self , keys : Union [str , List [str ]], ttl : int
526+ ) -> Union [int , List [int ]]:
527+ """Set the expiration time for a specific entry or entries in Redis.
528+
529+ Args:
530+ keys (Union[str, List[str]]): The document ID or IDs to set the expiration for.
531+ ttl (int): The time-to-live in seconds.
532+ """
533+ if isinstance (keys , list ):
534+ pipe = self ._redis_client .pipeline () # type: ignore
535+ for key in keys :
536+ pipe .expire (key , ttl )
537+ return pipe .execute ()
538+ else :
539+ return self ._redis_client .expire (keys , ttl ) # type: ignore
540+
527541 def load (
528542 self ,
529543 data : Iterable [Any ],
@@ -938,10 +952,6 @@ async def _get_client(self):
938952 )
939953 return self ._redis_client
940954
941- async def get_client (self ) -> aredis .Redis :
942- """Return this index's async Redis client."""
943- return await self ._get_client ()
944-
945955 async def _validate_client (self , redis_client : aredis .Redis ) -> aredis .Redis :
946956 if isinstance (redis_client , redis .Redis ):
947957 warnings .warn (
@@ -980,7 +990,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
980990 # overwrite an index in Redis; drop associated data (clean slate)
981991 await index.create(overwrite=True, drop=True)
982992 """
983- client = await self .get_client ()
993+ client = await self ._get_client ()
984994 redis_fields = self .schema .redis_fields
985995
986996 if not redis_fields :
@@ -1016,7 +1026,7 @@ async def delete(self, drop: bool = True):
10161026 Raises:
10171027 redis.exceptions.ResponseError: If the index does not exist.
10181028 """
1019- client = await self .get_client ()
1029+ client = await self ._get_client ()
10201030 try :
10211031 await client .ft (self .schema .index .name ).dropindex (delete_documents = drop )
10221032 except Exception as e :
@@ -1029,7 +1039,7 @@ async def clear(self) -> int:
10291039 Returns:
10301040 int: Count of records deleted from Redis.
10311041 """
1032- client = await self .get_client ()
1042+ client = await self ._get_client ()
10331043 total_records_deleted : int = 0
10341044
10351045 async for batch in self .paginate (
@@ -1050,12 +1060,30 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int:
10501060 Returns:
10511061 int: Count of records deleted from Redis.
10521062 """
1053- client = await self .get_client ()
1063+ client = await self ._get_client ()
10541064 if isinstance (keys , list ):
10551065 return await client .delete (* keys )
10561066 else :
10571067 return await client .delete (keys )
10581068
1069+ async def expire_keys (
1070+ self , keys : Union [str , List [str ]], ttl : int
1071+ ) -> Union [int , List [int ]]:
1072+ """Set the expiration time for a specific entry or entries in Redis.
1073+
1074+ Args:
1075+ keys (Union[str, List[str]]): The document ID or IDs to set the expiration for.
1076+ ttl (int): The time-to-live in seconds.
1077+ """
1078+ client = await self ._get_client ()
1079+ if isinstance (keys , list ):
1080+ pipe = client .pipeline ()
1081+ for key in keys :
1082+ pipe .expire (key , ttl )
1083+ return await pipe .execute ()
1084+ else :
1085+ return await client .expire (keys , ttl )
1086+
10591087 async def load (
10601088 self ,
10611089 data : Iterable [Any ],
@@ -1114,7 +1142,7 @@ async def add_field(d):
11141142 keys = await index.load(data, preprocess=add_field)
11151143
11161144 """
1117- client = await self .get_client ()
1145+ client = await self ._get_client ()
11181146 try :
11191147 return await self ._storage .awrite (
11201148 client ,
@@ -1141,7 +1169,7 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
11411169 Returns:
11421170 Dict[str, Any]: The fetched object.
11431171 """
1144- client = await self .get_client ()
1172+ client = await self ._get_client ()
11451173 obj = await self ._storage .aget (client , [self .key (id )])
11461174 if obj :
11471175 return convert_bytes (obj [0 ])
@@ -1157,7 +1185,7 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
11571185 Returns:
11581186 Result: Raw Redis aggregation results.
11591187 """
1160- client = await self .get_client ()
1188+ client = await self ._get_client ()
11611189 try :
11621190 return client .ft (self .schema .index .name ).aggregate (* args , ** kwargs )
11631191 except Exception as e :
@@ -1173,7 +1201,7 @@ async def search(self, *args, **kwargs) -> "Result":
11731201 Returns:
11741202 Result: Raw Redis search results.
11751203 """
1176- client = await self .get_client ()
1204+ client = await self ._get_client ()
11771205 try :
11781206 return await client .ft (self .schema .index .name ).search (* args , ** kwargs )
11791207 except Exception as e :
@@ -1266,7 +1294,7 @@ async def listall(self) -> List[str]:
12661294 Returns:
12671295 List[str]: The list of indices in the database.
12681296 """
1269- client = await self .get_client ()
1297+ client = await self ._get_client ()
12701298 return convert_bytes (await client .execute_command ("FT._LIST" ))
12711299
12721300 async def exists (self ) -> bool :
@@ -1287,7 +1315,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
12871315 Returns:
12881316 dict: A dictionary containing the information about the index.
12891317 """
1290- client = await self .get_client ()
1318+ client = await self ._get_client ()
12911319 index_name = name or self .schema .index .name
12921320 return await type (self )._info (index_name , client )
12931321
0 commit comments