Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion eng/tools/azure-sdk-tools/devtools_testutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@
CachedStorageAccountPreparer,
)

# cSpell:disable
from .envvariable_loader import EnvironmentVariableLoader
from .exceptions import AzureTestError, ReservedResourceNameError
from .proxy_fixtures import environment_variables, recorded_test, variable_recorder
from .proxy_startup import start_test_proxy, stop_test_proxy, test_proxy
from .proxy_testcase import recorded_by_proxy

# Import httpx decorators if httpx is available
try:
from .proxy_testcase_httpx import recorded_by_proxy_httpx
_httpx_available = True
except ImportError:
_httpx_available = False

from .sanitizers import (
add_api_version_transform,
add_batch_sanitizers,
Expand Down Expand Up @@ -118,3 +125,7 @@
"create_combined_bundle",
"is_live_and_not_recording",
]

# Add httpx decorator if available
if _httpx_available:
__all__.append("recorded_by_proxy_httpx")
11 changes: 11 additions & 0 deletions eng/tools/azure-sdk-tools/devtools_testutils/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from .proxy_testcase_async import recorded_by_proxy_async

# Import httpx decorator if httpx is available
try:
from .proxy_testcase_async_httpx import recorded_by_proxy_async_httpx
_httpx_available = True
except ImportError:
_httpx_available = False

__all__ = ["recorded_by_proxy_async"]

# Add httpx decorator if available
if _httpx_available:
__all__.append("recorded_by_proxy_async_httpx")
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Async proxy decorators for httpx-based clients (e.g., OpenAI AsyncOpenAI SDK).

These decorators monkeypatch httpx async transport classes to redirect requests through the test proxy,
enabling recording and playback for async clients that use httpx instead of Azure Core's transport layer.
"""
import logging
import urllib.parse as url_parse

from azure.core.exceptions import ResourceNotFoundError
from azure.core.pipeline.policies import ContentDecodePolicy

from ..helpers import is_live_and_not_recording, trim_kwargs_from_test_function
from ..proxy_testcase import (
get_test_id,
start_record_or_playback,
stop_record_or_playback,
get_proxy_netloc,
)
from ..helpers import is_live

try:
import httpx
except ImportError:
httpx = None


def recorded_by_proxy_async_httpx(test_func):
"""Decorator that redirects async httpx network requests to target the azure-sdk-tools test proxy.

Use this decorator for async tests that use httpx-based clients (like OpenAI AsyncOpenAI SDK)
instead of Azure SDK clients. It monkeypatches httpx.AsyncHTTPTransport.handle_async_request
to route requests through the test proxy.

For more details and usage examples, refer to
https://github.com/Azure/azure-sdk-for-python/blob/main/doc/dev/tests.md#write-or-run-tests
"""
if httpx is None:
raise ImportError("httpx is required to use recorded_by_proxy_async_httpx. Install it with: pip install httpx")

async def record_wrap(*args, **kwargs):
def transform_httpx_request(request: httpx.Request, recording_id: str) -> None:
"""Transform an httpx.Request to route through the test proxy."""
parsed_result = url_parse.urlparse(str(request.url))

# Store original upstream URI
if "x-recording-upstream-base-uri" not in request.headers:
request.headers["x-recording-upstream-base-uri"] = f"{parsed_result.scheme}://{parsed_result.netloc}"

# Set recording headers
request.headers["x-recording-id"] = recording_id
request.headers["x-recording-mode"] = "record" if is_live() else "playback"

# Rewrite URL to proxy
updated_target = parsed_result._replace(**get_proxy_netloc()).geturl()
request.url = httpx.URL(updated_target)

def restore_httpx_response_url(response: httpx.Response) -> httpx.Response:
"""Restore the response's request URL to the original upstream target."""
try:
parsed_resp = url_parse.urlparse(str(response.request.url))
upstream_uri_str = response.request.headers.get("x-recording-upstream-base-uri", "")
if upstream_uri_str:
upstream_uri = url_parse.urlparse(upstream_uri_str)
original_target = parsed_resp._replace(
scheme=upstream_uri.scheme or parsed_resp.scheme,
netloc=upstream_uri.netloc
).geturl()
response.request.url = httpx.URL(original_target)
except Exception:
# Best-effort restore; don't fail the call if something goes wrong
pass
return response

trimmed_kwargs = {k: v for k, v in kwargs.items()}
trim_kwargs_from_test_function(test_func, trimmed_kwargs)

if is_live_and_not_recording():
return await test_func(*args, **trimmed_kwargs)

test_id = get_test_id()
recording_id, variables = start_record_or_playback(test_id)
original_transport_func = httpx.AsyncHTTPTransport.handle_async_request

