Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions model_compression_toolkit/target_platform_capabilities/schema/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from enum import Enum
from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated

from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, ConfigDict, field_validator, \
from pydantic import BaseModel, Field, PositiveInt, ConfigDict, field_validator, \
model_validator

from mct_quantizers import QuantizationMethod
Expand Down Expand Up @@ -124,7 +124,7 @@ class AttributeQuantizationConfig(BaseModel):
@property
def field_names(self) -> list:
"""Return a list of field names for the model."""
return list(self.__fields__.keys())
return list(self.model_fields.keys())

def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig':
"""
Expand Down Expand Up @@ -194,7 +194,7 @@ def get_info(self) -> Dict[str, Any]:
Returns:
dict: Information about the quantization configuration as a dictionary.
"""
return self.dict() # pragma: no cover
return self.model_dump() # pragma: no cover

def clone_and_edit(
self,
Expand Down Expand Up @@ -240,6 +240,7 @@ class QuantizationConfigOptions(BaseModel):
model_config = ConfigDict(frozen=True)

@model_validator(mode="before")
@classmethod
def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate and set the base_config based on quantization_configurations.
Expand Down Expand Up @@ -464,6 +465,7 @@ class OperatorSetGroup(OperatorsSetBase):
model_config = ConfigDict(frozen=True)

@model_validator(mode="before")
@classmethod
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate the input and set the concatenated name based on the operators_set.
Expand Down Expand Up @@ -518,6 +520,7 @@ class Fusing(TargetPlatformModelComponent):
model_config = ConfigDict(frozen=True)

@model_validator(mode="before")
@classmethod
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate the operator_groups and set the name by concatenating operator group names.
Expand Down Expand Up @@ -545,14 +548,14 @@ def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values

@model_validator(mode="after")
def validate_after_initialization(cls, model: 'Fusing') -> Any:
def validate_after_initialization(self) -> 'Fusing':
"""
Perform validation after the model has been instantiated.
Ensures that there are at least two operator groups.
"""
if len(model.operator_groups) < 2:
if len(self.operator_groups) < 2:
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
return model
return self

def contains(self, other: Any) -> bool:
"""
Expand Down Expand Up @@ -633,32 +636,25 @@ class TargetPlatformCapabilities(BaseModel):
model_config = ConfigDict(frozen=True)

@model_validator(mode="after")
def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
def validate_after_initialization(self) -> 'TargetPlatformCapabilities':
"""
Perform validation after the model has been instantiated.

Args:
model (TargetPlatformCapabilities): The instantiated target platform model.

Returns:
TargetPlatformCapabilities: The validated model.
"""
# Validate `default_qco`
default_qco = model.default_qco
default_qco = self.default_qco
if len(default_qco.quantization_configurations) != 1:
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover

# Validate `operator_set` uniqueness
operator_set = model.operator_set
operator_set = self.operator_set
if operator_set is not None:
opsets_names = [
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
for op in operator_set
]
if len(set(opsets_names)) != len(opsets_names):
Logger.critical("Operator Sets must have unique names.") # pragma: no cover

return model
return self

def get_info(self) -> Dict[str, Any]:
"""
Expand Down
25 changes: 13 additions & 12 deletions model_compression_toolkit/target_platform_capabilities/schema/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from enum import Enum
from typing import Dict, Any, Union, Tuple, Optional, Annotated

from pydantic import BaseModel, Field, root_validator, model_validator, ConfigDict
from pydantic import BaseModel, Field, model_validator, ConfigDict

from mct_quantizers import QuantizationMethod
from model_compression_toolkit.constants import FLOAT_BITWIDTH
Expand Down Expand Up @@ -116,7 +116,8 @@ class Fusing(TargetPlatformModelComponent):

model_config = ConfigDict(frozen=True)

@model_validator(mode="before")
@model_validator(mode='before')
@classmethod
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate the operator_groups and set the name by concatenating operator group names.
Expand All @@ -143,16 +144,16 @@ def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:

return values

@model_validator(mode="after")
def validate_after_initialization(cls, model: 'Fusing') -> Any:
@model_validator(mode='after')
def validate_after_initialization(self) -> 'Fusing':
"""
Perform validation after the model has been instantiated.
Ensures that there are at least two operator groups.
"""
if len(model.operator_groups) < 2:
if len(self.operator_groups) < 2:
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
return model

return self

def contains(self, other: Any) -> bool:
"""
Expand Down Expand Up @@ -235,8 +236,8 @@ class TargetPlatformCapabilities(BaseModel):

model_config = ConfigDict(frozen=True)

@model_validator(mode="after")
def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
@model_validator(mode='after')
def validate_after_initialization(self) -> 'TargetPlatformCapabilities':
"""
Perform validation after the model has been instantiated.

Expand All @@ -247,12 +248,12 @@ def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> A
TargetPlatformCapabilities: The validated model.
"""
# Validate `default_qco`
default_qco = model.default_qco
default_qco = self.default_qco
if len(default_qco.quantization_configurations) != 1:
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover

# Validate `operator_set` uniqueness
operator_set = model.operator_set
operator_set = self.operator_set
if operator_set is not None:
opsets_names = [
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
Expand All @@ -261,7 +262,7 @@ def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> A
if len(set(opsets_names)) != len(opsets_names):
Logger.critical("Operator Sets must have unique names.") # pragma: no cover

return model
return self

def get_info(self) -> Dict[str, Any]:
"""
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ matplotlib<3.10.0
scipy
protobuf
mct-quantizers==1.6.0
pydantic>=2.0,<2.12.0
pydantic>=2.0,<3
edge-mdt-cl-dev