Skip to content

Commit bfb2a6d

Browse files
authored
connect: add single method for Stream class (#98)
1 parent a4765b0 commit bfb2a6d

File tree

4 files changed

+109
-83
lines changed

4 files changed

+109
-83
lines changed

examples/client.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import asyncio
55
from collections.abc import AsyncGenerator
66

7-
from connect.connect import StreamRequest, UnaryRequest, ensure_single
7+
from connect.connect import StreamRequest, UnaryRequest
88
from connect.connection_pool import AsyncConnectionPool
99

1010
from proto.connectrpc.eliza.v1.eliza_pb2 import IntroduceRequest, ReflectRequest, SayRequest
@@ -21,11 +21,7 @@ async def run_unary(client: ElizaServiceClient) -> None:
2121

2222
async def run_server_streaming(client: ElizaServiceClient) -> None:
2323
"""Run server streaming RPC (Introduce)."""
24-
25-
async def request_generator() -> AsyncGenerator[IntroduceRequest]:
26-
yield IntroduceRequest(name="Alice")
27-
28-
request = StreamRequest(request_generator())
24+
request = StreamRequest(IntroduceRequest(name="Alice"))
2925

3026
message_count = 1
3127
async with client.Introduce(request) as response:
@@ -43,7 +39,7 @@ async def request_generator() -> AsyncGenerator[ReflectRequest]:
4339

4440
request = StreamRequest(request_generator())
4541
async with client.Reflect(request) as response:
46-
message = await ensure_single(response.messages)
42+
message = await response.single()
4743

4844
print(f"Final response: {message.sentence}")
4945

examples/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import hypercorn
77
import hypercorn.asyncio
8-
from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse, ensure_single
8+
from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse
99
from connect.handler_context import HandlerContext
1010
from connect.middleware import ConnectMiddleware
1111
from starlette.applications import Starlette
@@ -36,7 +36,7 @@ async def Introduce(
3636
self, request: StreamRequest[IntroduceRequest], _context: HandlerContext
3737
) -> StreamResponse[IntroduceResponse]:
3838
"""Introduce the Eliza service."""
39-
message = await ensure_single(request.messages)
39+
message = await request.single()
4040
name = message.name
4141
intros = eliza.get_intro_responses(name)
4242

src/connect/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
StreamType,
2323
UnaryRequest,
2424
UnaryResponse,
25-
recieve_stream_response,
26-
recieve_unary_response,
25+
receive_stream_response,
26+
receive_unary_response,
2727
)
2828
from connect.connection_pool import AsyncConnectionPool
2929
from connect.error import ConnectError
@@ -244,7 +244,7 @@ def on_request_send(r: httpcore.Request) -> None:
244244

245245
await conn.send(aiterate([request.message]), call_options.timeout, abort_event=call_options.abort_event)
246246

247-
response = await recieve_unary_response(conn=conn, t=output, abort_event=call_options.abort_event)
247+
response = await receive_unary_response(conn=conn, t=output, abort_event=call_options.abort_event)
248248
return response
249249

250250
unary_func = apply_interceptors(_unary_func, options.interceptors)
@@ -290,7 +290,7 @@ def on_request_send(r: httpcore.Request) -> None:
290290

291291
await conn.send(request.messages, call_options.timeout, call_options.abort_event)
292292

293-
response = await recieve_stream_response(conn, output, request.spec, call_options.abort_event)
293+
response = await receive_stream_response(conn, output, request.spec, call_options.abort_event)
294294
return response
295295

296296
stream_func = apply_interceptors(_stream_func, options.interceptors)

src/connect/connect.py

Lines changed: 100 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,22 @@ class Peer(BaseModel):
5151

5252