async def combined_call(transport_self, request: httpx.Request) -> httpx.Response:
transform_httpx_request(request, recording_id)
result = await original_transport_func(transport_self, request)
return restore_httpx_response_url(result)

httpx.AsyncHTTPTransport.handle_async_request = combined_call

# Call the test function
test_variables = None
test_run = False
try:
try:
test_variables = await test_func(*args, variables=variables, **trimmed_kwargs)
test_run = True
except TypeError as error:
if "unexpected keyword argument" in str(error) and "variables" in str(error):
logger = logging.getLogger()
logger.info(
"This test can't accept variables as input. The test method should accept `**kwargs` and/or a "
"`variables` parameter to make use of recorded test variables."
)
else:
raise error
# If the test couldn't accept `variables`, run without passing them
if not test_run:
test_variables = await test_func(*args, **trimmed_kwargs)

except ResourceNotFoundError as error:
error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
message = error_body.get("message") or error_body.get("Message")
error_with_message = ResourceNotFoundError(message=message, response=error.response)
raise error_with_message from error

finally:
httpx.AsyncHTTPTransport.handle_async_request = original_transport_func
stop_record_or_playback(test_id, recording_id, test_variables)

return test_variables

return record_wrap
136 changes: 125 additions & 11 deletions eng/tools/azure-sdk-tools/devtools_testutils/proxy_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from inspect import iscoroutinefunction
import logging
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple, Optional
import urllib.parse as url_parse

import pytest
Expand All @@ -21,7 +21,7 @@
pass

from .helpers import get_test_id, is_live, is_live_and_not_recording
from .proxy_testcase import start_record_or_playback, stop_record_or_playback, transform_request
from .proxy_testcase import start_record_or_playback, stop_record_or_playback, transform_request, get_proxy_netloc
from .proxy_startup import test_proxy
from .sanitizers import add_batch_sanitizers, add_general_string_sanitizer, Sanitizer

Expand Down Expand Up @@ -158,7 +158,7 @@ async def recorded_test(test_proxy: None, request: "FixtureRequest") -> "Dict[st

# True if the function requesting the fixture is an async test
if iscoroutinefunction(request._pyfuncitem.function):
original_transport_func = await redirect_async_traffic(recording_id)
original_transport_func, original_httpx_async_func = await redirect_async_traffic(recording_id)
yield {"variables": variables} # yield relevant test info and allow tests to run
restore_async_traffic(original_transport_func, request)
else:
Expand Down Expand Up @@ -201,7 +201,7 @@ def start_proxy_session() -> "Tuple[str, str, Dict[str, str]]":
return (test_id, recording_id, variables)


async def redirect_async_traffic(recording_id: str) -> "Callable":
async def redirect_async_traffic(recording_id: str) -> Tuple["Callable", Optional["Callable"]]:
"""Redirects asynchronous network requests to target the test proxy.

:param str recording_id: Recording ID of the currently executing test.
Expand All @@ -210,7 +210,7 @@ async def redirect_async_traffic(recording_id: str) -> "Callable":
"""
from azure.core.pipeline.transport import AioHttpTransport

original_transport_func = AioHttpTransport.send
original_async_transport_func = AioHttpTransport.send

def transform_args(*args, **kwargs):
copied_positional_args = list(args)
Expand All @@ -222,7 +222,7 @@ def transform_args(*args, **kwargs):

async def combined_call(*args, **kwargs):
adjusted_args, adjusted_kwargs = transform_args(*args, **kwargs)
result = await original_transport_func(*adjusted_args, **adjusted_kwargs)
result = await original_async_transport_func(*adjusted_args, **adjusted_kwargs)

# make the x-recording-upstream-base-uri the URL of the request
# this makes the request look like it was made to the original endpoint instead of to the proxy
Expand All @@ -236,7 +236,52 @@ async def combined_call(*args, **kwargs):
return result

AioHttpTransport.send = combined_call
return original_transport_func

try:
import httpx

original_httpx_async_send = httpx.AsyncClient.send

if original_httpx_async_send is None:
raise ImportError("httpx.AsyncClient.send not found while able to import httpx")

def _transform_args(*args, **kwargs):
copied_positional_args = list(args)
request = copied_positional_args[1]

parsed = url_parse.urlparse(str(request.url))
if "x-recording-upstream-base-uri" not in request.headers:
request.headers["x-recording-upstream-base-uri"] = f"{parsed.scheme}://{parsed.netloc}"
request.headers["x-recording-id"] = recording_id
request.headers["x-recording-mode"] = "record" if is_live() else "playback"

proxied = parsed._replace(**get_proxy_netloc()).geturl()
request.url = httpx.URL(proxied)

