Skip to content

Commit 9d49e82

Browse files
Remove Config class in favor of fail_fast kwargs (#364)
Brings us full circle, effectively backing out #323. Resolves #362. [1]: https://docs.python.org/3/tutorial/controlflow.html#keyword-only-arguments
1 parent f450590 commit 9d49e82

File tree

6 files changed

+20
-88
lines changed

6 files changed

+20
-88
lines changed

protovalidate/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from protovalidate import config, validator
15+
from protovalidate import validator
1616

17-
Config = config.Config
1817
Validator = validator.Validator
1918
CompilationError = validator.CompilationError
2019
ValidationError = validator.ValidationError
@@ -24,4 +23,4 @@
2423
validate = _default_validator.validate
2524
collect_violations = _default_validator.collect_violations
2625

27-
__all__ = ["CompilationError", "Config", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]
26+
__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]

protovalidate/config.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

protovalidate/internal/rules.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from google.protobuf import any_pb2, descriptor, duration_pb2, message, message_factory, timestamp_pb2
2323

2424
from buf.validate import validate_pb2
25-
from protovalidate.config import Config
2625
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has
2726

2827

@@ -266,19 +265,14 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N
266265
class RuleContext:
267266
"""The state associated with a single rule evaluation."""
268267

269-
_cfg: Config
270268
_violations: list[Violation]
271269

272-
def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None):
273-
self._cfg = config
270+
def __init__(self, *, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None):
271+
self._fail_fast = fail_fast
274272
if violations is None:
275273
violations = []
276274
self._violations = violations
277275

278-
@property
279-
def fail_fast(self) -> bool:
280-
return self._cfg.fail_fast
281-
282276
@property
283277
def violations(self) -> list[Violation]:
284278
return self._violations
@@ -299,13 +293,13 @@ def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPat
299293

300294
@property
301295
def done(self) -> bool:
302-
return self.fail_fast and self.has_errors()
296+
return self._fail_fast and self.has_errors()
303297

304298
def has_errors(self) -> bool:
305299
return len(self._violations) > 0
306300

307301
def sub_context(self) -> "RuleContext":
308-
return RuleContext(config=self._cfg)
302+
return RuleContext(fail_fast=self._fail_fast)
309303

310304

311305
class Rules:

protovalidate/validator.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from google.protobuf import message
1818

1919
from buf.validate import validate_pb2
20-
from protovalidate.config import Config
2120
from protovalidate.internal import extra_func
2221
from protovalidate.internal import rules as _rules
2322

@@ -36,29 +35,25 @@ class Validator:
3635
"""
3736

3837
_factory: _rules.RuleFactory
39-
_cfg: Config
4038

41-
def __init__(self, config: typing.Optional[Config] = None):
42-
self._cfg = config if config is not None else Config()
39+
def __init__(self):
4340
funcs = extra_func.make_extra_funcs()
4441
self._factory = _rules.RuleFactory(funcs)
4542

46-
def validate(
47-
self,
48-
message: message.Message,
49-
):
43+
def validate(self, message: message.Message, *, fail_fast: bool = False):
5044
"""
5145
Validates the given message against the static rules defined in
5246
the message's descriptor.
5347
5448
Parameters:
5549
message: The message to validate.
50+
fail_fast: If true, validation will stop after the first iteration.
5651
Raises:
5752
CompilationError: If the static rules could not be compiled.
5853
ValidationError: If the message is invalid. The violations raised as part of this error should
5954
always be equal to the list of violations returned by `collect_violations`.
6055
"""
61-
violations = self.collect_violations(message)
56+
violations = self.collect_violations(message, fail_fast=fail_fast)
6257
if len(violations) > 0:
6358
msg = f"invalid {message.DESCRIPTOR.name}"
6459
raise ValidationError(msg, violations)
@@ -67,6 +62,7 @@ def collect_violations(
6762
self,
6863
message: message.Message,
6964
*,
65+
fail_fast: bool = False,
7066
into: typing.Optional[list[Violation]] = None,
7167
) -> list[Violation]:
7268
"""
@@ -80,12 +76,13 @@ def collect_violations(
8076
8177
Parameters:
8278
message: The message to validate.
79+
fail_fast: If true, validation will stop after the first iteration.
8380
into: If provided, any violations will be appended to the
8481
Violations object and the same object will be returned.
8582
Raises:
8683
CompilationError: If the static rules could not be compiled.
8784
"""
88-
ctx = _rules.RuleContext(config=self._cfg, violations=into)
85+
ctx = _rules.RuleContext(fail_fast=fail_fast, violations=into)
8986
for rule in self._factory.get(message.DESCRIPTOR):
9087
rule.validate(ctx, message)
9188
if ctx.done:

test/test_config.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

test/test_validate.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import protovalidate
2020
from gen.tests.example.v1 import validations_pb2
21-
from protovalidate.config import Config
2221
from protovalidate.internal import rules
2322

2423

@@ -27,13 +26,11 @@ def get_default_validator():
2726
2827
This allows testing for validators created via:
2928
- module-level singleton
30-
- instantiated class with no config
31-
- instantiated class with config
29+
- instantiated class
3230
"""
3331
return [
3432
("module singleton", protovalidate),
35-
("no config", protovalidate.Validator()),
36-
("with default config", protovalidate.Validator(Config())),
33+
("instantiated class", protovalidate.Validator()),
3734
]
3835

3936

@@ -42,8 +39,7 @@ class TestCollectViolations(unittest.TestCase):
4239
4340
A validator can be created via various ways:
4441
- a module-level singleton, which returns a default validator
45-
- instantiating the Validator class with no config, which returns a default validator
46-
- instantiating the Validator class with a config
42+
- instantiating the Validator class
4743
4844
In addition, the API for validating a message allows for two approaches:
4945
- via a call to `validate`, which will raise a ValidationError if validation fails
@@ -188,11 +184,7 @@ def test_concatenated_values(self):
188184
self._run_valid_tests(msg)
189185

190186
def test_fail_fast(self):
191-
"""Test that fail fast correctly fails on first violation
192-
193-
Note this does not use a default validator, but instead uses one with a custom config
194-
so that fail_fast can be set to True.
195-
"""
187+
"""Test that fail fast correctly fails on first violation"""
196188
msg = validations_pb2.MultipleValidations()
197189
msg.title = "bar"
198190
msg.name = "blah"
@@ -203,18 +195,17 @@ def test_fail_fast(self):
203195
expected_violation.field_value = msg.title
204196
expected_violation.rule_value = "foo"
205197

206-
cfg = Config(fail_fast=True)
207-
validator = protovalidate.Validator(config=cfg)
198+
validator = protovalidate.Validator()
208199

209200
# Test validate
210201
with self.assertRaises(protovalidate.ValidationError) as cm:
211-
validator.validate(msg)
202+
validator.validate(msg, fail_fast=True)
212203
e = cm.exception
213204
self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}")
214205
self._compare_violations(e.violations, [expected_violation])
215206

216207
# Test collect_violations
217-
violations = validator.collect_violations(msg)
208+
violations = validator.collect_violations(msg, fail_fast=True)
218209
self._compare_violations(violations, [expected_violation])
219210

220211
def _run_valid_tests(self, msg: message.Message):

0 commit comments

Comments
 (0)