Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions gen/tests/example/v1/validations_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions gen/tests/example/v1/validations_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions proto/tests/example/v1/validations.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ package tests.example.v1;
import "buf/validate/validate.proto";
import "google/protobuf/timestamp.proto";

message MultipleValidations {
string title = 1 [(buf.validate.field).string.prefix = "foo"];
string name = 2 [(buf.validate.field).string.min_len = 5];
}

message DoubleFinite {
double val = 1 [(buf.validate.field).double.finite = true];
}
Expand Down
11 changes: 6 additions & 5 deletions protovalidate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from protovalidate import validator
from protovalidate import config, validator

Config = config.Config
Validator = validator.Validator
CompilationError = validator.CompilationError
ValidationError = validator.ValidationError
Violations = validator.Violations

_validator = Validator()
validate = _validator.validate
collect_violations = _validator.collect_violations
_default_validator = Validator()
validate = _default_validator.validate
collect_violations = _default_validator.collect_violations

__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]
__all__ = ["CompilationError", "Config", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]
26 changes: 26 additions & 0 deletions protovalidate/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2023-2025 Buf Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass


@dataclass
class Config:
"""A class for holding configuration values for validation.

Attributes:
fail_fast (bool): If true, validation will stop after the first violation. Defaults to False.
"""

fail_fast: bool = False
13 changes: 8 additions & 5 deletions protovalidate/internal/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.protobuf import any_pb2, descriptor, message, message_factory

from buf.validate import validate_pb2 # type: ignore
from protovalidate.config import Config
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has


Expand Down Expand Up @@ -266,15 +267,17 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N
class RuleContext:
"""The state associated with a single rule evaluation."""

def __init__(self, *, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None):
self._fail_fast = fail_fast
_cfg: Config

def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None):
self._cfg = config
if violations is None:
violations = []
self._violations = violations

@property
def fail_fast(self) -> bool:
return self._fail_fast
return self._cfg.fail_fast

@property
def violations(self) -> list[Violation]:
Expand All @@ -296,13 +299,13 @@ def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPat

@property
def done(self) -> bool:
return self._fail_fast and self.has_errors()
return self.fail_fast and self.has_errors()

def has_errors(self) -> bool:
return len(self._violations) > 0

def sub_context(self):
return RuleContext(fail_fast=self._fail_fast)
return RuleContext(config=self._cfg)


class Rules:
Expand Down
26 changes: 14 additions & 12 deletions protovalidate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from google.protobuf import message

from buf.validate import validate_pb2 # type: ignore
from protovalidate.config import Config
from protovalidate.internal import extra_func
from protovalidate.internal import rules as _rules

Expand All @@ -35,28 +36,28 @@ class Validator:
"""

_factory: _rules.RuleFactory
_cfg: Config

def __init__(self):
def __init__(self, config=None):
self._factory = _rules.RuleFactory(extra_func.EXTRA_FUNCS)
self._cfg = config if config is not None else Config()

def validate(
self,
message: message.Message,
*,
fail_fast: bool = False,
):
"""
Validates the given message against the static rules defined in
the message's descriptor.

Parameters:
message: The message to validate.
fail_fast: If true, validation will stop after the first violation.
Raises:
CompilationError: If the static rules could not be compiled.
ValidationError: If the message is invalid.
ValidationError: If the message is invalid. The violations raised as part of this error should
always be equal to the list of violations returned by `collect_violations`.
"""
violations = self.collect_violations(message, fail_fast=fail_fast)
violations = self.collect_violations(message)
if len(violations) > 0:
msg = f"invalid {message.DESCRIPTOR.name}"
raise ValidationError(msg, violations)
Expand All @@ -65,24 +66,25 @@ def collect_violations(
self,
message: message.Message,
*,
fail_fast: bool = False,
into: typing.Optional[list[Violation]] = None,
) -> list[Violation]:
"""
Validates the given message against the static rules defined in
the message's descriptor. Compared to validate, collect_violations is
faster but puts the burden of raising an appropriate exception on the
caller.
the message's descriptor. Compared to `validate`, `collect_violations` simply
returns the violations as a list and puts the burden of raising an appropriate
exception on the caller.

The violations returned from this method should always be equal to the violations
raised as part of the ValidationError in the call to `validate`.

Parameters:
message: The message to validate.
fail_fast: If true, validation will stop after the first violation.
into: If provided, any violations will be appended to the
Violations object and the same object will be returned.
Raises:
CompilationError: If the static rules could not be compiled.
"""
ctx = _rules.RuleContext(fail_fast=fail_fast, violations=into)
ctx = _rules.RuleContext(config=self._cfg, violations=into)
for rule in self._factory.get(message.DESCRIPTOR):
rule.validate(ctx, message)
if ctx.done:
Expand Down
23 changes: 23 additions & 0 deletions tests/config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2023-2025 Buf Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from protovalidate import Config


class TestConfig(unittest.TestCase):
def test_defaults(self):
cfg = Config()
self.assertFalse(cfg.fail_fast)
Loading
Loading