Skip to content

Commit d6f2921

Browse files
Use attrs.define instead of dataclasses
- Split the changelog - Use TypeGuard correctly - Remove `transform` in favor of `__attrs_post_init__` to allow for a more personalized use of `validate` - Remove `define` from the docs, change format in changelog
1 parent 14d1f8c commit d6f2921

File tree

16 files changed

+209
-212
lines changed

16 files changed

+209
-212
lines changed

changes/285.internal.1.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
- Changed the way `Serializable` classes are handled:
2+
3+
Here is how a basic `Serializable` class looks like:
4+
5+
```python
6+
@final
7+
@dataclass
8+
class ToyClass(Serializable):
9+
"""Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`."""
10+
11+
a: int
12+
b: str | int
13+
14+
@override
15+
def __attrs_post_init__(self):
16+
"""Initialize the object."""
17+
if isinstance(self.b, int):
18+
self.b = str(self.b)
19+
20+
super().__attrs_post_init__() # This will call validate()
21+
22+
@override
23+
def serialize_to(self, buf: Buffer):
24+
"""Write the object to a buffer."""
25+
self.b = cast(str, self.b) # Handled by the __attrs_post_init__ method
26+
buf.write_varint(self.a)
27+
buf.write_utf(self.b)
28+
29+
@classmethod
30+
@override
31+
def deserialize(cls, buf: Buffer) -> ToyClass:
32+
"""Deserialize the object from a buffer."""
33+
a = buf.read_varint()
34+
if a == 0:
35+
raise ZeroDivisionError("a must be non-zero")
36+
b = buf.read_utf()
37+
return cls(a, b)
38+
39+
@override
40+
def validate(self) -> None:
41+
"""Validate the object's attributes."""
42+
if self.a == 0:
43+
raise ZeroDivisionError("a must be non-zero")
44+
if len(self.b) > 10:
45+
raise ValueError("b must be less than 10 characters")
46+
47+
```
48+
49+
The `Serializable` class implement the following methods:
50+
51+
- `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer.
52+
- `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer.
53+
54+
And the following optional methods:
55+
56+
- `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid.
57+
- `__attrs_post_init__() -> None`: Initializes the object. Call `super().__attrs_post_init__()` to validate the object.
Lines changed: 4 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,3 @@
1-
- Changed the way `Serializable` classes are handled:
2-
3-
Here is how a basic `Serializable` class looks like:
4-
```python
5-
@final
6-
@dataclass
7-
class ToyClass(Serializable):
8-
"""Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`."""
9-
10-
a: int
11-
b: str | int
12-
13-
@override
14-
def serialize_to(self, buf: Buffer):
15-
"""Write the object to a buffer."""
16-
self.b = cast(str, self.b) # Handled by the transform method
17-
buf.write_varint(self.a)
18-
buf.write_utf(self.b)
19-
20-
@classmethod
21-
@override
22-
def deserialize(cls, buf: Buffer) -> ToyClass:
23-
"""Deserialize the object from a buffer."""
24-
a = buf.read_varint()
25-
if a == 0:
26-
raise ZeroDivisionError("a must be non-zero")
27-
b = buf.read_utf()
28-
return cls(a, b)
29-
30-
@override
31-
def validate(self) -> None:
32-
"""Validate the object's attributes."""
33-
if self.a == 0:
34-
raise ZeroDivisionError("a must be non-zero")
35-
if (isinstance(self.b, int) and math.log10(self.b) > 10) or (isinstance(self.b, str) and len(self.b) > 10):
36-
raise ValueError("b must be less than 10 characters")
37-
38-
@override
39-
def transform(self) -> None:
40-
"""Apply a transformation to the payload of the object."""
41-
if isinstance(self.b, int):
42-
self.b = str(self.b)
43-
```
44-
45-
46-
The `Serializable` class implement the following methods:
47-
- `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer.
48-
- `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer.
49-
50-
And the following optional methods:
51-
- `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid.
52-
- `transform() -> None`: Transforms the the object's attributes, this method is meant to convert types like you would in a classic `__init__`.
53-
You can rely on this `validate` having been executed.
54-
551
- Added a test generator for `Serializable` classes:
562

