Skip to content

Commit 2278f25

Browse files
committed
add normalize cosine distance flag
1 parent 103c221 commit 2278f25

File tree

5 files changed

+158
-10
lines changed

5 files changed

+158
-10
lines changed

redisvl/index/index.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
Union,
1717
)
1818

19-
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
19+
from redisvl.utils.utils import (
20+
deprecated_argument,
21+
deprecated_function,
22+
norm_cosine_distance,
23+
sync_wrapper,
24+
)
2025

2126
if TYPE_CHECKING:
2227
from redis.commands.search.aggregation import AggregateResult
@@ -30,14 +35,21 @@
3035

3136
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
3237
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
33-
from redisvl.query import BaseQuery, CountQuery, FilterQuery
38+
from redisvl.query import (
39+
BaseQuery,
40+
CountQuery,
41+
FilterQuery,
42+
VectorQuery,
43+
VectorRangeQuery,
44+
)
3445
from redisvl.query.filter import FilterExpression
3546
from redisvl.redis.connection import (
3647
RedisConnectionFactory,
3748
convert_index_info_to_schema,
3849
)
3950
from redisvl.redis.utils import convert_bytes
4051
from redisvl.schema import IndexSchema, StorageType
52+
from redisvl.schema.fields import VectorDistanceMetric
4153
from redisvl.utils.log import get_logger
4254

4355
logger = get_logger(__name__)
@@ -50,7 +62,7 @@
5062

5163

5264
def process_results(
53-
results: "Result", query: BaseQuery, storage_type: StorageType
65+
results: "Result", query: BaseQuery, schema: IndexSchema
5466
) -> List[Dict[str, Any]]:
5567
"""Convert a list of search Result objects into a list of document
5668
dictionaries.
@@ -75,11 +87,18 @@ def process_results(
7587

7688
# Determine if unpacking JSON is needed
7789
unpack_json = (
78-
(storage_type == StorageType.JSON)
90+
(schema.index.storage_type == StorageType.JSON)
7991
and isinstance(query, FilterQuery)
8092
and not query._return_fields # type: ignore
8193
)
8294

95+
normalize_cosine_distance = (
96+
(isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery))
97+
and query._normalize_cosine_distance
98+
and schema.fields[query._vector_field_name].attrs.distance_metric # type: ignore
99+
== VectorDistanceMetric.COSINE
100+
)
101+
83102
# Process records
84103
def _process(doc: "Document") -> Dict[str, Any]:
85104
doc_dict = doc.__dict__
@@ -93,6 +112,12 @@ def _process(doc: "Document") -> Dict[str, Any]:
93112
return {"id": doc_dict.get("id"), **json_data}
94113
raise ValueError(f"Unable to parse json data from Redis {json_data}")
95114

115+
if normalize_cosine_distance:
116+
# convert float back to string to be consistent
117+
doc_dict[query.DISTANCE_ID] = str( # type: ignore
118+
norm_cosine_distance(float(doc_dict[query.DISTANCE_ID])) # type: ignore
119+
)
120+
96121
# Remove 'payload' if present
97122
doc_dict.pop("payload", None)
98123

@@ -665,9 +690,7 @@ def search(self, *args, **kwargs) -> "Result":
665690
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
666691
"""Execute a query and process results."""
667692
results = self.search(query.query, query_params=query.params)
668-
return process_results(
669-
results, query=query, storage_type=self.schema.index.storage_type
670-
)
693+
return process_results(results, query=query, schema=self.schema)
671694

672695
def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
673696
"""Execute a query on the index.
@@ -1219,9 +1242,7 @@ async def search(self, *args, **kwargs) -> "Result":
12191242
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
12201243
"""Asynchronously execute a query and process results."""
12211244
results = await self.search(query.query, query_params=query.params)
1222-
return process_results(
1223-
results, query=query, storage_type=self.schema.index.storage_type
1224-
)
1245+
return process_results(results, query=query, schema=self.schema)
12251246

