1616 Union ,
1717)
1818
19- from redisvl .utils .utils import deprecated_argument , deprecated_function , sync_wrapper
19+ from redisvl .utils .utils import (
20+ deprecated_argument ,
21+ deprecated_function ,
22+ norm_cosine_distance ,
23+ sync_wrapper ,
24+ )
2025
2126if TYPE_CHECKING :
2227 from redis .commands .search .aggregation import AggregateResult
3035
3136from redisvl .exceptions import RedisModuleVersionError , RedisSearchError
3237from redisvl .index .storage import BaseStorage , HashStorage , JsonStorage
33- from redisvl .query import BaseQuery , CountQuery , FilterQuery
38+ from redisvl .query import (
39+ BaseQuery ,
40+ CountQuery ,
41+ FilterQuery ,
42+ VectorQuery ,
43+ VectorRangeQuery ,
44+ )
3445from redisvl .query .filter import FilterExpression
3546from redisvl .redis .connection import (
3647 RedisConnectionFactory ,
3748 convert_index_info_to_schema ,
3849)
3950from redisvl .redis .utils import convert_bytes
4051from redisvl .schema import IndexSchema , StorageType
52+ from redisvl .schema .fields import VectorDistanceMetric
4153from redisvl .utils .log import get_logger
4254
4355logger = get_logger (__name__ )
5062
5163
5264def process_results (
53- results : "Result" , query : BaseQuery , storage_type : StorageType
65+ results : "Result" , query : BaseQuery , schema : IndexSchema
5466) -> List [Dict [str , Any ]]:
5567 """Convert a list of search Result objects into a list of document
5668 dictionaries.
@@ -75,11 +87,18 @@ def process_results(
7587
7688 # Determine if unpacking JSON is needed
7789 unpack_json = (
78- (storage_type == StorageType .JSON )
90+ (schema . index . storage_type == StorageType .JSON )
7991 and isinstance (query , FilterQuery )
8092 and not query ._return_fields # type: ignore
8193 )
8294
95+ normalize_cosine_distance = (
96+ (isinstance (query , VectorQuery ) or isinstance (query , VectorRangeQuery ))
97+ and query ._normalize_cosine_distance
98+ and schema .fields [query ._vector_field_name ].attrs .distance_metric # type: ignore
99+ == VectorDistanceMetric .COSINE
100+ )
101+
83102 # Process records
84103 def _process (doc : "Document" ) -> Dict [str , Any ]:
85104 doc_dict = doc .__dict__
@@ -93,6 +112,12 @@ def _process(doc: "Document") -> Dict[str, Any]:
93112 return {"id" : doc_dict .get ("id" ), ** json_data }
94113 raise ValueError (f"Unable to parse json data from Redis { json_data } " )
95114
115+ if normalize_cosine_distance :
116+ # convert float back to string to be consistent
117+ doc_dict [query .DISTANCE_ID ] = str ( # type: ignore
118+ norm_cosine_distance (float (doc_dict [query .DISTANCE_ID ])) # type: ignore
119+ )
120+
96121 # Remove 'payload' if present
97122 doc_dict .pop ("payload" , None )
98123
@@ -665,9 +690,7 @@ def search(self, *args, **kwargs) -> "Result":
665690 def _query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
666691 """Execute a query and process results."""
667692 results = self .search (query .query , query_params = query .params )
668- return process_results (
669- results , query = query , storage_type = self .schema .index .storage_type
670- )
693+ return process_results (results , query = query , schema = self .schema )
671694
672695 def query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
673696 """Execute a query on the index.
@@ -1219,9 +1242,7 @@ async def search(self, *args, **kwargs) -> "Result":
12191242 async def _query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
12201243 """Asynchronously execute a query and process results."""
12211244 results = await self .search (query .query , query_params = query .params )
1222- return process_results (
1223- results , query = query , storage_type = self .schema .index .storage_type
1224- )
1245+ return process_results (results , query = query , schema = self .schema )
12251246
12261247 async def query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
12271248 """Asynchronously execute a query on the index.
0 commit comments