Skip to content

Commit 1340690

Browse files
refactor the rerankers to conform to new pydantic
1 parent a0e6fa5 commit 1340690

File tree

4 files changed

+19
-17
lines changed

4 files changed

+19
-17
lines changed

redisvl/utils/rerank/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from typing import Any, Dict, List, Optional, Tuple, Union
33

4-
from pydantic.v1 import BaseModel, validator
4+
from pydantic import BaseModel, field_validator
55

66

77
class BaseReranker(BaseModel, ABC):
@@ -10,15 +10,15 @@ class BaseReranker(BaseModel, ABC):
1010
limit: int
1111
return_score: bool
1212

13-
@validator("limit")
13+
@field_validator("limit")
1414
@classmethod
1515
def check_limit(cls, value):
1616
"""Ensures the limit is a positive integer."""
1717
if value <= 0:
1818
raise ValueError("Limit must be a positive integer.")
1919
return value
2020

21-
@validator("rank_by")
21+
@field_validator("rank_by")
2222
@classmethod
2323
def check_rank_by(cls, value):
2424
"""Ensures that rank_by is a list of strings if provided."""

redisvl/utils/rerank/cohere.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Any, Dict, List, Optional, Tuple, Union
33

4-
from pydantic.v1 import PrivateAttr
4+
from pydantic import PrivateAttr
55

66
from redisvl.utils.rerank.base import BaseReranker
77

@@ -45,7 +45,8 @@ def __init__(
4545
limit: int = 5,
4646
return_score: bool = True,
4747
api_config: Optional[Dict] = None,
48-
) -> None:
48+
**kwargs,
49+
):
4950
"""
5051
Initialize the CohereReranker with specified model, ranking criteria,
5152
and API configuration.
@@ -71,9 +72,9 @@ def __init__(
7172
super().__init__(
7273
model=model, rank_by=rank_by, limit=limit, return_score=return_score
7374
)
74-
self._initialize_clients(api_config)
75+
self._initialize_clients(api_config, **kwargs)
7576

76-
def _initialize_clients(self, api_config: Optional[Dict]):
77+
def _initialize_clients(self, api_config: Optional[Dict], **kwargs):
7778
"""
7879
Setup the Cohere clients using the provided API key or an
7980
environment variable.
@@ -96,8 +97,8 @@ def _initialize_clients(self, api_config: Optional[Dict]):
9697
"Cohere API key is required. "
9798
"Provide it in api_config or set the COHERE_API_KEY environment variable."
9899
)
99-
self._client = Client(api_key=api_key, client_name="redisvl")
100-
self._aclient = AsyncClient(api_key=api_key, client_name="redisvl")
100+
self._client = Client(api_key=api_key, client_name="redisvl", **kwargs)
101+
self._aclient = AsyncClient(api_key=api_key, client_name="redisvl", **kwargs)
101102

102103
def _preprocess(
103104
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs

redisvl/utils/rerank/hf_cross_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, List, Optional, Tuple, Union
22

3-
from pydantic.v1 import PrivateAttr
3+
from pydantic import PrivateAttr
44

55
from redisvl.utils.rerank.base import BaseReranker
66

@@ -39,7 +39,7 @@ def __init__(
3939
limit: int = 3,
4040
return_score: bool = True,
4141
**kwargs,
42-
) -> None:
42+
):
4343
"""
4444
Initialize the HFCrossEncoderReranker with a specified model and ranking criteria.
4545

redisvl/utils/rerank/voyageai.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Any, Dict, List, Optional, Tuple, Union
33

4-
from pydantic.v1 import PrivateAttr
4+
from pydantic import PrivateAttr
55

66
from redisvl.utils.rerank.base import BaseReranker
77

@@ -45,7 +45,8 @@ def __init__(
4545
limit: int = 5,
4646
return_score: bool = True,
4747
api_config: Optional[Dict] = None,
48-
) -> None:
48+
**kwargs,
49+
):
4950
"""
5051
Initialize the VoyageAIReranker with specified model, ranking criteria,
5152
and API configuration.
@@ -70,9 +71,9 @@ def __init__(
7071
super().__init__(
7172
model=model, rank_by=rank_by, limit=limit, return_score=return_score
7273
)
73-
self._initialize_clients(api_config)
74+
self._initialize_clients(api_config, **kwargs)
7475

75-
def _initialize_clients(self, api_config: Optional[Dict]):
76+
def _initialize_clients(self, api_config: Optional[Dict], **kwargs):
7677
"""
7778
Setup the VoyageAI clients using the provided API key or an
7879
environment variable.
@@ -95,8 +96,8 @@ def _initialize_clients(self, api_config: Optional[Dict]):
9596
"VoyageAI API key is required. "
9697
"Provide it in api_config or set the VOYAGE_API_KEY environment variable."
9798
)
98-
self._client = Client(api_key=api_key)
99-
self._aclient = AsyncClient(api_key=api_key)
99+
self._client = Client(api_key=api_key, **kwargs)
100+
self._aclient = AsyncClient(api_key=api_key, **kwargs)
100101

101102
def _preprocess(
102103
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs

0 commit comments

Comments
 (0)