|
1 | 1 | import asyncio |
2 | 2 | import json |
3 | 3 | import threading |
| 4 | +import time |
4 | 5 | import warnings |
5 | 6 | import weakref |
6 | 7 | from typing import ( |
|
26 | 27 |
|
27 | 28 | import redis |
28 | 29 | import redis.asyncio as aredis |
| 30 | +from redis.client import NEVER_DECODE |
| 31 | +from redis.commands.helpers import get_protocol_version # type: ignore |
29 | 32 | from redis.commands.search.indexDefinition import IndexDefinition |
30 | 33 |
|
31 | 34 | from redisvl.exceptions import RedisModuleVersionError, RedisSearchError |
@@ -349,7 +352,7 @@ def client(self) -> Optional[redis.Redis]: |
349 | 352 | return self.__redis_client |
350 | 353 |
|
351 | 354 | @property |
352 | | - def _redis_client(self) -> Optional[redis.Redis]: |
| 355 | + def _redis_client(self) -> redis.Redis: |
353 | 356 | """ |
354 | 357 | Get a Redis client instance. |
355 | 358 |
|
@@ -652,6 +655,79 @@ def aggregate(self, *args, **kwargs) -> "AggregateResult": |
652 | 655 | except Exception as e: |
653 | 656 | raise RedisSearchError(f"Error while aggregating: {str(e)}") from e |
654 | 657 |
|
| 658 | + def batch_search( |
| 659 | + self, queries: List[str], batch_size: int = 100, **query_params |
| 660 | + ) -> List[List[Dict[str, Any]]]: |
| 661 | + """Perform a search against the index for multiple queries. |
| 662 | +
|
| 663 | + This method takes a list of queries and returns a list of search results. |
| 664 | + The results are returned in the same order as the queries. |
| 665 | +
|
| 666 | + Args: |
| 667 | + queries (List[str]): The queries to search for. |
| 668 | + batch_size (int, optional): The number of queries to search for at a time. |
| 669 | + Defaults to 100. |
| 670 | + query_params (dict, optional): The query parameters to pass to the search |
| 671 | + for each query. |
| 672 | +
|
| 673 | + Returns: |
| 674 | + List[List[Dict[str, Any]]]: The search results. |
| 675 | + """ |
| 676 | + all_parsed = [] |
| 677 | + search = self._redis_client.ft(self.schema.index.name) |
| 678 | + options = {} |
| 679 | + if get_protocol_version(self._redis_client) not in ["3", 3]: |
| 680 | + options[NEVER_DECODE] = True |
| 681 | + |
| 682 | + for i in range(0, len(queries), batch_size): |
| 683 | + batch_queries = queries[i : i + batch_size] |
| 684 | + print("batch queries", batch_queries) |
| 685 | + |
| 686 | + # redis-py doesn't support calling `search` in a pipeline, |
| 687 | + # so we need to manually execute each command in a pipeline |
| 688 | + # and parse the results |
| 689 | + with self._redis_client.pipeline(transaction=False) as pipe: |
| 690 | + batch_built_queries = [] |
| 691 | + for query in batch_queries: |
| 692 | + query_args, q = search._mk_query_args( # type: ignore |
| 693 | + query, query_params=query_params |
| 694 | + ) |
| 695 | + batch_built_queries.append(q) |
| 696 | + print("query", query_args, options) |
| 697 | + pipe.execute_command( |
| 698 | + "FT.SEARCH", |
| 699 | + *query_args, |
| 700 | + **options, |
| 701 | + ) |
| 702 | + |
| 703 | + st = time.time() |
| 704 | + # One list of results per query |
| 705 | + print("query stack", pipe.command_stack) |
| 706 | + results = pipe.execute() |
| 707 | + print("SUCCESS") |
| 708 | + |
| 709 | + # We don't know how long each query took, so we'll use the total time |
| 710 | + # for all queries in the batch as the duration for each query |
| 711 | + duration = (time.time() - st) * 1000.0 |
| 712 | + |
| 713 | + for i, query_results in enumerate(results): |
| 714 | + _built_query = batch_built_queries[i] |
| 715 | + parsed_raw = search._parse_search( # type: ignore |
| 716 | + query_results, |
| 717 | + query=_built_query, |
| 718 | + duration=duration, |
| 719 | + ) |
| 720 | + parsed = process_results( |
| 721 | + parsed_raw, |
| 722 | + query=_built_query, |
| 723 | + storage_type=self.schema.index.storage_type, |
| 724 | + ) |
| 725 | + # Create separate lists of parsed results for each query |
| 726 | + # passed in to the batch_search method, so that callers can |
| 727 | + # access the results for each query individually |
| 728 | + all_parsed.append(parsed) |
| 729 | + return all_parsed |
| 730 | + |
655 | 731 | def search(self, *args, **kwargs) -> "Result": |
656 | 732 | """Perform a search against the index. |
657 | 733 |
|
|
0 commit comments