Skip to content

Commit 6e007f7

Browse files
abrookinsjustin-cechmanek
authored andcommitted
Validate passed-in Redis clients (#296)
Prior to RedisVL 0.4.0, we validated passed-in Redis clients when the user called `set_client()`. This PR reintroduces similar behavior by validating all clients, whether we created them or not, on first access through the lazy-client mechanism. Closes RAAE-694.
1 parent f7a4b9e commit 6e007f7

File tree

4 files changed

+70
-7
lines changed

4 files changed

+70
-7
lines changed

redisvl/index/index.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def __init__(
317317
self._connection_kwargs = connection_kwargs or {}
318318
self._lock = threading.Lock()
319319

320+
self._validated_client = False
320321
self._owns_redis_client = redis_client is None
321322
if self._owns_redis_client:
322323
weakref.finalize(self, self.disconnect)
@@ -396,6 +397,12 @@ def _redis_client(self) -> Optional[redis.Redis]:
396397
redis_url=self._redis_url,
397398
**self._connection_kwargs,
398399
)
400+
if not self._validated_client:
401+
RedisConnectionFactory.validate_sync_redis(
402+
self.__redis_client,
403+
self._lib_name,
404+
)
405+
self._validated_client = True
399406
return self.__redis_client
400407

401408
@deprecated_function("connect", "Pass connection parameters in __init__.")
@@ -931,6 +938,7 @@ def __init__(
931938
self._connection_kwargs = connection_kwargs or {}
932939
self._lock = asyncio.Lock()
933940

941+
self._validated_client = False
934942
self._owns_redis_client = redis_client is None
935943
if self._owns_redis_client:
936944
weakref.finalize(self, sync_wrapper(self.disconnect))
@@ -1027,9 +1035,12 @@ async def _get_client(self) -> aredis.Redis:
10271035
self._redis_client = (
10281036
await RedisConnectionFactory._get_aredis_connection(**kwargs)
10291037
)
1038+
if not self._validated_client:
10301039
await RedisConnectionFactory.validate_async_redis(
1031-
self._redis_client, self._lib_name
1040+
self._redis_client,
1041+
self._lib_name,
10321042
)
1043+
self._validated_client = True
10331044
return self._redis_client
10341045

10351046
async def _validate_client(

redisvl/redis/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def validate_modules(
159159
required_modules: List of required modules.
160160
161161
Raises:
162-
ValueError: If required Redis modules are not installed.
162+
RedisModuleVersionError: If required Redis modules are not installed.
163163
"""
164164
required_modules = required_modules or DEFAULT_REQUIRED_MODULES
165165

tests/integration/test_async_search_index.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import warnings
2+
from unittest import mock
23

34
import pytest
45
from redis import Redis as SyncRedis
5-
from redis.asyncio import Redis
6+
from redis.asyncio import Redis as AsyncRedis
67

7-
from redisvl.exceptions import RedisSearchError
8+
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
89
from redisvl.index import AsyncSearchIndex
910
from redisvl.query import VectorQuery
1011
from redisvl.redis.utils import convert_bytes
@@ -172,12 +173,12 @@ async def test_search_index_set_client(client, redis_url, index_schema):
172173
with warnings.catch_warnings():
173174
warnings.filterwarnings("ignore", category=DeprecationWarning)
174175
await async_index.create(overwrite=True, drop=True)
175-
assert isinstance(async_index.client, Redis)
176+
assert isinstance(async_index.client, AsyncRedis)
176177

177178
# Tests deprecated sync -> async conversion behavior
178179
assert isinstance(client, SyncRedis)
179180
await async_index.set_client(client)
180-
assert isinstance(async_index.client, Redis)
181+
assert isinstance(async_index.client, AsyncRedis)
181182

182183
await async_index.disconnect()
183184
assert async_index.client is None
@@ -410,3 +411,28 @@ async def test_search_index_that_owns_client_disconnect_sync(index_schema, redis
410411
await async_index.create(overwrite=True, drop=True)
411412
await async_index.disconnect()
412413
assert async_index._redis_client is None
414+
415+
416+
@pytest.mark.asyncio
417+
async def test_async_search_index_validates_redis_modules(redis_url):
418+
"""
419+
A regression test for RAAE-694: we should validate that a passed-in
420+
Redis client has the correct modules installed.
421+
"""
422+
client = AsyncRedis.from_url(redis_url)
423+
with mock.patch(
424+
"redisvl.index.index.RedisConnectionFactory.validate_async_redis"
425+
) as mock_validate_async_redis:
426+
mock_validate_async_redis.side_effect = RedisModuleVersionError(
427+
"Required modules not installed"
428+
)
429+
with pytest.raises(RedisModuleVersionError):
430+
index = AsyncSearchIndex(
431+
schema=IndexSchema.from_dict(
432+
{"index": {"name": "my_index"}, "fields": fields}
433+
),
434+
redis_client=client,
435+
)
436+
await index.create(overwrite=True, drop=True)
437+
438+
mock_validate_async_redis.assert_called_once()

tests/integration/test_search_index.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import warnings
2+
from unittest import mock
23

34
import pytest
5+
from redis import Redis
46

5-
from redisvl.exceptions import RedisSearchError
7+
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
68
from redisvl.index import SearchIndex
79
from redisvl.query import VectorQuery
810
from redisvl.redis.utils import convert_bytes
@@ -363,3 +365,27 @@ def test_search_index_that_owns_client_disconnect(index_schema, redis_url):
363365
index.create(overwrite=True, drop=True)
364366
index.disconnect()
365367
assert index.client is None
368+
369+
370+
def test_search_index_validates_redis_modules(redis_url):
371+
"""
372+
A regression test for RAAE-694: we should validate that a passed-in
373+
Redis client has the correct modules installed.
374+
"""
375+
client = Redis.from_url(redis_url)
376+
with mock.patch(
377+
"redisvl.index.index.RedisConnectionFactory.validate_sync_redis"
378+
) as mock_validate_sync_redis:
379+
mock_validate_sync_redis.side_effect = RedisModuleVersionError(
380+
"Required modules not installed"
381+
)
382+
with pytest.raises(RedisModuleVersionError):
383+
index = SearchIndex(
384+
schema=IndexSchema.from_dict(
385+
{"index": {"name": "my_index"}, "fields": fields}
386+
),
387+
redis_client=client,
388+
)
389+
index.create(overwrite=True, drop=True)
390+
391+
mock_validate_sync_redis.assert_called_once()

0 commit comments

Comments
 (0)