12261247
async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
12271248
"""Asynchronously execute a query on the index.

redisvl/query/query.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def __init__(
188188
dialect: int = 2,
189189
sort_by: Optional[str] = None,
190190
in_order: bool = False,
191+
normalize_cosine_distance: bool = False,
191192
):
192193
"""A query for running a vector search along with an optional filter
193194
expression.
@@ -213,6 +214,9 @@ def __init__(
213214
in_order (bool): Requires the terms in the field to have
214215
the same order as the terms in the query filter, regardless of
215216
the offsets between them. Defaults to False.
217+
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
218+
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
219+
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
216220
217221
Raises:
218222
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -224,6 +228,7 @@ def __init__(
224228
self._vector_field_name = vector_field_name
225229
self._dtype = dtype
226230
self._num_results = num_results
231+
self._normalize_cosine_distance = normalize_cosine_distance
227232
self.set_filter(filter_expression)
228233
query_string = self._build_query_string()
229234

@@ -284,6 +289,7 @@ def __init__(
284289
dialect: int = 2,
285290
sort_by: Optional[str] = None,
286291
in_order: bool = False,
292+
normalize_cosine_distance: bool = False,
287293
):
288294
"""A query for running a filtered vector search based on semantic
289295
distance threshold.
@@ -312,6 +318,9 @@ def __init__(
312318
in_order (bool): Requires the terms in the field to have
313319
the same order as the terms in the query filter, regardless of
314320
the offsets between them. Defaults to False.
321+
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
322+
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
323+
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
315324
316325
Raises:
317326
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -324,6 +333,7 @@ def __init__(
324333
self._vector_field_name = vector_field_name
325334
self._dtype = dtype
326335
self._num_results = num_results
336+
self._normalize_cosine_distance = normalize_cosine_distance
327337
self.set_distance_threshold(distance_threshold)
328338
self.set_filter(filter_expression)
329339
query_string = self._build_query_string()

redisvl/utils/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,10 @@ def wrapper():
191191
return
192192

193193
return wrapper
194+
195+
196+
def norm_cosine_distance(value: float) -> float:
    """Map a cosine distance onto a similarity-style score.

    The mapping is ``(2 - value) / 2``: a distance of 0 becomes 1.0, a
    distance of 1 becomes 0.5 and a distance of 2 becomes 0.0. The input is
    not clamped, so values outside ``[0, 2]`` produce scores outside
    ``[0, 1]``.

    Args:
        value (float): Cosine distance to convert.

    Returns:
        float: ``(2 - value) / 2``.
    """
    return (2 - value) / 2

tests/integration/test_query.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ def sorted_vector_query():
5252
)
5353

5454

55+
@pytest.fixture
def normalized_vector_query():
    """VectorQuery on ``user_embedding`` with ``normalize_cosine_distance`` on."""
    wanted_fields = [
        "user",
        "credit_score",
        "age",
        "job",
        "location",
        "last_updated",
    ]
    query = VectorQuery(
        vector=[0.1, 0.1, 0.5],
        vector_field_name="user_embedding",
        return_fields=wanted_fields,
        return_score=True,
        normalize_cosine_distance=True,
    )
    return query
71+
72+
5573
@pytest.fixture
5674
def filter_query():
5775
return FilterQuery(
@@ -83,6 +101,18 @@ def sorted_filter_query():
83101
)
84102

85103

104+
@pytest.fixture
def normalized_range_query():
    """RangeQuery on ``user_embedding`` with ``normalize_cosine_distance`` on.

    The distance threshold is 0.2.
    """
    wanted_fields = ["user", "credit_score", "age", "job", "location"]
    query = RangeQuery(
        vector=[0.1, 0.1, 0.5],
        vector_field_name="user_embedding",
        distance_threshold=0.2,
        return_fields=wanted_fields,
        return_score=True,
        normalize_cosine_distance=True,
    )
    return query
114+
115+
86116
@pytest.fixture
87117
def range_query():
88118
return RangeQuery(
@@ -154,6 +184,56 @@ def hash_preprocess(item: dict) -> dict:
154184
index.delete(drop=True)
155185

156186

187+
@pytest.fixture
def L2_index(sample_data, redis_url):
    """Yield a hash-storage SearchIndex whose vector field uses the L2 metric.

    The index is created with ``overwrite=True``, loaded with ``sample_data``
    (embeddings converted to float32 buffers), handed to the test, and then
    deleted with ``drop=True``.
    """
    embedding_field = {
        "name": "user_embedding",
        "type": "vector",
        "attrs": {
            "dims": 3,
            "distance_metric": "L2",
            "algorithm": "flat",
            "datatype": "float32",
        },
    }
    schema_dict = {
        "index": {
            "name": "L2_index",
            "prefix": "L2_index",
            "storage_type": "hash",
        },
        "fields": [
            {"name": "credit_score", "type": "tag"},
            {"name": "job", "type": "text"},
            {"name": "age", "type": "numeric"},
            {"name": "last_updated", "type": "numeric"},
            {"name": "location", "type": "geo"},
            embedding_field,
        ],
    }

    # Build the index from the schema and create it before any data exists.
    l2_index = SearchIndex.from_dict(schema_dict, redis_url=redis_url)
    l2_index.create(overwrite=True)

    def hash_preprocess(item: dict) -> dict:
        # Copy the record and swap the embedding list for its byte buffer.
        record = dict(item)
        record["user_embedding"] = array_to_buffer(item["user_embedding"], "float32")
        return record

    l2_index.load(sample_data, preprocess=hash_preprocess)

    # Hand the populated index to the test.
    yield l2_index

    # Teardown: remove the index once the test is done.
    l2_index.delete(drop=True)
235+
236+
157237
def test_search_and_query(index):
158238
# *=>[KNN 7 @user_embedding $vector AS vector_distance]
159239
v = VectorQuery(
@@ -531,3 +611,26 @@ def test_query_with_chunk_number_zero():
531611
assert (
532612
str(filter_conditions) == expected_query_str
533613
), "Query with chunk_number zero is incorrect"
614+
615+
616+
def test_query_normalize_cosine_distance(index, normalized_vector_query):
    """Every score from a normalized vector query must lie in [0, 1]."""
    res = index.query(normalized_vector_query)

    # Guard against a vacuous pass: with no rows the loop below checks nothing.
    assert res, "expected the vector query to return at least one result"

    for r in res:
        assert 0 <= float(r["vector_distance"]) <= 1
622+
623+
624+
def test_query_normalize_cosine_distance_ip_distance(L2_index, normalized_vector_query):
    """Run the normalized query against the L2 index and expect a score above 1.

    NOTE: despite "ip" in the name, the fixture used here is the L2-metric
    index.
    """
    res = L2_index.query(normalized_vector_query)

    distances = (float(row["vector_distance"]) for row in res)
    assert any(distance > 1 for distance in distances)
629+
630+
631+
def test_range_query_normalize_cosine_distance(index, normalized_range_query):
    """Every score from a normalized range query must lie in [0, 1]."""
    rows = index.query(normalized_range_query)

    assert all(0 <= float(row["vector_distance"]) <= 1 for row in rows)

tests/unit/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,16 @@
1313
assert_no_warnings,
1414
deprecated_argument,
1515
deprecated_function,
16+
norm_cosine_distance,
1617
)
1718

1819

20+
def test_norm_cosine_distance():
    """Check norm_cosine_distance at both ends and the middle of [0, 2]."""
    # (distance, expected score) pairs; all are exactly representable floats.
    cases = [(2, 0), (1, 0.5), (0, 1)]
    for distance, expected in cases:
        assert norm_cosine_distance(distance) == expected
24+
25+
1926
def test_even_number_of_elements():
2027
"""Test with an even number of elements"""
2128
values = ["key1", "value1", "key2", "value2"]

0 commit comments

Comments
 (0)