Skip to content

Commit 0cf4411

Browse files
pydantic facelift to v2
1 parent b751db7 commit 0cf4411

File tree

12 files changed

+73
-71
lines changed

12 files changed

+73
-71
lines changed

redisvl/extensions/llmcache/schema.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, List, Optional
22

3-
from pydantic.v1 import BaseModel, Field, root_validator, validator
3+
from pydantic import BaseModel, Field, field_validator, model_validator
44

55
from redisvl.extensions.constants import (
66
CACHE_VECTOR_FIELD_NAME,
@@ -34,22 +34,23 @@ class CacheEntry(BaseModel):
3434
filters: Optional[Dict[str, Any]] = Field(default=None)
3535
"""Optional filter data stored on the cache entry for customizing retrieval"""
3636

37-
@root_validator(pre=True)
37+
@model_validator(mode="before")
3838
@classmethod
3939
def generate_id(cls, values):
4040
# Ensure entry_id is set
4141
if not values.get("entry_id"):
4242
values["entry_id"] = hashify(values["prompt"], values.get("filters"))
4343
return values
4444

45-
@validator("metadata")
45+
@field_validator("metadata")
46+
@classmethod
4647
def non_empty_metadata(cls, v):
4748
if v is not None and not isinstance(v, dict):
4849
raise TypeError("Metadata must be a dictionary.")
4950
return v
5051

5152
def to_dict(self, dtype: str) -> Dict:
52-
data = self.dict(exclude_none=True)
53+
data = self.model_dump(exclude_none=True)
5354
data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype)
5455
if self.metadata is not None:
5556
data["metadata"] = serialize(self.metadata)
@@ -79,18 +80,18 @@ class CacheHit(BaseModel):
7980
filters: Optional[Dict[str, Any]] = Field(default=None)
8081
"""Optional filter data stored on the cache entry for customizing retrieval"""
8182

82-
@root_validator(pre=True)
83+
@model_validator(mode="before")
8384
@classmethod
8485
def validate_cache_hit(cls, values):
8586
# Deserialize metadata if necessary
8687
if "metadata" in values and isinstance(values["metadata"], str):
8788
values["metadata"] = deserialize(values["metadata"])
8889

8990
# Separate filters from other fields
90-
known_fields = set(cls.__fields__.keys())
91+
known_fields = set(cls.model_fields.keys())
9192
filters = {k: v for k, v in values.items() if k not in known_fields}
9293

93-
# Add filters to values
94+
# Add filters to valuesgiy s
9495
if filters:
9596
values["filters"] = filters
9697

@@ -101,7 +102,7 @@ def validate_cache_hit(cls, values):
101102
return values
102103

103104
def to_dict(self) -> Dict:
104-
data = self.dict(exclude_none=True)
105+
data = self.model_dump(exclude_none=True)
105106
if self.filters:
106107
data.update(self.filters)
107108
del data["filters"]

redisvl/extensions/router/schema.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22
from typing import Dict, List, Optional
33

4-
from pydantic.v1 import BaseModel, Field, validator
4+
from pydantic import BaseModel, Field, field_validator
55

66
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
77
from redisvl.schema import IndexSchema
@@ -19,21 +19,24 @@ class Route(BaseModel):
1919
distance_threshold: float = Field(default=0.5)
2020
"""Distance threshold for matching the route."""
2121

22-
@validator("name")
22+
@field_validator("name")
23+
@classmethod
2324
def name_must_not_be_empty(cls, v):
2425
if not v or not v.strip():
2526
raise ValueError("Route name must not be empty")
2627
return v
2728

28-
@validator("references")
29+
@field_validator("references")
30+
@classmethod
2931
def references_must_not_be_empty(cls, v):
3032
if not v:
3133
raise ValueError("References must not be empty")
3234
if any(not ref.strip() for ref in v):
3335
raise ValueError("All references must be non-empty strings")
3436
return v
3537

36-
@validator("distance_threshold")
38+
@field_validator("distance_threshold")
39+
@classmethod
3740
def distance_threshold_must_be_positive(cls, v):
3841
if v is not None and v <= 0:
3942
raise ValueError("Route distance threshold must be greater than zero")
@@ -79,7 +82,8 @@ class RoutingConfig(BaseModel):
7982
description="Global distance threshold is deprecated all distance_thresholds now apply at route level.",
8083
)
8184

82-
@validator("max_k")
85+
@field_validator("max_k")
86+
@classmethod
8387
def max_k_must_be_positive(cls, v):
8488
if v <= 0:
8589
raise ValueError("max_k must be a positive integer")

redisvl/extensions/router/semantic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import redis.commands.search.reducers as reducers
55
import yaml
6-
from pydantic.v1 import BaseModel, Field, PrivateAttr
6+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
77
from redis import Redis
88
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
99
from redis.exceptions import ResponseError
@@ -44,8 +44,7 @@ class SemanticRouter(BaseModel):
4444

4545
_index: SearchIndex = PrivateAttr()
4646

