Commit 26f9d24

Add batch_search to sync Index
1 parent: 29cb397 · commit: 26f9d24

File tree

2 files changed: 126 additions & 1 deletion


redisvl/index/index.py

Lines changed: 77 additions & 1 deletion
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import threading
+import time
 import warnings
 import weakref
 from typing import (
@@ -26,6 +27,8 @@
 
 import redis
 import redis.asyncio as aredis
+from redis.client import NEVER_DECODE
+from redis.commands.helpers import get_protocol_version  # type: ignore
 from redis.commands.search.indexDefinition import IndexDefinition
 
 from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
@@ -349,7 +352,7 @@ def client(self) -> Optional[redis.Redis]:
         return self.__redis_client
 
     @property
-    def _redis_client(self) -> Optional[redis.Redis]:
+    def _redis_client(self) -> redis.Redis:
         """
         Get a Redis client instance.
 
@@ -652,6 +655,79 @@ def aggregate(self, *args, **kwargs) -> "AggregateResult":
         except Exception as e:
             raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
 
+    def batch_search(
+        self, queries: List[str], batch_size: int = 100, **query_params
+    ) -> List[List[Dict[str, Any]]]:
+        """Perform a search against the index for multiple queries.
+
+        This method takes a list of queries and returns a list of search results.
+        The results are returned in the same order as the queries.
+
+        Args:
+            queries (List[str]): The queries to search for.
+            batch_size (int, optional): The number of queries to search for at a time.
+                Defaults to 100.
+            query_params (dict, optional): The query parameters to pass to the search
+                for each query.
+
+        Returns:
+            List[List[Dict[str, Any]]]: The search results.
+        """
+        all_parsed = []
+        search = self._redis_client.ft(self.schema.index.name)
+        options = {}
+        if get_protocol_version(self._redis_client) not in ["3", 3]:
+            options[NEVER_DECODE] = True
+
+        for i in range(0, len(queries), batch_size):
+            batch_queries = queries[i : i + batch_size]
+            print("batch queries", batch_queries)
+
+            # redis-py doesn't support calling `search` in a pipeline,
+            # so we need to manually execute each command in a pipeline
+            # and parse the results
+            with self._redis_client.pipeline(transaction=False) as pipe:
+                batch_built_queries = []
+                for query in batch_queries:
+                    query_args, q = search._mk_query_args(  # type: ignore
+                        query, query_params=query_params
+                    )
+                    batch_built_queries.append(q)
+                    print("query", query_args, options)
+                    pipe.execute_command(
+                        "FT.SEARCH",
+                        *query_args,
+                        **options,
+                    )
+
+                st = time.time()
+                # One list of results per query
+                print("query stack", pipe.command_stack)
+                results = pipe.execute()
+                print("SUCCESS")
+
+                # We don't know how long each query took, so we'll use the total time
+                # for all queries in the batch as the duration for each query
+                duration = (time.time() - st) * 1000.0
+
+                for i, query_results in enumerate(results):
+                    _built_query = batch_built_queries[i]
+                    parsed_raw = search._parse_search(  # type: ignore
+                        query_results,
+                        query=_built_query,
+                        duration=duration,
+                    )
+                    parsed = process_results(
+                        parsed_raw,
+                        query=_built_query,
+                        storage_type=self.schema.index.storage_type,
+                    )
+                    # Create separate lists of parsed results for each query
+                    # passed in to the batch_search method, so that callers can
+                    # access the results for each query individually
+                    all_parsed.append(parsed)
+        return all_parsed
+
     def search(self, *args, **kwargs) -> "Result":
         """Perform a search against the index.
 
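For reference, a minimal usage sketch of the new method, assuming an existing SearchIndex instance named index whose schema defines a tag field called test (the field name and query strings here are illustrative, mirroring the integration tests below):

# Each query string becomes one FT.SEARCH call; the calls are pipelined in
# groups of `batch_size`, and the parsed results come back in query order.
queries = ["@test:{foo}", "@test:{bar}", "@test:{baz}"]
results = index.batch_search(queries, batch_size=2)

for query, hits in zip(queries, results):
    # `hits` is the list of parsed documents matching this single query.
    print(query, [hit["id"] for hit in hits])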

tests/integration/test_search_index.py

Lines changed: 49 additions & 0 deletions
@@ -389,3 +389,52 @@ def test_search_index_validates_redis_modules(redis_url):
         index.create(overwrite=True, drop=True)
 
         mock_validate_sync_redis.assert_called_once()
+
+
+def test_batch_search(index):
+    index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    index.load(data, id_field="id")
+
+    results = index.batch_search(["@test:{foo}", "@test:{bar}"])
+    assert len(results) == 2
+    assert results[0][0]["id"] == "rvl:1"
+    assert results[1][0]["id"] == "rvl:2"
+
+
+def test_batch_search_with_multiple_batches(index):
+    index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    index.load(data, id_field="id")
+
+    results = index.batch_search(["@test:{foo}", "@test:{bar}"])
+    assert len(results) == 2
+    assert len(results[0]) == 1
+    assert len(results[1]) == 1
+
+    results = index.batch_search(
+        [
+            "@test:{foo}",
+            "@test:{bar}",
+            "@test:{baz}",
+            "@test:{foo}",
+            "@test:{bar}",
+            "@test:{baz}",
+        ],
+        batch_size=2,
+    )
+    assert len(results) == 6
+
+    # First (and only) result for the first query
+    assert results[0][0]["id"] == "rvl:1"
+
+    # Second (and only) result for the second query
+    assert results[1][0]["id"] == "rvl:2"
+
+    # Third query has no results
+    assert len(results[2]) == 0
+
+    # Then the pattern repeats
+    assert results[3][0]["id"] == "rvl:1"
+    assert results[4][0]["id"] == "rvl:2"
+    assert len(results[5]) == 0
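Both tests use an index pytest fixture defined elsewhere in the test suite; it is not part of this diff. A rough sketch of what such a fixture could look like, given the rvl: key prefix and the test tag field that the assertions rely on (the schema details and fixture body are assumptions for illustration, not the repository's actual fixture):

import pytest

from redisvl.index import SearchIndex


@pytest.fixture
def index(redis_url):
    # Assumed schema: hash storage, keys prefixed with "rvl", and a single
    # tag field "test", matching the "rvl:1" keys and "@test:{foo}" queries
    # asserted above.
    schema = {
        "index": {"name": "test_index", "prefix": "rvl", "storage_type": "hash"},
        "fields": [{"name": "test", "type": "tag"}],
    }
    index = SearchIndex.from_dict(schema, redis_url=redis_url)
    yield index
    index.delete(drop=True)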
