Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,8 @@ class $3L:
def __init__(self, *, client_config: HTTPClientConfiguration | None = None):
self._client_config = client_config

TIMEOUT_EXCEPTIONS = ()

async def send(
self, request: HTTPRequest, *, request_config: HTTPRequestConfiguration | None = None
) -> HTTPResponse:
Expand All @@ -657,6 +659,8 @@ def __init__(
self.fields = tuples_to_fields(headers or [])
self.body = body

TIMEOUT_EXCEPTIONS = ()

async def send(
self, request: HTTPRequest, *, request_config: HTTPRequestConfiguration | None = None
) -> _HTTPResponse:
Expand Down
45 changes: 26 additions & 19 deletions packages/smithy-core/src/smithy_core/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..auth import AuthParams
from ..deserializers import DeserializeableShape, ShapeDeserializer
from ..endpoints import EndpointResolverParams
from ..exceptions import RetryError, SmithyError
from ..exceptions import ClientTimeoutError, RetryError, SmithyError
from ..interceptors import (
InputContext,
Interceptor,
Expand Down Expand Up @@ -448,24 +448,31 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](

_LOGGER.debug("Sending request %s", request_context.transport_request)

if request_future is not None:
# If we have an input event stream (or duplex event stream) then we
# need to let the client return ASAP so that it can start sending
# events. So here we start the transport send in a background task
# then set the result of the request future. It's important to sequence
# it just like that so that the client gets a stream that's ready
# to send.
transport_task = asyncio.create_task(
self.transport.send(request=request_context.transport_request)
)
request_future.set_result(request_context)
transport_response = await transport_task
else:
# If we don't have an input stream, there's no point in creating a
# task, so we just immediately await the coroutine.
transport_response = await self.transport.send(
request=request_context.transport_request
)
try:
if request_future is not None:
# If we have an input event stream (or duplex event stream) then we
# need to let the client return ASAP so that it can start sending
# events. So here we start the transport send in a background task
# then set the result of the request future. It's important to sequence
# it just like that so that the client gets a stream that's ready
# to send.
transport_task = asyncio.create_task(
self.transport.send(request=request_context.transport_request)
)
request_future.set_result(request_context)
transport_response = await transport_task
else:
# If we don't have an input stream, there's no point in creating a
# task, so we just immediately await the coroutine.
transport_response = await self.transport.send(
request=request_context.transport_request
)
except Exception as e:
if isinstance(e, self.transport.TIMEOUT_EXCEPTIONS):
raise ClientTimeoutError(
message=f"Client timeout occurred: {e}"
) from e
raise

_LOGGER.debug("Received response: %s", transport_response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,13 @@ async def resolve_endpoint(self, params: EndpointResolverParams[Any]) -> Endpoin


class ClientTransport[I: Request, O: Response](Protocol):
"""Protocol-agnostic representation of a client tranport (e.g. an HTTP client)."""
"""Protocol-agnostic representation of a client transport (e.g. an HTTP client).

Transports must define TIMEOUT_EXCEPTIONS as a tuple of exception types that
are raised when a timeout occurs.
"""

TIMEOUT_EXCEPTIONS: tuple[type[Exception], ...]

async def send(self, request: I) -> O:
"""Send a request over the transport and receive the response."""
Expand Down
17 changes: 17 additions & 0 deletions packages/smithy-core/src/smithy_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class CallError(SmithyError):
is_throttling_error: bool = False
"""Whether the error is a throttling error."""

is_timeout_error: bool = False
"""Whether the error represents a timeout condition."""

def __post_init__(self):
super().__init__(self.message)

Expand All @@ -61,6 +64,20 @@ class ModeledError(CallError):
fault: Fault = "client"


@dataclass(kw_only=True)
class ClientTimeoutError(CallError):
"""Exception raised when a client-side timeout occurs.

This error indicates that the client transport layer encountered a timeout while
attempting to communicate with the server. This typically occurs when network
requests take longer than the configured timeout period.
"""

fault: Fault = "client"
is_timeout_error: bool = True
is_retry_safe: bool | None = True


class SerializationError(SmithyError):
"""Base exception type for exceptions raised during serialization."""

Expand Down
2 changes: 2 additions & 0 deletions packages/smithy-http/src/smithy_http/aio/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __post_init__(self) -> None:
class AIOHTTPClient(HTTPClient):
"""Implementation of :py:class:`.interfaces.HTTPClient` using aiohttp."""

TIMEOUT_EXCEPTIONS = (TimeoutError,)

def __init__(
self,
*,
Expand Down
32 changes: 23 additions & 9 deletions packages/smithy-http/src/smithy_http/aio/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING, Any

from awscrt.exceptions import AwsCrtError

if TYPE_CHECKING:
# pyright doesn't like optional imports. This is reasonable because if we use these
# in type hints then they'd result in runtime errors.
Expand Down Expand Up @@ -129,9 +131,16 @@ def __post_init__(self) -> None:
_assert_crt()


class _CRTTimeoutError(Exception):
"""Internal wrapper for CRT timeout errors."""


class AWSCRTHTTPClient(http_aio_interfaces.HTTPClient):
_HTTP_PORT = 80
_HTTPS_PORT = 443
_TIMEOUT_ERROR_NAMES = frozenset(["AWS_IO_SOCKET_TIMEOUT", "AWS_IO_SOCKET_CLOSED"])

TIMEOUT_EXCEPTIONS = (_CRTTimeoutError,)

def __init__(
self,
Expand Down Expand Up @@ -163,18 +172,23 @@ async def send(
:param request: The request including destination URI, fields, payload.
:param request_config: Configuration specific to this request.
"""
crt_request = self._marshal_request(request)
connection = await self._get_connection(request.destination)
try:
crt_request = self._marshal_request(request)
connection = await self._get_connection(request.destination)

# Convert body to async iterator for request_body_generator
body_generator = self._create_body_generator(request.body)
# Convert body to async iterator for request_body_generator
body_generator = self._create_body_generator(request.body)

crt_stream = connection.request(
crt_request,
request_body_generator=body_generator,
)
crt_stream = connection.request(
crt_request,
request_body_generator=body_generator,
)

return await self._await_response(crt_stream)
return await self._await_response(crt_stream)
except AwsCrtError as e:
if e.name in self._TIMEOUT_ERROR_NAMES:
raise _CRTTimeoutError() from e
raise

async def _await_response(
self, stream: "AIOHttpClientStreamUnified"
Expand Down
11 changes: 8 additions & 3 deletions packages/smithy-http/src/smithy_http/aio/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ async def _create_error(
)
return error_shape.deserialize(deserializer)

is_throttle = response.status == 429
message = (
f"Unknown error for operation {operation.schema.id} "
f"- status: {response.status}"
Expand All @@ -224,11 +223,17 @@ async def _create_error(
message += f" - id: {error_id}"
if response.reason is not None:
message += f" - reason: {response.status}"

is_timeout = response.status == 408
is_throttle = response.status == 429
fault = "client" if response.status < 500 else "server"

return CallError(
message=message,
fault="client" if response.status < 500 else "server",
fault=fault,
is_throttling_error=is_throttle,
is_retry_safe=is_throttle or None,
is_timeout_error=is_timeout,
is_retry_safe=is_throttle or is_timeout or None,
)

def _matches_content_type(self, response: HTTPResponse) -> bool:
Expand Down
33 changes: 29 additions & 4 deletions packages/smithy-http/tests/unit/aio/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from unittest.mock import Mock

import pytest
from smithy_core import URI
Expand All @@ -11,14 +12,15 @@
from smithy_core.interfaces import URI as URIInterface
from smithy_core.schemas import APIOperation
from smithy_core.shapes import ShapeID
from smithy_core.types import TypedProperties as ConcreteTypedProperties
from smithy_http import Fields
from smithy_http.aio import HTTPRequest
from smithy_http.aio import HTTPRequest, HTTPResponse
from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface
from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface
from smithy_http.aio.protocols import HttpClientProtocol
from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol


class TestProtocol(HttpClientProtocol):
class MockProtocol(HttpClientProtocol):
_id = ShapeID("ns.foo#bar")

@property
Expand Down Expand Up @@ -125,7 +127,7 @@ def deserialize_response(
def test_http_protocol_joins_uris(
request_uri: URI, endpoint_uri: URI, expected: URI
) -> None:
protocol = TestProtocol()
protocol = MockProtocol()
request = HTTPRequest(
destination=request_uri,
method="GET",
Expand All @@ -135,3 +137,26 @@ def test_http_protocol_joins_uris(
updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint)
actual = updated_request.destination
assert actual == expected


async def test_http_408_creates_timeout_error() -> None:
protocol = Mock(spec=HttpBindingClientProtocol)
protocol.error_identifier = Mock()
protocol.error_identifier.identify.return_value = None

response = HTTPResponse(status=408, fields=Fields())

error = await HttpBindingClientProtocol._create_error( # type: ignore[reportPrivateUsage]
protocol,
operation=Mock(),
request=HTTPRequest(
destination=URI(host="example.com"), method="POST", fields=Fields()
),
response=response,
response_body=b"",
error_registry=TypeRegistry({}),
context=ConcreteTypedProperties(),
)

assert error.is_timeout_error is True
assert error.fault == "client"
Loading