Skip to content

Commit 49bf7c4

Browse files
committed
Put key expiration behind an interface
1 parent d27c36e commit 49bf7c4

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,14 +281,13 @@ async def adrop(
281281
def _refresh_ttl(self, key: str) -> None:
282282
"""Refresh the time-to-live for the specified key."""
283283
if self._ttl:
284-
self._index.client.expire(key, self._ttl) # type: ignore
284+
self._index.expire_keys(key, self._ttl)
285285

286286
async def _async_refresh_ttl(self, key: str) -> None:
287287
"""Async refresh the time-to-live for the specified key."""
288288
aindex = await self._get_async_index()
289289
if self._ttl:
290-
client = await aindex.get_client()
291-
await client.expire(key, self._ttl) # type: ignore
290+
await aindex.expire_keys(key, self._ttl)
292291

293292
def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
294293
"""Converts a text prompt to its vector representation using the

redisvl/index/index.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import asyncio
2-
import atexit
32
import json
4-
import threading
53
import warnings
64
from functools import wraps
75
from typing import (
@@ -15,7 +13,6 @@
1513
List,
1614
Optional,
1715
Union,
18-
cast,
1916
)
2017

2118
from redisvl.utils.utils import deprecated_function
@@ -524,6 +521,23 @@ def drop_keys(self, keys: Union[str, List[str]]) -> int:
524521
else:
525522
return self._redis_client.delete(keys) # type: ignore
526523

524+
def expire_keys(
525+
self, keys: Union[str, List[str]], ttl: int
526+
) -> Union[int, List[int]]:
527+
"""Set the expiration time for a specific entry or entries in Redis.
528+
529+
Args:
530+
keys (Union[str, List[str]]): The document ID or IDs to set the expiration for.
531+
ttl (int): The time-to-live in seconds.
532+
"""
533+
if isinstance(keys, list):
534+
pipe = self._redis_client.pipeline() # type: ignore
535+
for key in keys:
536+
pipe.expire(key, ttl)
537+
return pipe.execute()
538+
else:
539+
return self._redis_client.expire(keys, ttl) # type: ignore
540+
527541
def load(
528542
self,
529543
data: Iterable[Any],
@@ -938,10 +952,6 @@ async def _get_client(self):
938952
)
939953
return self._redis_client
940954

941-
async def get_client(self) -> aredis.Redis:
942-
"""Return this index's async Redis client."""
943-
return await self._get_client()
944-
945955
async def _validate_client(self, redis_client: aredis.Redis) -> aredis.Redis:
946956
if isinstance(redis_client, redis.Redis):
947957
warnings.warn(
@@ -980,7 +990,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
980990
# overwrite an index in Redis; drop associated data (clean slate)
981991
await index.create(overwrite=True, drop=True)
982992
"""
983-
client = await self.get_client()
993+
client = await self._get_client()
984994
redis_fields = self.schema.redis_fields
985995

986996
if not redis_fields:
@@ -1016,7 +1026,7 @@ async def delete(self, drop: bool = True):
10161026
Raises:
10171027
redis.exceptions.ResponseError: If the index does not exist.
10181028
"""
1019-
client = await self.get_client()
1029+
client = await self._get_client()
10201030
try:
10211031
await client.ft(self.schema.index.name).dropindex(delete_documents=drop)
10221032
except Exception as e:
@@ -1029,7 +1039,7 @@ async def clear(self) -> int:
10291039
Returns:
10301040
int: Count of records deleted from Redis.
10311041
"""
1032-
client = await self.get_client()
1042+
client = await self._get_client()
10331043
total_records_deleted: int = 0
10341044

10351045
async for batch in self.paginate(
@@ -1050,12 +1060,30 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int:
10501060
Returns:
10511061
int: Count of records deleted from Redis.
10521062
"""
1053-
client = await self.get_client()
1063+
client = await self._get_client()
10541064
if isinstance(keys, list):
10551065
return await client.delete(*keys)
10561066
else:
10571067
return await client.delete(keys)
10581068

1069+
async def expire_keys(
1070+
self, keys: Union[str, List[str]], ttl: int
1071+
) -> Union[int, List[int]]:
1072+
"""Set the expiration time for a specific entry or entries in Redis.
1073+
1074+
Args:
1075+
keys (Union[str, List[str]]): The document ID or IDs to set the expiration for.
1076+
ttl (int): The time-to-live in seconds.
1077+
"""
1078+
client = await self._get_client()
1079+
if isinstance(keys, list):
1080+
pipe = client.pipeline()
1081+
for key in keys:
1082+
pipe.expire(key, ttl)
1083+
return await pipe.execute()
1084+
else:
1085+
return await client.expire(keys, ttl)
1086+
10591087
async def load(
10601088
self,
10611089
data: Iterable[Any],
@@ -1114,7 +1142,7 @@ async def add_field(d):
11141142
keys = await index.load(data, preprocess=add_field)
11151143
11161144
"""
1117-
client = await self.get_client()
1145+
client = await self._get_client()
11181146
try:
11191147
return await self._storage.awrite(
11201148
client,
@@ -1141,7 +1169,7 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
11411169
Returns:
11421170
Dict[str, Any]: The fetched object.
11431171
"""
1144-
client = await self.get_client()
1172+
client = await self._get_client()
11451173
obj = await self._storage.aget(client, [self.key(id)])
11461174
if obj:
11471175
return convert_bytes(obj[0])
@@ -1157,7 +1185,7 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
11571185
Returns:
11581186
Result: Raw Redis aggregation results.
11591187
"""
1160-
client = await self.get_client()
1188+
client = await self._get_client()
11611189
try:
11621190
return client.ft(self.schema.index.name).aggregate(*args, **kwargs)
11631191
except Exception as e:
@@ -1173,7 +1201,7 @@ async def search(self, *args, **kwargs) -> "Result":
11731201
Returns:
11741202
Result: Raw Redis search results.
11751203
"""
1176-
client = await self.get_client()
1204+
client = await self._get_client()
11771205
try:
11781206
return await client.ft(self.schema.index.name).search(*args, **kwargs)
11791207
except Exception as e:
@@ -1266,7 +1294,7 @@ async def listall(self) -> List[str]:
12661294
Returns:
12671295
List[str]: The list of indices in the database.
12681296
"""
1269-
client = await self.get_client()
1297+
client = await self._get_client()
12701298
return convert_bytes(await client.execute_command("FT._LIST"))
12711299

12721300
async def exists(self) -> bool:
@@ -1287,7 +1315,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
12871315
Returns:
12881316
dict: A dictionary containing the information about the index.
12891317
"""
1290-
client = await self.get_client()
1318+
client = await self._get_client()
12911319
index_name = name or self.schema.index.name
12921320
return await type(self)._info(index_name, client)
12931321

0 commit comments

Comments
 (0)