Skip to content

Commit 98ee687

Browse files
Improve gen_test_serializable for more flexibility on error checking
1 parent d6f2921 commit 98ee687

File tree

12 files changed

+312
-263
lines changed

12 files changed

+312
-263
lines changed

changes/285.internal.1.md

Lines changed: 34 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,34 @@
1-
- Changed the way `Serializable` classes are handled:
2-
3-
Here is how a basic `Serializable` class looks like:
4-
5-
```python
6-
@final
7-
@dataclass
8-
class ToyClass(Serializable):
9-
"""Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`."""
10-
11-
a: int
12-
b: str | int
13-
14-
@override
15-
def __attrs_post_init__(self):
16-
"""Initialize the object."""
17-
if isinstance(self.b, int):
18-
self.b = str(self.b)
19-
20-
super().__attrs_post_init__() # This will call validate()
21-
22-
@override
23-
def serialize_to(self, buf: Buffer):
24-
"""Write the object to a buffer."""
25-
self.b = cast(str, self.b) # Handled by the __attrs_post_init__ method
26-
buf.write_varint(self.a)
27-
buf.write_utf(self.b)
28-
29-
@classmethod
30-
@override
31-
def deserialize(cls, buf: Buffer) -> ToyClass:
32-
"""Deserialize the object from a buffer."""
33-
a = buf.read_varint()
34-
if a == 0:
35-
raise ZeroDivisionError("a must be non-zero")
36-
b = buf.read_utf()
37-
return cls(a, b)
38-
39-
@override
40-
def validate(self) -> None:
41-
"""Validate the object's attributes."""
42-
if self.a == 0:
43-
raise ZeroDivisionError("a must be non-zero")
44-
if len(self.b) > 10:
45-
raise ValueError("b must be less than 10 characters")
46-
47-
```
48-
49-
The `Serializable` class implement the following methods:
50-
51-
- `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer.
52-
- `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer.
53-
54-
And the following optional methods:
55-
56-
- `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid.
57-
- `__attrs_post_init__() -> None`: Initializes the object. Call `super().__attrs_post_init__()` to validate the object.
1+
- **Function**: `gen_serializable_test`
2+
- Generates tests for serializable classes, covering serialization, deserialization, validation, and error handling.
3+
- **Parameters**:
4+
- `context` (dict): Context to add the test functions to (usually `globals()`).
5+
- `cls` (type): The serializable class to test.
6+
- `fields` (list): Tuples of field names and types of the serializable class.
7+
- `serialize_deserialize` (list, optional): Tuples for testing successful serialization/deserialization.
8+
- `validation_fail` (list, optional): Tuples for testing validation failures with expected exceptions.
9+
- `deserialization_fail` (list, optional): Tuples for testing deserialization failures with expected exceptions.
10+
- **Note**: Implement `__eq__` in the class for accurate comparison.
11+
12+
- The `gen_serializable_test` function generates a test class with the following tests:
13+
14+
.. literalinclude:: /../tests/mcproto/utils/test_serializable.py
15+
:language: python
16+
:start-after: # region Test ToyClass
17+
:end-before: # endregion Test ToyClass
18+
19+
- The generated test class will have the following tests:
20+
21+
```python
22+
class TestGenToyClass:
23+
def test_serialization(self):
24+
# 3 subtests for the cases 1, 2, 3 (serialize_deserialize)
25+
26+
def test_deserialization(self):
27+
# 3 subtests for the cases 1, 2, 3 (serialize_deserialize)
28+
29+
def test_validation(self):
30+
# 3 subtests for the cases 4, 5, 6 (validation_fail)
31+
32+
def test_exceptions(self):
33+
# 3 subtests for the cases 7, 8, 9 (deserialization_fail)
34+
```

changes/285.internal.2.md

