Skip to content

Commit 5964dd2

Browse files
committed
feat: Add multiple IO backend implementations
1 parent bc0a0d7 commit 5964dd2

File tree

12 files changed

+1290
-0
lines changed

12 files changed

+1290
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from ._base import NetworkStream, NetWorkBackend, AsyncNetworkStream, AsyncNetworkBackend
18+
from ._anyio import AnyIOTCPStream, AnyIOBackend
19+
from ._sync import SyncTCPStream, SyncBackend
20+
21+
__all__ = [
22+
"NetworkStream",
23+
"NetWorkBackend",
24+
"AsyncNetworkStream",
25+
"AsyncNetworkBackend",
26+
"AnyIOTCPStream",
27+
"AnyIOBackend",
28+
"SyncTCPStream",
29+
"SyncBackend",
30+
]
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import socket
18+
import ssl
19+
from ssl import SSLContext
20+
from typing import Any, Callable, Optional, Union
21+
22+
import anyio
23+
from anyio.abc import SocketStream, SocketAttribute
24+
from anyio.streams.tls import TLSListener, TLSStream
25+
26+
from ..types import IPAddressType
27+
from ._base import AsyncNetworkBackend, AsyncNetworkStream
28+
29+
30+
class AnyIOTCPStream(AsyncNetworkStream):
31+
"""
32+
AnyIOTCPStream is an asynchronous network stream implementation using AnyIO.
33+
"""
34+
35+
_socket: Union[SocketStream, TLSStream]
36+
37+
def __init__(self, stream: Union[SocketStream, TLSStream]) -> None:
38+
self._socket = stream
39+
40+
@property
41+
def local_addr(self) -> tuple:
42+
return self._socket.extra(SocketAttribute.local_address)
43+
44+
@property
45+
def remote_addr(self) -> tuple:
46+
return self._socket.extra(SocketAttribute.remote_address)
47+
48+
async def send(self, data: bytes, timeout: Optional[float] = None) -> None:
49+
with anyio.fail_after(timeout):
50+
await self._socket.send(data)
51+
52+
async def receive(self, max_size: int, timeout: Optional[float] = None) -> bytes:
53+
with anyio.fail_after(timeout):
54+
try:
55+
return await self._socket.receive(max_size)
56+
except anyio.EndOfStream:
57+
# It means no more data can be read from the stream.
58+
return b""
59+
60+
async def aclose(self) -> None:
61+
await self._socket.aclose()
62+
63+
64+
class AnyIOTCPServer:
65+
def __init__(self, listener, handler: Callable[[AsyncNetworkStream], Any]) -> None:
66+
self.listener = listener
67+
self.raw_handler = handler
68+
69+
async def _wrap_handler(self, stream: SocketStream) -> None:
70+
"""
71+
Wrap the handler to convert the stream to AnyIOTCPStream.
72+
"""
73+
stream = AnyIOTCPStream(stream)
74+
await self.raw_handler(stream)
75+
76+
async def serve_forever(self):
77+
"""
78+
Serve the TCP server forever.
79+
"""
80+
await self.listener.serve(self._wrap_handler)
81+
82+
83+
class AnyIOBackend(AsyncNetworkBackend):
84+
async def connect_tcp(
85+
self,
86+
host: IPAddressType,
87+
port: int,
88+
ssl_context: Optional[SSLContext] = None,
89+
timeout: Optional[float] = None,
90+
) -> AnyIOTCPStream:
91+
"""
92+
Connect to a TCP server.
93+
"""
94+
with anyio.fail_after(timeout):
95+
stream = await anyio.connect_tcp(host, port, ssl_context=ssl_context)
96+
return AnyIOTCPStream(stream)
97+
98+
async def listen_tcp(
99+
self,
100+
handler: Callable[[AsyncNetworkStream], Any],
101+
local_host: Optional[IPAddressType] = None,
102+
local_port: int = 0,
103+
ssl_context: Optional[ssl.SSLContext] = None,
104+
family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC,
105+
backlog: int = 65535,
106+
reuse_port: bool = False,
107+
) -> Any:
108+
"""
109+
Listen for incoming TCP connections.
110+
"""
111+
listener = await anyio.create_tcp_listener(
112+
local_host=local_host, local_port=local_port, family=family, backlog=backlog, reuse_port=reuse_port
113+
)
114+
if ssl_context:
115+
listener = TLSListener(listener, ssl_context)
116+
return AnyIOTCPServer(listener, handler)
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import abc
18+
19+
__all___ = ["NetworkStream", "NetWorkBackend", "AsyncNetworkStream", "AsyncNetworkBackend"]
20+
21+
import ssl
22+
23+
from typing import Optional
24+
25+
26+
class BaseNetworkStream(abc.ABC):
27+
"""
28+
BaseNetworkStream is an abstract base class for all network streams.
29+
"""
30+
31+
@property
32+
@abc.abstractmethod
33+
def local_addr(self) -> tuple:
34+
"""
35+
Get the local address of the network stream.
36+
"""
37+
raise NotImplementedError()
38+
39+
@property
40+
@abc.abstractmethod
41+
def remote_addr(self) -> tuple:
42+
"""
43+
Get the remote address of the network stream.
44+
"""
45+
raise NotImplementedError()
46+
47+
48+
class NetworkStream(BaseNetworkStream, abc.ABC):
49+
"""
50+
NetworkStream is an abstract base class for synchronous network streams.
51+
"""
52+
53+
def __enter__(self):
54+
"""
55+
Enter the runtime context related to this object.
56+
"""
57+
return self
58+
59+
def __exit__(self, exc_type, exc_val, exc_tb):
60+
"""
61+
Exit the runtime context related to this object.
62+
"""
63+
self.close()
64+
65+
@abc.abstractmethod
66+
def send(self, data: bytes, timeout: Optional[float] = None) -> None:
67+
"""
68+
Send data over the network stream.
69+
70+
:param data: The data to send.
71+
:param timeout: Optional timeout for the send operation.
72+
"""
73+
raise NotImplementedError()
74+
75+
@abc.abstractmethod
76+
def receive(self, max_size: int, timeout: Optional[float] = None) -> bytes:
77+
"""
78+
Receive data from the network stream.
79+
80+
:param max_size: The maximum size of data to receive.
81+
:param timeout: Optional timeout for the reception operation.
82+
:return: The received data.
83+
"""
84+
raise NotImplementedError()
85+
86+
@abc.abstractmethod
87+
def close(self) -> None:
88+
"""
89+
Close the network stream.
90+
"""
91+
raise NotImplementedError()
92+
93+
94+
class NetWorkBackend(abc.ABC):
95+
"""
96+
NetworkBackend is an abstract base class for synchronous network backends.
97+
"""
98+
99+
pass
100+
101+
102+
class AsyncNetworkStream(BaseNetworkStream, abc.ABC):
103+
"""
104+
AsyncNetworkStream is an abstract base class for asynchronous network streams.
105+
"""
106+
107+
async def __aenter__(self):
108+
"""
109+
Enter the asynchronous runtime context related to this object.
110+
"""
111+
return self
112+
113+
async def __aexit__(self, exc_type, exc_val, exc_tb):
114+
"""
115+
Exit the asynchronous runtime context related to this object.
116+
"""
117+
await self.aclose()
118+
119+
@abc.abstractmethod
120+
async def send(self, data: bytes, timeout: Optional[float] = None) -> None:
121+
"""
122+
Send data over the network stream.
123+
124+
:param data: The data to send.
125+
:param timeout: Optional timeout for the send operation.
126+
"""
127+
raise NotImplementedError()
128+
129+
@abc.abstractmethod
130+
async def receive(self, max_size: int, timeout: Optional[float] = None) -> bytes:
131+
"""
132+
Receive data from the network stream.
133+
134+
:param max_size: The maximum size of data to receive.
135+
:param timeout: Optional timeout for the reception operation.
136+
:return: The received data.
137+
"""
138+
raise NotImplementedError()
139+
140+
@abc.abstractmethod
141+
async def aclose(self) -> None:
142+
"""
143+
Close the network stream.
144+
"""
145+
raise NotImplementedError()
146+
147+
148+
class AsyncServer(abc.ABC):
149+
"""
150+
AsyncServer is an abstract base class for asynchronous network servers.
151+
"""
152+
153+
pass
154+
155+
156+
class AsyncNetworkBackend(abc.ABC):
157+
"""
158+
AsyncNetworkBackend is an abstract base class for asynchronous network backends.
159+
"""
160+
161+
pass

0 commit comments

Comments
 (0)