return tuple(copied_positional_args), kwargs

async def combined_async_send(*args, **kwargs):
adjusted_args, adjusted_kwargs = _transform_args(*args, **kwargs)
result = await original_httpx_async_send(*adjusted_args, **adjusted_kwargs)

try:
parsed_result = url_parse.urlparse(result.request.url)
upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"])
upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc}
original_target = parsed_result._replace(**upstream_uri_dict).geturl()
result.request.url = original_target
except Exception:
pass

return result

if original_httpx_async_send is not None:
httpx.AsyncClient.send = combined_async_send
return (original_async_transport_func, original_httpx_async_send)
except ImportError:
pass # there is no httpx to patch the return at the end is good enough.

return (original_async_transport_func, None)


def redirect_traffic(recording_id: str) -> "Callable":
Expand Down Expand Up @@ -272,10 +317,53 @@ def combined_call(*args, **kwargs):
return result

RequestsTransport.send = combined_call
return original_transport_func

# attempt to monkeypatch httpx.Client.send as well (if httpx is installed)
original_httpx_send = None
try:
import httpx

original_httpx_send = getattr(httpx.Client, "send", None)

def _transform_args(*args, **kwargs):
copied_positional_args = list(args)
request = copied_positional_args[1]

parsed = url_parse.urlparse(str(request.url))
if "x-recording-upstream-base-uri" not in request.headers:
request.headers["x-recording-upstream-base-uri"] = f"{parsed.scheme}://{parsed.netloc}"
request.headers["x-recording-id"] = recording_id
request.headers["x-recording-mode"] = "record" if is_live() else "playback"

def restore_async_traffic(original_transport_func: "Callable", request: "FixtureRequest") -> None:
proxied = parsed._replace(**get_proxy_netloc()).geturl()
request.url = httpx.URL(proxied)

return tuple(copied_positional_args), kwargs

def combined_httpx_send(*args, **kwargs):
adjusted_args, adjusted_kwargs = _transform_args(*args, **kwargs)
result = original_httpx_send(*adjusted_args, **adjusted_kwargs)

try:
parsed_result = url_parse.urlparse(result.request.url)
upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"])
upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc}
original_target = parsed_result._replace(**upstream_uri_dict).geturl()
result.request.url = httpx.URL(original_target)
except Exception:
pass

return result

if original_httpx_send is not None:
httpx.Client.send = combined_httpx_send
except Exception:
original_httpx_send = None

return (original_transport_func, original_httpx_send)


def restore_async_traffic(original_transport_func: "Callable", original_httpx_transport_func: Optional["Callable"], request: "FixtureRequest") -> None:
"""Resets asynchronous network traffic to no longer target the test proxy.

:param original_transport_func: The original transport function used by the currently executing test.
Expand All @@ -285,7 +373,20 @@ def restore_async_traffic(original_transport_func: "Callable", request: "Fixture
"""
from azure.core.pipeline.transport import AioHttpTransport

AioHttpTransport.send = original_transport_func # test finished running -- tear down
# original_transport_func may be a tuple (original_aio_send, original_httpx_async_send)
orig_aio_send = original_transport_func[0] if isinstance(original_transport_func, tuple) else original_transport_func
orig_httpx_async_send = original_transport_func[1] if isinstance(original_transport_func, tuple) else None

AioHttpTransport.send = orig_aio_send # test finished running -- tear down

# restore httpx.AsyncClient.send if we patched it
if orig_httpx_async_send is not None:
try:
import httpx

httpx.AsyncClient.send = orig_httpx_async_send
except Exception:
pass

if hasattr(request.node, "test_error"):
# Exceptions are logged here instead of being raised because of how pytest handles error raising from inside
Expand All @@ -308,7 +409,20 @@ def restore_traffic(original_transport_func: "Callable", request: "FixtureReques
:param request: The built-in `request` pytest fixture.
:type request: ~pytest.FixtureRequest
"""
RequestsTransport.send = original_transport_func # test finished running -- tear down
# original_transport_func may be a tuple (original_requests_send, original_httpx_send)
orig_requests_send = original_transport_func[0] if isinstance(original_transport_func, tuple) else original_transport_func
orig_httpx_send = original_transport_func[1] if isinstance(original_transport_func, tuple) else None

RequestsTransport.send = orig_requests_send # test finished running -- tear down

# restore httpx.Client.send if we patched it
if orig_httpx_send is not None:
try:
import httpx

httpx.Client.send = orig_httpx_send
except Exception:
pass

if hasattr(request.node, "test_error"):
# Exceptions are logged here instead of being raised because of how pytest handles error raising from inside
Expand Down
Loading
Loading