From 3395c87d6064e27e313189a243b7b3836df725ad Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Tue, 23 Sep 2025 20:58:27 -0700 Subject: [PATCH] feat: add text field weights support to TextQuery (#360) Adds the ability to specify weights for text fields in RedisVL queries, enabling users to prioritize certain fields over others in search results. - Support dictionary of field:weight mappings in TextQuery constructor - Maintain backward compatibility with single string field names - Add set_field_weights() method for dynamic weight updates - Generate proper Redis query syntax with weight modifiers - Comprehensive validation for positive numeric weights Example usage: ```python query = TextQuery(text="search", text_field_name={"title": 5.0}) query = TextQuery( text="search", text_field_name={"title": 3.0, "content": 1.5, "tags": 1.0} ) ``` - Add has_redisearch_module and has_redisearch_module_async helpers to conftest.py - Add skip_if_no_redisearch and skip_if_no_redisearch_async functions - Update test_no_proactive_module_checks.py to use shared helpers - Update test_semantic_router.py to check RediSearch availability in fixtures and tests - Update test_llmcache.py to check RediSearch availability in all cache fixtures - Update test_message_history.py to check RediSearch availability for semantic history - Ensure all tests that require RediSearch are properly skipped on Redis 6.2.6-v9 - BM25STD scorer is not available in Redis versions prior to 7.2.0. Add a version check to the weighted_index fixture so that the TextQuery weights integration tests are skipped on older Redis versions. 
--- redisvl/query/query.py | 95 ++++++++- tests/conftest.py | 46 +++++ tests/integration/test_llmcache.py | 30 ++- tests/integration/test_message_history.py | 19 +- .../test_no_proactive_module_checks.py | 35 +--- tests/integration/test_semantic_router.py | 31 ++- .../test_text_query_weights_integration.py | 190 ++++++++++++++++++ tests/unit/test_text_query_weights.py | 122 +++++++++++ 8 files changed, 514 insertions(+), 54 deletions(-) create mode 100644 tests/integration/test_text_query_weights_integration.py create mode 100644 tests/unit/test_text_query_weights.py diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 88aa9dfb..28dea9e0 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -801,7 +801,7 @@ class TextQuery(BaseQuery): def __init__( self, text: str, - text_field_name: str, + text_field_name: Union[str, Dict[str, float]], text_scorer: str = "BM25STD", filter_expression: Optional[Union[str, FilterExpression]] = None, return_fields: Optional[List[str]] = None, @@ -817,7 +817,8 @@ def __init__( Args: text (str): The text string to perform the text search with. - text_field_name (str): The name of the document field to perform text search on. + text_field_name (Union[str, Dict[str, float]]): The name of the document field to perform + text search on, or a dictionary mapping field names to their weights. text_scorer (str, optional): The text scoring algorithm to use. Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, TFIDF.DOCNORM, DISMAX, DOCSCORE}. See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/ @@ -849,7 +850,7 @@ def __init__( TypeError: If stopwords is not a valid iterable set of strings. 
""" self._text = text - self._text_field_name = text_field_name + self._field_weights = self._parse_field_weights(text_field_name) self._num_results = num_results self._set_stopwords(stopwords) @@ -934,15 +935,97 @@ def _tokenize_and_escape_query(self, user_query: str) -> str: [token for token in tokens if token and token not in self._stopwords] ) + def _parse_field_weights( + self, field_spec: Union[str, Dict[str, float]] + ) -> Dict[str, float]: + """Parse the field specification into a weights dictionary. + + Args: + field_spec: Either a single field name or dictionary of field:weight mappings + + Returns: + Dictionary mapping field names to their weights + """ + if isinstance(field_spec, str): + return {field_spec: 1.0} + elif isinstance(field_spec, dict): + # Validate all weights are numeric and positive + for field, weight in field_spec.items(): + if not isinstance(field, str): + raise TypeError(f"Field name must be a string, got {type(field)}") + if not isinstance(weight, (int, float)): + raise TypeError( + f"Weight for field '{field}' must be numeric, got {type(weight)}" + ) + if weight <= 0: + raise ValueError( + f"Weight for field '{field}' must be positive, got {weight}" + ) + return field_spec + else: + raise TypeError( + "text_field_name must be a string or dictionary of field:weight mappings" + ) + + def set_field_weights(self, field_weights: Union[str, Dict[str, float]]): + """Set or update the field weights for the query. + + Args: + field_weights: Either a single field name or dictionary of field:weight mappings + """ + self._field_weights = self._parse_field_weights(field_weights) + # Invalidate the query string + self._built_query_string = None + + @property + def field_weights(self) -> Dict[str, float]: + """Get the field weights for the query. 
+ + Returns: + Dictionary mapping field names to their weights + """ + return self._field_weights.copy() + + @property + def text_field_name(self) -> Union[str, Dict[str, float]]: + """Get the text field name(s) - for backward compatibility. + + Returns: + Either a single field name string (if only one field with weight 1.0) + or a dictionary of field:weight mappings. + """ + if len(self._field_weights) == 1: + field, weight = next(iter(self._field_weights.items())) + if weight == 1.0: + return field + return self._field_weights.copy() + def _build_query_string(self) -> str: """Build the full query string for text search with optional filtering.""" filter_expression = self._filter_expression if isinstance(filter_expression, FilterExpression): filter_expression = str(filter_expression) - text = ( - f"@{self._text_field_name}:({self._tokenize_and_escape_query(self._text)})" - ) + escaped_query = self._tokenize_and_escape_query(self._text) + + # Build query parts for each field with its weight + field_queries = [] + for field, weight in self._field_weights.items(): + if weight == 1.0: + # Default weight doesn't need explicit weight syntax + field_queries.append(f"@{field}:({escaped_query})") + else: + # Use Redis weight syntax for non-default weights + field_queries.append( + f"@{field}:({escaped_query}) => {{ $weight: {weight} }}" + ) + + # Join multiple field queries with OR operator + if len(field_queries) == 1: + text = field_queries[0] + else: + text = "(" + " | ".join(field_queries) + ")" + if filter_expression and filter_expression != "*": text += f" AND {filter_expression}" return text diff --git a/tests/conftest.py b/tests/conftest.py index 991b62cb..8a3bdae6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -579,6 +579,26 @@ async def get_redis_version_async(client): return info["redis_version"] +def has_redisearch_module(client): + """Check if RediSearch module is available.""" + try: + # Try to list indices - this is a RediSearch command + 
client.execute_command("FT._LIST") + return True + except Exception: + return False + + +async def has_redisearch_module_async(client): + """Check if RediSearch module is available (async).""" + try: + # Try to list indices - this is a RediSearch command + await client.execute_command("FT._LIST") + return True + except Exception: + return False + + def skip_if_redis_version_below(client, min_version: str, message: str = None): """ Skip test if Redis version is below minimum required. @@ -609,3 +629,29 @@ async def skip_if_redis_version_below_async( if not compare_versions(redis_version, min_version): skip_msg = message or f"Redis version {redis_version} < {min_version} required" pytest.skip(skip_msg) + + +def skip_if_no_redisearch(client, message: str = None): + """ + Skip test if RediSearch module is not available. + + Args: + client: Redis client instance + message: Custom skip message + """ + if not has_redisearch_module(client): + skip_msg = message or "RediSearch module not available" + pytest.skip(skip_msg) + + +async def skip_if_no_redisearch_async(client, message: str = None): + """ + Skip test if RediSearch module is not available (async version). 
+ + Args: + client: Async Redis client instance + message: Custom skip message + """ + if not await has_redisearch_module_async(client): + skip_msg = message or "RediSearch module not available" + pytest.skip(skip_msg) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 0dcba8be..348a2ca8 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -11,6 +11,7 @@ from redisvl.index.index import AsyncSearchIndex, SearchIndex from redisvl.query.filter import Num, Tag, Text from redisvl.utils.vectorize import HFTextVectorizer +from tests.conftest import skip_if_no_redisearch, skip_if_no_redisearch_async @pytest.fixture(scope="session") @@ -19,7 +20,8 @@ def vectorizer(): @pytest.fixture -def cache(vectorizer, redis_url, worker_id): +def cache(client, vectorizer, redis_url, worker_id): + skip_if_no_redisearch(client) cache_instance = SemanticCache( name=f"llmcache_{worker_id}", vectorizer=vectorizer, @@ -31,7 +33,8 @@ def cache(vectorizer, redis_url, worker_id): @pytest.fixture -def cache_with_filters(vectorizer, redis_url, worker_id): +def cache_with_filters(client, vectorizer, redis_url, worker_id): + skip_if_no_redisearch(client) cache_instance = SemanticCache( name=f"llmcache_filters_{worker_id}", vectorizer=vectorizer, @@ -44,7 +47,8 @@ def cache_with_filters(vectorizer, redis_url, worker_id): @pytest.fixture -def cache_no_cleanup(vectorizer, redis_url, worker_id): +def cache_no_cleanup(client, vectorizer, redis_url, worker_id): + skip_if_no_redisearch(client) cache_instance = SemanticCache( name=f"llmcache_no_cleanup_{worker_id}", vectorizer=vectorizer, @@ -55,7 +59,8 @@ def cache_no_cleanup(vectorizer, redis_url, worker_id): @pytest.fixture -def cache_with_ttl(vectorizer, redis_url, worker_id): +def cache_with_ttl(client, vectorizer, redis_url, worker_id): + skip_if_no_redisearch(client) cache_instance = SemanticCache( name=f"llmcache_ttl_{worker_id}", vectorizer=vectorizer, @@ -69,6 +74,7 @@ 
def cache_with_ttl(vectorizer, redis_url, worker_id): @pytest.fixture def cache_with_redis_client(vectorizer, client, worker_id): + skip_if_no_redisearch(client) cache_instance = SemanticCache( name=f"llmcache_client_{worker_id}", vectorizer=vectorizer, @@ -750,7 +756,8 @@ def test_cache_filtering(cache_with_filters): assert len(results) == 0 -def test_cache_bad_filters(vectorizer, redis_url, worker_id): +def test_cache_bad_filters(client, vectorizer, redis_url, worker_id): + skip_if_no_redisearch(client) with pytest.raises(ValueError): cache_instance = SemanticCache( name=f"test_bad_filters_1_{worker_id}", @@ -819,6 +826,7 @@ def test_complex_filters(cache_with_filters): def test_cache_index_overwrite(client, redis_url, worker_id, hf_vectorizer): + skip_if_no_redisearch(client) # Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly redis_version = client.info()["redis_version"] if redis_version.startswith("6.2"): @@ -921,7 +929,8 @@ def test_no_key_collision_on_identical_prompts(redis_url, worker_id, hf_vectoriz assert len(filtered_results) == 2 -def test_create_cache_with_different_vector_types(worker_id, redis_url): +def test_create_cache_with_different_vector_types(client, worker_id, redis_url): + skip_if_no_redisearch(client) try: bfloat_cache = SemanticCache( name=f"bfloat_cache_{worker_id}", dtype="bfloat16", redis_url=redis_url @@ -951,6 +960,7 @@ def test_create_cache_with_different_vector_types(worker_id, redis_url): def test_bad_dtype_connecting_to_existing_cache(client, redis_url, worker_id): + skip_if_no_redisearch(client) # Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly redis_version = client.info()["redis_version"] if redis_version.startswith("6.2"): @@ -1021,7 +1031,10 @@ def test_deprecated_dtype_argument(redis_url, worker_id): @pytest.mark.asyncio -async def test_cache_async_context_manager(redis_url, worker_id, hf_vectorizer): +async def test_cache_async_context_manager( + async_client, redis_url, 
worker_id, hf_vectorizer +): + await skip_if_no_redisearch_async(async_client) async with SemanticCache( name=f"test_cache_async_context_manager_{worker_id}", redis_url=redis_url, @@ -1034,8 +1047,9 @@ async def test_cache_async_context_manager(redis_url, worker_id, hf_vectorizer): @pytest.mark.asyncio async def test_cache_async_context_manager_with_exception( - redis_url, worker_id, hf_vectorizer + async_client, redis_url, worker_id, hf_vectorizer ): + await skip_if_no_redisearch_async(async_client) try: async with SemanticCache( name=f"test_cache_async_context_manager_with_exception_{worker_id}", diff --git a/tests/integration/test_message_history.py b/tests/integration/test_message_history.py index a6bde980..623ede4d 100644 --- a/tests/integration/test_message_history.py +++ b/tests/integration/test_message_history.py @@ -5,6 +5,7 @@ from redisvl.extensions.constants import ID_FIELD_NAME from redisvl.extensions.message_history import MessageHistory, SemanticMessageHistory +from tests.conftest import skip_if_no_redisearch @pytest.fixture @@ -21,6 +22,7 @@ def standard_history(app_name, client): @pytest.fixture def semantic_history(app_name, client, hf_vectorizer): + skip_if_no_redisearch(client) history = SemanticMessageHistory( app_name, redis_client=client, overwrite=True, vectorizer=hf_vectorizer ) @@ -326,6 +328,7 @@ def test_standard_clear(standard_history): # test semantic message history def test_semantic_specify_client(client, hf_vectorizer): + skip_if_no_redisearch(client) history = SemanticMessageHistory( name="test_app", session_tag="abc", @@ -616,7 +619,8 @@ def test_semantic_drop(semantic_history): ] -def test_different_vector_dtypes(redis_url): +def test_different_vector_dtypes(client, redis_url): + skip_if_no_redisearch(client) try: bfloat_sess = SemanticMessageHistory( name="bfloat_history", dtype="bfloat16", redis_url=redis_url @@ -647,6 +651,7 @@ def test_different_vector_dtypes(redis_url): def 
test_bad_dtype_connecting_to_exiting_history(client, redis_url): + skip_if_no_redisearch(client) # Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly redis_version = client.info()["redis_version"] if redis_version.startswith("6.2"): @@ -674,7 +679,8 @@ def create_same_type(): ) -def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16): +def test_vectorizer_dtype_mismatch(client, redis_url, hf_vectorizer_float16): + skip_if_no_redisearch(client) with pytest.raises(ValueError): SemanticMessageHistory( name="test_dtype_mismatch", @@ -685,7 +691,8 @@ def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16): ) -def test_invalid_vectorizer(redis_url): +def test_invalid_vectorizer(client, redis_url): + skip_if_no_redisearch(client) with pytest.raises(TypeError): SemanticMessageHistory( name="test_invalid_vectorizer", @@ -695,7 +702,8 @@ def test_invalid_vectorizer(redis_url): ) -def test_passes_through_dtype_to_default_vectorizer(redis_url): +def test_passes_through_dtype_to_default_vectorizer(client, redis_url): + skip_if_no_redisearch(client) # The default is float32, so we should see float64 if we pass it in. 
cache = SemanticMessageHistory( name="test_pass_through_dtype", @@ -706,7 +714,8 @@ def test_passes_through_dtype_to_default_vectorizer(redis_url): assert cache._vectorizer.dtype == "float64" -def test_deprecated_dtype_argument(redis_url): +def test_deprecated_dtype_argument(client, redis_url): + skip_if_no_redisearch(client) with pytest.warns(DeprecationWarning): SemanticMessageHistory( name="float64 history", dtype="float64", redis_url=redis_url, overwrite=True diff --git a/tests/integration/test_no_proactive_module_checks.py b/tests/integration/test_no_proactive_module_checks.py index 81097dd3..5e6bc590 100644 --- a/tests/integration/test_no_proactive_module_checks.py +++ b/tests/integration/test_no_proactive_module_checks.py @@ -19,26 +19,12 @@ from redisvl.redis.connection import RedisConnectionFactory from redisvl.schema import IndexSchema from redisvl.utils.vectorize.base import BaseVectorizer - - -def has_redisearch_module(client): - """Check if RediSearch module is available.""" - try: - # Try to list indices - this is a RediSearch command - client.execute_command("FT._LIST") - return True - except (ResponseError, Exception): - return False - - -async def has_redisearch_module_async(client): - """Check if RediSearch module is available (async).""" - try: - # Try to list indices - this is a RediSearch command - await client.execute_command("FT._LIST") - return True - except (ResponseError, Exception): - return False +from tests.conftest import ( + has_redisearch_module, + has_redisearch_module_async, + skip_if_no_redisearch, + skip_if_no_redisearch_async, +) @pytest.fixture @@ -143,8 +129,7 @@ async def test_async_search_index_init_no_validation( def test_search_index_create_with_modules(self, client, sample_schema, worker_id): """Test that index.create() works with RediSearch available.""" # Skip if RediSearch is not available - if not has_redisearch_module(client): - pytest.skip("RediSearch module not available") + skip_if_no_redisearch(client) # Update 
schema name to be unique schema_copy = IndexSchema.from_dict(sample_schema.to_dict()) @@ -172,8 +157,7 @@ async def test_async_search_index_create_with_modules( ): """Test that async index.create() works with RediSearch available.""" # Skip if RediSearch is not available - if not await has_redisearch_module_async(async_client): - pytest.skip("RediSearch module not available") + await skip_if_no_redisearch_async(async_client) # Update schema name to be unique schema_copy = IndexSchema.from_dict(sample_schema.to_dict()) @@ -345,8 +329,7 @@ async def test_async_cleanup_without_validation(self, redis_url): def test_from_existing_index_no_validation(self, client, worker_id): """Test that SearchIndex.from_existing doesn't validate modules.""" # Skip if RediSearch is not available - if not has_redisearch_module(client): - pytest.skip("RediSearch module not available") + skip_if_no_redisearch(client) # First create an index normally schema = IndexSchema.from_dict( diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 1a10cedc..d082ceaa 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -12,7 +12,7 @@ RoutingConfig, ) from redisvl.redis.connection import compare_versions -from tests.conftest import skip_if_redis_version_below +from tests.conftest import skip_if_no_redisearch, skip_if_redis_version_below def get_base_path(): @@ -39,6 +39,7 @@ def routes(): @pytest.fixture def semantic_router(client, routes, hf_vectorizer): + skip_if_no_redisearch(client) router = SemanticRouter( name=f"test-router-{str(ULID())}", routes=routes, @@ -252,7 +253,8 @@ def test_bad_connection_info(routes): ) -def test_different_vector_dtypes(redis_url, routes): +def test_different_vector_dtypes(client, redis_url, routes): + skip_if_no_redisearch(client) try: bfloat_router = SemanticRouter( name="bfloat_router", @@ -289,6 +291,7 @@ def test_different_vector_dtypes(redis_url, routes): def 
test_bad_dtype_connecting_to_exiting_router(client, redis_url, routes): + skip_if_no_redisearch(client) # Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly redis_version = client.info()["redis_version"] if redis_version.startswith("6.2"): @@ -319,7 +322,8 @@ def test_bad_dtype_connecting_to_exiting_router(client, redis_url, routes): ) -def test_vectorizer_dtype_mismatch(routes, redis_url, hf_vectorizer_float16): +def test_vectorizer_dtype_mismatch(client, routes, redis_url, hf_vectorizer_float16): + skip_if_no_redisearch(client) with pytest.raises(ValueError): SemanticRouter( name="test_dtype_mismatch", @@ -331,7 +335,8 @@ def test_vectorizer_dtype_mismatch(routes, redis_url, hf_vectorizer_float16): ) -def test_invalid_vectorizer(redis_url): +def test_invalid_vectorizer(client, redis_url): + skip_if_no_redisearch(client) with pytest.raises(TypeError): SemanticRouter( name="test_invalid_vectorizer", @@ -341,7 +346,8 @@ def test_invalid_vectorizer(redis_url): ) -def test_passes_through_dtype_to_default_vectorizer(routes, redis_url): +def test_passes_through_dtype_to_default_vectorizer(client, routes, redis_url): + skip_if_no_redisearch(client) # The default is float32, so we should see float64 if we pass it in. 
router = SemanticRouter( name="test_pass_through_dtype", @@ -353,7 +359,8 @@ def test_passes_through_dtype_to_default_vectorizer(routes, redis_url): assert router.vectorizer.dtype == "float64" -def test_deprecated_dtype_argument(routes, redis_url): +def test_deprecated_dtype_argument(client, routes, redis_url): + skip_if_no_redisearch(client) with pytest.warns(DeprecationWarning): SemanticRouter( name="test_deprecated_dtype", @@ -364,8 +371,11 @@ def test_deprecated_dtype_argument(routes, redis_url): ) -def test_deprecated_distance_threshold_argument(semantic_router, routes, redis_url): +def test_deprecated_distance_threshold_argument( + semantic_router, client, routes, redis_url +): skip_if_redis_version_below(semantic_router._index.client, "7.0.0") + skip_if_no_redisearch(client) router = SemanticRouter( name="test_pass_through_dtype", @@ -378,9 +388,10 @@ def test_deprecated_distance_threshold_argument(semantic_router, routes, redis_u def test_routes_different_distance_thresholds_get_two( - semantic_router, routes, redis_url + semantic_router, client, routes, redis_url ): skip_if_redis_version_below(semantic_router._index.client, "7.0.0") + skip_if_no_redisearch(client) routes[0].distance_threshold = 0.5 routes[1].distance_threshold = 0.7 @@ -398,9 +409,10 @@ def test_routes_different_distance_thresholds_get_two( def test_routes_different_distance_thresholds_get_one( - semantic_router, routes, redis_url + semantic_router, client, routes, redis_url ): skip_if_redis_version_below(semantic_router._index.client, "7.0.0") + skip_if_no_redisearch(client) routes[0].distance_threshold = 0.5 @@ -462,6 +474,7 @@ def test_add_delete_route_references(semantic_router): def test_from_existing(client, redis_url, routes): + skip_if_no_redisearch(client) skip_if_redis_version_below(client, "7.0.0") # connect separately diff --git a/tests/integration/test_text_query_weights_integration.py b/tests/integration/test_text_query_weights_integration.py new file mode 100644 index 
00000000..604eff5e --- /dev/null +++ b/tests/integration/test_text_query_weights_integration.py @@ -0,0 +1,190 @@ +"""Integration tests for TextQuery with field weights.""" + +import uuid + +import pytest + +from redisvl.index import SearchIndex +from redisvl.query import TextQuery +from redisvl.query.filter import Tag +from tests.conftest import skip_if_redis_version_below + + +@pytest.fixture +def weighted_index(client, redis_url, worker_id): + # BM25 scorer requires Redis Stack 7.2.0 or higher + skip_if_redis_version_below(client, "7.2.0", "BM25 scorer not available") + """Create an index with multiple text fields for testing weights.""" + unique_id = str(uuid.uuid4())[:8] + schema_dict = { + "index": { + "name": f"weighted_test_idx_{worker_id}_{unique_id}", + "prefix": f"weighted_doc_{worker_id}_{unique_id}", + "storage_type": "json", + }, + "fields": [ + {"name": "title", "type": "text"}, + {"name": "content", "type": "text"}, + {"name": "tags", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "score", "type": "numeric"}, + ], + } + + index = SearchIndex.from_dict(schema_dict, redis_url=redis_url) + index.create(overwrite=True) + + # Load test data + data = [ + { + "id": "1", + "title": "Redis database introduction", + "content": "A comprehensive guide to getting started with Redis", + "tags": "tutorial beginner", + "category": "database", + "score": 95, + }, + { + "id": "2", + "title": "Advanced caching strategies", + "content": "Learn about Redis caching patterns and best practices", + "tags": "redis cache performance", + "category": "optimization", + "score": 88, + }, + { + "id": "3", + "title": "Python programming basics", + "content": "Introduction to Python with examples using Redis client", + "tags": "python redis programming", + "category": "programming", + "score": 90, + }, + { + "id": "4", + "title": "Data structures overview", + "content": "Understanding Redis data structures and their applications", + "tags": "redis structures", + 
"category": "database", + "score": 85, + }, + ] + + index.load(data) + yield index + index.delete(drop=True) + + +def test_text_query_with_single_weighted_field(weighted_index): + """Test TextQuery with a single weighted field.""" + text = "redis" + + # Query with higher weight on title + query = TextQuery( + text=text, + text_field_name={"title": 5.0}, + return_fields=["title", "content"], + num_results=4, + ) + + results = weighted_index.query(query) + assert len(results) > 0 + + # The document with "Redis" in the title should rank high + top_result = results[0] + assert "redis" in top_result["title"].lower() + + +def test_text_query_with_multiple_weighted_fields(weighted_index): + """Test TextQuery with multiple weighted fields.""" + text = "redis" + + # Query across multiple fields with different weights + query = TextQuery( + text=text, + text_field_name={"title": 3.0, "content": 2.0, "tags": 1.0}, + return_fields=["title", "content", "tags"], + num_results=4, + ) + + results = weighted_index.query(query) + assert len(results) > 0 + + # Check that results contain the search term in at least one field + for result in results: + text_found = ( + "redis" in result.get("title", "").lower() + or "redis" in result.get("content", "").lower() + or "redis" in result.get("tags", "").lower() + ) + assert text_found + + +def test_text_query_weights_with_filter(weighted_index): + """Test TextQuery with weights and filter expression.""" + text = "redis" + + # Query with weights and filter + filter_expr = Tag("category") == "database" + query = TextQuery( + text=text, + text_field_name={"title": 5.0, "content": 1.0}, + filter_expression=filter_expr, + return_fields=["title", "content", "category"], + num_results=4, + ) + + results = weighted_index.query(query) + + # Should only get database category results + for result in results: + assert result["category"] == "database" + + +def test_dynamic_weight_update(weighted_index): + """Test updating field weights dynamically.""" + 
text = "redis" + + # Start with equal weights + query = TextQuery( + text=text, + text_field_name={"title": 1.0, "content": 1.0}, + return_fields=["title", "content"], + num_results=4, + ) + + results1 = weighted_index.query(query) + + # Update to prioritize title + query.set_field_weights({"title": 10.0, "content": 1.0}) + + results2 = weighted_index.query(query) + + # Results might be reordered based on new weights + # At minimum, both queries should return results + assert len(results1) > 0 + assert len(results2) > 0 + + +def test_backward_compatibility_single_field(weighted_index): + """Test that the original single field API still works.""" + text = "redis" + + # Original API with single field name + query = TextQuery( + text=text, + text_field_name="content", + return_fields=["title", "content"], + num_results=4, + ) + + results = weighted_index.query(query) + assert len(results) > 0 + + # Check results are from content field + for result in results: + if "redis" in result.get("content", "").lower(): + break + else: + # At least one result should have redis in content + assert False, "No results with 'redis' in content field" diff --git a/tests/unit/test_text_query_weights.py b/tests/unit/test_text_query_weights.py new file mode 100644 index 00000000..08157346 --- /dev/null +++ b/tests/unit/test_text_query_weights.py @@ -0,0 +1,122 @@ +import pytest + +from redisvl.query import TextQuery +from redisvl.query.filter import Tag + + +def test_text_query_accepts_weights_dict(): + """Test that TextQuery can accept a dictionary of field weights.""" + text = "example search query" + + # Dictionary with field names as keys and weights as values + field_weights = {"title": 5.0, "content": 2.0, "tags": 1.0} + + # Should be able to create a TextQuery with weights dict + text_query = TextQuery(text=text, text_field_name=field_weights, num_results=10) + + # The query should have a method to set field weights + assert hasattr(text_query, "set_field_weights") + + # Check 
that weights are stored correctly + assert text_query.field_weights == field_weights + + +def test_text_query_generates_weighted_query_string(): + """Test that TextQuery generates correct query string with field weights.""" + text = "search query" + + # Single field with weight > 1 + text_query = TextQuery(text=text, text_field_name={"title": 5.0}, num_results=10) + + query_string = str(text_query) + # Should generate: @title:(search | query)=>{$weight:5.0} + assert ( + "@title:(search | query)=>{ $weight: 5.0 }" in query_string + or "@title:(search | query)=>{$weight:5.0}" in query_string + or "@title:(search | query) => { $weight: 5.0 }" in query_string + ) + + +def test_text_query_multiple_fields_with_weights(): + """Test that TextQuery generates correct query string with multiple weighted fields.""" + text = "search terms" + + field_weights = {"title": 3.0, "content": 1.5, "tags": 1.0} + + text_query = TextQuery(text=text, text_field_name=field_weights, num_results=10) + + query_string = str(text_query) + + # Should generate query with all fields and their weights, combined with OR + # The exact format depends on implementation, but all fields should be present + assert "@title:" in query_string + assert "@content:" in query_string + assert "@tags:" in query_string + + # Weights should be in the query + assert "$weight: 3.0" in query_string or "$weight:3.0" in query_string + assert "$weight: 1.5" in query_string or "$weight:1.5" in query_string + # Weight of 1.0 might be omitted as it's the default + + +def test_text_query_backward_compatibility(): + """Test that TextQuery still works with a single string field name.""" + text = "backward compatible" + + # Should work with just a string field name (original API) + text_query = TextQuery(text=text, text_field_name="description", num_results=5) + + query_string = str(text_query) + assert "@description:" in query_string + assert "backward | compatible" in query_string + + # Field weights should have the single 
field with weight 1.0 + assert text_query.field_weights == {"description": 1.0} + + +def test_text_query_weight_validation(): + """Test that invalid weights are properly rejected.""" + text = "test query" + + # Test negative weight + with pytest.raises(ValueError, match="must be positive"): + TextQuery(text=text, text_field_name={"title": -1.0}, num_results=10) + + # Test zero weight + with pytest.raises(ValueError, match="must be positive"): + TextQuery(text=text, text_field_name={"title": 0}, num_results=10) + + # Test non-numeric weight + with pytest.raises(TypeError, match="must be numeric"): + TextQuery(text=text, text_field_name={"title": "five"}, num_results=10) + + # Test invalid field name type + with pytest.raises(TypeError, match="Field name must be a string"): + TextQuery(text=text, text_field_name={123: 1.0}, num_results=10) + + # Test invalid text_field_name type (not str or dict) + with pytest.raises( + TypeError, match="text_field_name must be a string or dictionary" + ): + TextQuery(text=text, text_field_name=["title", "content"], num_results=10) + + +def test_set_field_weights_method(): + """Test that set_field_weights method updates weights correctly.""" + text = "dynamic weights" + + # Start with single field + text_query = TextQuery(text=text, text_field_name="title", num_results=10) + + assert text_query.field_weights == {"title": 1.0} + + # Update to multiple fields with weights + new_weights = {"title": 5.0, "content": 2.0} + text_query.set_field_weights(new_weights) + + assert text_query.field_weights == new_weights + + # Query string should reflect new weights + query_string = str(text_query) + assert "$weight: 5.0" in query_string or "$weight:5.0" in query_string + assert "$weight: 2.0" in query_string or "$weight:2.0" in query_string