47-
class Config:
48-
arbitrary_types_allowed = True
47+
model_config = ConfigDict(arbitrary_types_allowed=True)
4948

5049
@deprecated_argument("dtype", "vectorizer")
5150
def __init__(

redisvl/extensions/session_manager/schema.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List, Optional
22

3-
from pydantic.v1 import BaseModel, Field, root_validator
3+
from pydantic import BaseModel, ConfigDict, Field, model_validator
44

55
from redisvl.extensions.constants import (
66
CONTENT_FIELD_NAME,
@@ -33,11 +33,9 @@ class ChatMessage(BaseModel):
3333
"""An optional identifier for a tool call associated with the message."""
3434
vector_field: Optional[List[float]] = Field(default=None)
3535
"""The vector representation of the message content."""
36+
model_config = ConfigDict(arbitrary_types_allowed=True)
3637

37-
class Config:
38-
arbitrary_types_allowed = True
39-
40-
@root_validator(pre=True)
38+
@model_validator(mode="before")
4139
@classmethod
4240
def generate_id(cls, values):
4341
if TIMESTAMP_FIELD_NAME not in values:
@@ -49,7 +47,7 @@ def generate_id(cls, values):
4947
return values
5048

5149
def to_dict(self, dtype: Optional[str] = None) -> Dict:
52-
data = self.dict(exclude_none=True)
50+
data = self.model_dump(exclude_none=True)
5351

5452
# handle optional fields
5553
if SESSION_VECTOR_FIELD_NAME in data:

redisvl/index/storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
from typing import Any, Callable, Dict, Iterable, List, Optional
33

4-
from pydantic.v1 import BaseModel
4+
from pydantic import BaseModel
55
from redis import Redis
66
from redis.asyncio import Redis as AsyncRedis
77
from redis.commands.search.indexDefinition import IndexType

redisvl/schema/fields.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,16 @@
66
"""
77

88
from enum import Enum
9-
from typing import Any, Dict, Optional, Tuple, Type, Union
9+
from typing import Any, Dict, Literal, Optional, Tuple, Type, Union
1010

11-
from pydantic.v1 import BaseModel, Field, validator
11+
from pydantic import BaseModel, Field, field_validator
1212
from redis.commands.search.field import Field as RedisField
1313
from redis.commands.search.field import GeoField as RedisGeoField
1414
from redis.commands.search.field import NumericField as RedisNumericField
1515
from redis.commands.search.field import TagField as RedisTagField
1616
from redis.commands.search.field import TextField as RedisTextField
1717
from redis.commands.search.field import VectorField as RedisVectorField
1818

19-
### Attribute Enums ###
20-
2119

2220
class VectorDistanceMetric(str, Enum):
2321
COSINE = "COSINE"
@@ -99,7 +97,7 @@ class BaseVectorFieldAttributes(BaseModel):
9997
initial_cap: Optional[int] = None
10098
"""Initial vector capacity in the index affecting memory allocation size of the index"""
10199

102-
@validator("algorithm", "datatype", "distance_metric", pre=True)
100+
@field_validator("algorithm", "datatype", "distance_metric", mode="before")
103101
@classmethod
104102
def uppercase_strings(cls, v):
105103
"""Validate that provided values are cast to uppercase"""
@@ -121,9 +119,7 @@ def field_data(self) -> Dict[str, Any]:
121119
class FlatVectorFieldAttributes(BaseVectorFieldAttributes):
122120
"""FLAT vector field attributes"""
123121

124-
algorithm: VectorIndexAlgorithm = Field(
125-
default=VectorIndexAlgorithm.FLAT, const=True
126-
)
122+
algorithm: Literal[VectorIndexAlgorithm.FLAT] = VectorIndexAlgorithm.FLAT
127123
"""The indexing algorithm for the vector field"""
128124
block_size: Optional[int] = None
129125
"""Block size to hold amount of vectors in a contiguous array. This is useful when the index is dynamic with respect to addition and deletion"""
@@ -132,9 +128,7 @@ class FlatVectorFieldAttributes(BaseVectorFieldAttributes):
132128
class HNSWVectorFieldAttributes(BaseVectorFieldAttributes):
133129
"""HNSW vector field attributes"""
134130

135-
algorithm: VectorIndexAlgorithm = Field(
136-
default=VectorIndexAlgorithm.HNSW, const=True
137-
)
131+
algorithm: Literal[VectorIndexAlgorithm.HNSW] = VectorIndexAlgorithm.HNSW
138132
"""The indexing algorithm for the vector field"""
139133
m: int = Field(default=16)
140134
"""Number of max outgoing edges for each graph node in each layer"""
@@ -173,7 +167,7 @@ def as_redis_field(self) -> RedisField:
173167
class TextField(BaseField):
174168
"""Text field supporting a full text search index"""
175169

176-
type: str = Field(default="text", const=True)
170+
type: Literal["text"] = "text"
177171
attrs: TextFieldAttributes = Field(default_factory=TextFieldAttributes)
178172

179173
def as_redis_field(self) -> RedisField:
@@ -191,7 +185,7 @@ def as_redis_field(self) -> RedisField:
191185
class TagField(BaseField):
192186
"""Tag field for simple boolean-style filtering"""
193187

194-
type: str = Field(default="tag", const=True)
188+
type: Literal["tag"] = "tag"
195189
attrs: TagFieldAttributes = Field(default_factory=TagFieldAttributes)
196190

197191
def as_redis_field(self) -> RedisField:
@@ -208,7 +202,7 @@ def as_redis_field(self) -> RedisField:
208202
class NumericField(BaseField):
209203
"""Numeric field for numeric range filtering"""
210204

211-
type: str = Field(default="numeric", const=True)
205+
type: Literal["numeric"] = "numeric"
212206
attrs: NumericFieldAttributes = Field(default_factory=NumericFieldAttributes)
213207

214208
def as_redis_field(self) -> RedisField:
@@ -223,7 +217,7 @@ def as_redis_field(self) -> RedisField:
223217
class GeoField(BaseField):
224218
"""Geo field with a geo-spatial index for location based search"""
225219

226-
type: str = Field(default="geo", const=True)
220+
type: Literal["geo"] = "geo"
227221
attrs: GeoFieldAttributes = Field(default_factory=GeoFieldAttributes)
228222

229223
def as_redis_field(self) -> RedisField:
@@ -238,7 +232,7 @@ def as_redis_field(self) -> RedisField:
238232
class FlatVectorField(BaseField):
239233
"Vector field with a FLAT index (brute force nearest neighbors search)"
240234

241-
type: str = Field(default="vector", const=True)
235+
type: Literal["vector"] = "vector"
242236
attrs: FlatVectorFieldAttributes
243237

244238
def as_redis_field(self) -> RedisField:
@@ -253,7 +247,7 @@ def as_redis_field(self) -> RedisField:
253247
class HNSWVectorField(BaseField):
254248
"""Vector field with an HNSW index (approximate nearest neighbors search)"""
255249

256-
type: str = Field(default="vector", const=True)
250+
type: Literal["vector"] = "vector"
257251
attrs: HNSWVectorFieldAttributes
258252

259253
def as_redis_field(self) -> RedisField:
@@ -271,20 +265,21 @@ def as_redis_field(self) -> RedisField:
271265
return RedisVectorField(name, self.attrs.algorithm, field_data, as_name=as_name)
272266

273267

274-
class FieldFactory:
275-
"""Factory class to create fields from client data and kwargs."""
268+
FIELD_TYPE_MAP = {
269+
"tag": TagField,
270+
"text": TextField,
271+
"numeric": NumericField,
272+
"geo": GeoField,
273+
}
276274

277-
FIELD_TYPE_MAP = {
278-
"tag": TagField,
279-
"text": TextField,
280-
"numeric": NumericField,
281-
"geo": GeoField,
282-
}
275+
VECTOR_FIELD_TYPE_MAP = {
276+
"flat": FlatVectorField,
277+
"hnsw": HNSWVectorField,
278+
}
283279

284-
VECTOR_FIELD_TYPE_MAP = {
285-
"flat": FlatVectorField,
286-
"hnsw": HNSWVectorField,
287-
}
280+
281+
class FieldFactory:
282+
"""Factory class to create fields from client data and kwargs."""
288283

289284
@classmethod
290285
def pick_vector_field_type(cls, attrs: Dict[str, Any]) -> Type[BaseField]:
@@ -296,10 +291,10 @@ def pick_vector_field_type(cls, attrs: Dict[str, Any]) -> Type[BaseField]:
296291
raise ValueError("Must provide dims param for the vector field.")
297292

298293
algorithm = attrs["algorithm"].lower()
299-
if algorithm not in cls.VECTOR_FIELD_TYPE_MAP:
294+
if algorithm not in VECTOR_FIELD_TYPE_MAP:
300295
raise ValueError(f"Unknown vector field algorithm: {algorithm}")
301296

302-
return cls.VECTOR_FIELD_TYPE_MAP[algorithm] # type: ignore
297+
return VECTOR_FIELD_TYPE_MAP[algorithm] # type: ignore
303298

304299
@classmethod
305300
def create_field(
@@ -314,8 +309,14 @@ def create_field(
314309
if type == "vector":
315310
field_class = cls.pick_vector_field_type(attrs)
316311
else:
317-
if type not in cls.FIELD_TYPE_MAP:
312+
if type not in FIELD_TYPE_MAP:
318313
raise ValueError(f"Unknown field type: {type}")
319-
field_class = cls.FIELD_TYPE_MAP[type] # type: ignore
314+
field_class = FIELD_TYPE_MAP[type] # type: ignore
320315

321-
return field_class(name=name, path=path, attrs=attrs) # type: ignore
316+
return field_class.model_validate(
317+
{
318+
"name": name,
319+
"path": path,
320+
"attrs": attrs,
321+
}
322+
)

0 commit comments

Comments
 (0)