5353
class RequestCommon:
54-
"""RequestCommon is a class that encapsulates common attributes and methods for handling HTTP requests.
54+
"""A common base class for handling request-related functionality.
5555
56-
Attributes:
57-
_spec (Spec): The specification for the request.
58-
_peer (Peer): The peer information.
59-
_headers (Headers): The request headers.
60-
_method (str): The HTTP method used for the request.
56+
This class encapsulates the common properties and behaviors shared across
57+
different types of requests, including specification details, peer information,
58+
headers, and HTTP method configuration.
6159
60+
Attributes:
61+
_spec (Spec): The specification for the request containing procedure details,
62+
descriptor, stream type, and idempotency level.
63+
_peer (Peer): The peer information including address, protocol, and query parameters.
64+
_headers (Headers): The request headers as a collection of key-value pairs.
65+
_method (str): The HTTP method used for the request (defaults to POST).
66+
67+
The class provides property accessors for all attributes with appropriate getters
68+
and setters where modification is allowed. Default values are provided for all
69+
parameters during initialization to ensure the object is always in a valid state.
6270
"""
6371

6472
_spec: Spec
@@ -73,17 +81,19 @@ def __init__(
7381
headers: Headers | None = None,
7482
method: str | None = None,
7583
) -> None:
76-
"""Initialize a new Request instance.
84+
"""Initialize a Connect request/response context.
7785
7886
Args:
79-
spec (Spec): The specification for the request.
80-
peer (Peer): The peer information.
81-
headers (Mapping[str, str]): The request headers.
82-
method (str): The HTTP method used for the request.
87+
spec: The RPC specification containing procedure name, descriptor, stream type,
88+
and idempotency level. If None, creates a default Spec with empty procedure,
89+
no descriptor, unary stream type, and idempotent level.
90+
peer: The peer information including address, protocol, and query parameters.
91+
If None, creates a default Peer with no address, empty protocol, and empty query.
92+
headers: HTTP headers for the request/response. If None, creates an empty Headers object.
93+
method: HTTP method to use for the request. If None, defaults to POST.
8394
8495
Returns:
8596
None
86-
8797
"""
8898
self._spec = (
8999
spec
@@ -138,18 +148,12 @@ def method(self, value: str) -> None:
138148
class StreamRequest[T](RequestCommon):
139149
"""StreamRequest class represents a request that can handle streaming messages.
140150
141-
Attributes:
142-
messages (AsyncIterable[T]): An asynchronous iterable of messages.
143-
_spec (Spec): The specification for the request.
144-
_peer (Peer): The peer information.
145-
_headers (Headers): The request headers.
146-
_method (str): The HTTP method used for the request.
147-
151+
This class provides a unified interface for handling both single and multiple
152+
messages in streaming requests. It automatically determines the appropriate
153+
method based on the stream type and usage context.
148154
"""
149155

150156
_messages: AsyncIterable[T]
151-
# timeout: float | None
152-
# abort_event: asyncio.Event | None = None
153157

154158
def __init__(
155159
self,
@@ -158,50 +162,53 @@ def __init__(
158162
peer: Peer | None = None,
159163
headers: Headers | None = None,
160164
method: str | None = None,
161-
# timeout: float | None = None,
162-
# abort_event: asyncio.Event | None = None,
163165
) -> None:
164-
"""Initialize a new Request instance.
166+
"""Initialize a new instance.
165167
166168
Args:
167-
content (AsyncIterable[T] | T): The request content, which can be an async iterable or a single message.
168-
spec (Spec): The specification for the request.
169-
peer (Peer): The peer information.
170-
headers (Mapping[str, str]): The request headers.
171-
method (str): The HTTP method used for the request.
172-
timeout (float): The timeout for the request.
173-
abort_event (asyncio.Event): An event to signal request abortion.
169+
content: The content to be processed, either a single item of type T or an async iterable of items.
170+
spec: Optional specification object defining the behavior or configuration.
171+
peer: Optional peer object representing the connection endpoint.
172+
headers: Optional headers dictionary for metadata or configuration.
173+
method: Optional string specifying the method or operation type.
174174
175175
Returns:
176176
None
177-
178177
"""
179178
super().__init__(spec, peer, headers, method)
180179
self._messages = content if isinstance(content, AsyncIterable) else aiterate([content])
181-
# self.timeout = timeout
182-
# self.abort_event = abort_event
183180

184181
@property
185182
def messages(self) -> AsyncIterable[T]:
186-
"""Return the request message."""
183+
"""Return the request messages as an async iterable.
184+
185+
Use this when you expect multiple messages (client streaming, bidi streaming).
186+
187+
Example:
188+
async for message in request.messages:
189+
process(message)
190+
"""
187191
return self._messages
188192

193+
async def single(self) -> T:
194+
"""Return a single message from the request.
189195
190-
class UnaryRequest[T](RequestCommon):
191-
"""UnaryRequest is a class that encapsulates a request with a message, specification, peer, headers, and method.
196+
Use this when you expect exactly one message (server-side handlers for client streaming).
197+
Raises ConnectError if there are zero or multiple messages.
192198
193-
Attributes:
194-
message (Req): The request message.
195-
_spec (Spec): The specification of the request.
196-
_peer (Peer): The peer associated with the request.
197-
_headers (Mapping[str, str]): The headers of the request.
198-
_method (str): The method of the request.
199+
Example:
200+
message = await request.single()
201+
process(message)
202+
"""
203+
return await ensure_single(self._messages)
199204

200-
"""
201205

202-
_message: T
203-
# timeout: float | None
204-
# abort_event: asyncio.Event | None = None
206+
class UnaryRequest[T](RequestCommon):
207+
"""A unary request wrapper that extends RequestCommon functionality.
208+
209+
This class encapsulates a single message/content of type T along with common request
210+
metadata such as specifications, peer information, headers, and HTTP method.
211+
"""
205212

206213
def __init__(
207214
self,
@@ -210,28 +217,21 @@ def __init__(
210217
peer: Peer | None = None,
211218
headers: Headers | None = None,
212219
method: str | None = None,
213-
# timeout: float | None = None,
214-
# abort_event: asyncio.Event | None = None,
215220
) -> None:
216-
"""Initialize a new Request instance.
221+
"""Initialize a new instance with content and optional parameters.
217222
218223
Args:
219-
content (T): The request message.
220-
spec (Spec): The specification for the request.
221-
peer (Peer): The peer information.
222-
headers (Mapping[str, str]): The request headers.
223-
method (str): The HTTP method used for the request.
224-
timeout (float): The timeout for the request.
225-
abort_event (asyncio.Event): An event to signal request abortion.
224+
content (T): The main content/message to be stored in this instance.
225+
spec (Spec | None, optional): Specification object defining behavior or configuration. Defaults to None.
226+
peer (Peer | None, optional): Peer object representing the remote endpoint or connection. Defaults to None.
227+
headers (Headers | None, optional): HTTP headers or metadata associated with the request/response. Defaults to None.
228+
method (str | None, optional): HTTP method or operation type (e.g., 'GET', 'POST'). Defaults to None.
226229
227230
Returns:
228231
None
229-
230232
"""
231233
super().__init__(spec, peer, headers, method)
232234
self._message = content
233-
# self.timeout = timeout
234-
# self.abort_event = abort_event
235235

236236
@property
237237
def message(self) -> T:
@@ -293,7 +293,12 @@ def message(self) -> T:
293293

294294

295295
class StreamResponse[T](ResponseCommon):
296-
"""Response class for handling responses."""
296+
"""Response class for handling streaming responses.
297+
298+
This class provides a unified interface for handling both single and multiple
299+
messages from streaming responses. It automatically determines the appropriate
300+
method based on the stream type and usage context.
301+
"""
297302

298303
_messages: AsyncIterable[T]
299304

@@ -303,15 +308,40 @@ def __init__(
303308
headers: Headers | None = None,
304309
trailers: Headers | None = None,
305310
) -> None:
306-
"""Initialize the response with a message."""
311+
"""Initialize the response with content.
312+
313+
Args:
314+
content: Either a single message or an async iterable of messages
315+
headers: Optional response headers
316+
trailers: Optional response trailers
317+
"""
307318
super().__init__(headers, trailers)
308319
self._messages = content if isinstance(content, AsyncIterable) else aiterate([content])
309320

310321
@property
311322
def messages(self) -> AsyncIterable[T]:
312-
"""Return the response message."""
323+
"""Return the response messages as an async iterable.
324+
325+
Use this when you expect multiple messages (server streaming, bidi streaming).
326+
327+
Example:
328+
async for message in response.messages:
329+
print(message)
330+
"""
313331
return self._messages
314332

333+
async def single(self) -> T:
334+
"""Return a single message from the response.
335+
336+
Use this when you expect exactly one message (client streaming results).
337+
Raises ConnectError if there are zero or multiple messages.
338+
339+
Example:
340+
message = await response.single()
341+
print(message)
342+
"""
343+
return await ensure_single(self._messages)
344+
315345
async def aclose(self) -> None:
316346
"""Asynchronously close the response stream."""
317347
aclose = get_acallable_attribute(self._messages, "aclose")
@@ -475,8 +505,8 @@ async def send_error(self, error: ConnectError) -> None:
475505
raise NotImplementedError()
476506

477507

478-
class UnaryClientConn:
479-
"""Abstract base class for a streaming client connection."""
508+
class UnaryClientConn(abc.ABC):
509+
"""Abstract base class for a unary client connection."""
480510

481511
@property
482512
@abc.abstractmethod
@@ -529,7 +559,7 @@ async def aclose(self) -> None:
529559
raise NotImplementedError()
530560

531561

532-
class StreamingClientConn:
562+
class StreamingClientConn(abc.ABC):
533563
"""Abstract base class for a streaming client connection."""
534564

535565
@property
@@ -645,7 +675,7 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S
645675
)
646676

647677

648-
async def recieve_unary_response[T](
678+
async def receive_unary_response[T](
649679
conn: StreamingClientConn, t: type[T], abort_event: asyncio.Event | None
650680
) -> UnaryResponse[T]:
651681
"""Receives a unary response message from a streaming client connection.
@@ -672,7 +702,7 @@ async def recieve_unary_response[T](
672702
return UnaryResponse(message, conn.response_headers, conn.response_trailers)
673703

674704

675-
async def recieve_stream_response[T](
705+
async def receive_stream_response[T](
676706
conn: StreamingClientConn, t: type[T], spec: Spec, abort_event: asyncio.Event | None
677707
) -> StreamResponse[T]:
678708
"""Handle receiving a stream response from a streaming client connection.
@@ -697,10 +727,10 @@ async def recieve_stream_response[T](
697727
698728
"""
699729
if spec.stream_type == StreamType.ClientStream:
700-
single_message = await ensure_single(conn.receive(t, abort_event), conn.aclose)
730+
single_message = await ensure_single(conn.receive(t, abort_event))
701731

702732
return StreamResponse(
703-
AsyncDataStream[T](aiterate([single_message])), conn.response_headers, conn.response_trailers
733+
AsyncDataStream[T](aiterate([single_message]), conn.aclose), conn.response_headers, conn.response_trailers
704734
)
705735
else:
706736
return StreamResponse(

0 commit comments

Comments
 (0)