573
The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments:
@@ -69,6 +15,7 @@ And the following optional methods:
6915
- `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize.
7016

7117
The `gen_serializable_test` function generates a test class with the following tests:
18+
7219
```python
7320
gen_serializable_test(
7421
context=globals(),
@@ -80,7 +27,7 @@ gen_serializable_test(
8027
((3, 1234567890), b"\x03\x0a1234567890"),
8128
((0, "hello"), ZeroDivisionError("a must be non-zero")), # With an error message
8229
((1, "hello world"), ValueError), # No error message
83-
((1, 12345678900), ValueError),
30+
((1, 12345678900), ValueError("b must be less than 10 .*")), # With an error message and regex
8431
(ZeroDivisionError, b"\x00"),
8532
(ZeroDivisionError, b"\x01\x05hello"),
8633
(IOError, b"\x01"),
@@ -90,7 +37,7 @@ gen_serializable_test(
9037

9138
The generated test class will have the following tests:
9239

93-
```python
40+
```python
9441
class TestGenToyClass:
9542
def test_serialization(self):
9643
# 2 subtests for the cases 1 and 2
@@ -103,4 +50,4 @@ class TestGenToyClass:
10350

10451
def test_exceptions(self):
10552
# 2 subtests for the cases 5 and 6
106-
```
53+
```

docs/api/internal.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ should not be used externally**, as we do not guarantee their backwards compatib
77
may be introduced between patch versions without any warnings.
88

99
.. automodule:: mcproto.utils.abc
10+
:exclude-members: define
1011

1112
.. autofunction:: tests.helpers.gen_serializable_test
1213
..

mcproto/packets/handshaking/handshake.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mcproto.buffer import Buffer
99
from mcproto.packets.packet import GameState, ServerBoundPacket
1010
from mcproto.protocol.base_io import StructFormat
11-
from mcproto.utils.abc import dataclass
11+
from mcproto.utils.abc import define
1212

1313
__all__ = [
1414
"NextState",
@@ -24,7 +24,7 @@ class NextState(IntEnum):
2424

2525

2626
@final
27-
@dataclass
27+
@define
2828
class Handshake(ServerBoundPacket):
2929
"""Initializes connection between server and client. (Client -> Server).
3030
@@ -44,10 +44,17 @@ class Handshake(ServerBoundPacket):
4444
server_port: int
4545
next_state: NextState | int
4646

47+
@override
48+
def __attrs_post_init__(self) -> None:
49+
if not isinstance(self.next_state, NextState):
50+
self.next_state = NextState(self.next_state)
51+
52+
super().__attrs_post_init__()
53+
4754
@override
4855
def serialize_to(self, buf: Buffer) -> None:
4956
"""Serialize the packet."""
50-
self.next_state = cast(NextState, self.next_state) # Handled by the transform method
57+
self.next_state = cast(NextState, self.next_state) # Handled by the __attrs_post_init__ method
5158
buf.write_varint(self.protocol_version)
5259
buf.write_utf(self.server_address)
5360
buf.write_value(StructFormat.USHORT, self.server_port)
@@ -69,8 +76,3 @@ def validate(self) -> None:
6976
rev_lookup = {x.value: x for x in NextState.__members__.values()}
7077
if self.next_state not in rev_lookup:
7178
raise ValueError("No such next_state.")
72-
73-
@override
74-
def transform(self) -> None:
75-
"""Get the next state enum from the integer value."""
76-
self.next_state = NextState(self.next_state)

mcproto/packets/login/login.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket
1212
from mcproto.types.chat import ChatMessage
1313
from mcproto.types.uuid import UUID
14-
from mcproto.utils.abc import dataclass
14+
from mcproto.utils.abc import define
1515

1616
__all__ = [
1717
"LoginDisconnect",
@@ -26,7 +26,7 @@
2626

2727

2828
@final
29-
@dataclass
29+
@define
3030
class LoginStart(ServerBoundPacket):
3131
"""Packet from client asking to start login process. (Client -> Server).
3232
@@ -56,7 +56,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
5656

5757

5858
@final
59-
@dataclass
59+
@define
6060
class LoginEncryptionRequest(ClientBoundPacket):
6161
"""Used by the server to ask the client to encrypt the login process. (Server -> Client).
6262
@@ -74,6 +74,13 @@ class LoginEncryptionRequest(ClientBoundPacket):
7474
verify_token: bytes
7575
server_id: str | None = None
7676

77+
@override
78+
def __attrs_post_init__(self) -> None:
79+
if self.server_id is None:
80+
self.server_id = " " * 20
81+
82+
super().__attrs_post_init__()
83+
7784
@override
7885
def serialize_to(self, buf: Buffer) -> None:
7986
self.server_id = cast(str, self.server_id)
@@ -96,14 +103,9 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
96103

97104
return cls(server_id=server_id, public_key=public_key, verify_token=verify_token)
98105

99-
@override
100-
def transform(self) -> None:
101-
if self.server_id is None:
102-
self.server_id = " " * 20
103-
104106

105107
@final
106-
@dataclass
108+
@define
107109
class LoginEncryptionResponse(ServerBoundPacket):
108110
"""Response from the client to :class:`LoginEncryptionRequest` packet. (Client -> Server).
109111
@@ -134,7 +136,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
134136

135137

136138
@final
137-
@dataclass
139+
@define
138140
class LoginSuccess(ClientBoundPacket):
139141
"""Sent by the server to denote a successful login. (Server -> Client).
140142
@@ -164,7 +166,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
164166

165167

166168
@final
167-
@dataclass
169+
@define
168170
class LoginDisconnect(ClientBoundPacket):
169171
"""Sent by the server to kick a player while in the login state. (Server -> Client).
170172
@@ -190,7 +192,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
190192

191193

192194
@final
193-
@dataclass
195+
@define
194196
class LoginPluginRequest(ClientBoundPacket):
195197
"""Sent by the server to implement a custom handshaking flow. (Server -> Client).
196198
@@ -224,7 +226,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
224226

225227

226228
@final
227-
@dataclass
229+
@define
228230
class LoginPluginResponse(ServerBoundPacket):
229231
"""Response to LoginPluginRequest from client. (Client -> Server).
230232
@@ -254,7 +256,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
254256

255257

256258
@final
257-
@dataclass
259+
@define
258260
class LoginSetCompression(ClientBoundPacket):
259261
"""Sent by the server to specify whether to use compression on future packets or not (Server -> Client).
260262

mcproto/packets/status/ping.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
from mcproto.buffer import Buffer
88
from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket
99
from mcproto.protocol.base_io import StructFormat
10-
from mcproto.utils.abc import dataclass
10+
from mcproto.utils.abc import define
1111

1212
__all__ = ["PingPong"]
1313

1414

1515
@final
16-
@dataclass
16+
@define
1717
class PingPong(ClientBoundPacket, ServerBoundPacket):
1818
"""Ping request/Pong response (Server <-> Client).
1919

mcproto/packets/status/status.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77

88
from mcproto.buffer import Buffer
99
from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket
10-
from mcproto.utils.abc import dataclass
10+
from mcproto.utils.abc import define
1111

1212
__all__ = ["StatusRequest", "StatusResponse"]
1313

1414

1515
@final
16-
@dataclass
16+
@define
1717
class StatusRequest(ServerBoundPacket):
1818
"""Request from the client to get information on the server. (Client -> Server)."""
1919

@@ -31,7 +31,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to t
3131

3232

3333
@final
34-
@dataclass
34+
@define
3535
class StatusResponse(ClientBoundPacket):
3636
"""Response from the server to requesting client with status data information. (Server -> Client).
3737

mcproto/types/abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3-
from mcproto.utils.abc import Serializable, dataclass
3+
from mcproto.utils.abc import Serializable, define
44

5-
__all__ = ["MCType", "dataclass"] # That way we can import it from mcproto.types.abc
5+
__all__ = ["MCType", "define"] # That way we can import it from mcproto.types.abc
66

77

88
class MCType(Serializable):

mcproto/types/chat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing_extensions import Self, TypeAlias, override
77

88
from mcproto.buffer import Buffer
9-
from mcproto.types.abc import MCType, dataclass
9+
from mcproto.types.abc import MCType, define
1010

1111
__all__ = [
1212
"ChatMessage",
@@ -33,13 +33,15 @@ class RawChatMessageDict(TypedDict, total=False):
3333
RawChatMessage: TypeAlias = Union[RawChatMessageDict, "list[RawChatMessageDict]", str]
3434

3535

36-
@dataclass
3736
@final
37+
@define
3838
class ChatMessage(MCType):
3939
"""Minecraft chat message representation."""
4040

4141
raw: RawChatMessage
4242

43+
__slots__ = ("raw",)
44+
4345
def as_dict(self) -> RawChatMessageDict:
4446
"""Convert received ``raw`` into a stadard :class:`dict` form."""
4547
if isinstance(self.raw, list):

0 commit comments

Comments
 (0)