Lines changed: 16 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,16 @@
1-
- Added a test generator for `Serializable` classes:
2-
3-
The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments:
4-
5-
- `context`: The dictionary containing the context in which the generated test class will be placed (e.g. `globals()`).
6-
> Dictionary updates must reflect in the context. This is the case for `globals()` but implementation-specific for `locals()`.
7-
- `cls`: The `Serializable` class to generate tests for.
8-
- `fields`: A list of fields where the test values will be placed.
9-
10-
> In the example above, the `ToyClass` class has two fields: `a` and `b`.
11-
12-
- `test_data`: A list of tuples containing either:
13-
- `((field1_value, field2_value, ...), expected_bytes)`: The values of the fields and the expected serialized bytes. This needs to work both ways, i.e. `cls(field1_value, field2_value, ...) == cls.deserialize(expected_bytes).`
14-
- `((field1_value, field2_value, ...), exception)`: The values of the fields and the expected exception when validating the object.
15-
- `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize.
16-
17-
The `gen_serializable_test` function generates a test class with the following tests:
18-
19-
```python
20-
gen_serializable_test(
21-
context=globals(),
22-
cls=ToyClass,
23-
fields=[("a", int), ("b", str)],
24-
test_data=[
25-
((1, "hello"), b"\x01\x05hello"),
26-
((2, "world"), b"\x02\x05world"),
27-
((3, 1234567890), b"\x03\x0a1234567890"),
28-
((0, "hello"), ZeroDivisionError("a must be non-zero")), # With an error message
29-
((1, "hello world"), ValueError), # No error message
30-
((1, 12345678900), ValueError("b must be less than 10 .*")), # With an error message and regex
31-
(ZeroDivisionError, b"\x00"),
32-
(ZeroDivisionError, b"\x01\x05hello"),
33-
(IOError, b"\x01"),
34-
],
35-
)
36-
```
37-
38-
The generated test class will have the following tests:
39-
40-
```python
41-
class TestGenToyClass:
42-
def test_serialization(self):
43-
# 2 subtests for the cases 1 and 2
44-
45-
def test_deserialization(self):
46-
# 2 subtests for the cases 1 and 2
47-
48-
def test_validation(self):
49-
# 2 subtests for the cases 3 and 4
50-
51-
def test_exceptions(self):
52-
# 2 subtests for the cases 5 and 6
53-
```
1+
- **Class**: `Serializable`
2+
- Base class for types that should be (de)serializable into/from `mcproto.Buffer` data.
3+
- **Methods**:
4+
- `__attrs_post_init__()`: Runs validation after object initialization, override to define custom behavior.
5+
- `serialize() -> Buffer`: Returns the object as a `Buffer`.
6+
- `serialize_to(buf: Buffer)`: Abstract method to write the object to a `Buffer`.
7+
- `validate()`: Validates the object's attributes; can be overridden for custom validation.
8+
- `deserialize(cls, buf: Buffer) -> Self`: Abstract method to construct the object from a `Buffer`.
9+
- **Note**: Use the `dataclass` decorator when adding parameters to subclasses.
10+
11+
- Exemple:
12+
13+
.. literalinclude:: /../tests/mcproto/utils/test_serializable.py
14+
:language: python
15+
:start-after: # region ToyClass
16+
:end-before: # endregion ToyClass

mcproto/utils/abc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __new__(cls: type[Self], *a: Any, **kw: Any) -> Self:
6868
class Serializable(ABC):
6969
"""Base class for any type that should be (de)serializable into/from :class:`~mcproto.Buffer` data.
7070
71-
Any class that inherits from this class and adds parameters should use the :func:`~mcproto.utils.abc.dataclass`
71+
Any class that inherits from this class and adds parameters should use the :func:`~mcproto.utils.abc.define`
7272
decorator.
7373
"""
7474

tests/helpers.py

Lines changed: 82 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import asyncio
44
import inspect
5+
import re
56
import unittest.mock
67
from collections.abc import Callable, Coroutine
7-
from typing import Any, Generic, TypeVar
8+
from typing import Any, Generic, NamedTuple, TypeVar
89
from typing_extensions import TypeGuard
910

1011
import pytest
@@ -17,7 +18,14 @@
1718
P = ParamSpec("P")
1819
T_Mock = TypeVar("T_Mock", bound=unittest.mock.Mock)
1920

20-
__all__ = ["synchronize", "SynchronizedMixin", "UnpropagatingMockMixin", "CustomMockMixin", "gen_serializable_test"]
21+
__all__ = [
22+
"synchronize",
23+
"SynchronizedMixin",
24+
"UnpropagatingMockMixin",
25+
"CustomMockMixin",
26+
"gen_serializable_test",
27+
"TestExc",
28+
]
2129

2230

2331
def synchronize(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
@@ -169,27 +177,42 @@ def __init__(self, **kwargs):
169177
super().__init__(spec_set=self.spec_set, **kwargs) # type: ignore # Mixin class, this __init__ is valid
170178

171179

172-
def isexception(obj: object) -> TypeGuard[type[Exception] | Exception]:
180+
def isexception(obj: object) -> TypeGuard[type[Exception] | TestExc]:
173181
"""Check if the object is an exception."""
174-
return (isinstance(obj, type) and issubclass(obj, Exception)) or isinstance(obj, Exception)
182+
return (isinstance(obj, type) and issubclass(obj, Exception)) or isinstance(obj, TestExc)
175183

176184

177-
def get_exception(exception: type[Exception] | Exception) -> tuple[type[Exception], str | None]:
178-
"""Get the exception type and message."""
179-
if isinstance(exception, type):
180-
return exception, None
181-
return type(exception), str(exception)
185+
class TestExc(NamedTuple):
186+
"""Named tuple to check if an exception is raised with a specific message.
187+
188+
:param exception: The exception type.
189+
:param match: If specified, a string containing a regular expression, or a regular expression object, that is
190+
tested against the string representation of the exception using :func:`re.search`.
191+
192+
:param kwargs: The keyword arguments passed to the exception.
193+
194+
If :attr:`kwargs` is not None, the exception instance will need to have the same attributes with the same values.
195+
"""
196+
197+
exception: type[Exception] | tuple[type[Exception], ...]
198+
match: str | re.Pattern[str] | None = None
199+
kwargs: dict[str, Any] | None = None
200+
201+
@classmethod
202+
def from_exception(cls, exception: type[Exception] | tuple[type[Exception], ...] | TestExc) -> TestExc:
203+
"""Create a :class:`TestExc` from an exception, does nothing if the object is already a :class:`TestExc`."""
204+
if isinstance(exception, TestExc):
205+
return exception
206+
return cls(exception)
182207

183208

184209
def gen_serializable_test(
185210
context: dict[str, Any],
186211
cls: type[Serializable],
187212
fields: list[tuple[str, type | str]],
188-
test_data: list[
189-
tuple[tuple[Any, ...], bytes]
190-
| tuple[tuple[Any, ...], type[Exception] | Exception]
191-
| tuple[type[Exception] | Exception, bytes]
192-
],
213+
serialize_deserialize: list[tuple[tuple[Any, ...], bytes]] | None = None,
214+
validation_fail: list[tuple[tuple[Any, ...], type[Exception] | TestExc]] | None = None,
215+
deserialization_fail: list[tuple[bytes, type[Exception] | TestExc]] | None = None,
193216
):
194217
"""Generate tests for a serializable class.
195218
@@ -199,15 +222,14 @@ def gen_serializable_test(
199222
:param context: The context to add the test functions to. This is usually `globals()`.
200223
:param cls: The serializable class to test.
201224
:param fields: A list of tuples containing the field names and types of the serializable class.
202-
:param test_data: A list of test data. Each element is a tuple containing either:
203-
- A tuple of parameters to pass to the serializable class constructor and the expected bytes after
204-
serialization
205-
- A tuple of parameters to pass to the serializable class constructor and the expected exception during
206-
validation
207-
- An exception to expect during deserialization and the bytes to deserialize
208-
209-
Exception can be either a type or an instance of an exception, in the latter case the exception message will
210-
be used to match the exception, and can contain regex patterns.
225+
:param serialize_deserialize: A list of tuples containing:
226+
- The tuple representing the arguments to pass to the :class:`mcproto.utils.abc.Serializable` class
227+
- The expected bytes
228+
:param validation_fail: A list of tuples containing the arguments to pass to the
229+
:class:`mcproto.utils.abc.Serializable` class and the expected exception, either as is or wrapped in a
230+
:class:`TestExc` object.
231+
:param deserialization_fail: A list of tuples containing the bytes to pass to the :meth:`deserialize` method of the
232+
class and the expected exception, either as is or wrapped in a :class:`TestExc` object.
211233
212234
Example usage:
213235
@@ -221,28 +243,30 @@ def gen_serializable_test(
221243
222244
.. note::
223245
The test cases will use :meth:`__eq__` to compare the objects, so make sure to implement it in the class if
224-
you are not using a dataclass.
246+
you are not using the autogenerated method from :func:`attrs.define`.
225247
226248
"""
227-
# Separate the test data into parameters for each test function
228249
# This holds the parameters for the serialization and deserialization tests
229250
parameters: list[tuple[dict[str, Any], bytes]] = []
230251

231252
# This holds the parameters for the validation tests
232-
validation_fail: list[tuple[dict[str, Any], type[Exception] | Exception]] = []
253+
validation_fail_kw: list[tuple[dict[str, Any], TestExc]] = []
254+
255+
for data, exp_bytes in [] if serialize_deserialize is None else serialize_deserialize:
256+
kwargs = dict(zip([f[0] for f in fields], data))
257+
parameters.append((kwargs, exp_bytes))
233258

234-
# This holds the parameters for the deserialization error tests
235-
deserialization_fail: list[tuple[bytes, type[Exception] | Exception]] = []
259+
for data, exc in [] if validation_fail is None else validation_fail:
260+
kwargs = dict(zip([f[0] for f in fields], data))
261+
exc_wrapped = TestExc.from_exception(exc)
262+
validation_fail_kw.append((kwargs, exc_wrapped))
236263

237-
for data_or_exc, expected_bytes_or_exc in test_data:
238-
if isinstance(data_or_exc, tuple) and isinstance(expected_bytes_or_exc, bytes):
239-
kwargs = dict(zip([f[0] for f in fields], data_or_exc))
240-
parameters.append((kwargs, expected_bytes_or_exc))
241-
elif isexception(data_or_exc) and isinstance(expected_bytes_or_exc, bytes):
242-
deserialization_fail.append((expected_bytes_or_exc, data_or_exc))
243-
elif isinstance(data_or_exc, tuple) and isexception(expected_bytes_or_exc):
244-
kwargs = dict(zip([f[0] for f in fields], data_or_exc))
245-
validation_fail.append((kwargs, expected_bytes_or_exc))
264+
# Just make sure that the exceptions are wrapped in TestExc
265+
deserialization_fail = (
266+
[]
267+
if deserialization_fail is None
268+
else [(data, TestExc.from_exception(exc)) for data, exc in deserialization_fail]
269+
)
246270

247271
def generate_name(param: dict[str, Any] | bytes, i: int) -> str:
248272
"""Generate a name for the test case."""
@@ -301,33 +325,45 @@ def test_deserialization(self, kwargs: dict[str, Any], expected_bytes: bytes):
301325

302326
@pytest.mark.parametrize(
303327
("kwargs", "exc"),
304-
validation_fail,
305-
ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail)),
328+
validation_fail_kw,
329+
ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail_kw)),
306330
)
307-
def test_validation(self, kwargs: dict[str, Any], exc: type[Exception] | Exception):
331+
def test_validation(self, kwargs: dict[str, Any], exc: TestExc):
308332
"""Test validation of the object."""
309-
exc, msg = get_exception(exc)
310-
with pytest.raises(exc, match=msg):
333+
with pytest.raises(exc.exception, match=exc.match) as exc_info:
311334
cls(**kwargs)
312335

336+
# If exc.kwargs is not None, check them against the exception
337+
if exc.kwargs is not None:
338+
for key, value in exc.kwargs.items():
339+
assert value == getattr(
340+
exc_info.value, key
341+
), f"{key}: {value!r} != {getattr(exc_info.value, key)!r}"
342+
313343
@pytest.mark.parametrize(
314344
("content", "exc"),
315345
deserialization_fail,
316346
ids=tuple(generate_name(content, i) for i, (content, _) in enumerate(deserialization_fail)),
317347
)
318-
def test_deserialization_error(self, content: bytes, exc: type[Exception] | Exception):
348+
def test_deserialization_error(self, content: bytes, exc: TestExc):
319349
"""Test deserialization error handling."""
320350
buf = Buffer(content)
321-
exc, msg = get_exception(exc)
322-
with pytest.raises(exc, match=msg):
351+
with pytest.raises(exc.exception, match=exc.match) as exc_info:
323352
cls.deserialize(buf)
324353

354+
# If exc.kwargs is not None, check them against the exception
355+
if exc.kwargs is not None:
356+
for key, value in exc.kwargs.items():
357+
assert value == getattr(
358+
exc_info.value, key
359+
), f"{key}: {value!r} != {getattr(exc_info.value, key)!r}"
360+
325361
if len(parameters) == 0:
326362
# If there are no serialization tests, remove them
327363
del TestClass.test_serialization
328364
del TestClass.test_deserialization
329365

330-
if len(validation_fail) == 0:
366+
if len(validation_fail_kw) == 0:
331367
# If there are no validation tests, remove them
332368
del TestClass.test_validation
333369

tests/mcproto/packets/handshaking/test_handshake.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
("server_port", int),
1313
("next_state", NextState),
1414
],
15-
test_data=[
15+
serialize_deserialize=[
1616
(
1717
(757, "mc.aircs.racing", 25565, NextState.LOGIN),
1818
bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd02"),
@@ -29,6 +29,8 @@
2929
(757, "hypixel.net", 25565, NextState.STATUS),
3030
bytes.fromhex("f5050b6879706978656c2e6e657463dd01"),
3131
),
32+
],
33+
validation_fail=[
3234
# Invalid next state
3335
((757, "localhost", 25565, 3), ValueError),
3436
],

0 commit comments

Comments
 (0)