From 11d88e9951ba9d8c8b660be0d5fb662046da0952 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 07:17:23 +0200 Subject: [PATCH 01/54] SNOW-2183023: Session manager refactored --- src/snowflake/connector/network.py | 110 ++++++++++++++++------------- 1 file changed, 61 insertions(+), 49 deletions(-) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 46758eef6d..1504e43b88 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -334,19 +334,59 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: return r +class SessionManager: + def __init__(self, use_pooling: bool = True): + self._use_pooling = use_pooling + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: SessionPool(self) + ) + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self._sessions_map + + def make_session(self) -> Session: + s = requests.Session() + s.mount("http://", ProxySupportAdapter(max_retries=REQUESTS_RETRY)) + s.mount("https://", ProxySupportAdapter(max_retries=REQUESTS_RETRY)) + s._reuse_count = itertools.count() + return s + + @contextlib.contextmanager + def use_session(self, url: str | None = None): + if not self._use_pooling: + session = self.make_session() + try: + yield session + finally: + session.close() + else: + hostname = urlparse(url).hostname if url else None + pool = self._sessions_map[hostname] + session = pool.get_session() + try: + yield session + finally: + pool.return_session(session) + + def close(self): + for pool in self._sessions_map.values(): + pool.close() + + class SessionPool: - def __init__(self, rest: SnowflakeRestful) -> None: + def __init__(self, manager: SessionManager) -> None: # A stack of the idle sessions - self._idle_sessions: list[Session] = [] - self._active_sessions: set[Session] = set() - self._rest: SnowflakeRestful = rest + self._idle_sessions = [] + self._active_sessions 
= set() + self._manager = manager def get_session(self) -> Session: """Returns a session from the session pool or creates a new one.""" try: session = self._idle_sessions.pop() except IndexError: - session = self._rest.make_requests_session() + session = self._manager.make_session() self._active_sessions.add(session) return session @@ -368,11 +408,11 @@ def close(self) -> None: """Closes all active and idle sessions in this session pool.""" if self._active_sessions: logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for s in itertools.chain(self._active_sessions, self._idle_sessions): + for session in itertools.chain(self._active_sessions, self._idle_sessions): try: - s.close() + session.close() except Exception as e: - logger.info(f"Session cleanup failed: {e}") + logger.info(f"Session cleanup failed - failed to close session: {e}") self._active_sessions.clear() self._idle_sessions.clear() @@ -403,8 +443,8 @@ def __init__( self._inject_client_pause = inject_client_pause self._connection = connection self._lock_token = Lock() - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: SessionPool(self) + self._session_manager = SessionManager( + use_pooling=not self._connection.disable_request_pooling, ) # OCSP mode (OCSPMode.FAIL_OPEN by default) @@ -470,6 +510,14 @@ def mfa_token(self, value: str) -> None: def server_url(self) -> str: return f"{self._protocol}://{self._host}:{self._port}" + @property + def session_manager(self) -> SessionManager: + return self._session_manager + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self.session_manager.sessions_map + def close(self) -> None: if hasattr(self, "_token"): del self._token @@ -480,8 +528,7 @@ def close(self) -> None: if hasattr(self, "_mfa_token"): del self._mfa_token - for session_pool in self._sessions_map.values(): - session_pool.close() + self._session_manager.close() def request( self, @@ -1258,40 +1305,5 @@ def _request_exec( 
except Exception as err: raise err - def make_requests_session(self) -> Session: - s = requests.Session() - s.mount("http://", ProxySupportAdapter(max_retries=REQUESTS_RETRY)) - s.mount("https://", ProxySupportAdapter(max_retries=REQUESTS_RETRY)) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def _use_requests_session(self, url: str | None = None): - """Session caching context manager. - - Notes: - The session is not closed until close() is called so each session may be used multiple times. - """ - # short-lived session, not added to the _sessions_map - if self._connection.disable_request_pooling: - session = self.make_requests_session() - try: - yield session - finally: - session.close() - else: - try: - hostname = urlparse(url).hostname - except Exception: - hostname = None - - session_pool: SessionPool = self._sessions_map[hostname] - session = session_pool.get_session() - logger.debug(f"Session status for SessionPool '{hostname}', {session_pool}") - try: - yield session - finally: - session_pool.return_session(session) - logger.debug( - f"Session status for SessionPool '{hostname}', {session_pool}" - ) + def _use_requests_session(self, url=None): + return self._session_manager.use_session(url) From c37516436c06165e6ad9e2c1e83366bda6411773 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 07:37:08 +0200 Subject: [PATCH 02/54] SNOW-2183023: Added adapter factory --- src/snowflake/connector/network.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 1504e43b88..8c00f0b2f5 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -12,7 +12,7 @@ import uuid from collections import OrderedDict from threading import Lock -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable import OpenSSL.SSL @@ -335,8 +335,13 @@ def 
__call__(self, r: PreparedRequest) -> PreparedRequest: class SessionManager: - def __init__(self, use_pooling: bool = True): + def __init__( + self, + use_pooling: bool = True, + adapter_factory: Callable[..., HTTPAdapter] | None = None, + ): self._use_pooling = use_pooling + self._adapter_factory = adapter_factory or ProxySupportAdapter self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( lambda: SessionPool(self) ) @@ -347,8 +352,9 @@ def sessions_map(self) -> dict[str, SessionPool]: def make_session(self) -> Session: s = requests.Session() - s.mount("http://", ProxySupportAdapter(max_retries=REQUESTS_RETRY)) - s.mount("https://", ProxySupportAdapter(max_retries=REQUESTS_RETRY)) + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + s.mount("http://", adapter) + s.mount("https://", adapter) s._reuse_count = itertools.count() return s @@ -436,6 +442,7 @@ def __init__( protocol: str = "http", inject_client_pause: int = 0, connection: SnowflakeConnection | None = None, + adapter_factory: Callable[[], HTTPAdapter] | None = None, ) -> None: self._host = host self._port = port @@ -445,6 +452,7 @@ def __init__( self._lock_token = Lock() self._session_manager = SessionManager( use_pooling=not self._connection.disable_request_pooling, + adapter_factory=adapter_factory, ) # OCSP mode (OCSPMode.FAIL_OPEN by default) From e820b4dab1fbd546405f9f333f2a50fdd3f9954d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 08:49:03 +0200 Subject: [PATCH 03/54] SNOW-2183023: Adapters fixed in tests --- src/snowflake/connector/network.py | 14 ++++++++++---- test/unit/test_session_manager.py | 6 +++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 8c00f0b2f5..48d8b8f0bd 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -338,7 +338,9 @@ class SessionManager: def __init__( self, 
use_pooling: bool = True, - adapter_factory: Callable[..., HTTPAdapter] | None = None, + adapter_factory: ( + Callable[..., HTTPAdapter] | None + ) = lambda *args, **kwargs: None, ): self._use_pooling = use_pooling self._adapter_factory = adapter_factory or ProxySupportAdapter @@ -350,11 +352,15 @@ def __init__( def sessions_map(self) -> dict[str, SessionPool]: return self._sessions_map + def _mount_adapter(self, session: requests.Session) -> None: + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + if adapter is not None: + session.mount("http://", adapter) + session.mount("https://", adapter) + def make_session(self) -> Session: s = requests.Session() - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - s.mount("http://", adapter) - s.mount("https://", adapter) + self._mount_adapter(s) s._reuse_count = itertools.count() return s diff --git a/test/unit/test_session_manager.py b/test/unit/test_session_manager.py index 8ca3044b6b..67d303d7d2 100644 --- a/test/unit/test_session_manager.py +++ b/test/unit/test_session_manager.py @@ -48,7 +48,7 @@ def create_session( create_session(rest, num_sessions - 1, url) -@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session") +@mock.patch("snowflake.connector.network.SessionManager.make_session") def test_no_url_multiple_sessions(make_session_mock): rest = SnowflakeRestful(connection=mock_conn) @@ -65,7 +65,7 @@ def test_no_url_multiple_sessions(make_session_mock): close_sessions(rest, 1) -@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session") +@mock.patch("snowflake.connector.network.SessionManager.make_session") def test_multiple_urls_multiple_sessions(make_session_mock): rest = SnowflakeRestful(connection=mock_conn) @@ -85,7 +85,7 @@ def test_multiple_urls_multiple_sessions(make_session_mock): close_sessions(rest, 3) -@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session") 
+@mock.patch("snowflake.connector.network.SessionManager.make_session") def test_multiple_urls_reuse_sessions(make_session_mock): rest = SnowflakeRestful(connection=mock_conn) for url in [url_1, url_2, url_3, None]: From 965786ae6f9da6983845154583f448d0e0fc7d13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 09:15:31 +0200 Subject: [PATCH 04/54] SNOW-2183023: Removed boto --- src/snowflake/connector/wif_util.py | 307 ++++++++++++++++++++-------- 1 file changed, 221 insertions(+), 86 deletions(-) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 3449cdd5ef..117cc23ed7 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -1,17 +1,17 @@ from __future__ import annotations +import hashlib +import hmac import json import logging import os from base64 import b64encode from dataclasses import dataclass +from datetime import datetime, timezone from enum import Enum, unique +from urllib.parse import urlparse -import boto3 import jwt -from botocore.auth import SigV4Auth -from botocore.awsrequest import AWSRequest -from botocore.utils import InstanceMetadataRegionFetcher from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError @@ -62,6 +62,15 @@ class WorkloadIdentityAttestation: user_identifier_components: dict +@dataclass +class AwsCredentials: + """AWS credentials container.""" + + access_key: str + secret_key: str + token: str | None = None + + def try_metadata_service_call( method: str, url: str, headers: dict, timeout_sec: int = 3 ) -> Response | None: @@ -105,84 +114,202 @@ def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[st return claims["iss"], claims["sub"] -def get_aws_region() -> str | None: - """Get the current AWS workload's region, if any.""" - if "AWS_REGION" in os.environ: # Lambda - return os.environ["AWS_REGION"] - else: # EC2 - return 
InstanceMetadataRegionFetcher().retrieve_region() +def get_aws_credentials() -> AwsCredentials | None: + """Get AWS credentials from environment variables or instance metadata. + Implements the AWS credential chain without using boto3. + """ + # Try environment variables first + access_key = os.environ.get("AWS_ACCESS_KEY_ID") + secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") + session_token = os.environ.get("AWS_SESSION_TOKEN") -def get_aws_arn() -> str | None: - """Get the current AWS workload's ARN, if any.""" - caller_identity = boto3.client("sts").get_caller_identity() - if not caller_identity or "Arn" not in caller_identity: - return None - return caller_identity["Arn"] + if access_key and secret_key: + return AwsCredentials(access_key, secret_key, session_token) + + # Try instance metadata service (IMDSv2) + try: + # First, get a token for IMDSv2 + token_res = try_metadata_service_call( + method="PUT", + url="http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + ) + if token_res is None: + logger.debug("Failed to get IMDSv2 token from metadata service.") + return None -def get_aws_partition(arn: str) -> str | None: - """Get the current AWS partition from ARN, if any. + token = token_res.text.strip() - Args: - arn (str): The Amazon Resource Name (ARN) string. + # Get the security credentials from the metadata service + res = try_metadata_service_call( + method="GET", + url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", + headers={"X-aws-ec2-metadata-token": token}, + ) + if res is None: + logger.debug("Failed to get IAM role list from metadata service.") + return None - Returns: - str | None: The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov') - if found, otherwise None. 
+ role_name = res.text.strip() + if not role_name: + logger.debug("No IAM role found in metadata service.") + return None + + # Get credentials for the role + res = try_metadata_service_call( + method="GET", + url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", + headers={"X-aws-ec2-metadata-token": token}, + ) + if res is None: + logger.debug("Failed to get IAM role credentials from metadata service.") + return None + + creds_data = res.json() + access_key = creds_data.get("AccessKeyId") + secret_key = creds_data.get("SecretAccessKey") + token = creds_data.get("Token") + + if access_key and secret_key: + return AwsCredentials(access_key, secret_key, token) + + except Exception as e: + logger.debug(f"Error getting AWS credentials from metadata service: {e}") + + return None + + +def get_aws_region() -> str | None: + """Get the current AWS workload's region, if any.""" + # Try environment variable first + region = os.environ.get("AWS_REGION") + if region: + return region + + # Try instance metadata service (IMDSv2) + try: + # First, get a token for IMDSv2 + token_res = try_metadata_service_call( + method="PUT", + url="http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + ) + + if token_res is None: + logger.debug("Failed to get IMDSv2 token from metadata service.") + return None + + token = token_res.text.strip() + + # Get region from metadata service + res = try_metadata_service_call( + method="GET", + url="http://169.254.169.254/latest/meta-data/placement/region", + headers={"X-aws-ec2-metadata-token": token}, + ) + if res is not None: + return res.text.strip() + except Exception as e: + logger.debug(f"Error getting AWS region from metadata service: {e}") - Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html. 
- """ - if not arn or not isinstance(arn, str): - return None - parts = arn.split(":") - if len(parts) > 1 and parts[0] == "arn" and parts[1]: - return parts[1] - logger.warning("Invalid AWS ARN: %s", arn) return None -def get_aws_sts_hostname(region: str, partition: str) -> str | None: - """Constructs the AWS STS hostname for a given region and partition. +def get_aws_sts_hostname(region: str) -> str: + """Constructs the AWS STS hostname for a given region. Args: region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1'). - partition (str): The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov'). Returns: - str | None: The AWS STS hostname (e.g., 'sts.us-east-1.amazonaws.com') - if a valid hostname can be constructed, otherwise None. + str: The AWS STS hostname (e.g., 'sts.us-east-1.amazonaws.com') References: - https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html - https://docs.aws.amazon.com/general/latest/gr/sts.html """ - if ( - not region - or not partition - or not isinstance(region, str) - or not isinstance(partition, str) - ): - return None - - if partition == "aws": - # For the 'aws' partition, STS endpoints are generally regional - # except for the global endpoint (sts.amazonaws.com) which is - # generally resolved to us-east-1 under the hood by the SDKs - # when a region is not explicitly specified. 
- # However, for explicit regional endpoints, the format is sts..amazonaws.com - return f"sts.{region}.amazonaws.com" - elif partition == "aws-cn": + if region.startswith("cn-"): # China regions have a different domain suffix return f"sts.{region}.amazonaws.com.cn" - elif partition == "aws-us-gov": - return ( - f"sts.{region}.amazonaws.com" # GovCloud uses .com, but dedicated regions - ) else: - logger.warning("Invalid AWS partition: %s", partition) - return None + # Standard AWS regions + return f"sts.{region}.amazonaws.com" + + +def aws_signature_v4_sign( + credentials: AwsCredentials, + method: str, + url: str, + region: str, + service: str, + headers: dict, + payload: str = "", +) -> dict: + """Sign an AWS request using Signature Version 4. + + Based on the C# implementation in AwsSignature4Signer.cs. + """ + # Parse the URL + parsed_url = urlparse(url) + + # Create timestamp + utc_now = datetime.now(timezone.utc) + amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ") + date_string = utc_now.strftime("%Y%m%d") + + # Add required headers + headers = headers.copy() + headers["x-amz-date"] = amz_date + if credentials.token: + headers["x-amz-security-token"] = credentials.token + + # Create canonical request + canonical_uri = parsed_url.path or "/" + canonical_querystring = parsed_url.query or "" + + # Sort headers and create canonical headers + sorted_headers = sorted(headers.items(), key=lambda x: x[0].lower()) + canonical_headers = "" + signed_headers = "" + + for key, value in sorted_headers: + canonical_headers += f"{key.lower()}:{str(value).strip()}\n" + if signed_headers: + signed_headers += ";" + signed_headers += key.lower() + + # Create payload hash + payload_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest() + + # Create canonical request + canonical_request = f"{method}\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" + + # Create string to sign + algorithm = "AWS4-HMAC-SHA256" + credential_scope = 
f"{date_string}/{region}/{service}/aws4_request" + string_to_sign = f"{algorithm}\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}" + + # Calculate signature + def hmac_sha256(key: bytes, msg: str) -> bytes: + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() + + k_date = hmac_sha256(f"AWS4{credentials.secret_key}".encode(), date_string) + k_region = hmac_sha256(k_date, region) + k_service = hmac_sha256(k_region, service) + k_signing = hmac_sha256(k_service, "aws4_request") + + signature = hmac.new( + k_signing, string_to_sign.encode("utf-8"), hashlib.sha256 + ).hexdigest() + + # Create authorization header + authorization = f"{algorithm} Credential={credentials.access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}" + headers["authorization"] = authorization + + return headers def create_aws_attestation() -> WorkloadIdentityAttestation | None: @@ -190,43 +317,51 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: If the application isn't running on AWS or no credentials were found, returns None. 
""" - aws_creds = boto3.session.Session().get_credentials() - if not aws_creds: + credentials = get_aws_credentials() + if not credentials: logger.debug("No AWS credentials were found.") return None + region = get_aws_region() if not region: logger.debug("No AWS region was found.") return None - arn = get_aws_arn() - if not arn: - logger.debug("No AWS caller identity was found.") - return None - partition = get_aws_partition(arn) - if not partition: - logger.debug("No AWS partition was found.") - return None - sts_hostname = get_aws_sts_hostname(region, partition) - request = AWSRequest( + # Create the GetCallerIdentity request + sts_hostname = get_aws_sts_hostname(region) + url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15" + + headers = { + "Host": sts_hostname, + "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, + } + + # Sign the request + signed_headers = aws_signature_v4_sign( + credentials=credentials, method="POST", - url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", - headers={ - "Host": sts_hostname, - "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, - }, + url=url, + region=region, + service="sts", + headers=headers, ) - SigV4Auth(aws_creds, "sts", region).add_auth(request) - - assertion_dict = { - "url": request.url, - "method": request.method, - "headers": dict(request.headers.items()), + # Create attestation request + attestation_request = { + "method": "POST", + "url": url, + "headers": signed_headers, } - credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8") + + # Encode to base64 + credential = b64encode(json.dumps(attestation_request).encode("utf-8")).decode( + "utf-8" + ) + return WorkloadIdentityAttestation( - AttestationProvider.AWS, credential, {"arn": arn} + AttestationProvider.AWS, + credential, + {}, # No user identifier components needed - Snowflake will extract from the signed request ) @@ -356,11 +491,11 @@ def create_autodetect_attestation( if attestation: return 
attestation - attestation = create_aws_attestation() + attestation = create_azure_attestation(entra_resource) if attestation: return attestation - attestation = create_azure_attestation(entra_resource) + attestation = create_aws_attestation() if attestation: return attestation @@ -385,7 +520,7 @@ def create_attestation( """ entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE - attestation: WorkloadIdentityAttestation = None + attestation: WorkloadIdentityAttestation | None = None if provider == AttestationProvider.AWS: attestation = create_aws_attestation() elif provider == AttestationProvider.AZURE: From f8b09446706788334a5c68344585db31b9c99852 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 09:56:38 +0200 Subject: [PATCH 05/54] SNOW-2183023: Removed boto --- src/snowflake/connector/wif_util.py | 154 ++++++++++++++--------- test/unit/test_auth_workload_identity.py | 108 ++++++---------- 2 files changed, 132 insertions(+), 130 deletions(-) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 117cc23ed7..21a7142c30 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from datetime import datetime, timezone from enum import Enum, unique -from urllib.parse import urlparse +from urllib.parse import parse_qsl, quote, urlparse import jwt @@ -89,7 +89,9 @@ def try_metadata_service_call( return res -def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[str, str]: +def extract_iss_and_sub_without_signature_verification( + jwt_str: str, +) -> tuple[str | None, str | None]: """Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature. Note: the real token verification (including signature verification) happens on the Snowflake side. 
The driver doesn't have @@ -114,6 +116,18 @@ def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[st return claims["iss"], claims["sub"] +# --------------------------------------------------------------------------- # +# AWS helper utilities (token, credentials, region) # +# --------------------------------------------------------------------------- # +def _imds_v2_token() -> str | None: + res = try_metadata_service_call( + method="PUT", + url="http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + ) + return res.text.strip() if res else None + + def get_aws_credentials() -> AwsCredentials | None: """Get AWS credentials from environment variables or instance metadata. @@ -129,24 +143,18 @@ def get_aws_credentials() -> AwsCredentials | None: # Try instance metadata service (IMDSv2) try: - # First, get a token for IMDSv2 - token_res = try_metadata_service_call( - method="PUT", - url="http://169.254.169.254/latest/api/token", - headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, - ) - - if token_res is None: + token = _imds_v2_token() + if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None - token = token_res.text.strip() + token_hdr = {"X-aws-ec2-metadata-token": token} if token else {} # Get the security credentials from the metadata service res = try_metadata_service_call( method="GET", url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", - headers={"X-aws-ec2-metadata-token": token}, + headers=token_hdr, ) if res is None: logger.debug("Failed to get IAM role list from metadata service.") @@ -161,7 +169,7 @@ def get_aws_credentials() -> AwsCredentials | None: res = try_metadata_service_call( method="GET", url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", - headers={"X-aws-ec2-metadata-token": token}, + headers=token_hdr, ) if res is None: logger.debug("Failed to get IAM role credentials from 
metadata service.") @@ -174,7 +182,6 @@ def get_aws_credentials() -> AwsCredentials | None: if access_key and secret_key: return AwsCredentials(access_key, secret_key, token) - except Exception as e: logger.debug(f"Error getting AWS credentials from metadata service: {e}") @@ -183,43 +190,47 @@ def get_aws_credentials() -> AwsCredentials | None: def get_aws_region() -> str | None: """Get the current AWS workload's region, if any.""" - # Try environment variable first region = os.environ.get("AWS_REGION") if region: return region - # Try instance metadata service (IMDSv2) try: - # First, get a token for IMDSv2 - token_res = try_metadata_service_call( - method="PUT", - url="http://169.254.169.254/latest/api/token", - headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, - ) - - if token_res is None: + token = _imds_v2_token() + if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None - token = token_res.text.strip() + token_hdr = {"X-aws-ec2-metadata-token": token} if token else {} # Get region from metadata service res = try_metadata_service_call( method="GET", url="http://169.254.169.254/latest/meta-data/placement/region", - headers={"X-aws-ec2-metadata-token": token}, + headers=token_hdr, ) if res is not None: return res.text.strip() + + res = try_metadata_service_call( + method="GET", + url="http://169.254.169.254/latest/meta-data/placement/availability-zone", + headers=token_hdr, + ) + if res is not None: + return res.text.strip()[:-1] except Exception as e: logger.debug(f"Error getting AWS region from metadata service: {e}") return None -def get_aws_sts_hostname(region: str) -> str: +def get_aws_sts_hostname(region: str) -> str | None: """Constructs the AWS STS hostname for a given region. + * China regions (`cn-*`) → sts..amazonaws.com.cn + * All other regions → sts..amazonaws.com + * Any invalid input → None + Args: region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1'). 
@@ -231,6 +242,10 @@ def get_aws_sts_hostname(region: str) -> str: - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html - https://docs.aws.amazon.com/general/latest/gr/sts.html """ + + if not region or not isinstance(region, str): + return None + if region.startswith("cn-"): # China regions have a different domain suffix return f"sts.{region}.amazonaws.com.cn" @@ -239,6 +254,19 @@ def get_aws_sts_hostname(region: str) -> str: return f"sts.{region}.amazonaws.com" +def _aws_percent_encode(s: str) -> str: + return quote(s, safe="~") + + +def _canonical_query(query: str) -> str: + if not query: + return "" + pairs = sorted(parse_qsl(query, keep_blank_values=True)) + return "&".join( + f"{_aws_percent_encode(k)}={_aws_percent_encode(v)}" for k, v in pairs + ) + + def aws_signature_v4_sign( credentials: AwsCredentials, method: str, @@ -252,47 +280,43 @@ def aws_signature_v4_sign( Based on the C# implementation in AwsSignature4Signer.cs. """ - # Parse the URL parsed_url = urlparse(url) - # Create timestamp utc_now = datetime.now(timezone.utc) amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ") date_string = utc_now.strftime("%Y%m%d") - # Add required headers - headers = headers.copy() - headers["x-amz-date"] = amz_date + headers_lower = {k.lower(): str(v).strip() for k, v in headers.items()} + headers_lower["host"] = parsed_url.netloc + headers_lower["x-amz-date"] = amz_date if credentials.token: - headers["x-amz-security-token"] = credentials.token + headers_lower["x-amz-security-token"] = credentials.token - # Create canonical request - canonical_uri = parsed_url.path or "/" - canonical_querystring = parsed_url.query or "" + sorted_header_keys = sorted(headers_lower.keys()) + canonical_headers = "".join(f"{k}:{headers_lower[k]}\n" for k in sorted_header_keys) + signed_headers = ";".join(sorted_header_keys) - # Sort headers and create canonical headers - sorted_headers = sorted(headers.items(), key=lambda x: x[0].lower()) - 
canonical_headers = "" - signed_headers = "" - - for key, value in sorted_headers: - canonical_headers += f"{key.lower()}:{str(value).strip()}\n" - if signed_headers: - signed_headers += ";" - signed_headers += key.lower() - - # Create payload hash + canonical_querystring = _canonical_query(parsed_url.query) payload_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest() - # Create canonical request - canonical_request = f"{method}\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" + canonical_request = ( + f"{method}\n" + f"{parsed_url.path or '/'}\n" + f"{canonical_querystring}\n" + f"{canonical_headers}" + f"{signed_headers}\n" + f"{payload_hash}" + ) - # Create string to sign algorithm = "AWS4-HMAC-SHA256" credential_scope = f"{date_string}/{region}/{service}/aws4_request" - string_to_sign = f"{algorithm}\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}" + string_to_sign = ( + f"{algorithm}\n" + f"{amz_date}\n" + f"{credential_scope}\n" + f"{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}" + ) - # Calculate signature def hmac_sha256(key: bytes, msg: str) -> bytes: return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() @@ -305,11 +329,20 @@ def hmac_sha256(key: bytes, msg: str) -> bytes: k_signing, string_to_sign.encode("utf-8"), hashlib.sha256 ).hexdigest() - # Create authorization header - authorization = f"{algorithm} Credential={credentials.access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}" - headers["authorization"] = authorization + authorization = ( + f"{algorithm} " + f"Credential={credentials.access_key}/{credential_scope}, " + f"SignedHeaders={signed_headers}, Signature={signature}" + ) + + final_headers = headers.copy() + final_headers["Host"] = parsed_url.netloc + final_headers["X-Amz-Date"] = amz_date + if credentials.token: + final_headers["X-Amz-Security-Token"] = credentials.token + 
final_headers["Authorization"] = authorization - return headers + return final_headers def create_aws_attestation() -> WorkloadIdentityAttestation | None: @@ -331,19 +364,17 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: sts_hostname = get_aws_sts_hostname(region) url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15" - headers = { - "Host": sts_hostname, + base_headers = { "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, } - # Sign the request signed_headers = aws_signature_v4_sign( credentials=credentials, method="POST", url=url, region=region, service="sts", - headers=headers, + headers=base_headers, ) # Create attestation request @@ -353,7 +384,6 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: "headers": signed_headers, } - # Encode to base64 credential = b64encode(json.dumps(attestation_request).encode("utf-8")).decode( "utf-8" ) diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index f2e42aae3e..6df7e0fb11 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -17,7 +17,6 @@ from snowflake.connector.wif_util import ( AZURE_ISSUER_PREFIXES, AttestationProvider, - get_aws_partition, get_aws_sts_hostname, ) @@ -33,7 +32,7 @@ def extract_api_data(auth_class: AuthByWorkloadIdentity): return req_body["data"] -def verify_aws_token(token: str, region: str): +def verify_aws_token(token: str, region: str, expect_session_token: bool = True): """Performs some basic checks on a 'token' produced for AWS, to ensure it includes the expected fields.""" decoded_token = json.loads(b64decode(token)) @@ -47,18 +46,22 @@ def verify_aws_token(token: str, region: str): assert decoded_token["method"] == "POST" headers = decoded_token["headers"] - assert set(headers.keys()) == { + + base_expected = { "Host", "X-Snowflake-Audience", "X-Amz-Date", - "X-Amz-Security-Token", "Authorization", } + if expect_session_token: + 
base_expected.add("X-Amz-Security-Token") + + assert base_expected.issubset(headers.keys()) assert headers["Host"] == f"sts.{region}.amazonaws.com" assert headers["X-Snowflake-Audience"] == "snowflakecomputing.com" -# -- OIDC Tests -- +# -- OIDC Tests -------------------------------------------------------------- def test_explicit_oidc_valid_inline_token_plumbed_to_api(): @@ -104,7 +107,7 @@ def test_explicit_oidc_no_token_raises_error(): assert "No workload identity credential was found for 'OIDC'" in str(excinfo.value) -# -- AWS Tests -- +# -- AWS Tests --------------------------------------------------------------- def test_explicit_aws_no_auth_raises_error(fake_aws_environment: FakeAwsEnvironment): @@ -125,7 +128,11 @@ def test_explicit_aws_encodes_audience_host_signature_to_api( data = extract_api_data(auth_class) assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" assert data["PROVIDER"] == "AWS" - verify_aws_token(data["TOKEN"], fake_aws_environment.region) + verify_aws_token( + data["TOKEN"], + fake_aws_environment.region, + expect_session_token=fake_aws_environment.credentials.token is not None, + ) def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnvironment): @@ -160,73 +167,34 @@ def test_explicit_aws_generates_unique_assertion_content( @pytest.mark.parametrize( - "arn, expected_partition", + "region, expected_hostname", [ - ("arn:aws:iam::123456789012:role/MyTestRole", "aws"), + # Standard partition + ("us-east-1", "sts.us-east-1.amazonaws.com"), + ("eu-west-2", "sts.eu-west-2.amazonaws.com"), + # China partition + ("cn-north-1", "sts.cn-north-1.amazonaws.com.cn"), + ("cn-northwest-1", "sts.cn-northwest-1.amazonaws.com.cn"), + # GovCloud partition + ("us-gov-west-1", "sts.us-gov-west-1.amazonaws.com"), + ("us-gov-east-1", "sts.us-gov-east-1.amazonaws.com"), ( - "arn:aws-cn:ec2:cn-north-1:987654321098:instance/i-1234567890abcdef0", - "aws-cn", + "invalid-region-valid-format", + "sts.invalid-region-valid-format.amazonaws.com", ), 
- ("arn:aws-us-gov:s3:::my-gov-bucket", "aws-us-gov"), - ("arn:aws:s3:::my-bucket/my/key", "aws"), - ("arn:aws:lambda:us-east-1:123456789012:function:my-function", "aws"), - ("arn:aws:sns:eu-west-1:111122223333:my-topic", "aws"), - # Edge cases / Invalid inputs - ("invalid-arn", None), - ("arn::service:region:account:resource", None), # Missing partition - ("arn:aws:iam:", "aws"), # Incomplete ARN, but partition is present - ("", None), # Empty string - (None, None), # None input - (123, None), # Non-string input ], ) -def test_get_aws_partition_valid_and_invalid_arns(arn, expected_partition): - assert get_aws_partition(arn) == expected_partition +def test_get_aws_sts_hostname_valid_and_invalid_inputs(region, expected_hostname): + assert get_aws_sts_hostname(region) == expected_hostname -@pytest.mark.parametrize( - "region, partition, expected_hostname", - [ - # AWS partition - ("us-east-1", "aws", "sts.us-east-1.amazonaws.com"), - ("eu-west-2", "aws", "sts.eu-west-2.amazonaws.com"), - ("ap-southeast-1", "aws", "sts.ap-southeast-1.amazonaws.com"), - ( - "us-east-1", - "aws", - "sts.us-east-1.amazonaws.com", - ), # Redundant but good for coverage - # AWS China partition - ("cn-north-1", "aws-cn", "sts.cn-north-1.amazonaws.com.cn"), - ("cn-northwest-1", "aws-cn", "sts.cn-northwest-1.amazonaws.com.cn"), - ("", "aws-cn", None), # No global endpoint for 'aws-cn' without region - # AWS GovCloud partition - ("us-gov-west-1", "aws-us-gov", "sts.us-gov-west-1.amazonaws.com"), - ("us-gov-east-1", "aws-us-gov", "sts.us-gov-east-1.amazonaws.com"), - ("", "aws-us-gov", None), # No global endpoint for 'aws-us-gov' without region - # Invalid/Edge cases - ("us-east-1", "unknown-partition", None), # Unknown partition - ("some-region", "invalid-partition", None), # Invalid partition - (None, "aws", None), # None region - ("us-east-1", None, None), # None partition - (123, "aws", None), # Non-string region - ("us-east-1", 456, None), # Non-string partition - ("", "", None), # Empty 
region and partition - ("us-east-1", "", None), # Empty partition - ( - "invalid-region", - "aws", - "sts.invalid-region.amazonaws.com", - ), # Valid format, invalid region name - ], -) -def test_get_aws_sts_hostname_valid_and_invalid_inputs( - region, partition, expected_hostname -): - assert get_aws_sts_hostname(region, partition) == expected_hostname +@pytest.mark.parametrize("bad_region", [None, "", 123, object()]) +def test_get_aws_sts_hostname_returns_none_on_invalid_input(bad_region): + # Non-string / empty inputs should fail gracefully + assert get_aws_sts_hostname(bad_region) is None # type: ignore[arg-type] -# -- GCP Tests -- +# -- GCP Tests --------------------------------------------------------------- @pytest.mark.parametrize( @@ -284,7 +252,7 @@ def test_explicit_gcp_generates_unique_assertion_content( assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}' -# -- Azure Tests -- +# -- Azure Tests ------------------------------------------------------------- @pytest.mark.parametrize( @@ -400,7 +368,7 @@ def test_azure_issuer_prefixes(issuer): ) -# -- Auto-detect Tests -- +# -- Auto-detect Tests ------------------------------------------------------- def test_autodetect_aws_present( @@ -412,7 +380,11 @@ def test_autodetect_aws_present( data = extract_api_data(auth_class) assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" assert data["PROVIDER"] == "AWS" - verify_aws_token(data["TOKEN"], fake_aws_environment.region) + verify_aws_token( + data["TOKEN"], + fake_aws_environment.region, + expect_session_token=fake_aws_environment.credentials.token is not None, + ) def test_autodetect_gcp_present(fake_gce_metadata_service: FakeGceMetadataService): From f872fd22220d71a9c8fc7a5ec84bc5540c10a9be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 11:14:44 +0200 Subject: [PATCH 06/54] SNOW-2183023: SEssions map access fixed --- test/unit/test_session_manager.py | 12 ++++++------ 1 file changed, 6 
insertions(+), 6 deletions(-) diff --git a/test/unit/test_session_manager.py b/test/unit/test_session_manager.py index 67d303d7d2..8f2eb5c0f6 100644 --- a/test/unit/test_session_manager.py +++ b/test/unit/test_session_manager.py @@ -56,9 +56,9 @@ def test_no_url_multiple_sessions(make_session_mock): assert make_session_mock.call_count == 2 - assert list(rest._sessions_map.keys()) == [None] + assert list(rest.sessions_map.keys()) == [None] - session_pool = rest._sessions_map[None] + session_pool = rest.sessions_map[None] assert len(session_pool._idle_sessions) == 2 assert len(session_pool._active_sessions) == 0 @@ -74,11 +74,11 @@ def test_multiple_urls_multiple_sessions(make_session_mock): assert make_session_mock.call_count == 6 - hostnames = list(rest._sessions_map.keys()) + hostnames = list(rest.sessions_map.keys()) for hostname in [hostname_1, hostname_2, None]: assert hostname in hostnames - for pool in rest._sessions_map.values(): + for pool in rest.sessions_map.values(): assert len(pool._idle_sessions) == 2 assert len(pool._active_sessions) == 0 @@ -96,12 +96,12 @@ def test_multiple_urls_reuse_sessions(make_session_mock): # only one session is created and reused thereafter assert make_session_mock.call_count == 3 - hostnames = list(rest._sessions_map.keys()) + hostnames = list(rest.sessions_map.keys()) assert len(hostnames) == 3 for hostname in [hostname_1, hostname_2, None]: assert hostname in hostnames - for pool in rest._sessions_map.values(): + for pool in rest.sessions_map.values(): assert len(pool._idle_sessions) == 1 assert len(pool._active_sessions) == 0 From 95ace1f0c66140ce8000f476a43548957d1ad036 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 21:12:11 +0200 Subject: [PATCH 07/54] SNOW-2183023: use pooling --- src/snowflake/connector/network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 
48d8b8f0bd..84652205fa 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -457,7 +457,11 @@ def __init__( self._connection = connection self._lock_token = Lock() self._session_manager = SessionManager( - use_pooling=not self._connection.disable_request_pooling, + use_pooling=( + not self._connection.disable_request_pooling + if self._connection + else True + ), adapter_factory=adapter_factory, ) From 53087d43a382245fd435c0f81683a2a5d36f2891 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 21:42:43 +0200 Subject: [PATCH 08/54] SNOW-2183023: refactored --- .../connector/auth/workload_identity.py | 7 +- src/snowflake/connector/connection.py | 5 + src/snowflake/connector/http_client.py | 103 +++++++++++ src/snowflake/connector/network.py | 145 +--------------- src/snowflake/connector/session_manager.py | 162 ++++++++++++++++++ src/snowflake/connector/wif_util.py | 77 ++++++--- 6 files changed, 334 insertions(+), 165 deletions(-) create mode 100644 src/snowflake/connector/http_client.py create mode 100644 src/snowflake/connector/session_manager.py diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 3c80c965e4..7f8ab60718 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -74,10 +74,13 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: ).value body["data"]["TOKEN"] = self.attestation.credential - def prepare(self, **kwargs: typing.Any) -> None: + def prepare(self, *, conn, **kwargs: typing.Any) -> None: """Fetch the token.""" self.attestation = create_attestation( - self.provider, self.entra_resource, self.token + self.provider, + self.entra_resource, + self.token, + session_manager=conn.session_manager if conn else None, ) def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: diff --git 
a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 963e04ee8a..ffc193df37 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -712,6 +712,11 @@ def client_fetch_use_mp(self) -> bool: def rest(self) -> SnowflakeRestful | None: return self._rest + @property + def session_manager(self): + """Access to the connection's SessionManager for making HTTP requests.""" + return self._rest.session_manager if self._rest else None + @property def application(self) -> str: return self._application diff --git a/src/snowflake/connector/http_client.py b/src/snowflake/connector/http_client.py new file mode 100644 index 0000000000..1c802c2439 --- /dev/null +++ b/src/snowflake/connector/http_client.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import logging +from typing import Any, Mapping + +from .session_manager import SessionManager +from .vendored.requests import Response + +logger = logging.getLogger(__name__) + + +class HttpClient: + """HTTP client that uses SessionManager for connection pooling and adapter management.""" + + def __init__(self, session_manager: SessionManager): + """Initialize HttpClient with a SessionManager. + + Args: + session_manager: SessionManager instance to use for all requests + """ + self.session_manager = session_manager + + def request( + self, + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + use_pooling: bool | None = None, + **kwargs: Any, + ) -> Response: + """Make an HTTP request using the configured SessionManager. + + Args: + method: HTTP method (GET, POST, etc.) 
+ url: Target URL + headers: Optional HTTP headers + timeout_sec: Request timeout in seconds + use_pooling: Whether to use connection pooling (overrides session_manager setting) + **kwargs: Additional arguments passed to requests.Session.request + + Returns: + Response object from the request + """ + mgr = ( + self.session_manager + if use_pooling is None + else self.session_manager.clone(use_pooling=use_pooling) + ) + + with mgr.use_session(url) as session: + return session.request( + method=method.upper(), + url=url, + headers=headers, + timeout=timeout_sec, + **kwargs, + ) + + +# Convenience function for backwards compatibility and simple usage +def request( + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + session_manager: SessionManager | None = None, + use_pooling: bool | None = None, + **kwargs: Any, +) -> Response: + """Convenience function for making HTTP requests. + + Args: + method: HTTP method (GET, POST, etc.) + url: Target URL + headers: Optional HTTP headers + timeout_sec: Request timeout in seconds + session_manager: SessionManager instance to use (required) + use_pooling: Whether to use connection pooling (overrides session_manager setting) + **kwargs: Additional arguments passed to requests.Session.request + + Returns: + Response object from the request + + Raises: + ValueError: If session_manager is None + """ + if session_manager is None: + raise ValueError( + "session_manager is required - no default session manager available" + ) + + client = HttpClient(session_manager) + return client.request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + use_pooling=use_pooling, + **kwargs, + ) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 84652205fa..fbea591258 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -1,16 +1,12 @@ #!/usr/bin/env python from __future__ import annotations 
-import collections -import contextlib import gzip -import itertools import json import logging import re import time import uuid -from collections import OrderedDict from threading import Lock from typing import TYPE_CHECKING, Any, Callable @@ -18,10 +14,6 @@ from snowflake.connector.secret_detector import SecretDetector from snowflake.connector.vendored.requests.models import PreparedRequest -from snowflake.connector.vendored.urllib3.connectionpool import ( - HTTPConnectionPool, - HTTPSConnectionPool, -) from . import ssl_wrap_socket from .compat import ( @@ -84,6 +76,7 @@ ServiceUnavailableError, TooManyRequests, ) +from .session_manager import SessionManager, SessionPool from .sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_REJECTED, @@ -96,19 +89,16 @@ ) from .tool.probe_connection import probe_connection from .vendored import requests -from .vendored.requests import Response, Session +from .vendored.requests import Response from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.auth import AuthBase from .vendored.requests.exceptions import ( ConnectionError, ConnectTimeout, - InvalidProxyURL, ReadTimeout, SSLError, ) -from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy from .vendored.urllib3.exceptions import ProtocolError -from .vendored.urllib3.poolmanager import ProxyManager from .vendored.urllib3.util.url import parse_url if TYPE_CHECKING: @@ -248,42 +238,6 @@ def is_login_request(url: str) -> bool: return "login-request" in parse_url(url).path -class ProxySupportAdapter(HTTPAdapter): - """This Adapter creates proper headers for Proxy CONNECT messages.""" - - def get_connection( - self, url: str, proxies: OrderedDict | None = None - ) -> HTTPConnectionPool | HTTPSConnectionPool: - proxy = select_proxy(url, proxies) - parsed_url = urlparse(url) - - if proxy: - proxy = prepend_scheme_if_needed(proxy, "http") - proxy_url = parse_url(proxy) - if not proxy_url.host: - raise InvalidProxyURL( - 
"Please check proxy URL. It is malformed" - " and could be missing the host." - ) - proxy_manager = self.proxy_manager_for(proxy) - - if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname - else: - logger.debug( - f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" - f" it does not have attribute 'proxy_headers'." - ) - conn = proxy_manager.connection_from_url(url) - else: - # Only scheme should be lower case - url = parsed_url.geturl() - conn = self.poolmanager.connection_from_url(url) - - return conn - - class RetryRequest(Exception): """Signal to retry request.""" @@ -334,101 +288,6 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: return r -class SessionManager: - def __init__( - self, - use_pooling: bool = True, - adapter_factory: ( - Callable[..., HTTPAdapter] | None - ) = lambda *args, **kwargs: None, - ): - self._use_pooling = use_pooling - self._adapter_factory = adapter_factory or ProxySupportAdapter - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: SessionPool(self) - ) - - @property - def sessions_map(self) -> dict[str, SessionPool]: - return self._sessions_map - - def _mount_adapter(self, session: requests.Session) -> None: - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - if adapter is not None: - session.mount("http://", adapter) - session.mount("https://", adapter) - - def make_session(self) -> Session: - s = requests.Session() - self._mount_adapter(s) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def use_session(self, url: str | None = None): - if not self._use_pooling: - session = self.make_session() - try: - yield session - finally: - session.close() - else: - hostname = urlparse(url).hostname if url else None - pool = self._sessions_map[hostname] - session = pool.get_session() - try: - yield session - finally: - 
pool.return_session(session) - - def close(self): - for pool in self._sessions_map.values(): - pool.close() - - -class SessionPool: - def __init__(self, manager: SessionManager) -> None: - # A stack of the idle sessions - self._idle_sessions = [] - self._active_sessions = set() - self._manager = manager - - def get_session(self) -> Session: - """Returns a session from the session pool or creates a new one.""" - try: - session = self._idle_sessions.pop() - except IndexError: - session = self._manager.make_session() - self._active_sessions.add(session) - return session - - def return_session(self, session: Session) -> None: - """Places an active session back into the idle session stack.""" - try: - self._active_sessions.remove(session) - except KeyError: - logger.debug("session doesn't exist in the active session pool. Ignored...") - self._idle_sessions.append(session) - - def __str__(self) -> str: - total_sessions = len(self._active_sessions) + len(self._idle_sessions) - return ( - f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" - ) - - def close(self) -> None: - """Closes all active and idle sessions in this session pool.""" - if self._active_sessions: - logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for session in itertools.chain(self._active_sessions, self._idle_sessions): - try: - session.close() - except Exception as e: - logger.info(f"Session cleanup failed - failed to close session: {e}") - self._active_sessions.clear() - self._idle_sessions.clear() - - # Customizable JSONEncoder to support additional types. 
class SnowflakeRestfulJsonEncoder(json.JSONEncoder): def default(self, o): diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py new file mode 100644 index 0000000000..85b46ccfd4 --- /dev/null +++ b/src/snowflake/connector/session_manager.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import collections +import contextlib +import itertools +import logging +from typing import TYPE_CHECKING, Callable + +from .compat import urlparse +from .vendored import requests +from .vendored.requests import Session +from .vendored.requests.adapters import HTTPAdapter +from .vendored.requests.exceptions import InvalidProxyURL +from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy +from .vendored.urllib3.poolmanager import ProxyManager +from .vendored.urllib3.util.url import parse_url + +if TYPE_CHECKING: + from .vendored.urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool + +logger = logging.getLogger(__name__) + +# requests parameters +REQUESTS_RETRY = 1 # requests library builtin retry + + +class ProxySupportAdapter(HTTPAdapter): + """This Adapter creates proper headers for Proxy CONNECT messages.""" + + def get_connection( + self, url: str, proxies: dict | None = None + ) -> HTTPConnectionPool | HTTPSConnectionPool: + proxy = select_proxy(url, proxies) + parsed_url = urlparse(url) + + if proxy: + proxy = prepend_scheme_if_needed(proxy, "http") + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL( + "Please check proxy URL. It is malformed" + " and could be missing the host." + ) + proxy_manager = self.proxy_manager_for(proxy) + + if isinstance(proxy_manager, ProxyManager): + # Add Host to proxy header SNOW-232777 + proxy_manager.proxy_headers["Host"] = parsed_url.hostname + else: + logger.debug( + f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" + f" it does not have attribute 'proxy_headers'." 
+ ) + conn = proxy_manager.connection_from_url(url) + else: + # Only scheme should be lower case + url = parsed_url.geturl() + conn = self.poolmanager.connection_from_url(url) + + return conn + + +class SessionPool: + def __init__(self, manager: SessionManager) -> None: + # A stack of the idle sessions + self._idle_sessions = [] + self._active_sessions = set() + self._manager = manager + + def get_session(self) -> Session: + """Returns a session from the session pool or creates a new one.""" + try: + session = self._idle_sessions.pop() + except IndexError: + session = self._manager.make_session() + self._active_sessions.add(session) + return session + + def return_session(self, session: Session) -> None: + """Places an active session back into the idle session stack.""" + try: + self._active_sessions.remove(session) + except KeyError: + logger.debug("session doesn't exist in the active session pool. Ignored...") + self._idle_sessions.append(session) + + def __str__(self) -> str: + total_sessions = len(self._active_sessions) + len(self._idle_sessions) + return ( + f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" + ) + + def close(self) -> None: + """Closes all active and idle sessions in this session pool.""" + if self._active_sessions: + logger.debug(f"Closing {len(self._active_sessions)} active sessions") + for session in itertools.chain(self._active_sessions, self._idle_sessions): + try: + session.close() + except Exception as e: + logger.info(f"Session cleanup failed - failed to close session: {e}") + self._active_sessions.clear() + self._idle_sessions.clear() + + +class SessionManager: + def __init__( + self, + use_pooling: bool = True, + adapter_factory: ( + Callable[..., HTTPAdapter] | None + ) = lambda *args, **kwargs: None, + ): + self._use_pooling = use_pooling + self._adapter_factory = adapter_factory or ProxySupportAdapter + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: 
SessionPool(self) + ) + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self._sessions_map + + def _mount_adapter(self, session: requests.Session) -> None: + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + if adapter is not None: + session.mount("http://", adapter) + session.mount("https://", adapter) + + def make_session(self) -> Session: + s = requests.Session() + self._mount_adapter(s) + s._reuse_count = itertools.count() + return s + + @contextlib.contextmanager + def use_session(self, url: str | None = None): + if not self._use_pooling: + session = self.make_session() + try: + yield session + finally: + session.close() + else: + hostname = urlparse(url).hostname if url else None + pool = self._sessions_map[hostname] + session = pool.get_session() + try: + yield session + finally: + pool.return_session(session) + + def close(self): + for pool in self._sessions_map.values(): + pool.close() + + def clone(self, *, use_pooling: bool | None = None) -> SessionManager: + """Return an independent manager that reuses the adapter_factory.""" + return SessionManager( + use_pooling=self._use_pooling if use_pooling is None else use_pooling, + adapter_factory=self._adapter_factory, + ) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 21a7142c30..df6e7cea57 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -15,6 +15,8 @@ from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError +from .http_client import request as http_request +from .session_manager import SessionManager from .vendored import requests from .vendored.requests import Response @@ -72,15 +74,28 @@ class AwsCredentials: def try_metadata_service_call( - method: str, url: str, headers: dict, timeout_sec: int = 3 + method: str, + url: str, + headers: dict, + timeout_sec: int = 3, + session_manager: SessionManager | None = None, ) -> Response | None: """Tries to 
make a HTTP request to the metadata service with the given URL, method, headers and timeout. If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. """ try: - res: Response = requests.request( - method=method, url=url, headers=headers, timeout=timeout_sec + # If no session_manager provided, create a basic one for this call + if session_manager is None: + session_manager = SessionManager(use_pooling=False) + + res: Response = http_request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + session_manager=session_manager, + use_pooling=False, # IMDS calls are rare → don't pollute pool ) if not res.ok: return None @@ -119,16 +134,19 @@ def extract_iss_and_sub_without_signature_verification( # --------------------------------------------------------------------------- # # AWS helper utilities (token, credentials, region) # # --------------------------------------------------------------------------- # -def _imds_v2_token() -> str | None: +def _imds_v2_token(session_manager: SessionManager | None = None) -> str | None: res = try_metadata_service_call( method="PUT", url="http://169.254.169.254/latest/api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + session_manager=session_manager, ) return res.text.strip() if res else None -def get_aws_credentials() -> AwsCredentials | None: +def get_aws_credentials( + session_manager: SessionManager | None = None, +) -> AwsCredentials | None: """Get AWS credentials from environment variables or instance metadata. Implements the AWS credential chain without using boto3. 
@@ -143,7 +161,7 @@ def get_aws_credentials() -> AwsCredentials | None: # Try instance metadata service (IMDSv2) try: - token = _imds_v2_token() + token = _imds_v2_token(session_manager) if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -155,6 +173,7 @@ def get_aws_credentials() -> AwsCredentials | None: method="GET", url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", headers=token_hdr, + session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role list from metadata service.") @@ -170,6 +189,7 @@ def get_aws_credentials() -> AwsCredentials | None: method="GET", url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", headers=token_hdr, + session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role credentials from metadata service.") @@ -188,14 +208,14 @@ def get_aws_credentials() -> AwsCredentials | None: return None -def get_aws_region() -> str | None: +def get_aws_region(session_manager: SessionManager | None = None) -> str | None: """Get the current AWS workload's region, if any.""" region = os.environ.get("AWS_REGION") if region: return region try: - token = _imds_v2_token() + token = _imds_v2_token(session_manager) if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -207,6 +227,7 @@ def get_aws_region() -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/region", headers=token_hdr, + session_manager=session_manager, ) if res is not None: return res.text.strip() @@ -215,6 +236,7 @@ def get_aws_region() -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/availability-zone", headers=token_hdr, + session_manager=session_manager, ) if res is not None: return res.text.strip()[:-1] @@ -345,17 +367,19 @@ def hmac_sha256(key: bytes, msg: str) -> bytes: return final_headers -def create_aws_attestation() -> 
WorkloadIdentityAttestation | None: +def create_aws_attestation( + session_manager: SessionManager | None = None, +) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, returns None. """ - credentials = get_aws_credentials() + credentials = get_aws_credentials(session_manager) if not credentials: logger.debug("No AWS credentials were found.") return None - region = get_aws_region() + region = get_aws_region(session_manager) if not region: logger.debug("No AWS region was found.") return None @@ -395,7 +419,9 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: ) -def create_gcp_attestation() -> WorkloadIdentityAttestation | None: +def create_gcp_attestation( + session_manager: SessionManager | None = None, +) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for GCP. If the application isn't running on GCP or no credentials were found, returns None. @@ -406,6 +432,7 @@ def create_gcp_attestation() -> WorkloadIdentityAttestation | None: headers={ "Metadata-Flavor": "Google", }, + session_manager=session_manager, ) if res is None: # Most likely we're just not running on GCP, which may be expected. @@ -428,6 +455,7 @@ def create_gcp_attestation() -> WorkloadIdentityAttestation | None: def create_azure_attestation( snowflake_entra_resource: str, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for Azure. @@ -461,6 +489,7 @@ def create_azure_attestation( method="GET", url=f"{url_without_query_string}?{query_params}", headers=headers, + session_manager=session_manager, ) if res is None: # Most likely we're just not running on Azure, which may be expected. 
@@ -511,7 +540,9 @@ def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | def create_autodetect_attestation( - entra_resource: str, token: str | None = None + entra_resource: str, + token: str | None = None, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create an attestation using the auto-detected runtime environment. @@ -521,15 +552,15 @@ def create_autodetect_attestation( if attestation: return attestation - attestation = create_azure_attestation(entra_resource) + attestation = create_azure_attestation(entra_resource, session_manager) if attestation: return attestation - attestation = create_aws_attestation() + attestation = create_aws_attestation(session_manager) if attestation: return attestation - attestation = create_gcp_attestation() + attestation = create_gcp_attestation(session_manager) if attestation: return attestation @@ -540,6 +571,7 @@ def create_attestation( provider: AttestationProvider | None, entra_resource: str | None = None, token: str | None = None, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -549,18 +581,23 @@ def create_attestation( If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. 
""" entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE + session_manager = ( + session_manager.clone() if session_manager else SessionManager(use_pooling=True) + ) attestation: WorkloadIdentityAttestation | None = None if provider == AttestationProvider.AWS: - attestation = create_aws_attestation() + attestation = create_aws_attestation(session_manager) elif provider == AttestationProvider.AZURE: - attestation = create_azure_attestation(entra_resource) + attestation = create_azure_attestation(entra_resource, session_manager) elif provider == AttestationProvider.GCP: - attestation = create_gcp_attestation() + attestation = create_gcp_attestation(session_manager) elif provider == AttestationProvider.OIDC: attestation = create_oidc_attestation(token) elif provider is None: - attestation = create_autodetect_attestation(entra_resource, token) + attestation = create_autodetect_attestation( + entra_resource, token, session_manager + ) if not attestation: provider_str = "auto-detect" if provider is None else provider.value From 23e338bcacd9eba1af47ebf9b5847f7c10d445c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 22:03:47 +0200 Subject: [PATCH 09/54] SNOW-2183023: fixed fake env --- test/csp_helpers.py | 113 ++++++++++++++++++++++++++++---------------- 1 file changed, 71 insertions(+), 42 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index ac35336166..c2efce1a94 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -import datetime import json import logging import os @@ -9,12 +8,13 @@ from urllib.parse import parse_qs, urlparse import jwt -from botocore.awsrequest import AWSRequest -from botocore.credentials import Credentials from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError from snowflake.connector.vendored.requests.models import Response +# NEW: import the light-weight creds class from the refactored util +from 
snowflake.connector.wif_util import AwsCredentials + logger = logging.getLogger(__name__) @@ -47,6 +47,9 @@ def build_response(content: bytes, status_code: int = 200) -> Response: return response +# --------------------------------------------------------------------------- # +# Generic metadata-service test harness # +# --------------------------------------------------------------------------- # class FakeMetadataService(ABC): """Base class for fake metadata service implementations.""" @@ -244,19 +247,20 @@ def handle_request(self, method, parsed_url, headers, timeout): class FakeAwsEnvironment: - """Emulates the AWS environment-specific functions used in wif_util.py. - - Unlike the other metadata services, the HTTP calls made by AWS are deep within boto libaries, so - emulating them here would be complex and fragile. Instead, we emulate the higher-level functions - called by the connector code. - """ + """Emulates the AWS environment-specific helpers now used in wif_util.py.""" def __init__(self): - # Defaults used for generating a token. Can be overriden in individual tests. 
self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" self.region = "us-east-1" - self.credentials = Credentials(access_key="ak", secret_key="sk") + self.credentials: AwsCredentials | None = AwsCredentials( + access_key="ak", + secret_key="sk", + token="SESSION_TOKEN", + ) + # --------------------------------------------------------------------- # + # Helper getters (used as side-effects for patching) # + # --------------------------------------------------------------------- # def get_region(self): return self.region @@ -266,43 +270,68 @@ def get_arn(self): def get_credentials(self): return self.credentials - def sign_request(self, request: AWSRequest): - request.headers.add_header("X-Amz-Date", datetime.time().isoformat()) - request.headers.add_header("X-Amz-Security-Token", "") - request.headers.add_header( - "Authorization", - f"AWS4-HMAC-SHA256 Credential=, SignedHeaders={';'.join(request.headers.keys())}, Signature=", - ) - + # --------------------------------------------------------------------- # + # Context-manager patching # + # --------------------------------------------------------------------- # def __enter__(self): - # Patch the relevant functions to do what we want. 
- self.patchers = [] - self.patchers.append( + # Stash current env so we can restore later + self._prev_env = { + "AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"), + "AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"), + "AWS_SESSION_TOKEN": os.environ.get("AWS_SESSION_TOKEN"), + "AWS_REGION": os.environ.get("AWS_REGION"), + } + + # Expose creds & region via env vars (preferred path in new util) + if self.credentials: + os.environ["AWS_ACCESS_KEY_ID"] = self.credentials.access_key + os.environ["AWS_SECRET_ACCESS_KEY"] = self.credentials.secret_key + if self.credentials.token: + os.environ["AWS_SESSION_TOKEN"] = self.credentials.token + os.environ["AWS_REGION"] = self.region + + self.patchers: list[mock._patch] = [ + # Force util helpers to return our fake data mock.patch( - "boto3.session.Session.get_credentials", + "snowflake.connector.wif_util.get_aws_credentials", side_effect=self.get_credentials, - ) - ) - self.patchers.append( - mock.patch( - "botocore.auth.SigV4Auth.add_auth", side_effect=self.sign_request - ) - ) - self.patchers.append( + ), mock.patch( "snowflake.connector.wif_util.get_aws_region", side_effect=self.get_region, - ) - ) - self.patchers.append( + ), + # _imds_v2_token() must not hit the network + ( + mock.patch( + "snowflake.connector.wif_util._imds_v2_token", + return_value=None, + ) + if hasattr( + __import__( + "snowflake.connector.wif_util", fromlist=["get_aws_arn"] + ), + "get_aws_arn", + ) + else mock.patch.dict({}, {}, clear=True) + ), # dummy, no-op patch + # Block any accidental real HTTP calls via urllib3 mock.patch( - "snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn - ) - ) - for patcher in self.patchers: - patcher.__enter__() + "urllib3.connection.HTTPConnection.request", + side_effect=ConnectTimeout(), + ), + ] + + for p in self.patchers: + p.__enter__() return self - def __exit__(self, *args, **kwargs): - for patcher in self.patchers: - patcher.__exit__(*args, **kwargs) + def 
__exit__(self, *args): + for p in self.patchers: + p.__exit__(*args) + + # Restore previous env-vars + for key, val in self._prev_env.items(): + if val is None: + os.environ.pop(key, None) + else: + os.environ[key] = val From 3faa9423e4ecfd590b11904f3dbb44e7b7e1d5ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 22:04:02 +0200 Subject: [PATCH 10/54] Revert "SNOW-2183023: refactored" This reverts commit 53087d43a382245fd435c0f81683a2a5d36f2891. --- .../connector/auth/workload_identity.py | 7 +- src/snowflake/connector/connection.py | 5 - src/snowflake/connector/http_client.py | 103 ----------- src/snowflake/connector/network.py | 145 +++++++++++++++- src/snowflake/connector/session_manager.py | 162 ------------------ src/snowflake/connector/wif_util.py | 77 +++------ 6 files changed, 165 insertions(+), 334 deletions(-) delete mode 100644 src/snowflake/connector/http_client.py delete mode 100644 src/snowflake/connector/session_manager.py diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 7f8ab60718..3c80c965e4 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -74,13 +74,10 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: ).value body["data"]["TOKEN"] = self.attestation.credential - def prepare(self, *, conn, **kwargs: typing.Any) -> None: + def prepare(self, **kwargs: typing.Any) -> None: """Fetch the token.""" self.attestation = create_attestation( - self.provider, - self.entra_resource, - self.token, - session_manager=conn.session_manager if conn else None, + self.provider, self.entra_resource, self.token ) def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index ffc193df37..963e04ee8a 100644 --- a/src/snowflake/connector/connection.py +++ 
b/src/snowflake/connector/connection.py @@ -712,11 +712,6 @@ def client_fetch_use_mp(self) -> bool: def rest(self) -> SnowflakeRestful | None: return self._rest - @property - def session_manager(self): - """Access to the connection's SessionManager for making HTTP requests.""" - return self._rest.session_manager if self._rest else None - @property def application(self) -> str: return self._application diff --git a/src/snowflake/connector/http_client.py b/src/snowflake/connector/http_client.py deleted file mode 100644 index 1c802c2439..0000000000 --- a/src/snowflake/connector/http_client.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Any, Mapping - -from .session_manager import SessionManager -from .vendored.requests import Response - -logger = logging.getLogger(__name__) - - -class HttpClient: - """HTTP client that uses SessionManager for connection pooling and adapter management.""" - - def __init__(self, session_manager: SessionManager): - """Initialize HttpClient with a SessionManager. - - Args: - session_manager: SessionManager instance to use for all requests - """ - self.session_manager = session_manager - - def request( - self, - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - use_pooling: bool | None = None, - **kwargs: Any, - ) -> Response: - """Make an HTTP request using the configured SessionManager. - - Args: - method: HTTP method (GET, POST, etc.) 
- url: Target URL - headers: Optional HTTP headers - timeout_sec: Request timeout in seconds - use_pooling: Whether to use connection pooling (overrides session_manager setting) - **kwargs: Additional arguments passed to requests.Session.request - - Returns: - Response object from the request - """ - mgr = ( - self.session_manager - if use_pooling is None - else self.session_manager.clone(use_pooling=use_pooling) - ) - - with mgr.use_session(url) as session: - return session.request( - method=method.upper(), - url=url, - headers=headers, - timeout=timeout_sec, - **kwargs, - ) - - -# Convenience function for backwards compatibility and simple usage -def request( - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - session_manager: SessionManager | None = None, - use_pooling: bool | None = None, - **kwargs: Any, -) -> Response: - """Convenience function for making HTTP requests. - - Args: - method: HTTP method (GET, POST, etc.) - url: Target URL - headers: Optional HTTP headers - timeout_sec: Request timeout in seconds - session_manager: SessionManager instance to use (required) - use_pooling: Whether to use connection pooling (overrides session_manager setting) - **kwargs: Additional arguments passed to requests.Session.request - - Returns: - Response object from the request - - Raises: - ValueError: If session_manager is None - """ - if session_manager is None: - raise ValueError( - "session_manager is required - no default session manager available" - ) - - client = HttpClient(session_manager) - return client.request( - method=method, - url=url, - headers=headers, - timeout_sec=timeout_sec, - use_pooling=use_pooling, - **kwargs, - ) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index fbea591258..84652205fa 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -1,12 +1,16 @@ #!/usr/bin/env python from __future__ import annotations 
+import collections +import contextlib import gzip +import itertools import json import logging import re import time import uuid +from collections import OrderedDict from threading import Lock from typing import TYPE_CHECKING, Any, Callable @@ -14,6 +18,10 @@ from snowflake.connector.secret_detector import SecretDetector from snowflake.connector.vendored.requests.models import PreparedRequest +from snowflake.connector.vendored.urllib3.connectionpool import ( + HTTPConnectionPool, + HTTPSConnectionPool, +) from . import ssl_wrap_socket from .compat import ( @@ -76,7 +84,6 @@ ServiceUnavailableError, TooManyRequests, ) -from .session_manager import SessionManager, SessionPool from .sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_REJECTED, @@ -89,16 +96,19 @@ ) from .tool.probe_connection import probe_connection from .vendored import requests -from .vendored.requests import Response +from .vendored.requests import Response, Session from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.auth import AuthBase from .vendored.requests.exceptions import ( ConnectionError, ConnectTimeout, + InvalidProxyURL, ReadTimeout, SSLError, ) +from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy from .vendored.urllib3.exceptions import ProtocolError +from .vendored.urllib3.poolmanager import ProxyManager from .vendored.urllib3.util.url import parse_url if TYPE_CHECKING: @@ -238,6 +248,42 @@ def is_login_request(url: str) -> bool: return "login-request" in parse_url(url).path +class ProxySupportAdapter(HTTPAdapter): + """This Adapter creates proper headers for Proxy CONNECT messages.""" + + def get_connection( + self, url: str, proxies: OrderedDict | None = None + ) -> HTTPConnectionPool | HTTPSConnectionPool: + proxy = select_proxy(url, proxies) + parsed_url = urlparse(url) + + if proxy: + proxy = prepend_scheme_if_needed(proxy, "http") + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL( + 
"Please check proxy URL. It is malformed" + " and could be missing the host." + ) + proxy_manager = self.proxy_manager_for(proxy) + + if isinstance(proxy_manager, ProxyManager): + # Add Host to proxy header SNOW-232777 + proxy_manager.proxy_headers["Host"] = parsed_url.hostname + else: + logger.debug( + f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" + f" it does not have attribute 'proxy_headers'." + ) + conn = proxy_manager.connection_from_url(url) + else: + # Only scheme should be lower case + url = parsed_url.geturl() + conn = self.poolmanager.connection_from_url(url) + + return conn + + class RetryRequest(Exception): """Signal to retry request.""" @@ -288,6 +334,101 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: return r +class SessionManager: + def __init__( + self, + use_pooling: bool = True, + adapter_factory: ( + Callable[..., HTTPAdapter] | None + ) = lambda *args, **kwargs: None, + ): + self._use_pooling = use_pooling + self._adapter_factory = adapter_factory or ProxySupportAdapter + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: SessionPool(self) + ) + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self._sessions_map + + def _mount_adapter(self, session: requests.Session) -> None: + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + if adapter is not None: + session.mount("http://", adapter) + session.mount("https://", adapter) + + def make_session(self) -> Session: + s = requests.Session() + self._mount_adapter(s) + s._reuse_count = itertools.count() + return s + + @contextlib.contextmanager + def use_session(self, url: str | None = None): + if not self._use_pooling: + session = self.make_session() + try: + yield session + finally: + session.close() + else: + hostname = urlparse(url).hostname if url else None + pool = self._sessions_map[hostname] + session = pool.get_session() + try: + yield session + finally: + 
pool.return_session(session) + + def close(self): + for pool in self._sessions_map.values(): + pool.close() + + +class SessionPool: + def __init__(self, manager: SessionManager) -> None: + # A stack of the idle sessions + self._idle_sessions = [] + self._active_sessions = set() + self._manager = manager + + def get_session(self) -> Session: + """Returns a session from the session pool or creates a new one.""" + try: + session = self._idle_sessions.pop() + except IndexError: + session = self._manager.make_session() + self._active_sessions.add(session) + return session + + def return_session(self, session: Session) -> None: + """Places an active session back into the idle session stack.""" + try: + self._active_sessions.remove(session) + except KeyError: + logger.debug("session doesn't exist in the active session pool. Ignored...") + self._idle_sessions.append(session) + + def __str__(self) -> str: + total_sessions = len(self._active_sessions) + len(self._idle_sessions) + return ( + f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" + ) + + def close(self) -> None: + """Closes all active and idle sessions in this session pool.""" + if self._active_sessions: + logger.debug(f"Closing {len(self._active_sessions)} active sessions") + for session in itertools.chain(self._active_sessions, self._idle_sessions): + try: + session.close() + except Exception as e: + logger.info(f"Session cleanup failed - failed to close session: {e}") + self._active_sessions.clear() + self._idle_sessions.clear() + + # Customizable JSONEncoder to support additional types. 
class SnowflakeRestfulJsonEncoder(json.JSONEncoder): def default(self, o): diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py deleted file mode 100644 index 85b46ccfd4..0000000000 --- a/src/snowflake/connector/session_manager.py +++ /dev/null @@ -1,162 +0,0 @@ -from __future__ import annotations - -import collections -import contextlib -import itertools -import logging -from typing import TYPE_CHECKING, Callable - -from .compat import urlparse -from .vendored import requests -from .vendored.requests import Session -from .vendored.requests.adapters import HTTPAdapter -from .vendored.requests.exceptions import InvalidProxyURL -from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy -from .vendored.urllib3.poolmanager import ProxyManager -from .vendored.urllib3.util.url import parse_url - -if TYPE_CHECKING: - from .vendored.urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool - -logger = logging.getLogger(__name__) - -# requests parameters -REQUESTS_RETRY = 1 # requests library builtin retry - - -class ProxySupportAdapter(HTTPAdapter): - """This Adapter creates proper headers for Proxy CONNECT messages.""" - - def get_connection( - self, url: str, proxies: dict | None = None - ) -> HTTPConnectionPool | HTTPSConnectionPool: - proxy = select_proxy(url, proxies) - parsed_url = urlparse(url) - - if proxy: - proxy = prepend_scheme_if_needed(proxy, "http") - proxy_url = parse_url(proxy) - if not proxy_url.host: - raise InvalidProxyURL( - "Please check proxy URL. It is malformed" - " and could be missing the host." - ) - proxy_manager = self.proxy_manager_for(proxy) - - if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname - else: - logger.debug( - f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" - f" it does not have attribute 'proxy_headers'." 
- ) - conn = proxy_manager.connection_from_url(url) - else: - # Only scheme should be lower case - url = parsed_url.geturl() - conn = self.poolmanager.connection_from_url(url) - - return conn - - -class SessionPool: - def __init__(self, manager: SessionManager) -> None: - # A stack of the idle sessions - self._idle_sessions = [] - self._active_sessions = set() - self._manager = manager - - def get_session(self) -> Session: - """Returns a session from the session pool or creates a new one.""" - try: - session = self._idle_sessions.pop() - except IndexError: - session = self._manager.make_session() - self._active_sessions.add(session) - return session - - def return_session(self, session: Session) -> None: - """Places an active session back into the idle session stack.""" - try: - self._active_sessions.remove(session) - except KeyError: - logger.debug("session doesn't exist in the active session pool. Ignored...") - self._idle_sessions.append(session) - - def __str__(self) -> str: - total_sessions = len(self._active_sessions) + len(self._idle_sessions) - return ( - f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" - ) - - def close(self) -> None: - """Closes all active and idle sessions in this session pool.""" - if self._active_sessions: - logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for session in itertools.chain(self._active_sessions, self._idle_sessions): - try: - session.close() - except Exception as e: - logger.info(f"Session cleanup failed - failed to close session: {e}") - self._active_sessions.clear() - self._idle_sessions.clear() - - -class SessionManager: - def __init__( - self, - use_pooling: bool = True, - adapter_factory: ( - Callable[..., HTTPAdapter] | None - ) = lambda *args, **kwargs: None, - ): - self._use_pooling = use_pooling - self._adapter_factory = adapter_factory or ProxySupportAdapter - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: 
SessionPool(self) - ) - - @property - def sessions_map(self) -> dict[str, SessionPool]: - return self._sessions_map - - def _mount_adapter(self, session: requests.Session) -> None: - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - if adapter is not None: - session.mount("http://", adapter) - session.mount("https://", adapter) - - def make_session(self) -> Session: - s = requests.Session() - self._mount_adapter(s) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def use_session(self, url: str | None = None): - if not self._use_pooling: - session = self.make_session() - try: - yield session - finally: - session.close() - else: - hostname = urlparse(url).hostname if url else None - pool = self._sessions_map[hostname] - session = pool.get_session() - try: - yield session - finally: - pool.return_session(session) - - def close(self): - for pool in self._sessions_map.values(): - pool.close() - - def clone(self, *, use_pooling: bool | None = None) -> SessionManager: - """Return an independent manager that reuses the adapter_factory.""" - return SessionManager( - use_pooling=self._use_pooling if use_pooling is None else use_pooling, - adapter_factory=self._adapter_factory, - ) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index df6e7cea57..21a7142c30 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -15,8 +15,6 @@ from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError -from .http_client import request as http_request -from .session_manager import SessionManager from .vendored import requests from .vendored.requests import Response @@ -74,28 +72,15 @@ class AwsCredentials: def try_metadata_service_call( - method: str, - url: str, - headers: dict, - timeout_sec: int = 3, - session_manager: SessionManager | None = None, + method: str, url: str, headers: dict, timeout_sec: int = 3 ) -> Response | None: """Tries to 
make a HTTP request to the metadata service with the given URL, method, headers and timeout. If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. """ try: - # If no session_manager provided, create a basic one for this call - if session_manager is None: - session_manager = SessionManager(use_pooling=False) - - res: Response = http_request( - method=method, - url=url, - headers=headers, - timeout_sec=timeout_sec, - session_manager=session_manager, - use_pooling=False, # IMDS calls are rare → don't pollute pool + res: Response = requests.request( + method=method, url=url, headers=headers, timeout=timeout_sec ) if not res.ok: return None @@ -134,19 +119,16 @@ def extract_iss_and_sub_without_signature_verification( # --------------------------------------------------------------------------- # # AWS helper utilities (token, credentials, region) # # --------------------------------------------------------------------------- # -def _imds_v2_token(session_manager: SessionManager | None = None) -> str | None: +def _imds_v2_token() -> str | None: res = try_metadata_service_call( method="PUT", url="http://169.254.169.254/latest/api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, - session_manager=session_manager, ) return res.text.strip() if res else None -def get_aws_credentials( - session_manager: SessionManager | None = None, -) -> AwsCredentials | None: +def get_aws_credentials() -> AwsCredentials | None: """Get AWS credentials from environment variables or instance metadata. Implements the AWS credential chain without using boto3. 
@@ -161,7 +143,7 @@ def get_aws_credentials( # Try instance metadata service (IMDSv2) try: - token = _imds_v2_token(session_manager) + token = _imds_v2_token() if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -173,7 +155,6 @@ def get_aws_credentials( method="GET", url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", headers=token_hdr, - session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role list from metadata service.") @@ -189,7 +170,6 @@ def get_aws_credentials( method="GET", url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", headers=token_hdr, - session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role credentials from metadata service.") @@ -208,14 +188,14 @@ def get_aws_credentials( return None -def get_aws_region(session_manager: SessionManager | None = None) -> str | None: +def get_aws_region() -> str | None: """Get the current AWS workload's region, if any.""" region = os.environ.get("AWS_REGION") if region: return region try: - token = _imds_v2_token(session_manager) + token = _imds_v2_token() if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -227,7 +207,6 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/region", headers=token_hdr, - session_manager=session_manager, ) if res is not None: return res.text.strip() @@ -236,7 +215,6 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/availability-zone", headers=token_hdr, - session_manager=session_manager, ) if res is not None: return res.text.strip()[:-1] @@ -367,19 +345,17 @@ def hmac_sha256(key: bytes, msg: str) -> bytes: return final_headers -def create_aws_attestation( - session_manager: 
SessionManager | None = None, -) -> WorkloadIdentityAttestation | None: +def create_aws_attestation() -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, returns None. """ - credentials = get_aws_credentials(session_manager) + credentials = get_aws_credentials() if not credentials: logger.debug("No AWS credentials were found.") return None - region = get_aws_region(session_manager) + region = get_aws_region() if not region: logger.debug("No AWS region was found.") return None @@ -419,9 +395,7 @@ def create_aws_attestation( ) -def create_gcp_attestation( - session_manager: SessionManager | None = None, -) -> WorkloadIdentityAttestation | None: +def create_gcp_attestation() -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for GCP. If the application isn't running on GCP or no credentials were found, returns None. @@ -432,7 +406,6 @@ def create_gcp_attestation( headers={ "Metadata-Flavor": "Google", }, - session_manager=session_manager, ) if res is None: # Most likely we're just not running on GCP, which may be expected. @@ -455,7 +428,6 @@ def create_gcp_attestation( def create_azure_attestation( snowflake_entra_resource: str, - session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for Azure. @@ -489,7 +461,6 @@ def create_azure_attestation( method="GET", url=f"{url_without_query_string}?{query_params}", headers=headers, - session_manager=session_manager, ) if res is None: # Most likely we're just not running on Azure, which may be expected. 
@@ -540,9 +511,7 @@ def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | def create_autodetect_attestation( - entra_resource: str, - token: str | None = None, - session_manager: SessionManager | None = None, + entra_resource: str, token: str | None = None ) -> WorkloadIdentityAttestation | None: """Tries to create an attestation using the auto-detected runtime environment. @@ -552,15 +521,15 @@ def create_autodetect_attestation( if attestation: return attestation - attestation = create_azure_attestation(entra_resource, session_manager) + attestation = create_azure_attestation(entra_resource) if attestation: return attestation - attestation = create_aws_attestation(session_manager) + attestation = create_aws_attestation() if attestation: return attestation - attestation = create_gcp_attestation(session_manager) + attestation = create_gcp_attestation() if attestation: return attestation @@ -571,7 +540,6 @@ def create_attestation( provider: AttestationProvider | None, entra_resource: str | None = None, token: str | None = None, - session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -581,23 +549,18 @@ def create_attestation( If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. 
""" entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE - session_manager = ( - session_manager.clone() if session_manager else SessionManager(use_pooling=True) - ) attestation: WorkloadIdentityAttestation | None = None if provider == AttestationProvider.AWS: - attestation = create_aws_attestation(session_manager) + attestation = create_aws_attestation() elif provider == AttestationProvider.AZURE: - attestation = create_azure_attestation(entra_resource, session_manager) + attestation = create_azure_attestation(entra_resource) elif provider == AttestationProvider.GCP: - attestation = create_gcp_attestation(session_manager) + attestation = create_gcp_attestation() elif provider == AttestationProvider.OIDC: attestation = create_oidc_attestation(token) elif provider is None: - attestation = create_autodetect_attestation( - entra_resource, token, session_manager - ) + attestation = create_autodetect_attestation(entra_resource, token) if not attestation: provider_str = "auto-detect" if provider is None else provider.value From 43f7868f07b250f8e4828b7dd14d07d627c49e98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 22:04:32 +0200 Subject: [PATCH 11/54] Reapply "SNOW-2183023: refactored" This reverts commit 3faa9423e4ecfd590b11904f3dbb44e7b7e1d5ad. 
--- .../connector/auth/workload_identity.py | 7 +- src/snowflake/connector/connection.py | 5 + src/snowflake/connector/http_client.py | 103 +++++++++++ src/snowflake/connector/network.py | 145 +--------------- src/snowflake/connector/session_manager.py | 162 ++++++++++++++++++ src/snowflake/connector/wif_util.py | 77 ++++++--- 6 files changed, 334 insertions(+), 165 deletions(-) create mode 100644 src/snowflake/connector/http_client.py create mode 100644 src/snowflake/connector/session_manager.py diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 3c80c965e4..7f8ab60718 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -74,10 +74,13 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: ).value body["data"]["TOKEN"] = self.attestation.credential - def prepare(self, **kwargs: typing.Any) -> None: + def prepare(self, *, conn, **kwargs: typing.Any) -> None: """Fetch the token.""" self.attestation = create_attestation( - self.provider, self.entra_resource, self.token + self.provider, + self.entra_resource, + self.token, + session_manager=conn.session_manager if conn else None, ) def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 963e04ee8a..ffc193df37 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -712,6 +712,11 @@ def client_fetch_use_mp(self) -> bool: def rest(self) -> SnowflakeRestful | None: return self._rest + @property + def session_manager(self): + """Access to the connection's SessionManager for making HTTP requests.""" + return self._rest.session_manager if self._rest else None + @property def application(self) -> str: return self._application diff --git a/src/snowflake/connector/http_client.py b/src/snowflake/connector/http_client.py new file 
mode 100644 index 0000000000..1c802c2439 --- /dev/null +++ b/src/snowflake/connector/http_client.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import logging +from typing import Any, Mapping + +from .session_manager import SessionManager +from .vendored.requests import Response + +logger = logging.getLogger(__name__) + + +class HttpClient: + """HTTP client that uses SessionManager for connection pooling and adapter management.""" + + def __init__(self, session_manager: SessionManager): + """Initialize HttpClient with a SessionManager. + + Args: + session_manager: SessionManager instance to use for all requests + """ + self.session_manager = session_manager + + def request( + self, + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + use_pooling: bool | None = None, + **kwargs: Any, + ) -> Response: + """Make an HTTP request using the configured SessionManager. + + Args: + method: HTTP method (GET, POST, etc.) + url: Target URL + headers: Optional HTTP headers + timeout_sec: Request timeout in seconds + use_pooling: Whether to use connection pooling (overrides session_manager setting) + **kwargs: Additional arguments passed to requests.Session.request + + Returns: + Response object from the request + """ + mgr = ( + self.session_manager + if use_pooling is None + else self.session_manager.clone(use_pooling=use_pooling) + ) + + with mgr.use_session(url) as session: + return session.request( + method=method.upper(), + url=url, + headers=headers, + timeout=timeout_sec, + **kwargs, + ) + + +# Convenience function for backwards compatibility and simple usage +def request( + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + session_manager: SessionManager | None = None, + use_pooling: bool | None = None, + **kwargs: Any, +) -> Response: + """Convenience function for making HTTP requests. + + Args: + method: HTTP method (GET, POST, etc.) 
+ url: Target URL + headers: Optional HTTP headers + timeout_sec: Request timeout in seconds + session_manager: SessionManager instance to use (required) + use_pooling: Whether to use connection pooling (overrides session_manager setting) + **kwargs: Additional arguments passed to requests.Session.request + + Returns: + Response object from the request + + Raises: + ValueError: If session_manager is None + """ + if session_manager is None: + raise ValueError( + "session_manager is required - no default session manager available" + ) + + client = HttpClient(session_manager) + return client.request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + use_pooling=use_pooling, + **kwargs, + ) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 84652205fa..fbea591258 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -1,16 +1,12 @@ #!/usr/bin/env python from __future__ import annotations -import collections -import contextlib import gzip -import itertools import json import logging import re import time import uuid -from collections import OrderedDict from threading import Lock from typing import TYPE_CHECKING, Any, Callable @@ -18,10 +14,6 @@ from snowflake.connector.secret_detector import SecretDetector from snowflake.connector.vendored.requests.models import PreparedRequest -from snowflake.connector.vendored.urllib3.connectionpool import ( - HTTPConnectionPool, - HTTPSConnectionPool, -) from . 
import ssl_wrap_socket from .compat import ( @@ -84,6 +76,7 @@ ServiceUnavailableError, TooManyRequests, ) +from .session_manager import SessionManager, SessionPool from .sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_REJECTED, @@ -96,19 +89,16 @@ ) from .tool.probe_connection import probe_connection from .vendored import requests -from .vendored.requests import Response, Session +from .vendored.requests import Response from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.auth import AuthBase from .vendored.requests.exceptions import ( ConnectionError, ConnectTimeout, - InvalidProxyURL, ReadTimeout, SSLError, ) -from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy from .vendored.urllib3.exceptions import ProtocolError -from .vendored.urllib3.poolmanager import ProxyManager from .vendored.urllib3.util.url import parse_url if TYPE_CHECKING: @@ -248,42 +238,6 @@ def is_login_request(url: str) -> bool: return "login-request" in parse_url(url).path -class ProxySupportAdapter(HTTPAdapter): - """This Adapter creates proper headers for Proxy CONNECT messages.""" - - def get_connection( - self, url: str, proxies: OrderedDict | None = None - ) -> HTTPConnectionPool | HTTPSConnectionPool: - proxy = select_proxy(url, proxies) - parsed_url = urlparse(url) - - if proxy: - proxy = prepend_scheme_if_needed(proxy, "http") - proxy_url = parse_url(proxy) - if not proxy_url.host: - raise InvalidProxyURL( - "Please check proxy URL. It is malformed" - " and could be missing the host." - ) - proxy_manager = self.proxy_manager_for(proxy) - - if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname - else: - logger.debug( - f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" - f" it does not have attribute 'proxy_headers'." 
- ) - conn = proxy_manager.connection_from_url(url) - else: - # Only scheme should be lower case - url = parsed_url.geturl() - conn = self.poolmanager.connection_from_url(url) - - return conn - - class RetryRequest(Exception): """Signal to retry request.""" @@ -334,101 +288,6 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: return r -class SessionManager: - def __init__( - self, - use_pooling: bool = True, - adapter_factory: ( - Callable[..., HTTPAdapter] | None - ) = lambda *args, **kwargs: None, - ): - self._use_pooling = use_pooling - self._adapter_factory = adapter_factory or ProxySupportAdapter - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: SessionPool(self) - ) - - @property - def sessions_map(self) -> dict[str, SessionPool]: - return self._sessions_map - - def _mount_adapter(self, session: requests.Session) -> None: - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - if adapter is not None: - session.mount("http://", adapter) - session.mount("https://", adapter) - - def make_session(self) -> Session: - s = requests.Session() - self._mount_adapter(s) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def use_session(self, url: str | None = None): - if not self._use_pooling: - session = self.make_session() - try: - yield session - finally: - session.close() - else: - hostname = urlparse(url).hostname if url else None - pool = self._sessions_map[hostname] - session = pool.get_session() - try: - yield session - finally: - pool.return_session(session) - - def close(self): - for pool in self._sessions_map.values(): - pool.close() - - -class SessionPool: - def __init__(self, manager: SessionManager) -> None: - # A stack of the idle sessions - self._idle_sessions = [] - self._active_sessions = set() - self._manager = manager - - def get_session(self) -> Session: - """Returns a session from the session pool or creates a new one.""" - try: - session = 
self._idle_sessions.pop() - except IndexError: - session = self._manager.make_session() - self._active_sessions.add(session) - return session - - def return_session(self, session: Session) -> None: - """Places an active session back into the idle session stack.""" - try: - self._active_sessions.remove(session) - except KeyError: - logger.debug("session doesn't exist in the active session pool. Ignored...") - self._idle_sessions.append(session) - - def __str__(self) -> str: - total_sessions = len(self._active_sessions) + len(self._idle_sessions) - return ( - f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" - ) - - def close(self) -> None: - """Closes all active and idle sessions in this session pool.""" - if self._active_sessions: - logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for session in itertools.chain(self._active_sessions, self._idle_sessions): - try: - session.close() - except Exception as e: - logger.info(f"Session cleanup failed - failed to close session: {e}") - self._active_sessions.clear() - self._idle_sessions.clear() - - # Customizable JSONEncoder to support additional types. 
class SnowflakeRestfulJsonEncoder(json.JSONEncoder): def default(self, o): diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py new file mode 100644 index 0000000000..85b46ccfd4 --- /dev/null +++ b/src/snowflake/connector/session_manager.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import collections +import contextlib +import itertools +import logging +from typing import TYPE_CHECKING, Callable + +from .compat import urlparse +from .vendored import requests +from .vendored.requests import Session +from .vendored.requests.adapters import HTTPAdapter +from .vendored.requests.exceptions import InvalidProxyURL +from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy +from .vendored.urllib3.poolmanager import ProxyManager +from .vendored.urllib3.util.url import parse_url + +if TYPE_CHECKING: + from .vendored.urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool + +logger = logging.getLogger(__name__) + +# requests parameters +REQUESTS_RETRY = 1 # requests library builtin retry + + +class ProxySupportAdapter(HTTPAdapter): + """This Adapter creates proper headers for Proxy CONNECT messages.""" + + def get_connection( + self, url: str, proxies: dict | None = None + ) -> HTTPConnectionPool | HTTPSConnectionPool: + proxy = select_proxy(url, proxies) + parsed_url = urlparse(url) + + if proxy: + proxy = prepend_scheme_if_needed(proxy, "http") + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL( + "Please check proxy URL. It is malformed" + " and could be missing the host." + ) + proxy_manager = self.proxy_manager_for(proxy) + + if isinstance(proxy_manager, ProxyManager): + # Add Host to proxy header SNOW-232777 + proxy_manager.proxy_headers["Host"] = parsed_url.hostname + else: + logger.debug( + f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" + f" it does not have attribute 'proxy_headers'." 
+ ) + conn = proxy_manager.connection_from_url(url) + else: + # Only scheme should be lower case + url = parsed_url.geturl() + conn = self.poolmanager.connection_from_url(url) + + return conn + + +class SessionPool: + def __init__(self, manager: SessionManager) -> None: + # A stack of the idle sessions + self._idle_sessions = [] + self._active_sessions = set() + self._manager = manager + + def get_session(self) -> Session: + """Returns a session from the session pool or creates a new one.""" + try: + session = self._idle_sessions.pop() + except IndexError: + session = self._manager.make_session() + self._active_sessions.add(session) + return session + + def return_session(self, session: Session) -> None: + """Places an active session back into the idle session stack.""" + try: + self._active_sessions.remove(session) + except KeyError: + logger.debug("session doesn't exist in the active session pool. Ignored...") + self._idle_sessions.append(session) + + def __str__(self) -> str: + total_sessions = len(self._active_sessions) + len(self._idle_sessions) + return ( + f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" + ) + + def close(self) -> None: + """Closes all active and idle sessions in this session pool.""" + if self._active_sessions: + logger.debug(f"Closing {len(self._active_sessions)} active sessions") + for session in itertools.chain(self._active_sessions, self._idle_sessions): + try: + session.close() + except Exception as e: + logger.info(f"Session cleanup failed - failed to close session: {e}") + self._active_sessions.clear() + self._idle_sessions.clear() + + +class SessionManager: + def __init__( + self, + use_pooling: bool = True, + adapter_factory: ( + Callable[..., HTTPAdapter] | None + ) = lambda *args, **kwargs: None, + ): + self._use_pooling = use_pooling + self._adapter_factory = adapter_factory or ProxySupportAdapter + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: 
SessionPool(self) + ) + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self._sessions_map + + def _mount_adapter(self, session: requests.Session) -> None: + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + if adapter is not None: + session.mount("http://", adapter) + session.mount("https://", adapter) + + def make_session(self) -> Session: + s = requests.Session() + self._mount_adapter(s) + s._reuse_count = itertools.count() + return s + + @contextlib.contextmanager + def use_session(self, url: str | None = None): + if not self._use_pooling: + session = self.make_session() + try: + yield session + finally: + session.close() + else: + hostname = urlparse(url).hostname if url else None + pool = self._sessions_map[hostname] + session = pool.get_session() + try: + yield session + finally: + pool.return_session(session) + + def close(self): + for pool in self._sessions_map.values(): + pool.close() + + def clone(self, *, use_pooling: bool | None = None) -> SessionManager: + """Return an independent manager that reuses the adapter_factory.""" + return SessionManager( + use_pooling=self._use_pooling if use_pooling is None else use_pooling, + adapter_factory=self._adapter_factory, + ) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 21a7142c30..df6e7cea57 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -15,6 +15,8 @@ from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError +from .http_client import request as http_request +from .session_manager import SessionManager from .vendored import requests from .vendored.requests import Response @@ -72,15 +74,28 @@ class AwsCredentials: def try_metadata_service_call( - method: str, url: str, headers: dict, timeout_sec: int = 3 + method: str, + url: str, + headers: dict, + timeout_sec: int = 3, + session_manager: SessionManager | None = None, ) -> Response | None: """Tries to 
make a HTTP request to the metadata service with the given URL, method, headers and timeout. If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. """ try: - res: Response = requests.request( - method=method, url=url, headers=headers, timeout=timeout_sec + # If no session_manager provided, create a basic one for this call + if session_manager is None: + session_manager = SessionManager(use_pooling=False) + + res: Response = http_request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + session_manager=session_manager, + use_pooling=False, # IMDS calls are rare → don't pollute pool ) if not res.ok: return None @@ -119,16 +134,19 @@ def extract_iss_and_sub_without_signature_verification( # --------------------------------------------------------------------------- # # AWS helper utilities (token, credentials, region) # # --------------------------------------------------------------------------- # -def _imds_v2_token() -> str | None: +def _imds_v2_token(session_manager: SessionManager | None = None) -> str | None: res = try_metadata_service_call( method="PUT", url="http://169.254.169.254/latest/api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + session_manager=session_manager, ) return res.text.strip() if res else None -def get_aws_credentials() -> AwsCredentials | None: +def get_aws_credentials( + session_manager: SessionManager | None = None, +) -> AwsCredentials | None: """Get AWS credentials from environment variables or instance metadata. Implements the AWS credential chain without using boto3. 
@@ -143,7 +161,7 @@ def get_aws_credentials() -> AwsCredentials | None: # Try instance metadata service (IMDSv2) try: - token = _imds_v2_token() + token = _imds_v2_token(session_manager) if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -155,6 +173,7 @@ def get_aws_credentials() -> AwsCredentials | None: method="GET", url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", headers=token_hdr, + session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role list from metadata service.") @@ -170,6 +189,7 @@ def get_aws_credentials() -> AwsCredentials | None: method="GET", url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", headers=token_hdr, + session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role credentials from metadata service.") @@ -188,14 +208,14 @@ def get_aws_credentials() -> AwsCredentials | None: return None -def get_aws_region() -> str | None: +def get_aws_region(session_manager: SessionManager | None = None) -> str | None: """Get the current AWS workload's region, if any.""" region = os.environ.get("AWS_REGION") if region: return region try: - token = _imds_v2_token() + token = _imds_v2_token(session_manager) if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -207,6 +227,7 @@ def get_aws_region() -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/region", headers=token_hdr, + session_manager=session_manager, ) if res is not None: return res.text.strip() @@ -215,6 +236,7 @@ def get_aws_region() -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/availability-zone", headers=token_hdr, + session_manager=session_manager, ) if res is not None: return res.text.strip()[:-1] @@ -345,17 +367,19 @@ def hmac_sha256(key: bytes, msg: str) -> bytes: return final_headers -def create_aws_attestation() -> 
WorkloadIdentityAttestation | None: +def create_aws_attestation( + session_manager: SessionManager | None = None, +) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, returns None. """ - credentials = get_aws_credentials() + credentials = get_aws_credentials(session_manager) if not credentials: logger.debug("No AWS credentials were found.") return None - region = get_aws_region() + region = get_aws_region(session_manager) if not region: logger.debug("No AWS region was found.") return None @@ -395,7 +419,9 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: ) -def create_gcp_attestation() -> WorkloadIdentityAttestation | None: +def create_gcp_attestation( + session_manager: SessionManager | None = None, +) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for GCP. If the application isn't running on GCP or no credentials were found, returns None. @@ -406,6 +432,7 @@ def create_gcp_attestation() -> WorkloadIdentityAttestation | None: headers={ "Metadata-Flavor": "Google", }, + session_manager=session_manager, ) if res is None: # Most likely we're just not running on GCP, which may be expected. @@ -428,6 +455,7 @@ def create_gcp_attestation() -> WorkloadIdentityAttestation | None: def create_azure_attestation( snowflake_entra_resource: str, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for Azure. @@ -461,6 +489,7 @@ def create_azure_attestation( method="GET", url=f"{url_without_query_string}?{query_params}", headers=headers, + session_manager=session_manager, ) if res is None: # Most likely we're just not running on Azure, which may be expected. 
@@ -511,7 +540,9 @@ def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | def create_autodetect_attestation( - entra_resource: str, token: str | None = None + entra_resource: str, + token: str | None = None, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create an attestation using the auto-detected runtime environment. @@ -521,15 +552,15 @@ def create_autodetect_attestation( if attestation: return attestation - attestation = create_azure_attestation(entra_resource) + attestation = create_azure_attestation(entra_resource, session_manager) if attestation: return attestation - attestation = create_aws_attestation() + attestation = create_aws_attestation(session_manager) if attestation: return attestation - attestation = create_gcp_attestation() + attestation = create_gcp_attestation(session_manager) if attestation: return attestation @@ -540,6 +571,7 @@ def create_attestation( provider: AttestationProvider | None, entra_resource: str | None = None, token: str | None = None, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -549,18 +581,23 @@ def create_attestation( If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. 
""" entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE + session_manager = ( + session_manager.clone() if session_manager else SessionManager(use_pooling=True) + ) attestation: WorkloadIdentityAttestation | None = None if provider == AttestationProvider.AWS: - attestation = create_aws_attestation() + attestation = create_aws_attestation(session_manager) elif provider == AttestationProvider.AZURE: - attestation = create_azure_attestation(entra_resource) + attestation = create_azure_attestation(entra_resource, session_manager) elif provider == AttestationProvider.GCP: - attestation = create_gcp_attestation() + attestation = create_gcp_attestation(session_manager) elif provider == AttestationProvider.OIDC: attestation = create_oidc_attestation(token) elif provider is None: - attestation = create_autodetect_attestation(entra_resource, token) + attestation = create_autodetect_attestation( + entra_resource, token, session_manager + ) if not attestation: provider_str = "auto-detect" if provider is None else provider.value From 2241969584c9eeba41e80ea379257bede3b0fd2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 22:24:16 +0200 Subject: [PATCH 12/54] SNOW-2183023: fixed csp_helper --- test/csp_helpers.py | 118 +++++++++++++++++++------------------------- 1 file changed, 51 insertions(+), 67 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index c2efce1a94..05509bab85 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +from __future__ import annotations + import json import logging import os @@ -11,40 +13,32 @@ from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError from snowflake.connector.vendored.requests.models import Response - -# NEW: import the light-weight creds class from the refactored util -from snowflake.connector.wif_util import AwsCredentials +from snowflake.connector.wif_util import AwsCredentials # 
light-weight creds logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- # +# Helpers # +# --------------------------------------------------------------------------- # def gen_dummy_id_token( - sub="test-subject", iss="test-issuer", aud="snowflakecomputing.com" + sub: str = "test-subject", + iss: str = "test-issuer", + aud: str = "snowflakecomputing.com", ) -> str: """Generates a dummy ID token using the given subject and issuer.""" now = int(time()) - key = "secret" - payload = { - "sub": sub, - "iss": iss, - "aud": aud, - "iat": now, - "exp": now + 60 * 60, - } + payload = {"sub": sub, "iss": iss, "aud": aud, "iat": now, "exp": now + 3600} logger.debug(f"Generating dummy token with the following claims:\n{str(payload)}") - return jwt.encode( - payload=payload, - key=key, - algorithm="HS256", - ) + return jwt.encode(payload, key="secret", algorithm="HS256") def build_response(content: bytes, status_code: int = 200) -> Response: """Builds a requests.Response object with the given status code and content.""" - response = Response() - response.status_code = status_code - response._content = content - return response + resp = Response() + resp.status_code = status_code + resp._content = content + return resp # --------------------------------------------------------------------------- # @@ -246,52 +240,52 @@ def handle_request(self, method, parsed_url, headers, timeout): return build_response(self.token.encode("utf-8")) +# --------------------------------------------------------------------------- # +# AWS environment fake # +# --------------------------------------------------------------------------- # class FakeAwsEnvironment: - """Emulates the AWS environment-specific helpers now used in wif_util.py.""" + """Emulates the AWS environment-specific helpers used in wif_util.py.""" - def __init__(self): - self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" + def __init__(self) -> 
None: self.region = "us-east-1" self.credentials: AwsCredentials | None = AwsCredentials( - access_key="ak", - secret_key="sk", - token="SESSION_TOKEN", + access_key="ak", secret_key="sk", token="SESSION_TOKEN" ) - # --------------------------------------------------------------------- # - # Helper getters (used as side-effects for patching) # - # --------------------------------------------------------------------- # - def get_region(self): + # ------------------------------------------------------------------ # + # Helper getters (swallow any extra args like session_manager) # + # ------------------------------------------------------------------ # + def get_region(self, *_, **__) -> str | None: return self.region - def get_arn(self): - return self.arn - - def get_credentials(self): + def get_credentials(self, *_, **__) -> AwsCredentials | None: return self.credentials - # --------------------------------------------------------------------- # - # Context-manager patching # - # --------------------------------------------------------------------- # + # ------------------------------------------------------------------ # + # Context-manager patching # + # ------------------------------------------------------------------ # def __enter__(self): - # Stash current env so we can restore later + # Save & override env vars self._prev_env = { - "AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"), - "AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"), - "AWS_SESSION_TOKEN": os.environ.get("AWS_SESSION_TOKEN"), - "AWS_REGION": os.environ.get("AWS_REGION"), + k: os.environ.get(k) + for k in ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_REGION", + ) } - - # Expose creds & region via env vars (preferred path in new util) if self.credentials: - os.environ["AWS_ACCESS_KEY_ID"] = self.credentials.access_key - os.environ["AWS_SECRET_ACCESS_KEY"] = self.credentials.secret_key - if self.credentials.token: - 
os.environ["AWS_SESSION_TOKEN"] = self.credentials.token + os.environ.update( + { + "AWS_ACCESS_KEY_ID": self.credentials.access_key, + "AWS_SECRET_ACCESS_KEY": self.credentials.secret_key, + "AWS_SESSION_TOKEN": (self.credentials.token or ""), + } + ) os.environ["AWS_REGION"] = self.region self.patchers: list[mock._patch] = [ - # Force util helpers to return our fake data mock.patch( "snowflake.connector.wif_util.get_aws_credentials", side_effect=self.get_credentials, @@ -300,21 +294,12 @@ def __enter__(self): "snowflake.connector.wif_util.get_aws_region", side_effect=self.get_region, ), - # _imds_v2_token() must not hit the network - ( - mock.patch( - "snowflake.connector.wif_util._imds_v2_token", - return_value=None, - ) - if hasattr( - __import__( - "snowflake.connector.wif_util", fromlist=["get_aws_arn"] - ), - "get_aws_arn", - ) - else mock.patch.dict({}, {}, clear=True) - ), # dummy, no-op patch - # Block any accidental real HTTP calls via urllib3 + # Avoid real network for IMDS token + mock.patch( + "snowflake.connector.wif_util._imds_v2_token", + return_value=None, + ), + # Block stray HTTP traffic mock.patch( "urllib3.connection.HTTPConnection.request", side_effect=ConnectTimeout(), @@ -328,8 +313,7 @@ def __enter__(self): def __exit__(self, *args): for p in self.patchers: p.__exit__(*args) - - # Restore previous env-vars + # Restore original env vars for key, val in self._prev_env.items(): if val is None: os.environ.pop(key, None) From 4f0785c0341d3f2ade44c762ce40a37fd410771c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 22:24:26 +0200 Subject: [PATCH 13/54] Revert "Reapply "SNOW-2183023: refactored"" This reverts commit 43f7868f07b250f8e4828b7dd14d07d627c49e98. 
--- .../connector/auth/workload_identity.py | 7 +- src/snowflake/connector/connection.py | 5 - src/snowflake/connector/http_client.py | 103 ----------- src/snowflake/connector/network.py | 145 +++++++++++++++- src/snowflake/connector/session_manager.py | 162 ------------------ src/snowflake/connector/wif_util.py | 77 +++------ 6 files changed, 165 insertions(+), 334 deletions(-) delete mode 100644 src/snowflake/connector/http_client.py delete mode 100644 src/snowflake/connector/session_manager.py diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 7f8ab60718..3c80c965e4 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -74,13 +74,10 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: ).value body["data"]["TOKEN"] = self.attestation.credential - def prepare(self, *, conn, **kwargs: typing.Any) -> None: + def prepare(self, **kwargs: typing.Any) -> None: """Fetch the token.""" self.attestation = create_attestation( - self.provider, - self.entra_resource, - self.token, - session_manager=conn.session_manager if conn else None, + self.provider, self.entra_resource, self.token ) def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index ffc193df37..963e04ee8a 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -712,11 +712,6 @@ def client_fetch_use_mp(self) -> bool: def rest(self) -> SnowflakeRestful | None: return self._rest - @property - def session_manager(self): - """Access to the connection's SessionManager for making HTTP requests.""" - return self._rest.session_manager if self._rest else None - @property def application(self) -> str: return self._application diff --git a/src/snowflake/connector/http_client.py b/src/snowflake/connector/http_client.py deleted 
file mode 100644 index 1c802c2439..0000000000 --- a/src/snowflake/connector/http_client.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Any, Mapping - -from .session_manager import SessionManager -from .vendored.requests import Response - -logger = logging.getLogger(__name__) - - -class HttpClient: - """HTTP client that uses SessionManager for connection pooling and adapter management.""" - - def __init__(self, session_manager: SessionManager): - """Initialize HttpClient with a SessionManager. - - Args: - session_manager: SessionManager instance to use for all requests - """ - self.session_manager = session_manager - - def request( - self, - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - use_pooling: bool | None = None, - **kwargs: Any, - ) -> Response: - """Make an HTTP request using the configured SessionManager. - - Args: - method: HTTP method (GET, POST, etc.) - url: Target URL - headers: Optional HTTP headers - timeout_sec: Request timeout in seconds - use_pooling: Whether to use connection pooling (overrides session_manager setting) - **kwargs: Additional arguments passed to requests.Session.request - - Returns: - Response object from the request - """ - mgr = ( - self.session_manager - if use_pooling is None - else self.session_manager.clone(use_pooling=use_pooling) - ) - - with mgr.use_session(url) as session: - return session.request( - method=method.upper(), - url=url, - headers=headers, - timeout=timeout_sec, - **kwargs, - ) - - -# Convenience function for backwards compatibility and simple usage -def request( - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - session_manager: SessionManager | None = None, - use_pooling: bool | None = None, - **kwargs: Any, -) -> Response: - """Convenience function for making HTTP requests. - - Args: - method: HTTP method (GET, POST, etc.) 
- url: Target URL - headers: Optional HTTP headers - timeout_sec: Request timeout in seconds - session_manager: SessionManager instance to use (required) - use_pooling: Whether to use connection pooling (overrides session_manager setting) - **kwargs: Additional arguments passed to requests.Session.request - - Returns: - Response object from the request - - Raises: - ValueError: If session_manager is None - """ - if session_manager is None: - raise ValueError( - "session_manager is required - no default session manager available" - ) - - client = HttpClient(session_manager) - return client.request( - method=method, - url=url, - headers=headers, - timeout_sec=timeout_sec, - use_pooling=use_pooling, - **kwargs, - ) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index fbea591258..84652205fa 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -1,12 +1,16 @@ #!/usr/bin/env python from __future__ import annotations +import collections +import contextlib import gzip +import itertools import json import logging import re import time import uuid +from collections import OrderedDict from threading import Lock from typing import TYPE_CHECKING, Any, Callable @@ -14,6 +18,10 @@ from snowflake.connector.secret_detector import SecretDetector from snowflake.connector.vendored.requests.models import PreparedRequest +from snowflake.connector.vendored.urllib3.connectionpool import ( + HTTPConnectionPool, + HTTPSConnectionPool, +) from . 
import ssl_wrap_socket from .compat import ( @@ -76,7 +84,6 @@ ServiceUnavailableError, TooManyRequests, ) -from .session_manager import SessionManager, SessionPool from .sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_REJECTED, @@ -89,16 +96,19 @@ ) from .tool.probe_connection import probe_connection from .vendored import requests -from .vendored.requests import Response +from .vendored.requests import Response, Session from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.auth import AuthBase from .vendored.requests.exceptions import ( ConnectionError, ConnectTimeout, + InvalidProxyURL, ReadTimeout, SSLError, ) +from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy from .vendored.urllib3.exceptions import ProtocolError +from .vendored.urllib3.poolmanager import ProxyManager from .vendored.urllib3.util.url import parse_url if TYPE_CHECKING: @@ -238,6 +248,42 @@ def is_login_request(url: str) -> bool: return "login-request" in parse_url(url).path +class ProxySupportAdapter(HTTPAdapter): + """This Adapter creates proper headers for Proxy CONNECT messages.""" + + def get_connection( + self, url: str, proxies: OrderedDict | None = None + ) -> HTTPConnectionPool | HTTPSConnectionPool: + proxy = select_proxy(url, proxies) + parsed_url = urlparse(url) + + if proxy: + proxy = prepend_scheme_if_needed(proxy, "http") + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL( + "Please check proxy URL. It is malformed" + " and could be missing the host." + ) + proxy_manager = self.proxy_manager_for(proxy) + + if isinstance(proxy_manager, ProxyManager): + # Add Host to proxy header SNOW-232777 + proxy_manager.proxy_headers["Host"] = parsed_url.hostname + else: + logger.debug( + f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" + f" it does not have attribute 'proxy_headers'." 
+ ) + conn = proxy_manager.connection_from_url(url) + else: + # Only scheme should be lower case + url = parsed_url.geturl() + conn = self.poolmanager.connection_from_url(url) + + return conn + + class RetryRequest(Exception): """Signal to retry request.""" @@ -288,6 +334,101 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: return r +class SessionManager: + def __init__( + self, + use_pooling: bool = True, + adapter_factory: ( + Callable[..., HTTPAdapter] | None + ) = lambda *args, **kwargs: None, + ): + self._use_pooling = use_pooling + self._adapter_factory = adapter_factory or ProxySupportAdapter + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: SessionPool(self) + ) + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self._sessions_map + + def _mount_adapter(self, session: requests.Session) -> None: + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + if adapter is not None: + session.mount("http://", adapter) + session.mount("https://", adapter) + + def make_session(self) -> Session: + s = requests.Session() + self._mount_adapter(s) + s._reuse_count = itertools.count() + return s + + @contextlib.contextmanager + def use_session(self, url: str | None = None): + if not self._use_pooling: + session = self.make_session() + try: + yield session + finally: + session.close() + else: + hostname = urlparse(url).hostname if url else None + pool = self._sessions_map[hostname] + session = pool.get_session() + try: + yield session + finally: + pool.return_session(session) + + def close(self): + for pool in self._sessions_map.values(): + pool.close() + + +class SessionPool: + def __init__(self, manager: SessionManager) -> None: + # A stack of the idle sessions + self._idle_sessions = [] + self._active_sessions = set() + self._manager = manager + + def get_session(self) -> Session: + """Returns a session from the session pool or creates a new one.""" + try: + session = 
self._idle_sessions.pop() + except IndexError: + session = self._manager.make_session() + self._active_sessions.add(session) + return session + + def return_session(self, session: Session) -> None: + """Places an active session back into the idle session stack.""" + try: + self._active_sessions.remove(session) + except KeyError: + logger.debug("session doesn't exist in the active session pool. Ignored...") + self._idle_sessions.append(session) + + def __str__(self) -> str: + total_sessions = len(self._active_sessions) + len(self._idle_sessions) + return ( + f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" + ) + + def close(self) -> None: + """Closes all active and idle sessions in this session pool.""" + if self._active_sessions: + logger.debug(f"Closing {len(self._active_sessions)} active sessions") + for session in itertools.chain(self._active_sessions, self._idle_sessions): + try: + session.close() + except Exception as e: + logger.info(f"Session cleanup failed - failed to close session: {e}") + self._active_sessions.clear() + self._idle_sessions.clear() + + # Customizable JSONEncoder to support additional types. 
class SnowflakeRestfulJsonEncoder(json.JSONEncoder): def default(self, o): diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py deleted file mode 100644 index 85b46ccfd4..0000000000 --- a/src/snowflake/connector/session_manager.py +++ /dev/null @@ -1,162 +0,0 @@ -from __future__ import annotations - -import collections -import contextlib -import itertools -import logging -from typing import TYPE_CHECKING, Callable - -from .compat import urlparse -from .vendored import requests -from .vendored.requests import Session -from .vendored.requests.adapters import HTTPAdapter -from .vendored.requests.exceptions import InvalidProxyURL -from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy -from .vendored.urllib3.poolmanager import ProxyManager -from .vendored.urllib3.util.url import parse_url - -if TYPE_CHECKING: - from .vendored.urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool - -logger = logging.getLogger(__name__) - -# requests parameters -REQUESTS_RETRY = 1 # requests library builtin retry - - -class ProxySupportAdapter(HTTPAdapter): - """This Adapter creates proper headers for Proxy CONNECT messages.""" - - def get_connection( - self, url: str, proxies: dict | None = None - ) -> HTTPConnectionPool | HTTPSConnectionPool: - proxy = select_proxy(url, proxies) - parsed_url = urlparse(url) - - if proxy: - proxy = prepend_scheme_if_needed(proxy, "http") - proxy_url = parse_url(proxy) - if not proxy_url.host: - raise InvalidProxyURL( - "Please check proxy URL. It is malformed" - " and could be missing the host." - ) - proxy_manager = self.proxy_manager_for(proxy) - - if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname - else: - logger.debug( - f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" - f" it does not have attribute 'proxy_headers'." 
- ) - conn = proxy_manager.connection_from_url(url) - else: - # Only scheme should be lower case - url = parsed_url.geturl() - conn = self.poolmanager.connection_from_url(url) - - return conn - - -class SessionPool: - def __init__(self, manager: SessionManager) -> None: - # A stack of the idle sessions - self._idle_sessions = [] - self._active_sessions = set() - self._manager = manager - - def get_session(self) -> Session: - """Returns a session from the session pool or creates a new one.""" - try: - session = self._idle_sessions.pop() - except IndexError: - session = self._manager.make_session() - self._active_sessions.add(session) - return session - - def return_session(self, session: Session) -> None: - """Places an active session back into the idle session stack.""" - try: - self._active_sessions.remove(session) - except KeyError: - logger.debug("session doesn't exist in the active session pool. Ignored...") - self._idle_sessions.append(session) - - def __str__(self) -> str: - total_sessions = len(self._active_sessions) + len(self._idle_sessions) - return ( - f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" - ) - - def close(self) -> None: - """Closes all active and idle sessions in this session pool.""" - if self._active_sessions: - logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for session in itertools.chain(self._active_sessions, self._idle_sessions): - try: - session.close() - except Exception as e: - logger.info(f"Session cleanup failed - failed to close session: {e}") - self._active_sessions.clear() - self._idle_sessions.clear() - - -class SessionManager: - def __init__( - self, - use_pooling: bool = True, - adapter_factory: ( - Callable[..., HTTPAdapter] | None - ) = lambda *args, **kwargs: None, - ): - self._use_pooling = use_pooling - self._adapter_factory = adapter_factory or ProxySupportAdapter - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: 
SessionPool(self) - ) - - @property - def sessions_map(self) -> dict[str, SessionPool]: - return self._sessions_map - - def _mount_adapter(self, session: requests.Session) -> None: - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - if adapter is not None: - session.mount("http://", adapter) - session.mount("https://", adapter) - - def make_session(self) -> Session: - s = requests.Session() - self._mount_adapter(s) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def use_session(self, url: str | None = None): - if not self._use_pooling: - session = self.make_session() - try: - yield session - finally: - session.close() - else: - hostname = urlparse(url).hostname if url else None - pool = self._sessions_map[hostname] - session = pool.get_session() - try: - yield session - finally: - pool.return_session(session) - - def close(self): - for pool in self._sessions_map.values(): - pool.close() - - def clone(self, *, use_pooling: bool | None = None) -> SessionManager: - """Return an independent manager that reuses the adapter_factory.""" - return SessionManager( - use_pooling=self._use_pooling if use_pooling is None else use_pooling, - adapter_factory=self._adapter_factory, - ) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index df6e7cea57..21a7142c30 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -15,8 +15,6 @@ from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError -from .http_client import request as http_request -from .session_manager import SessionManager from .vendored import requests from .vendored.requests import Response @@ -74,28 +72,15 @@ class AwsCredentials: def try_metadata_service_call( - method: str, - url: str, - headers: dict, - timeout_sec: int = 3, - session_manager: SessionManager | None = None, + method: str, url: str, headers: dict, timeout_sec: int = 3 ) -> Response | None: """Tries to 
make a HTTP request to the metadata service with the given URL, method, headers and timeout. If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. """ try: - # If no session_manager provided, create a basic one for this call - if session_manager is None: - session_manager = SessionManager(use_pooling=False) - - res: Response = http_request( - method=method, - url=url, - headers=headers, - timeout_sec=timeout_sec, - session_manager=session_manager, - use_pooling=False, # IMDS calls are rare → don't pollute pool + res: Response = requests.request( + method=method, url=url, headers=headers, timeout=timeout_sec ) if not res.ok: return None @@ -134,19 +119,16 @@ def extract_iss_and_sub_without_signature_verification( # --------------------------------------------------------------------------- # # AWS helper utilities (token, credentials, region) # # --------------------------------------------------------------------------- # -def _imds_v2_token(session_manager: SessionManager | None = None) -> str | None: +def _imds_v2_token() -> str | None: res = try_metadata_service_call( method="PUT", url="http://169.254.169.254/latest/api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, - session_manager=session_manager, ) return res.text.strip() if res else None -def get_aws_credentials( - session_manager: SessionManager | None = None, -) -> AwsCredentials | None: +def get_aws_credentials() -> AwsCredentials | None: """Get AWS credentials from environment variables or instance metadata. Implements the AWS credential chain without using boto3. 
@@ -161,7 +143,7 @@ def get_aws_credentials( # Try instance metadata service (IMDSv2) try: - token = _imds_v2_token(session_manager) + token = _imds_v2_token() if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -173,7 +155,6 @@ def get_aws_credentials( method="GET", url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", headers=token_hdr, - session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role list from metadata service.") @@ -189,7 +170,6 @@ def get_aws_credentials( method="GET", url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", headers=token_hdr, - session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role credentials from metadata service.") @@ -208,14 +188,14 @@ def get_aws_credentials( return None -def get_aws_region(session_manager: SessionManager | None = None) -> str | None: +def get_aws_region() -> str | None: """Get the current AWS workload's region, if any.""" region = os.environ.get("AWS_REGION") if region: return region try: - token = _imds_v2_token(session_manager) + token = _imds_v2_token() if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -227,7 +207,6 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/region", headers=token_hdr, - session_manager=session_manager, ) if res is not None: return res.text.strip() @@ -236,7 +215,6 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/availability-zone", headers=token_hdr, - session_manager=session_manager, ) if res is not None: return res.text.strip()[:-1] @@ -367,19 +345,17 @@ def hmac_sha256(key: bytes, msg: str) -> bytes: return final_headers -def create_aws_attestation( - session_manager: 
SessionManager | None = None, -) -> WorkloadIdentityAttestation | None: +def create_aws_attestation() -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, returns None. """ - credentials = get_aws_credentials(session_manager) + credentials = get_aws_credentials() if not credentials: logger.debug("No AWS credentials were found.") return None - region = get_aws_region(session_manager) + region = get_aws_region() if not region: logger.debug("No AWS region was found.") return None @@ -419,9 +395,7 @@ def create_aws_attestation( ) -def create_gcp_attestation( - session_manager: SessionManager | None = None, -) -> WorkloadIdentityAttestation | None: +def create_gcp_attestation() -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for GCP. If the application isn't running on GCP or no credentials were found, returns None. @@ -432,7 +406,6 @@ def create_gcp_attestation( headers={ "Metadata-Flavor": "Google", }, - session_manager=session_manager, ) if res is None: # Most likely we're just not running on GCP, which may be expected. @@ -455,7 +428,6 @@ def create_gcp_attestation( def create_azure_attestation( snowflake_entra_resource: str, - session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for Azure. @@ -489,7 +461,6 @@ def create_azure_attestation( method="GET", url=f"{url_without_query_string}?{query_params}", headers=headers, - session_manager=session_manager, ) if res is None: # Most likely we're just not running on Azure, which may be expected. 
@@ -540,9 +511,7 @@ def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | def create_autodetect_attestation( - entra_resource: str, - token: str | None = None, - session_manager: SessionManager | None = None, + entra_resource: str, token: str | None = None ) -> WorkloadIdentityAttestation | None: """Tries to create an attestation using the auto-detected runtime environment. @@ -552,15 +521,15 @@ def create_autodetect_attestation( if attestation: return attestation - attestation = create_azure_attestation(entra_resource, session_manager) + attestation = create_azure_attestation(entra_resource) if attestation: return attestation - attestation = create_aws_attestation(session_manager) + attestation = create_aws_attestation() if attestation: return attestation - attestation = create_gcp_attestation(session_manager) + attestation = create_gcp_attestation() if attestation: return attestation @@ -571,7 +540,6 @@ def create_attestation( provider: AttestationProvider | None, entra_resource: str | None = None, token: str | None = None, - session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -581,23 +549,18 @@ def create_attestation( If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. 
""" entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE - session_manager = ( - session_manager.clone() if session_manager else SessionManager(use_pooling=True) - ) attestation: WorkloadIdentityAttestation | None = None if provider == AttestationProvider.AWS: - attestation = create_aws_attestation(session_manager) + attestation = create_aws_attestation() elif provider == AttestationProvider.AZURE: - attestation = create_azure_attestation(entra_resource, session_manager) + attestation = create_azure_attestation(entra_resource) elif provider == AttestationProvider.GCP: - attestation = create_gcp_attestation(session_manager) + attestation = create_gcp_attestation() elif provider == AttestationProvider.OIDC: attestation = create_oidc_attestation(token) elif provider is None: - attestation = create_autodetect_attestation( - entra_resource, token, session_manager - ) + attestation = create_autodetect_attestation(entra_resource, token) if not attestation: provider_str = "auto-detect" if provider is None else provider.value From e603c1c76b116a26513f333d6d326f9694e67c1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Mon, 7 Jul 2025 22:24:46 +0200 Subject: [PATCH 14/54] Reapply "Reapply "SNOW-2183023: refactored"" This reverts commit 4f0785c0341d3f2ade44c762ce40a37fd410771c. 
--- .../connector/auth/workload_identity.py | 7 +- src/snowflake/connector/connection.py | 5 + src/snowflake/connector/http_client.py | 103 +++++++++++ src/snowflake/connector/network.py | 145 +--------------- src/snowflake/connector/session_manager.py | 162 ++++++++++++++++++ src/snowflake/connector/wif_util.py | 77 ++++++--- 6 files changed, 334 insertions(+), 165 deletions(-) create mode 100644 src/snowflake/connector/http_client.py create mode 100644 src/snowflake/connector/session_manager.py diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 3c80c965e4..7f8ab60718 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -74,10 +74,13 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: ).value body["data"]["TOKEN"] = self.attestation.credential - def prepare(self, **kwargs: typing.Any) -> None: + def prepare(self, *, conn, **kwargs: typing.Any) -> None: """Fetch the token.""" self.attestation = create_attestation( - self.provider, self.entra_resource, self.token + self.provider, + self.entra_resource, + self.token, + session_manager=conn.session_manager if conn else None, ) def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 963e04ee8a..ffc193df37 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -712,6 +712,11 @@ def client_fetch_use_mp(self) -> bool: def rest(self) -> SnowflakeRestful | None: return self._rest + @property + def session_manager(self): + """Access to the connection's SessionManager for making HTTP requests.""" + return self._rest.session_manager if self._rest else None + @property def application(self) -> str: return self._application diff --git a/src/snowflake/connector/http_client.py b/src/snowflake/connector/http_client.py new file 
mode 100644 index 0000000000..1c802c2439 --- /dev/null +++ b/src/snowflake/connector/http_client.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import logging +from typing import Any, Mapping + +from .session_manager import SessionManager +from .vendored.requests import Response + +logger = logging.getLogger(__name__) + + +class HttpClient: + """HTTP client that uses SessionManager for connection pooling and adapter management.""" + + def __init__(self, session_manager: SessionManager): + """Initialize HttpClient with a SessionManager. + + Args: + session_manager: SessionManager instance to use for all requests + """ + self.session_manager = session_manager + + def request( + self, + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + use_pooling: bool | None = None, + **kwargs: Any, + ) -> Response: + """Make an HTTP request using the configured SessionManager. + + Args: + method: HTTP method (GET, POST, etc.) + url: Target URL + headers: Optional HTTP headers + timeout_sec: Request timeout in seconds + use_pooling: Whether to use connection pooling (overrides session_manager setting) + **kwargs: Additional arguments passed to requests.Session.request + + Returns: + Response object from the request + """ + mgr = ( + self.session_manager + if use_pooling is None + else self.session_manager.clone(use_pooling=use_pooling) + ) + + with mgr.use_session(url) as session: + return session.request( + method=method.upper(), + url=url, + headers=headers, + timeout=timeout_sec, + **kwargs, + ) + + +# Convenience function for backwards compatibility and simple usage +def request( + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + session_manager: SessionManager | None = None, + use_pooling: bool | None = None, + **kwargs: Any, +) -> Response: + """Convenience function for making HTTP requests. + + Args: + method: HTTP method (GET, POST, etc.) 
+ url: Target URL + headers: Optional HTTP headers + timeout_sec: Request timeout in seconds + session_manager: SessionManager instance to use (required) + use_pooling: Whether to use connection pooling (overrides session_manager setting) + **kwargs: Additional arguments passed to requests.Session.request + + Returns: + Response object from the request + + Raises: + ValueError: If session_manager is None + """ + if session_manager is None: + raise ValueError( + "session_manager is required - no default session manager available" + ) + + client = HttpClient(session_manager) + return client.request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + use_pooling=use_pooling, + **kwargs, + ) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 84652205fa..fbea591258 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -1,16 +1,12 @@ #!/usr/bin/env python from __future__ import annotations -import collections -import contextlib import gzip -import itertools import json import logging import re import time import uuid -from collections import OrderedDict from threading import Lock from typing import TYPE_CHECKING, Any, Callable @@ -18,10 +14,6 @@ from snowflake.connector.secret_detector import SecretDetector from snowflake.connector.vendored.requests.models import PreparedRequest -from snowflake.connector.vendored.urllib3.connectionpool import ( - HTTPConnectionPool, - HTTPSConnectionPool, -) from . 
import ssl_wrap_socket from .compat import ( @@ -84,6 +76,7 @@ ServiceUnavailableError, TooManyRequests, ) +from .session_manager import SessionManager, SessionPool from .sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_REJECTED, @@ -96,19 +89,16 @@ ) from .tool.probe_connection import probe_connection from .vendored import requests -from .vendored.requests import Response, Session +from .vendored.requests import Response from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.auth import AuthBase from .vendored.requests.exceptions import ( ConnectionError, ConnectTimeout, - InvalidProxyURL, ReadTimeout, SSLError, ) -from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy from .vendored.urllib3.exceptions import ProtocolError -from .vendored.urllib3.poolmanager import ProxyManager from .vendored.urllib3.util.url import parse_url if TYPE_CHECKING: @@ -248,42 +238,6 @@ def is_login_request(url: str) -> bool: return "login-request" in parse_url(url).path -class ProxySupportAdapter(HTTPAdapter): - """This Adapter creates proper headers for Proxy CONNECT messages.""" - - def get_connection( - self, url: str, proxies: OrderedDict | None = None - ) -> HTTPConnectionPool | HTTPSConnectionPool: - proxy = select_proxy(url, proxies) - parsed_url = urlparse(url) - - if proxy: - proxy = prepend_scheme_if_needed(proxy, "http") - proxy_url = parse_url(proxy) - if not proxy_url.host: - raise InvalidProxyURL( - "Please check proxy URL. It is malformed" - " and could be missing the host." - ) - proxy_manager = self.proxy_manager_for(proxy) - - if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname - else: - logger.debug( - f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" - f" it does not have attribute 'proxy_headers'." 
- ) - conn = proxy_manager.connection_from_url(url) - else: - # Only scheme should be lower case - url = parsed_url.geturl() - conn = self.poolmanager.connection_from_url(url) - - return conn - - class RetryRequest(Exception): """Signal to retry request.""" @@ -334,101 +288,6 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: return r -class SessionManager: - def __init__( - self, - use_pooling: bool = True, - adapter_factory: ( - Callable[..., HTTPAdapter] | None - ) = lambda *args, **kwargs: None, - ): - self._use_pooling = use_pooling - self._adapter_factory = adapter_factory or ProxySupportAdapter - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: SessionPool(self) - ) - - @property - def sessions_map(self) -> dict[str, SessionPool]: - return self._sessions_map - - def _mount_adapter(self, session: requests.Session) -> None: - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - if adapter is not None: - session.mount("http://", adapter) - session.mount("https://", adapter) - - def make_session(self) -> Session: - s = requests.Session() - self._mount_adapter(s) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def use_session(self, url: str | None = None): - if not self._use_pooling: - session = self.make_session() - try: - yield session - finally: - session.close() - else: - hostname = urlparse(url).hostname if url else None - pool = self._sessions_map[hostname] - session = pool.get_session() - try: - yield session - finally: - pool.return_session(session) - - def close(self): - for pool in self._sessions_map.values(): - pool.close() - - -class SessionPool: - def __init__(self, manager: SessionManager) -> None: - # A stack of the idle sessions - self._idle_sessions = [] - self._active_sessions = set() - self._manager = manager - - def get_session(self) -> Session: - """Returns a session from the session pool or creates a new one.""" - try: - session = 
self._idle_sessions.pop() - except IndexError: - session = self._manager.make_session() - self._active_sessions.add(session) - return session - - def return_session(self, session: Session) -> None: - """Places an active session back into the idle session stack.""" - try: - self._active_sessions.remove(session) - except KeyError: - logger.debug("session doesn't exist in the active session pool. Ignored...") - self._idle_sessions.append(session) - - def __str__(self) -> str: - total_sessions = len(self._active_sessions) + len(self._idle_sessions) - return ( - f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" - ) - - def close(self) -> None: - """Closes all active and idle sessions in this session pool.""" - if self._active_sessions: - logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for session in itertools.chain(self._active_sessions, self._idle_sessions): - try: - session.close() - except Exception as e: - logger.info(f"Session cleanup failed - failed to close session: {e}") - self._active_sessions.clear() - self._idle_sessions.clear() - - # Customizable JSONEncoder to support additional types. 
class SnowflakeRestfulJsonEncoder(json.JSONEncoder): def default(self, o): diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py new file mode 100644 index 0000000000..85b46ccfd4 --- /dev/null +++ b/src/snowflake/connector/session_manager.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import collections +import contextlib +import itertools +import logging +from typing import TYPE_CHECKING, Callable + +from .compat import urlparse +from .vendored import requests +from .vendored.requests import Session +from .vendored.requests.adapters import HTTPAdapter +from .vendored.requests.exceptions import InvalidProxyURL +from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy +from .vendored.urllib3.poolmanager import ProxyManager +from .vendored.urllib3.util.url import parse_url + +if TYPE_CHECKING: + from .vendored.urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool + +logger = logging.getLogger(__name__) + +# requests parameters +REQUESTS_RETRY = 1 # requests library builtin retry + + +class ProxySupportAdapter(HTTPAdapter): + """This Adapter creates proper headers for Proxy CONNECT messages.""" + + def get_connection( + self, url: str, proxies: dict | None = None + ) -> HTTPConnectionPool | HTTPSConnectionPool: + proxy = select_proxy(url, proxies) + parsed_url = urlparse(url) + + if proxy: + proxy = prepend_scheme_if_needed(proxy, "http") + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL( + "Please check proxy URL. It is malformed" + " and could be missing the host." + ) + proxy_manager = self.proxy_manager_for(proxy) + + if isinstance(proxy_manager, ProxyManager): + # Add Host to proxy header SNOW-232777 + proxy_manager.proxy_headers["Host"] = parsed_url.hostname + else: + logger.debug( + f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" + f" it does not have attribute 'proxy_headers'." 
+ ) + conn = proxy_manager.connection_from_url(url) + else: + # Only scheme should be lower case + url = parsed_url.geturl() + conn = self.poolmanager.connection_from_url(url) + + return conn + + +class SessionPool: + def __init__(self, manager: SessionManager) -> None: + # A stack of the idle sessions + self._idle_sessions = [] + self._active_sessions = set() + self._manager = manager + + def get_session(self) -> Session: + """Returns a session from the session pool or creates a new one.""" + try: + session = self._idle_sessions.pop() + except IndexError: + session = self._manager.make_session() + self._active_sessions.add(session) + return session + + def return_session(self, session: Session) -> None: + """Places an active session back into the idle session stack.""" + try: + self._active_sessions.remove(session) + except KeyError: + logger.debug("session doesn't exist in the active session pool. Ignored...") + self._idle_sessions.append(session) + + def __str__(self) -> str: + total_sessions = len(self._active_sessions) + len(self._idle_sessions) + return ( + f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" + ) + + def close(self) -> None: + """Closes all active and idle sessions in this session pool.""" + if self._active_sessions: + logger.debug(f"Closing {len(self._active_sessions)} active sessions") + for session in itertools.chain(self._active_sessions, self._idle_sessions): + try: + session.close() + except Exception as e: + logger.info(f"Session cleanup failed - failed to close session: {e}") + self._active_sessions.clear() + self._idle_sessions.clear() + + +class SessionManager: + def __init__( + self, + use_pooling: bool = True, + adapter_factory: ( + Callable[..., HTTPAdapter] | None + ) = lambda *args, **kwargs: None, + ): + self._use_pooling = use_pooling + self._adapter_factory = adapter_factory or ProxySupportAdapter + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: 
SessionPool(self) + ) + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self._sessions_map + + def _mount_adapter(self, session: requests.Session) -> None: + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + if adapter is not None: + session.mount("http://", adapter) + session.mount("https://", adapter) + + def make_session(self) -> Session: + s = requests.Session() + self._mount_adapter(s) + s._reuse_count = itertools.count() + return s + + @contextlib.contextmanager + def use_session(self, url: str | None = None): + if not self._use_pooling: + session = self.make_session() + try: + yield session + finally: + session.close() + else: + hostname = urlparse(url).hostname if url else None + pool = self._sessions_map[hostname] + session = pool.get_session() + try: + yield session + finally: + pool.return_session(session) + + def close(self): + for pool in self._sessions_map.values(): + pool.close() + + def clone(self, *, use_pooling: bool | None = None) -> SessionManager: + """Return an independent manager that reuses the adapter_factory.""" + return SessionManager( + use_pooling=self._use_pooling if use_pooling is None else use_pooling, + adapter_factory=self._adapter_factory, + ) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 21a7142c30..df6e7cea57 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -15,6 +15,8 @@ from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError +from .http_client import request as http_request +from .session_manager import SessionManager from .vendored import requests from .vendored.requests import Response @@ -72,15 +74,28 @@ class AwsCredentials: def try_metadata_service_call( - method: str, url: str, headers: dict, timeout_sec: int = 3 + method: str, + url: str, + headers: dict, + timeout_sec: int = 3, + session_manager: SessionManager | None = None, ) -> Response | None: """Tries to 
make a HTTP request to the metadata service with the given URL, method, headers and timeout. If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. """ try: - res: Response = requests.request( - method=method, url=url, headers=headers, timeout=timeout_sec + # If no session_manager provided, create a basic one for this call + if session_manager is None: + session_manager = SessionManager(use_pooling=False) + + res: Response = http_request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + session_manager=session_manager, + use_pooling=False, # IMDS calls are rare → don't pollute pool ) if not res.ok: return None @@ -119,16 +134,19 @@ def extract_iss_and_sub_without_signature_verification( # --------------------------------------------------------------------------- # # AWS helper utilities (token, credentials, region) # # --------------------------------------------------------------------------- # -def _imds_v2_token() -> str | None: +def _imds_v2_token(session_manager: SessionManager | None = None) -> str | None: res = try_metadata_service_call( method="PUT", url="http://169.254.169.254/latest/api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + session_manager=session_manager, ) return res.text.strip() if res else None -def get_aws_credentials() -> AwsCredentials | None: +def get_aws_credentials( + session_manager: SessionManager | None = None, +) -> AwsCredentials | None: """Get AWS credentials from environment variables or instance metadata. Implements the AWS credential chain without using boto3. 
@@ -143,7 +161,7 @@ def get_aws_credentials() -> AwsCredentials | None: # Try instance metadata service (IMDSv2) try: - token = _imds_v2_token() + token = _imds_v2_token(session_manager) if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -155,6 +173,7 @@ def get_aws_credentials() -> AwsCredentials | None: method="GET", url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", headers=token_hdr, + session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role list from metadata service.") @@ -170,6 +189,7 @@ def get_aws_credentials() -> AwsCredentials | None: method="GET", url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", headers=token_hdr, + session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role credentials from metadata service.") @@ -188,14 +208,14 @@ def get_aws_credentials() -> AwsCredentials | None: return None -def get_aws_region() -> str | None: +def get_aws_region(session_manager: SessionManager | None = None) -> str | None: """Get the current AWS workload's region, if any.""" region = os.environ.get("AWS_REGION") if region: return region try: - token = _imds_v2_token() + token = _imds_v2_token(session_manager) if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -207,6 +227,7 @@ def get_aws_region() -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/region", headers=token_hdr, + session_manager=session_manager, ) if res is not None: return res.text.strip() @@ -215,6 +236,7 @@ def get_aws_region() -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/availability-zone", headers=token_hdr, + session_manager=session_manager, ) if res is not None: return res.text.strip()[:-1] @@ -345,17 +367,19 @@ def hmac_sha256(key: bytes, msg: str) -> bytes: return final_headers -def create_aws_attestation() -> 
WorkloadIdentityAttestation | None: +def create_aws_attestation( + session_manager: SessionManager | None = None, +) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, returns None. """ - credentials = get_aws_credentials() + credentials = get_aws_credentials(session_manager) if not credentials: logger.debug("No AWS credentials were found.") return None - region = get_aws_region() + region = get_aws_region(session_manager) if not region: logger.debug("No AWS region was found.") return None @@ -395,7 +419,9 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: ) -def create_gcp_attestation() -> WorkloadIdentityAttestation | None: +def create_gcp_attestation( + session_manager: SessionManager | None = None, +) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for GCP. If the application isn't running on GCP or no credentials were found, returns None. @@ -406,6 +432,7 @@ def create_gcp_attestation() -> WorkloadIdentityAttestation | None: headers={ "Metadata-Flavor": "Google", }, + session_manager=session_manager, ) if res is None: # Most likely we're just not running on GCP, which may be expected. @@ -428,6 +455,7 @@ def create_gcp_attestation() -> WorkloadIdentityAttestation | None: def create_azure_attestation( snowflake_entra_resource: str, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for Azure. @@ -461,6 +489,7 @@ def create_azure_attestation( method="GET", url=f"{url_without_query_string}?{query_params}", headers=headers, + session_manager=session_manager, ) if res is None: # Most likely we're just not running on Azure, which may be expected. 
@@ -511,7 +540,9 @@ def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | def create_autodetect_attestation( - entra_resource: str, token: str | None = None + entra_resource: str, + token: str | None = None, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create an attestation using the auto-detected runtime environment. @@ -521,15 +552,15 @@ def create_autodetect_attestation( if attestation: return attestation - attestation = create_azure_attestation(entra_resource) + attestation = create_azure_attestation(entra_resource, session_manager) if attestation: return attestation - attestation = create_aws_attestation() + attestation = create_aws_attestation(session_manager) if attestation: return attestation - attestation = create_gcp_attestation() + attestation = create_gcp_attestation(session_manager) if attestation: return attestation @@ -540,6 +571,7 @@ def create_attestation( provider: AttestationProvider | None, entra_resource: str | None = None, token: str | None = None, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -549,18 +581,23 @@ def create_attestation( If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. 
""" entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE + session_manager = ( + session_manager.clone() if session_manager else SessionManager(use_pooling=True) + ) attestation: WorkloadIdentityAttestation | None = None if provider == AttestationProvider.AWS: - attestation = create_aws_attestation() + attestation = create_aws_attestation(session_manager) elif provider == AttestationProvider.AZURE: - attestation = create_azure_attestation(entra_resource) + attestation = create_azure_attestation(entra_resource, session_manager) elif provider == AttestationProvider.GCP: - attestation = create_gcp_attestation() + attestation = create_gcp_attestation(session_manager) elif provider == AttestationProvider.OIDC: attestation = create_oidc_attestation(token) elif provider is None: - attestation = create_autodetect_attestation(entra_resource, token) + attestation = create_autodetect_attestation( + entra_resource, token, session_manager + ) if not attestation: provider_str = "auto-detect" if provider is None else provider.value From fde7556d600c31fd1da46889f1f12850ea12d67d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 08:25:04 +0200 Subject: [PATCH 15/54] SNOW-2183023: fixed csp_helper and added boto as dev-dep --- setup.cfg | 4 +- test/csp_helpers.py | 150 ++++++++++++++++++++++++++------------------ 2 files changed, 90 insertions(+), 64 deletions(-) diff --git a/setup.cfg b/setup.cfg index 2fd14a0ea8..ee1c38407a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,8 +44,6 @@ python_requires = >=3.9 packages = find_namespace: install_requires = asn1crypto>0.24.0,<2.0.0 - boto3>=1.24 - botocore>=1.24 cffi>=1.9,<2.0.0 cryptography>=3.1.0 pyOpenSSL>=22.0.0,<26.0.0 @@ -92,6 +90,8 @@ development = pytest-timeout pytest-xdist pytzdata + boto3>=1.24 + botocore>=1.24 pandas = pandas>=2.1.2,<3.0.0 pyarrow<19.0.0 diff --git a/test/csp_helpers.py b/test/csp_helpers.py index 05509bab85..d9a395b6f8 100644 --- a/test/csp_helpers.py +++ 
b/test/csp_helpers.py @@ -1,6 +1,7 @@ #!/usr/bin/env python from __future__ import annotations +import datetime import json import logging import os @@ -11,39 +12,46 @@ import jwt +# Boto is left as a development-dependency - to be sure our http requests correspond to the appropriate behavior and old driver tests are passing in the future +from botocore.awsrequest import AWSRequest +from botocore.credentials import Credentials + from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError from snowflake.connector.vendored.requests.models import Response -from snowflake.connector.wif_util import AwsCredentials # light-weight creds +from snowflake.connector.wif_util import AwsCredentials logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- # -# Helpers # -# --------------------------------------------------------------------------- # def gen_dummy_id_token( - sub: str = "test-subject", - iss: str = "test-issuer", - aud: str = "snowflakecomputing.com", + sub="test-subject", iss="test-issuer", aud="snowflakecomputing.com" ) -> str: """Generates a dummy ID token using the given subject and issuer.""" now = int(time()) - payload = {"sub": sub, "iss": iss, "aud": aud, "iat": now, "exp": now + 3600} + key = "secret" + payload = { + "sub": sub, + "iss": iss, + "aud": aud, + "iat": now, + "exp": now + 60 * 60, + } logger.debug(f"Generating dummy token with the following claims:\n{str(payload)}") - return jwt.encode(payload, key="secret", algorithm="HS256") + return jwt.encode( + payload=payload, + key=key, + algorithm="HS256", + ) def build_response(content: bytes, status_code: int = 200) -> Response: """Builds a requests.Response object with the given status code and content.""" - resp = Response() - resp.status_code = status_code - resp._content = content - return resp + response = Response() + response.status_code = status_code + response._content = content + return response -# 
--------------------------------------------------------------------------- # -# Generic metadata-service test harness # -# --------------------------------------------------------------------------- # class FakeMetadataService(ABC): """Base class for fake metadata service implementations.""" @@ -240,33 +248,48 @@ def handle_request(self, method, parsed_url, headers, timeout): return build_response(self.token.encode("utf-8")) -# --------------------------------------------------------------------------- # -# AWS environment fake # -# --------------------------------------------------------------------------- # class FakeAwsEnvironment: - """Emulates the AWS environment-specific helpers used in wif_util.py.""" + """Emulates AWS for both the legacy boto path and the new SDK-free helpers.""" - def __init__(self) -> None: + def __init__(self): + # Defaults used for generating a token. Can be overriden in individual tests. + self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-abc123" self.region = "us-east-1" - self.credentials: AwsCredentials | None = AwsCredentials( - access_key="ak", secret_key="sk", token="SESSION_TOKEN" + + # boto-style creds (used by old tests / patches) + self.boto_creds = Credentials("AKIA123", "SECRET123", token="SESSION_TOKEN") + + # util-style creds (returned by get_aws_credentials) + self.util_creds = AwsCredentials( + access_key=self.boto_creds.access_key, + secret_key=self.boto_creds.secret_key, + token=self.boto_creds.token, ) - # ------------------------------------------------------------------ # - # Helper getters (swallow any extra args like session_manager) # - # ------------------------------------------------------------------ # - def get_region(self, *_, **__) -> str | None: + def get_region(self, *_, **__) -> str: return self.region - def get_credentials(self, *_, **__) -> AwsCredentials | None: - return self.credentials + def get_arn(self, *_, **__) -> str: + return self.arn + + def get_boto_credentials(self, *_, **__) -> 
Credentials | None: + return self.boto_creds + + def get_aws_credentials(self, *_, **__) -> AwsCredentials | None: + return self.util_creds + + def sign_request(self, request: AWSRequest): + request.headers.add_header("X-Amz-Date", datetime.datetime.utcnow().isoformat()) + request.headers.add_header("X-Amz-Security-Token", "") + request.headers.add_header( + "Authorization", + "AWS4-HMAC-SHA256 Credential=, SignedHeaders=host;x-amz-date," + " Signature=", + ) - # ------------------------------------------------------------------ # - # Context-manager patching # - # ------------------------------------------------------------------ # def __enter__(self): - # Save & override env vars - self._prev_env = { + # Preserve existing env and then set creds/region for util fallback + self._old_env = { k: os.environ.get(k) for k in ( "AWS_ACCESS_KEY_ID", @@ -275,47 +298,50 @@ def __enter__(self): "AWS_REGION", ) } - if self.credentials: - os.environ.update( - { - "AWS_ACCESS_KEY_ID": self.credentials.access_key, - "AWS_SECRET_ACCESS_KEY": self.credentials.secret_key, - "AWS_SESSION_TOKEN": (self.credentials.token or ""), - } - ) - os.environ["AWS_REGION"] = self.region + os.environ.update( + { + "AWS_ACCESS_KEY_ID": self.util_creds.access_key, + "AWS_SECRET_ACCESS_KEY": self.util_creds.secret_key, + "AWS_SESSION_TOKEN": self.util_creds.token or "", + "AWS_REGION": self.region, + } + ) - self.patchers: list[mock._patch] = [ + self.patchers = [ + # boto patches - for old driver tests mock.patch( - "snowflake.connector.wif_util.get_aws_credentials", - side_effect=self.get_credentials, + "boto3.session.Session.get_credentials", + side_effect=self.get_boto_credentials, ), + mock.patch( + "botocore.auth.SigV4Auth.add_auth", side_effect=self.sign_request + ), + # http approach patches - for new driver tests mock.patch( "snowflake.connector.wif_util.get_aws_region", side_effect=self.get_region, ), - # Avoid real network for IMDS token mock.patch( - 
"snowflake.connector.wif_util._imds_v2_token", - return_value=None, + "snowflake.connector.wif_util.get_aws_credentials", + side_effect=self.get_aws_credentials, ), - # Block stray HTTP traffic + # never contact IMDS for token mock.patch( - "urllib3.connection.HTTPConnection.request", - side_effect=ConnectTimeout(), + "snowflake.connector.wif_util._imds_v2_token", return_value=None ), ] - for p in self.patchers: - p.__enter__() + for patcher in self.patchers: + patcher.__enter__() return self - def __exit__(self, *args): - for p in self.patchers: - p.__exit__(*args) - # Restore original env vars - for key, val in self._prev_env.items(): - if val is None: - os.environ.pop(key, None) + def __exit__(self, *args, **kwargs): + for patcher in self.patchers: + patcher.__exit__(*args, **kwargs) + + # restore previous env + for k, v in self._old_env.items(): + if v is None: + os.environ.pop(k, None) else: - os.environ[key] = val + os.environ[k] = v From 023f0aeaf99514ca36b287d6f2fffd744188d4c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 08:33:38 +0200 Subject: [PATCH 16/54] SNOW-2183023: added arn fetch --- src/snowflake/connector/wif_util.py | 88 +++++++++++++++++++++++- test/csp_helpers.py | 4 ++ test/unit/test_auth_workload_identity.py | 72 +++++++++++++++++++ 3 files changed, 162 insertions(+), 2 deletions(-) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index df6e7cea57..53faffadfa 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -246,6 +246,79 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None: return None +def get_aws_arn(session_manager: SessionManager | None = None) -> str | None: + """Get the current AWS workload's ARN by calling GetCallerIdentity. + + Note: This function makes a network call to AWS STS and is only used for + assertion content generation (logging and backward compatibility purposes). 
+ The ARN is not required for authentication - it's just used as a unique + identifier for the workload in logs and assertion content. + + Returns the ARN of the current AWS identity, or None if it cannot be determined. + """ + credentials = get_aws_credentials(session_manager) + if not credentials: + logger.debug("No AWS credentials available for ARN lookup.") + return None + + region = get_aws_region(session_manager) + if not region: + logger.debug("No AWS region available for ARN lookup.") + return None + + try: + # Create the GetCallerIdentity request + sts_hostname = get_aws_sts_hostname(region) + url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15" + + base_headers = { + "Content-Type": "application/x-amz-json-1.1", + } + + signed_headers = aws_signature_v4_sign( + credentials=credentials, + method="POST", + url=url, + region=region, + service="sts", + headers=base_headers, + ) + + # Make the actual request to get caller identity + response = http_request( + method="POST", + url=url, + headers=signed_headers, + timeout_sec=10, + session_manager=session_manager, + ) + + if response and response.ok: + # Parse the XML response to extract the ARN + import xml.etree.ElementTree as ET + + # Ensure content is bytes and decode it + content = response.content + if isinstance(content, bytes): + content_str = content.decode("utf-8") + else: + content_str = str(content) if content else "" + + if content_str: + root = ET.fromstring(content_str) + + # Find the Arn element in the response + for elem in root.iter(): + if elem.tag.endswith("Arn") and elem.text: + return elem.text.strip() + + logger.debug("Failed to get ARN from GetCallerIdentity response.") + return None + except Exception as e: + logger.debug(f"Error getting AWS ARN: {e}") + return None + + def get_aws_sts_hostname(region: str) -> str | None: """Constructs the AWS STS hostname for a given region. 
@@ -412,10 +485,15 @@ def create_aws_attestation( "utf-8" ) + # Get the ARN for user identifier components (used only for assertion content - logging and backward compatibility) + # The ARN is not required for authentication, but provides a unique identifier for the workload + arn = get_aws_arn(session_manager) + user_identifier_components = {"arn": arn} if arn else {} + return WorkloadIdentityAttestation( AttestationProvider.AWS, credential, - {}, # No user identifier components needed - Snowflake will extract from the signed request + user_identifier_components, ) @@ -439,7 +517,13 @@ def create_gcp_attestation( logger.debug("GCP metadata server request was not successful.") return None - jwt_str = res.content.decode("utf-8") + # Ensure content is bytes and decode it + content = res.content + if isinstance(content, bytes): + jwt_str = content.decode("utf-8") + else: + jwt_str = str(content) + issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str) if not issuer or not subject: return None diff --git a/test/csp_helpers.py b/test/csp_helpers.py index d9a395b6f8..b6b50cb850 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -325,6 +325,10 @@ def __enter__(self): "snowflake.connector.wif_util.get_aws_credentials", side_effect=self.get_aws_credentials, ), + mock.patch( + "snowflake.connector.wif_util.get_aws_arn", + side_effect=self.get_arn, + ), # never contact IMDS for token mock.patch( "snowflake.connector.wif_util._imds_v2_token", return_value=None diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index 6df7e0fb11..c0dfd62e20 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -1,11 +1,14 @@ import json import logging +import re from base64 import b64decode from unittest import mock from urllib.parse import parse_qs, urlparse import jwt import pytest +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest from 
snowflake.connector.auth import AuthByWorkloadIdentity from snowflake.connector.errors import ProgrammingError @@ -17,6 +20,7 @@ from snowflake.connector.wif_util import ( AZURE_ISSUER_PREFIXES, AttestationProvider, + get_aws_credentials, get_aws_sts_hostname, ) @@ -25,6 +29,15 @@ logger = logging.getLogger(__name__) +_SIGV4_RE = re.compile( + r"^AWS4-HMAC-SHA256 " + r"Credential=[^/]+/\d{8}/[^/]+/sts/aws4_request, " + r"SignedHeaders=[a-z0-9\-;]+, " + r"Signature=[0-9a-f]{64}$", + re.IGNORECASE, +) + + def extract_api_data(auth_class: AuthByWorkloadIdentity): """Extracts the 'data' portion of the request body populated by the given auth class.""" req_body = {"data": {}} @@ -428,3 +441,62 @@ def test_autodetect_no_provider_raises_error(no_metadata_service): assert "No workload identity credential was found for 'auto-detect" in str( excinfo.value ) + + +def test_aws_token_authorization_header_format( + fake_aws_environment: FakeAwsEnvironment, +): + """ + The internal signer should still emit a well-formed SigV4 `Authorization` + header (same format that botocore would produce). + """ + auth = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + auth.prepare() + + headers = json.loads(b64decode(extract_api_data(auth)["TOKEN"]))["headers"] + assert _SIGV4_RE.match(headers["Authorization"]) + + +def test_internal_signer_vs_botocore_prefix(fake_aws_environment: FakeAwsEnvironment): + """ + Compare the *static* parts of the Authorization header (algorithm, + Credential=…, SignedHeaders=…) between our pure-HTTP signer and the + reference implementation in botocore. We ignore the Signature value + because it will differ when the dates differ. 
+ """ + # --- headers from the new HTTP-only path ----------------------------- + auth = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + auth.prepare() + new_hdr = json.loads(b64decode(extract_api_data(auth)["TOKEN"]))["headers"][ + "Authorization" + ] + new_prefix = new_hdr.split("Signature=")[0] + + # --- headers from real botocore SigV4Auth ---------------------------- + creds = fake_aws_environment.credentials # boto Credentials + region = fake_aws_environment.region + url = ( + f"https://sts.{region}.amazonaws.com/" + "?Action=GetCallerIdentity&Version=2011-06-15" + ) + + req = AWSRequest(method="POST", url=url) + req.headers["Host"] = f"sts.{region}.amazonaws.com" + req.headers["X-Snowflake-Audience"] = "snowflakecomputing.com" + SigV4Auth(creds, "sts", region).add_auth(req) + boto_prefix = req.headers["Authorization"].split("Signature=")[0] + + # Credential=… and SignedHeaders=… should be identical + assert new_prefix == boto_prefix + + +def test_get_aws_credentials_fallback_env(fake_aws_environment: FakeAwsEnvironment): + """ + The util’s environment-variable fallback path should return exactly the + values we injected via FakeAwsEnvironment. + """ + + creds = get_aws_credentials() + assert creds.access_key == fake_aws_environment.util_creds.access_key + assert creds.secret_key == fake_aws_environment.util_creds.secret_key + assert creds.token == fake_aws_environment.util_creds.token From 3fef4e71a3b8b0f889d62e6164df132369eff901 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 08:35:04 +0200 Subject: [PATCH 17/54] Revert "Reapply "Reapply "SNOW-2183023: refactored""" This reverts commit e603c1c76b116a26513f333d6d326f9694e67c1d. 
--- .../connector/auth/workload_identity.py | 7 +- src/snowflake/connector/connection.py | 5 - src/snowflake/connector/http_client.py | 103 ----------- src/snowflake/connector/network.py | 145 +++++++++++++++- src/snowflake/connector/session_manager.py | 162 ------------------ src/snowflake/connector/wif_util.py | 77 +++------ 6 files changed, 165 insertions(+), 334 deletions(-) delete mode 100644 src/snowflake/connector/http_client.py delete mode 100644 src/snowflake/connector/session_manager.py diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 7f8ab60718..3c80c965e4 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -74,13 +74,10 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: ).value body["data"]["TOKEN"] = self.attestation.credential - def prepare(self, *, conn, **kwargs: typing.Any) -> None: + def prepare(self, **kwargs: typing.Any) -> None: """Fetch the token.""" self.attestation = create_attestation( - self.provider, - self.entra_resource, - self.token, - session_manager=conn.session_manager if conn else None, + self.provider, self.entra_resource, self.token ) def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index ffc193df37..963e04ee8a 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -712,11 +712,6 @@ def client_fetch_use_mp(self) -> bool: def rest(self) -> SnowflakeRestful | None: return self._rest - @property - def session_manager(self): - """Access to the connection's SessionManager for making HTTP requests.""" - return self._rest.session_manager if self._rest else None - @property def application(self) -> str: return self._application diff --git a/src/snowflake/connector/http_client.py b/src/snowflake/connector/http_client.py deleted 
file mode 100644 index 1c802c2439..0000000000 --- a/src/snowflake/connector/http_client.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Any, Mapping - -from .session_manager import SessionManager -from .vendored.requests import Response - -logger = logging.getLogger(__name__) - - -class HttpClient: - """HTTP client that uses SessionManager for connection pooling and adapter management.""" - - def __init__(self, session_manager: SessionManager): - """Initialize HttpClient with a SessionManager. - - Args: - session_manager: SessionManager instance to use for all requests - """ - self.session_manager = session_manager - - def request( - self, - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - use_pooling: bool | None = None, - **kwargs: Any, - ) -> Response: - """Make an HTTP request using the configured SessionManager. - - Args: - method: HTTP method (GET, POST, etc.) - url: Target URL - headers: Optional HTTP headers - timeout_sec: Request timeout in seconds - use_pooling: Whether to use connection pooling (overrides session_manager setting) - **kwargs: Additional arguments passed to requests.Session.request - - Returns: - Response object from the request - """ - mgr = ( - self.session_manager - if use_pooling is None - else self.session_manager.clone(use_pooling=use_pooling) - ) - - with mgr.use_session(url) as session: - return session.request( - method=method.upper(), - url=url, - headers=headers, - timeout=timeout_sec, - **kwargs, - ) - - -# Convenience function for backwards compatibility and simple usage -def request( - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - session_manager: SessionManager | None = None, - use_pooling: bool | None = None, - **kwargs: Any, -) -> Response: - """Convenience function for making HTTP requests. - - Args: - method: HTTP method (GET, POST, etc.) 
- url: Target URL - headers: Optional HTTP headers - timeout_sec: Request timeout in seconds - session_manager: SessionManager instance to use (required) - use_pooling: Whether to use connection pooling (overrides session_manager setting) - **kwargs: Additional arguments passed to requests.Session.request - - Returns: - Response object from the request - - Raises: - ValueError: If session_manager is None - """ - if session_manager is None: - raise ValueError( - "session_manager is required - no default session manager available" - ) - - client = HttpClient(session_manager) - return client.request( - method=method, - url=url, - headers=headers, - timeout_sec=timeout_sec, - use_pooling=use_pooling, - **kwargs, - ) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index fbea591258..84652205fa 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -1,12 +1,16 @@ #!/usr/bin/env python from __future__ import annotations +import collections +import contextlib import gzip +import itertools import json import logging import re import time import uuid +from collections import OrderedDict from threading import Lock from typing import TYPE_CHECKING, Any, Callable @@ -14,6 +18,10 @@ from snowflake.connector.secret_detector import SecretDetector from snowflake.connector.vendored.requests.models import PreparedRequest +from snowflake.connector.vendored.urllib3.connectionpool import ( + HTTPConnectionPool, + HTTPSConnectionPool, +) from . 
import ssl_wrap_socket from .compat import ( @@ -76,7 +84,6 @@ ServiceUnavailableError, TooManyRequests, ) -from .session_manager import SessionManager, SessionPool from .sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_REJECTED, @@ -89,16 +96,19 @@ ) from .tool.probe_connection import probe_connection from .vendored import requests -from .vendored.requests import Response +from .vendored.requests import Response, Session from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.auth import AuthBase from .vendored.requests.exceptions import ( ConnectionError, ConnectTimeout, + InvalidProxyURL, ReadTimeout, SSLError, ) +from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy from .vendored.urllib3.exceptions import ProtocolError +from .vendored.urllib3.poolmanager import ProxyManager from .vendored.urllib3.util.url import parse_url if TYPE_CHECKING: @@ -238,6 +248,42 @@ def is_login_request(url: str) -> bool: return "login-request" in parse_url(url).path +class ProxySupportAdapter(HTTPAdapter): + """This Adapter creates proper headers for Proxy CONNECT messages.""" + + def get_connection( + self, url: str, proxies: OrderedDict | None = None + ) -> HTTPConnectionPool | HTTPSConnectionPool: + proxy = select_proxy(url, proxies) + parsed_url = urlparse(url) + + if proxy: + proxy = prepend_scheme_if_needed(proxy, "http") + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL( + "Please check proxy URL. It is malformed" + " and could be missing the host." + ) + proxy_manager = self.proxy_manager_for(proxy) + + if isinstance(proxy_manager, ProxyManager): + # Add Host to proxy header SNOW-232777 + proxy_manager.proxy_headers["Host"] = parsed_url.hostname + else: + logger.debug( + f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" + f" it does not have attribute 'proxy_headers'." 
+ ) + conn = proxy_manager.connection_from_url(url) + else: + # Only scheme should be lower case + url = parsed_url.geturl() + conn = self.poolmanager.connection_from_url(url) + + return conn + + class RetryRequest(Exception): """Signal to retry request.""" @@ -288,6 +334,101 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: return r +class SessionManager: + def __init__( + self, + use_pooling: bool = True, + adapter_factory: ( + Callable[..., HTTPAdapter] | None + ) = lambda *args, **kwargs: None, + ): + self._use_pooling = use_pooling + self._adapter_factory = adapter_factory or ProxySupportAdapter + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: SessionPool(self) + ) + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self._sessions_map + + def _mount_adapter(self, session: requests.Session) -> None: + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + if adapter is not None: + session.mount("http://", adapter) + session.mount("https://", adapter) + + def make_session(self) -> Session: + s = requests.Session() + self._mount_adapter(s) + s._reuse_count = itertools.count() + return s + + @contextlib.contextmanager + def use_session(self, url: str | None = None): + if not self._use_pooling: + session = self.make_session() + try: + yield session + finally: + session.close() + else: + hostname = urlparse(url).hostname if url else None + pool = self._sessions_map[hostname] + session = pool.get_session() + try: + yield session + finally: + pool.return_session(session) + + def close(self): + for pool in self._sessions_map.values(): + pool.close() + + +class SessionPool: + def __init__(self, manager: SessionManager) -> None: + # A stack of the idle sessions + self._idle_sessions = [] + self._active_sessions = set() + self._manager = manager + + def get_session(self) -> Session: + """Returns a session from the session pool or creates a new one.""" + try: + session = 
self._idle_sessions.pop() + except IndexError: + session = self._manager.make_session() + self._active_sessions.add(session) + return session + + def return_session(self, session: Session) -> None: + """Places an active session back into the idle session stack.""" + try: + self._active_sessions.remove(session) + except KeyError: + logger.debug("session doesn't exist in the active session pool. Ignored...") + self._idle_sessions.append(session) + + def __str__(self) -> str: + total_sessions = len(self._active_sessions) + len(self._idle_sessions) + return ( + f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" + ) + + def close(self) -> None: + """Closes all active and idle sessions in this session pool.""" + if self._active_sessions: + logger.debug(f"Closing {len(self._active_sessions)} active sessions") + for session in itertools.chain(self._active_sessions, self._idle_sessions): + try: + session.close() + except Exception as e: + logger.info(f"Session cleanup failed - failed to close session: {e}") + self._active_sessions.clear() + self._idle_sessions.clear() + + # Customizable JSONEncoder to support additional types. 
class SnowflakeRestfulJsonEncoder(json.JSONEncoder): def default(self, o): diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py deleted file mode 100644 index 85b46ccfd4..0000000000 --- a/src/snowflake/connector/session_manager.py +++ /dev/null @@ -1,162 +0,0 @@ -from __future__ import annotations - -import collections -import contextlib -import itertools -import logging -from typing import TYPE_CHECKING, Callable - -from .compat import urlparse -from .vendored import requests -from .vendored.requests import Session -from .vendored.requests.adapters import HTTPAdapter -from .vendored.requests.exceptions import InvalidProxyURL -from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy -from .vendored.urllib3.poolmanager import ProxyManager -from .vendored.urllib3.util.url import parse_url - -if TYPE_CHECKING: - from .vendored.urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool - -logger = logging.getLogger(__name__) - -# requests parameters -REQUESTS_RETRY = 1 # requests library builtin retry - - -class ProxySupportAdapter(HTTPAdapter): - """This Adapter creates proper headers for Proxy CONNECT messages.""" - - def get_connection( - self, url: str, proxies: dict | None = None - ) -> HTTPConnectionPool | HTTPSConnectionPool: - proxy = select_proxy(url, proxies) - parsed_url = urlparse(url) - - if proxy: - proxy = prepend_scheme_if_needed(proxy, "http") - proxy_url = parse_url(proxy) - if not proxy_url.host: - raise InvalidProxyURL( - "Please check proxy URL. It is malformed" - " and could be missing the host." - ) - proxy_manager = self.proxy_manager_for(proxy) - - if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname - else: - logger.debug( - f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" - f" it does not have attribute 'proxy_headers'." 
- ) - conn = proxy_manager.connection_from_url(url) - else: - # Only scheme should be lower case - url = parsed_url.geturl() - conn = self.poolmanager.connection_from_url(url) - - return conn - - -class SessionPool: - def __init__(self, manager: SessionManager) -> None: - # A stack of the idle sessions - self._idle_sessions = [] - self._active_sessions = set() - self._manager = manager - - def get_session(self) -> Session: - """Returns a session from the session pool or creates a new one.""" - try: - session = self._idle_sessions.pop() - except IndexError: - session = self._manager.make_session() - self._active_sessions.add(session) - return session - - def return_session(self, session: Session) -> None: - """Places an active session back into the idle session stack.""" - try: - self._active_sessions.remove(session) - except KeyError: - logger.debug("session doesn't exist in the active session pool. Ignored...") - self._idle_sessions.append(session) - - def __str__(self) -> str: - total_sessions = len(self._active_sessions) + len(self._idle_sessions) - return ( - f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" - ) - - def close(self) -> None: - """Closes all active and idle sessions in this session pool.""" - if self._active_sessions: - logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for session in itertools.chain(self._active_sessions, self._idle_sessions): - try: - session.close() - except Exception as e: - logger.info(f"Session cleanup failed - failed to close session: {e}") - self._active_sessions.clear() - self._idle_sessions.clear() - - -class SessionManager: - def __init__( - self, - use_pooling: bool = True, - adapter_factory: ( - Callable[..., HTTPAdapter] | None - ) = lambda *args, **kwargs: None, - ): - self._use_pooling = use_pooling - self._adapter_factory = adapter_factory or ProxySupportAdapter - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: 
SessionPool(self) - ) - - @property - def sessions_map(self) -> dict[str, SessionPool]: - return self._sessions_map - - def _mount_adapter(self, session: requests.Session) -> None: - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - if adapter is not None: - session.mount("http://", adapter) - session.mount("https://", adapter) - - def make_session(self) -> Session: - s = requests.Session() - self._mount_adapter(s) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def use_session(self, url: str | None = None): - if not self._use_pooling: - session = self.make_session() - try: - yield session - finally: - session.close() - else: - hostname = urlparse(url).hostname if url else None - pool = self._sessions_map[hostname] - session = pool.get_session() - try: - yield session - finally: - pool.return_session(session) - - def close(self): - for pool in self._sessions_map.values(): - pool.close() - - def clone(self, *, use_pooling: bool | None = None) -> SessionManager: - """Return an independent manager that reuses the adapter_factory.""" - return SessionManager( - use_pooling=self._use_pooling if use_pooling is None else use_pooling, - adapter_factory=self._adapter_factory, - ) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 53faffadfa..0063c23f07 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -15,8 +15,6 @@ from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError -from .http_client import request as http_request -from .session_manager import SessionManager from .vendored import requests from .vendored.requests import Response @@ -74,28 +72,15 @@ class AwsCredentials: def try_metadata_service_call( - method: str, - url: str, - headers: dict, - timeout_sec: int = 3, - session_manager: SessionManager | None = None, + method: str, url: str, headers: dict, timeout_sec: int = 3 ) -> Response | None: """Tries to 
make a HTTP request to the metadata service with the given URL, method, headers and timeout. If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. """ try: - # If no session_manager provided, create a basic one for this call - if session_manager is None: - session_manager = SessionManager(use_pooling=False) - - res: Response = http_request( - method=method, - url=url, - headers=headers, - timeout_sec=timeout_sec, - session_manager=session_manager, - use_pooling=False, # IMDS calls are rare → don't pollute pool + res: Response = requests.request( + method=method, url=url, headers=headers, timeout=timeout_sec ) if not res.ok: return None @@ -134,19 +119,16 @@ def extract_iss_and_sub_without_signature_verification( # --------------------------------------------------------------------------- # # AWS helper utilities (token, credentials, region) # # --------------------------------------------------------------------------- # -def _imds_v2_token(session_manager: SessionManager | None = None) -> str | None: +def _imds_v2_token() -> str | None: res = try_metadata_service_call( method="PUT", url="http://169.254.169.254/latest/api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, - session_manager=session_manager, ) return res.text.strip() if res else None -def get_aws_credentials( - session_manager: SessionManager | None = None, -) -> AwsCredentials | None: +def get_aws_credentials() -> AwsCredentials | None: """Get AWS credentials from environment variables or instance metadata. Implements the AWS credential chain without using boto3. 
@@ -161,7 +143,7 @@ def get_aws_credentials( # Try instance metadata service (IMDSv2) try: - token = _imds_v2_token(session_manager) + token = _imds_v2_token() if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -173,7 +155,6 @@ def get_aws_credentials( method="GET", url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", headers=token_hdr, - session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role list from metadata service.") @@ -189,7 +170,6 @@ def get_aws_credentials( method="GET", url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", headers=token_hdr, - session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role credentials from metadata service.") @@ -208,14 +188,14 @@ def get_aws_credentials( return None -def get_aws_region(session_manager: SessionManager | None = None) -> str | None: +def get_aws_region() -> str | None: """Get the current AWS workload's region, if any.""" region = os.environ.get("AWS_REGION") if region: return region try: - token = _imds_v2_token(session_manager) + token = _imds_v2_token() if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -227,7 +207,6 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/region", headers=token_hdr, - session_manager=session_manager, ) if res is not None: return res.text.strip() @@ -236,7 +215,6 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/availability-zone", headers=token_hdr, - session_manager=session_manager, ) if res is not None: return res.text.strip()[:-1] @@ -440,19 +418,17 @@ def hmac_sha256(key: bytes, msg: str) -> bytes: return final_headers -def create_aws_attestation( - session_manager: 
SessionManager | None = None, -) -> WorkloadIdentityAttestation | None: +def create_aws_attestation() -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, returns None. """ - credentials = get_aws_credentials(session_manager) + credentials = get_aws_credentials() if not credentials: logger.debug("No AWS credentials were found.") return None - region = get_aws_region(session_manager) + region = get_aws_region() if not region: logger.debug("No AWS region was found.") return None @@ -497,9 +473,7 @@ def create_aws_attestation( ) -def create_gcp_attestation( - session_manager: SessionManager | None = None, -) -> WorkloadIdentityAttestation | None: +def create_gcp_attestation() -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for GCP. If the application isn't running on GCP or no credentials were found, returns None. @@ -510,7 +484,6 @@ def create_gcp_attestation( headers={ "Metadata-Flavor": "Google", }, - session_manager=session_manager, ) if res is None: # Most likely we're just not running on GCP, which may be expected. @@ -539,7 +512,6 @@ def create_gcp_attestation( def create_azure_attestation( snowflake_entra_resource: str, - session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for Azure. @@ -573,7 +545,6 @@ def create_azure_attestation( method="GET", url=f"{url_without_query_string}?{query_params}", headers=headers, - session_manager=session_manager, ) if res is None: # Most likely we're just not running on Azure, which may be expected. 
@@ -624,9 +595,7 @@ def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | def create_autodetect_attestation( - entra_resource: str, - token: str | None = None, - session_manager: SessionManager | None = None, + entra_resource: str, token: str | None = None ) -> WorkloadIdentityAttestation | None: """Tries to create an attestation using the auto-detected runtime environment. @@ -636,15 +605,15 @@ def create_autodetect_attestation( if attestation: return attestation - attestation = create_azure_attestation(entra_resource, session_manager) + attestation = create_azure_attestation(entra_resource) if attestation: return attestation - attestation = create_aws_attestation(session_manager) + attestation = create_aws_attestation() if attestation: return attestation - attestation = create_gcp_attestation(session_manager) + attestation = create_gcp_attestation() if attestation: return attestation @@ -655,7 +624,6 @@ def create_attestation( provider: AttestationProvider | None, entra_resource: str | None = None, token: str | None = None, - session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -665,23 +633,18 @@ def create_attestation( If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. 
""" entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE - session_manager = ( - session_manager.clone() if session_manager else SessionManager(use_pooling=True) - ) attestation: WorkloadIdentityAttestation | None = None if provider == AttestationProvider.AWS: - attestation = create_aws_attestation(session_manager) + attestation = create_aws_attestation() elif provider == AttestationProvider.AZURE: - attestation = create_azure_attestation(entra_resource, session_manager) + attestation = create_azure_attestation(entra_resource) elif provider == AttestationProvider.GCP: - attestation = create_gcp_attestation(session_manager) + attestation = create_gcp_attestation() elif provider == AttestationProvider.OIDC: attestation = create_oidc_attestation(token) elif provider is None: - attestation = create_autodetect_attestation( - entra_resource, token, session_manager - ) + attestation = create_autodetect_attestation(entra_resource, token) if not attestation: provider_str = "auto-detect" if provider is None else provider.value From 2e23fe3b47ecf78ebfdf2ffa485e43b907c703e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 08:35:41 +0200 Subject: [PATCH 18/54] Reapply "Reapply "Reapply "SNOW-2183023: refactored""" This reverts commit 3fef4e71a3b8b0f889d62e6164df132369eff901. 
--- .../connector/auth/workload_identity.py | 7 +- src/snowflake/connector/connection.py | 5 + src/snowflake/connector/http_client.py | 103 +++++++++++ src/snowflake/connector/network.py | 145 +--------------- src/snowflake/connector/session_manager.py | 162 ++++++++++++++++++ src/snowflake/connector/wif_util.py | 77 ++++++--- 6 files changed, 334 insertions(+), 165 deletions(-) create mode 100644 src/snowflake/connector/http_client.py create mode 100644 src/snowflake/connector/session_manager.py diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 3c80c965e4..7f8ab60718 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -74,10 +74,13 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: ).value body["data"]["TOKEN"] = self.attestation.credential - def prepare(self, **kwargs: typing.Any) -> None: + def prepare(self, *, conn, **kwargs: typing.Any) -> None: """Fetch the token.""" self.attestation = create_attestation( - self.provider, self.entra_resource, self.token + self.provider, + self.entra_resource, + self.token, + session_manager=conn.session_manager if conn else None, ) def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 963e04ee8a..ffc193df37 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -712,6 +712,11 @@ def client_fetch_use_mp(self) -> bool: def rest(self) -> SnowflakeRestful | None: return self._rest + @property + def session_manager(self): + """Access to the connection's SessionManager for making HTTP requests.""" + return self._rest.session_manager if self._rest else None + @property def application(self) -> str: return self._application diff --git a/src/snowflake/connector/http_client.py b/src/snowflake/connector/http_client.py new file 
mode 100644 index 0000000000..1c802c2439 --- /dev/null +++ b/src/snowflake/connector/http_client.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import logging +from typing import Any, Mapping + +from .session_manager import SessionManager +from .vendored.requests import Response + +logger = logging.getLogger(__name__) + + +class HttpClient: + """HTTP client that uses SessionManager for connection pooling and adapter management.""" + + def __init__(self, session_manager: SessionManager): + """Initialize HttpClient with a SessionManager. + + Args: + session_manager: SessionManager instance to use for all requests + """ + self.session_manager = session_manager + + def request( + self, + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + use_pooling: bool | None = None, + **kwargs: Any, + ) -> Response: + """Make an HTTP request using the configured SessionManager. + + Args: + method: HTTP method (GET, POST, etc.) + url: Target URL + headers: Optional HTTP headers + timeout_sec: Request timeout in seconds + use_pooling: Whether to use connection pooling (overrides session_manager setting) + **kwargs: Additional arguments passed to requests.Session.request + + Returns: + Response object from the request + """ + mgr = ( + self.session_manager + if use_pooling is None + else self.session_manager.clone(use_pooling=use_pooling) + ) + + with mgr.use_session(url) as session: + return session.request( + method=method.upper(), + url=url, + headers=headers, + timeout=timeout_sec, + **kwargs, + ) + + +# Convenience function for backwards compatibility and simple usage +def request( + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + session_manager: SessionManager | None = None, + use_pooling: bool | None = None, + **kwargs: Any, +) -> Response: + """Convenience function for making HTTP requests. + + Args: + method: HTTP method (GET, POST, etc.) 
+ url: Target URL + headers: Optional HTTP headers + timeout_sec: Request timeout in seconds + session_manager: SessionManager instance to use (required) + use_pooling: Whether to use connection pooling (overrides session_manager setting) + **kwargs: Additional arguments passed to requests.Session.request + + Returns: + Response object from the request + + Raises: + ValueError: If session_manager is None + """ + if session_manager is None: + raise ValueError( + "session_manager is required - no default session manager available" + ) + + client = HttpClient(session_manager) + return client.request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + use_pooling=use_pooling, + **kwargs, + ) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 84652205fa..fbea591258 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -1,16 +1,12 @@ #!/usr/bin/env python from __future__ import annotations -import collections -import contextlib import gzip -import itertools import json import logging import re import time import uuid -from collections import OrderedDict from threading import Lock from typing import TYPE_CHECKING, Any, Callable @@ -18,10 +14,6 @@ from snowflake.connector.secret_detector import SecretDetector from snowflake.connector.vendored.requests.models import PreparedRequest -from snowflake.connector.vendored.urllib3.connectionpool import ( - HTTPConnectionPool, - HTTPSConnectionPool, -) from . 
import ssl_wrap_socket from .compat import ( @@ -84,6 +76,7 @@ ServiceUnavailableError, TooManyRequests, ) +from .session_manager import SessionManager, SessionPool from .sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_REJECTED, @@ -96,19 +89,16 @@ ) from .tool.probe_connection import probe_connection from .vendored import requests -from .vendored.requests import Response, Session +from .vendored.requests import Response from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.auth import AuthBase from .vendored.requests.exceptions import ( ConnectionError, ConnectTimeout, - InvalidProxyURL, ReadTimeout, SSLError, ) -from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy from .vendored.urllib3.exceptions import ProtocolError -from .vendored.urllib3.poolmanager import ProxyManager from .vendored.urllib3.util.url import parse_url if TYPE_CHECKING: @@ -248,42 +238,6 @@ def is_login_request(url: str) -> bool: return "login-request" in parse_url(url).path -class ProxySupportAdapter(HTTPAdapter): - """This Adapter creates proper headers for Proxy CONNECT messages.""" - - def get_connection( - self, url: str, proxies: OrderedDict | None = None - ) -> HTTPConnectionPool | HTTPSConnectionPool: - proxy = select_proxy(url, proxies) - parsed_url = urlparse(url) - - if proxy: - proxy = prepend_scheme_if_needed(proxy, "http") - proxy_url = parse_url(proxy) - if not proxy_url.host: - raise InvalidProxyURL( - "Please check proxy URL. It is malformed" - " and could be missing the host." - ) - proxy_manager = self.proxy_manager_for(proxy) - - if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname - else: - logger.debug( - f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" - f" it does not have attribute 'proxy_headers'." 
- ) - conn = proxy_manager.connection_from_url(url) - else: - # Only scheme should be lower case - url = parsed_url.geturl() - conn = self.poolmanager.connection_from_url(url) - - return conn - - class RetryRequest(Exception): """Signal to retry request.""" @@ -334,101 +288,6 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: return r -class SessionManager: - def __init__( - self, - use_pooling: bool = True, - adapter_factory: ( - Callable[..., HTTPAdapter] | None - ) = lambda *args, **kwargs: None, - ): - self._use_pooling = use_pooling - self._adapter_factory = adapter_factory or ProxySupportAdapter - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: SessionPool(self) - ) - - @property - def sessions_map(self) -> dict[str, SessionPool]: - return self._sessions_map - - def _mount_adapter(self, session: requests.Session) -> None: - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - if adapter is not None: - session.mount("http://", adapter) - session.mount("https://", adapter) - - def make_session(self) -> Session: - s = requests.Session() - self._mount_adapter(s) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def use_session(self, url: str | None = None): - if not self._use_pooling: - session = self.make_session() - try: - yield session - finally: - session.close() - else: - hostname = urlparse(url).hostname if url else None - pool = self._sessions_map[hostname] - session = pool.get_session() - try: - yield session - finally: - pool.return_session(session) - - def close(self): - for pool in self._sessions_map.values(): - pool.close() - - -class SessionPool: - def __init__(self, manager: SessionManager) -> None: - # A stack of the idle sessions - self._idle_sessions = [] - self._active_sessions = set() - self._manager = manager - - def get_session(self) -> Session: - """Returns a session from the session pool or creates a new one.""" - try: - session = 
self._idle_sessions.pop() - except IndexError: - session = self._manager.make_session() - self._active_sessions.add(session) - return session - - def return_session(self, session: Session) -> None: - """Places an active session back into the idle session stack.""" - try: - self._active_sessions.remove(session) - except KeyError: - logger.debug("session doesn't exist in the active session pool. Ignored...") - self._idle_sessions.append(session) - - def __str__(self) -> str: - total_sessions = len(self._active_sessions) + len(self._idle_sessions) - return ( - f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" - ) - - def close(self) -> None: - """Closes all active and idle sessions in this session pool.""" - if self._active_sessions: - logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for session in itertools.chain(self._active_sessions, self._idle_sessions): - try: - session.close() - except Exception as e: - logger.info(f"Session cleanup failed - failed to close session: {e}") - self._active_sessions.clear() - self._idle_sessions.clear() - - # Customizable JSONEncoder to support additional types. 
class SnowflakeRestfulJsonEncoder(json.JSONEncoder): def default(self, o): diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py new file mode 100644 index 0000000000..85b46ccfd4 --- /dev/null +++ b/src/snowflake/connector/session_manager.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import collections +import contextlib +import itertools +import logging +from typing import TYPE_CHECKING, Callable + +from .compat import urlparse +from .vendored import requests +from .vendored.requests import Session +from .vendored.requests.adapters import HTTPAdapter +from .vendored.requests.exceptions import InvalidProxyURL +from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy +from .vendored.urllib3.poolmanager import ProxyManager +from .vendored.urllib3.util.url import parse_url + +if TYPE_CHECKING: + from .vendored.urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool + +logger = logging.getLogger(__name__) + +# requests parameters +REQUESTS_RETRY = 1 # requests library builtin retry + + +class ProxySupportAdapter(HTTPAdapter): + """This Adapter creates proper headers for Proxy CONNECT messages.""" + + def get_connection( + self, url: str, proxies: dict | None = None + ) -> HTTPConnectionPool | HTTPSConnectionPool: + proxy = select_proxy(url, proxies) + parsed_url = urlparse(url) + + if proxy: + proxy = prepend_scheme_if_needed(proxy, "http") + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL( + "Please check proxy URL. It is malformed" + " and could be missing the host." + ) + proxy_manager = self.proxy_manager_for(proxy) + + if isinstance(proxy_manager, ProxyManager): + # Add Host to proxy header SNOW-232777 + proxy_manager.proxy_headers["Host"] = parsed_url.hostname + else: + logger.debug( + f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" + f" it does not have attribute 'proxy_headers'." 
+ ) + conn = proxy_manager.connection_from_url(url) + else: + # Only scheme should be lower case + url = parsed_url.geturl() + conn = self.poolmanager.connection_from_url(url) + + return conn + + +class SessionPool: + def __init__(self, manager: SessionManager) -> None: + # A stack of the idle sessions + self._idle_sessions = [] + self._active_sessions = set() + self._manager = manager + + def get_session(self) -> Session: + """Returns a session from the session pool or creates a new one.""" + try: + session = self._idle_sessions.pop() + except IndexError: + session = self._manager.make_session() + self._active_sessions.add(session) + return session + + def return_session(self, session: Session) -> None: + """Places an active session back into the idle session stack.""" + try: + self._active_sessions.remove(session) + except KeyError: + logger.debug("session doesn't exist in the active session pool. Ignored...") + self._idle_sessions.append(session) + + def __str__(self) -> str: + total_sessions = len(self._active_sessions) + len(self._idle_sessions) + return ( + f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" + ) + + def close(self) -> None: + """Closes all active and idle sessions in this session pool.""" + if self._active_sessions: + logger.debug(f"Closing {len(self._active_sessions)} active sessions") + for session in itertools.chain(self._active_sessions, self._idle_sessions): + try: + session.close() + except Exception as e: + logger.info(f"Session cleanup failed - failed to close session: {e}") + self._active_sessions.clear() + self._idle_sessions.clear() + + +class SessionManager: + def __init__( + self, + use_pooling: bool = True, + adapter_factory: ( + Callable[..., HTTPAdapter] | None + ) = lambda *args, **kwargs: None, + ): + self._use_pooling = use_pooling + self._adapter_factory = adapter_factory or ProxySupportAdapter + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: 
SessionPool(self) + ) + + @property + def sessions_map(self) -> dict[str, SessionPool]: + return self._sessions_map + + def _mount_adapter(self, session: requests.Session) -> None: + adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) + if adapter is not None: + session.mount("http://", adapter) + session.mount("https://", adapter) + + def make_session(self) -> Session: + s = requests.Session() + self._mount_adapter(s) + s._reuse_count = itertools.count() + return s + + @contextlib.contextmanager + def use_session(self, url: str | None = None): + if not self._use_pooling: + session = self.make_session() + try: + yield session + finally: + session.close() + else: + hostname = urlparse(url).hostname if url else None + pool = self._sessions_map[hostname] + session = pool.get_session() + try: + yield session + finally: + pool.return_session(session) + + def close(self): + for pool in self._sessions_map.values(): + pool.close() + + def clone(self, *, use_pooling: bool | None = None) -> SessionManager: + """Return an independent manager that reuses the adapter_factory.""" + return SessionManager( + use_pooling=self._use_pooling if use_pooling is None else use_pooling, + adapter_factory=self._adapter_factory, + ) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 0063c23f07..53faffadfa 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -15,6 +15,8 @@ from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError +from .http_client import request as http_request +from .session_manager import SessionManager from .vendored import requests from .vendored.requests import Response @@ -72,15 +74,28 @@ class AwsCredentials: def try_metadata_service_call( - method: str, url: str, headers: dict, timeout_sec: int = 3 + method: str, + url: str, + headers: dict, + timeout_sec: int = 3, + session_manager: SessionManager | None = None, ) -> Response | None: """Tries to 
make a HTTP request to the metadata service with the given URL, method, headers and timeout. If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. """ try: - res: Response = requests.request( - method=method, url=url, headers=headers, timeout=timeout_sec + # If no session_manager provided, create a basic one for this call + if session_manager is None: + session_manager = SessionManager(use_pooling=False) + + res: Response = http_request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + session_manager=session_manager, + use_pooling=False, # IMDS calls are rare → don't pollute pool ) if not res.ok: return None @@ -119,16 +134,19 @@ def extract_iss_and_sub_without_signature_verification( # --------------------------------------------------------------------------- # # AWS helper utilities (token, credentials, region) # # --------------------------------------------------------------------------- # -def _imds_v2_token() -> str | None: +def _imds_v2_token(session_manager: SessionManager | None = None) -> str | None: res = try_metadata_service_call( method="PUT", url="http://169.254.169.254/latest/api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + session_manager=session_manager, ) return res.text.strip() if res else None -def get_aws_credentials() -> AwsCredentials | None: +def get_aws_credentials( + session_manager: SessionManager | None = None, +) -> AwsCredentials | None: """Get AWS credentials from environment variables or instance metadata. Implements the AWS credential chain without using boto3. 
@@ -143,7 +161,7 @@ def get_aws_credentials() -> AwsCredentials | None: # Try instance metadata service (IMDSv2) try: - token = _imds_v2_token() + token = _imds_v2_token(session_manager) if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -155,6 +173,7 @@ def get_aws_credentials() -> AwsCredentials | None: method="GET", url="http://169.254.169.254/latest/meta-data/iam/security-credentials/", headers=token_hdr, + session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role list from metadata service.") @@ -170,6 +189,7 @@ def get_aws_credentials() -> AwsCredentials | None: method="GET", url=f"http://169.254.169.254/latest/meta-data/iam/security-credentials/{role_name}", headers=token_hdr, + session_manager=session_manager, ) if res is None: logger.debug("Failed to get IAM role credentials from metadata service.") @@ -188,14 +208,14 @@ def get_aws_credentials() -> AwsCredentials | None: return None -def get_aws_region() -> str | None: +def get_aws_region(session_manager: SessionManager | None = None) -> str | None: """Get the current AWS workload's region, if any.""" region = os.environ.get("AWS_REGION") if region: return region try: - token = _imds_v2_token() + token = _imds_v2_token(session_manager) if token is None: logger.debug("Failed to get IMDSv2 token from metadata service.") return None @@ -207,6 +227,7 @@ def get_aws_region() -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/region", headers=token_hdr, + session_manager=session_manager, ) if res is not None: return res.text.strip() @@ -215,6 +236,7 @@ def get_aws_region() -> str | None: method="GET", url="http://169.254.169.254/latest/meta-data/placement/availability-zone", headers=token_hdr, + session_manager=session_manager, ) if res is not None: return res.text.strip()[:-1] @@ -418,17 +440,19 @@ def hmac_sha256(key: bytes, msg: str) -> bytes: return final_headers -def create_aws_attestation() -> 
WorkloadIdentityAttestation | None: +def create_aws_attestation( + session_manager: SessionManager | None = None, +) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, returns None. """ - credentials = get_aws_credentials() + credentials = get_aws_credentials(session_manager) if not credentials: logger.debug("No AWS credentials were found.") return None - region = get_aws_region() + region = get_aws_region(session_manager) if not region: logger.debug("No AWS region was found.") return None @@ -473,7 +497,9 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: ) -def create_gcp_attestation() -> WorkloadIdentityAttestation | None: +def create_gcp_attestation( + session_manager: SessionManager | None = None, +) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for GCP. If the application isn't running on GCP or no credentials were found, returns None. @@ -484,6 +510,7 @@ def create_gcp_attestation() -> WorkloadIdentityAttestation | None: headers={ "Metadata-Flavor": "Google", }, + session_manager=session_manager, ) if res is None: # Most likely we're just not running on GCP, which may be expected. @@ -512,6 +539,7 @@ def create_gcp_attestation() -> WorkloadIdentityAttestation | None: def create_azure_attestation( snowflake_entra_resource: str, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for Azure. @@ -545,6 +573,7 @@ def create_azure_attestation( method="GET", url=f"{url_without_query_string}?{query_params}", headers=headers, + session_manager=session_manager, ) if res is None: # Most likely we're just not running on Azure, which may be expected. 
@@ -595,7 +624,9 @@ def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | def create_autodetect_attestation( - entra_resource: str, token: str | None = None + entra_resource: str, + token: str | None = None, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation | None: """Tries to create an attestation using the auto-detected runtime environment. @@ -605,15 +636,15 @@ def create_autodetect_attestation( if attestation: return attestation - attestation = create_azure_attestation(entra_resource) + attestation = create_azure_attestation(entra_resource, session_manager) if attestation: return attestation - attestation = create_aws_attestation() + attestation = create_aws_attestation(session_manager) if attestation: return attestation - attestation = create_gcp_attestation() + attestation = create_gcp_attestation(session_manager) if attestation: return attestation @@ -624,6 +655,7 @@ def create_attestation( provider: AttestationProvider | None, entra_resource: str | None = None, token: str | None = None, + session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -633,18 +665,23 @@ def create_attestation( If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. 
""" entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE + session_manager = ( + session_manager.clone() if session_manager else SessionManager(use_pooling=True) + ) attestation: WorkloadIdentityAttestation | None = None if provider == AttestationProvider.AWS: - attestation = create_aws_attestation() + attestation = create_aws_attestation(session_manager) elif provider == AttestationProvider.AZURE: - attestation = create_azure_attestation(entra_resource) + attestation = create_azure_attestation(entra_resource, session_manager) elif provider == AttestationProvider.GCP: - attestation = create_gcp_attestation() + attestation = create_gcp_attestation(session_manager) elif provider == AttestationProvider.OIDC: attestation = create_oidc_attestation(token) elif provider is None: - attestation = create_autodetect_attestation(entra_resource, token) + attestation = create_autodetect_attestation( + entra_resource, token, session_manager + ) if not attestation: provider_str = "auto-detect" if provider is None else provider.value From 59a3373599f6d2e3c9b133b77a75b6c4cbe7bbb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 09:21:15 +0200 Subject: [PATCH 19/54] SNOW-2183023: fixed missing prepare args --- .../connector/auth/workload_identity.py | 7 ++++- test/unit/test_auth_workload_identity.py | 28 +++++++++---------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 7f8ab60718..e9a5a2c41d 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -4,6 +4,9 @@ import typing from enum import Enum, unique +if typing.TYPE_CHECKING: + from snowflake.connector.connection import SnowflakeConnection + from ..network import WORKLOAD_IDENTITY_AUTHENTICATOR from ..wif_util import ( AttestationProvider, @@ -74,7 +77,9 @@ def update_body(self, body: 
dict[typing.Any, typing.Any]) -> None: ).value body["data"]["TOKEN"] = self.attestation.credential - def prepare(self, *, conn, **kwargs: typing.Any) -> None: + def prepare( + self, *, conn: SnowflakeConnection | None, **kwargs: typing.Any + ) -> None: """Fetch the token.""" self.attestation = create_attestation( self.provider, diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index c0dfd62e20..3c0191c967 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -82,7 +82,7 @@ def test_explicit_oidc_valid_inline_token_plumbed_to_api(): auth_class = AuthByWorkloadIdentity( provider=AttestationProvider.OIDC, token=dummy_token ) - auth_class.prepare() + auth_class.prepare(conn=None) assert extract_api_data(auth_class) == { "AUTHENTICATOR": "WORKLOAD_IDENTITY", @@ -96,7 +96,7 @@ def test_explicit_oidc_valid_inline_token_generates_unique_assertion_content(): auth_class = AuthByWorkloadIdentity( provider=AttestationProvider.OIDC, token=dummy_token ) - auth_class.prepare() + auth_class.prepare(conn=None) assert ( auth_class.assertion_content == '{"_provider":"OIDC","iss":"issuer-1","sub":"service-1"}' @@ -109,14 +109,14 @@ def test_explicit_oidc_invalid_inline_token_raises_error(): provider=AttestationProvider.OIDC, token=invalid_token ) with pytest.raises(ProgrammingError) as excinfo: - auth_class.prepare() + auth_class.prepare(conn=None) assert "No workload identity credential was found for 'OIDC'" in str(excinfo.value) def test_explicit_oidc_no_token_raises_error(): auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.OIDC, token=None) with pytest.raises(ProgrammingError) as excinfo: - auth_class.prepare() + auth_class.prepare(conn=None) assert "No workload identity credential was found for 'OIDC'" in str(excinfo.value) @@ -128,7 +128,7 @@ def test_explicit_aws_no_auth_raises_error(fake_aws_environment: FakeAwsEnvironm auth_class = 
AuthByWorkloadIdentity(provider=AttestationProvider.AWS) with pytest.raises(ProgrammingError) as excinfo: - auth_class.prepare() + auth_class.prepare(conn=None) assert "No workload identity credential was found for 'AWS'" in str(excinfo.value) @@ -136,7 +136,7 @@ def test_explicit_aws_encodes_audience_host_signature_to_api( fake_aws_environment: FakeAwsEnvironment, ): auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) - auth_class.prepare() + auth_class.prepare(conn=None) data = extract_api_data(auth_class) assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" @@ -152,7 +152,7 @@ def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnviro fake_aws_environment.region = "antarctica-northeast-3" auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) - auth_class.prepare() + auth_class.prepare(conn=None) data = extract_api_data(auth_class) decoded_token = json.loads(b64decode(data["TOKEN"])) @@ -171,7 +171,7 @@ def test_explicit_aws_generates_unique_assertion_content( "arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab" ) auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) - auth_class.prepare() + auth_class.prepare(conn=None) assert ( '{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}' @@ -224,7 +224,7 @@ def test_explicit_gcp_metadata_server_error_raises_auth_error(exception): "snowflake.connector.vendored.requests.request", side_effect=exception ): with pytest.raises(ProgrammingError) as excinfo: - auth_class.prepare() + auth_class.prepare(conn=None) assert "No workload identity credential was found for 'GCP'" in str( excinfo.value ) @@ -237,7 +237,7 @@ def test_explicit_gcp_wrong_issuer_raises_error( auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) with pytest.raises(ProgrammingError) as excinfo: - auth_class.prepare() + auth_class.prepare(conn=None) assert "No workload identity credential was found for 
'GCP'" in str(excinfo.value) @@ -245,7 +245,7 @@ def test_explicit_gcp_plumbs_token_to_api( fake_gce_metadata_service: FakeGceMetadataService, ): auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) - auth_class.prepare() + auth_class.prepare(conn=None) assert extract_api_data(auth_class) == { "AUTHENTICATOR": "WORKLOAD_IDENTITY", @@ -260,7 +260,7 @@ def test_explicit_gcp_generates_unique_assertion_content( fake_gce_metadata_service.sub = "123456" auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) - auth_class.prepare() + auth_class.prepare(conn=None) assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}' @@ -282,7 +282,7 @@ def test_explicit_azure_metadata_server_error_raises_auth_error(exception): "snowflake.connector.vendored.requests.request", side_effect=exception ): with pytest.raises(ProgrammingError) as excinfo: - auth_class.prepare() + auth_class.prepare(conn=None) assert "No workload identity credential was found for 'AZURE'" in str( excinfo.value ) @@ -293,7 +293,7 @@ def test_explicit_azure_wrong_issuer_raises_error(fake_azure_metadata_service): auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) with pytest.raises(ProgrammingError) as excinfo: - auth_class.prepare() + auth_class.prepare(conn=None) assert "No workload identity credential was found for 'AZURE'" in str(excinfo.value) From 093b5b525355d2514cf210299238627ea9594b58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 09:42:03 +0200 Subject: [PATCH 20/54] SNOW-2183023: fixed missing prepare args --- test/unit/test_auth_workload_identity.py | 33 ++++++++++++------------ 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index 3c0191c967..c796d989c9 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -124,7 +124,7 @@ def 
test_explicit_oidc_no_token_raises_error(): def test_explicit_aws_no_auth_raises_error(fake_aws_environment: FakeAwsEnvironment): - fake_aws_environment.credentials = None + fake_aws_environment.util_creds = None auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) with pytest.raises(ProgrammingError) as excinfo: @@ -144,7 +144,7 @@ def test_explicit_aws_encodes_audience_host_signature_to_api( verify_aws_token( data["TOKEN"], fake_aws_environment.region, - expect_session_token=fake_aws_environment.credentials.token is not None, + expect_session_token=fake_aws_environment.util_creds.token is not None, ) @@ -310,14 +310,14 @@ def test_explicit_azure_v1_and_v2_issuers_accepted(fake_azure_metadata_service, fake_azure_metadata_service.iss = issuer auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) - auth_class.prepare() + auth_class.prepare(conn=None) assert issuer == json.loads(auth_class.assertion_content)["iss"] def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service): auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) - auth_class.prepare() + auth_class.prepare(conn=None) assert extract_api_data(auth_class) == { "AUTHENTICATOR": "WORKLOAD_IDENTITY", @@ -333,7 +333,7 @@ def test_explicit_azure_generates_unique_assertion_content(fake_azure_metadata_s fake_azure_metadata_service.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) - auth_class.prepare() + auth_class.prepare(conn=None) assert ( '{"_provider":"AZURE","iss":"https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd","sub":"611ab25b-2e81-4e18-92a7-b21f2bebb269"}' @@ -345,7 +345,7 @@ def test_explicit_azure_uses_default_entra_resource_if_unspecified( fake_azure_metadata_service, ): auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) - auth_class.prepare() + auth_class.prepare(conn=None) token = fake_azure_metadata_service.token parsed = 
jwt.decode(token, options={"verify_signature": False}) @@ -358,7 +358,7 @@ def test_explicit_azure_uses_explicit_entra_resource(fake_azure_metadata_service auth_class = AuthByWorkloadIdentity( provider=AttestationProvider.AZURE, entra_resource="api://non-standard" ) - auth_class.prepare() + auth_class.prepare(conn=None) token = fake_azure_metadata_service.token parsed = jwt.decode(token, options={"verify_signature": False}) @@ -388,7 +388,7 @@ def test_autodetect_aws_present( no_metadata_service, fake_aws_environment: FakeAwsEnvironment ): auth_class = AuthByWorkloadIdentity(provider=None) - auth_class.prepare() + auth_class.prepare(conn=None) data = extract_api_data(auth_class) assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" @@ -396,13 +396,13 @@ def test_autodetect_aws_present( verify_aws_token( data["TOKEN"], fake_aws_environment.region, - expect_session_token=fake_aws_environment.credentials.token is not None, + expect_session_token=fake_aws_environment.util_creds.token is not None, ) def test_autodetect_gcp_present(fake_gce_metadata_service: FakeGceMetadataService): auth_class = AuthByWorkloadIdentity(provider=None) - auth_class.prepare() + auth_class.prepare(conn=None) assert extract_api_data(auth_class) == { "AUTHENTICATOR": "WORKLOAD_IDENTITY", @@ -413,7 +413,7 @@ def test_autodetect_gcp_present(fake_gce_metadata_service: FakeGceMetadataServic def test_autodetect_azure_present(fake_azure_metadata_service): auth_class = AuthByWorkloadIdentity(provider=None) - auth_class.prepare() + auth_class.prepare(conn=None) assert extract_api_data(auth_class) == { "AUTHENTICATOR": "WORKLOAD_IDENTITY", @@ -425,7 +425,7 @@ def test_autodetect_azure_present(fake_azure_metadata_service): def test_autodetect_oidc_present(no_metadata_service): dummy_token = gen_dummy_id_token(sub="service-1", iss="issuer-1") auth_class = AuthByWorkloadIdentity(provider=None, token=dummy_token) - auth_class.prepare() + auth_class.prepare(conn=None) assert extract_api_data(auth_class) == { 
"AUTHENTICATOR": "WORKLOAD_IDENTITY", @@ -437,7 +437,7 @@ def test_autodetect_oidc_present(no_metadata_service): def test_autodetect_no_provider_raises_error(no_metadata_service): auth_class = AuthByWorkloadIdentity(provider=None, token=None) with pytest.raises(ProgrammingError) as excinfo: - auth_class.prepare() + auth_class.prepare(conn=None) assert "No workload identity credential was found for 'auto-detect" in str( excinfo.value ) @@ -451,7 +451,7 @@ def test_aws_token_authorization_header_format( header (same format that botocore would produce). """ auth = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) - auth.prepare() + auth.prepare(conn=None) headers = json.loads(b64decode(extract_api_data(auth)["TOKEN"]))["headers"] assert _SIGV4_RE.match(headers["Authorization"]) @@ -466,14 +466,14 @@ def test_internal_signer_vs_botocore_prefix(fake_aws_environment: FakeAwsEnviron """ # --- headers from the new HTTP-only path ----------------------------- auth = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) - auth.prepare() + auth.prepare(conn=None) new_hdr = json.loads(b64decode(extract_api_data(auth)["TOKEN"]))["headers"][ "Authorization" ] new_prefix = new_hdr.split("Signature=")[0] # --- headers from real botocore SigV4Auth ---------------------------- - creds = fake_aws_environment.credentials # boto Credentials + creds = fake_aws_environment.util_creds # boto Credentials region = fake_aws_environment.region url = ( f"https://sts.{region}.amazonaws.com/" @@ -497,6 +497,7 @@ def test_get_aws_credentials_fallback_env(fake_aws_environment: FakeAwsEnvironme """ creds = get_aws_credentials() + assert fake_aws_environment.util_creds is not None assert creds.access_key == fake_aws_environment.util_creds.access_key assert creds.secret_key == fake_aws_environment.util_creds.secret_key assert creds.token == fake_aws_environment.util_creds.token From 219e8c66f7846dfe620334e85f719680253f26e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= 
Date: Tue, 8 Jul 2025 09:48:17 +0200 Subject: [PATCH 21/54] SNOW-2183023: fixed patching --- test/csp_helpers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index b6b50cb850..8102ff2dfe 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -97,11 +97,12 @@ def __enter__(self): """Patches the relevant HTTP calls when entering as a context manager.""" self.reset_defaults() self.patchers = [] - # requests.request is used by the direct metadata service API calls from our code. This is the main + # Session.request is used by the direct metadata service API calls from our code. This is the main # thing being faked here. self.patchers.append( mock.patch( - "snowflake.connector.vendored.requests.request", side_effect=self + "snowflake.connector.vendored.requests.Session.request", + side_effect=self, ) ) # HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we From b4862391c9f3fdba3b507eb3aacd1bfcc41d8030 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 09:51:59 +0200 Subject: [PATCH 22/54] SNOW-2183023: fixed tests --- test/csp_helpers.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index 8102ff2dfe..2218f1c1b2 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -280,14 +280,31 @@ def get_aws_credentials(self, *_, **__) -> AwsCredentials | None: return self.util_creds def sign_request(self, request: AWSRequest): - request.headers.add_header("X-Amz-Date", datetime.datetime.utcnow().isoformat()) - request.headers.add_header("X-Amz-Security-Token", "") - request.headers.add_header( - "Authorization", - "AWS4-HMAC-SHA256 Credential=, SignedHeaders=host;x-amz-date," - " Signature=", + # Generate a proper-looking authorization header that matches what the real signer would produce + utc_now = datetime.datetime.utcnow() 
+ amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ") + date_string = utc_now.strftime("%Y%m%d") + + # Add the same headers that the real signer would add + request.headers.add_header("X-Amz-Date", amz_date) + request.headers.add_header("X-Amz-Security-Token", self.util_creds.token) + + # Generate signed headers list that matches what the real signer would include + header_keys = [] + for key in sorted(request.headers.keys(), key=str.lower): + header_keys.append(key.lower()) + + signed_headers = ";".join(header_keys) + credential_scope = f"{date_string}/{self.region}/sts/aws4_request" + + authorization = ( + f"AWS4-HMAC-SHA256 " + f"Credential={self.util_creds.access_key}/{credential_scope}, " + f"SignedHeaders={signed_headers}, Signature=" ) + request.headers.add_header("Authorization", authorization) + def __enter__(self): # Preserve existing env and then set creds/region for util fallback self._old_env = { From af2f4a32eb6a5bb615c1895768407be3fc44ea9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 09:55:13 +0200 Subject: [PATCH 23/54] SNOW-2183023: fixed tests --- test/csp_helpers.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index 2218f1c1b2..c4c40912b0 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -279,32 +279,36 @@ def get_boto_credentials(self, *_, **__) -> Credentials | None: def get_aws_credentials(self, *_, **__) -> AwsCredentials | None: return self.util_creds - def sign_request(self, request: AWSRequest): - # Generate a proper-looking authorization header that matches what the real signer would produce + def sign_request(self, request: AWSRequest) -> None: + """ + Fake replacement for botocore SigV4Auth.add_auth that produces the same + *static* parts of the Authorization header (everything before + `Signature=`). 
+ """ + # Add the headers a real signer would inject utc_now = datetime.datetime.utcnow() amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ") - date_string = utc_now.strftime("%Y%m%d") + date_stamp = utc_now.strftime("%Y%m%d") - # Add the same headers that the real signer would add - request.headers.add_header("X-Amz-Date", amz_date) - request.headers.add_header("X-Amz-Security-Token", self.util_creds.token) + request.headers["X-Amz-Date"] = amz_date + if self.util_creds.token: + request.headers["X-Amz-Security-Token"] = self.util_creds.token - # Generate signed headers list that matches what the real signer would include - header_keys = [] - for key in sorted(request.headers.keys(), key=str.lower): - header_keys.append(key.lower()) + # Host header is already set by the test; add it if a future test forgets + if "Host" not in request.headers: + request.headers["Host"] = urlparse(request.url).netloc - signed_headers = ";".join(header_keys) - credential_scope = f"{date_string}/{self.region}/sts/aws4_request" + # Build the signed-headers list + signed_headers = ";".join(sorted(h.lower() for h in request.headers.keys())) - authorization = ( - f"AWS4-HMAC-SHA256 " + credential_scope = f"{date_stamp}/{self.region}/sts/aws4_request" + + request.headers["Authorization"] = ( + "AWS4-HMAC-SHA256 " f"Credential={self.util_creds.access_key}/{credential_scope}, " f"SignedHeaders={signed_headers}, Signature=" ) - request.headers.add_header("Authorization", authorization) - def __enter__(self): # Preserve existing env and then set creds/region for util fallback self._old_env = { From 88ccb83f5e2e35ec0e48d3a9e9a1aa18e4cbc942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 10:02:09 +0200 Subject: [PATCH 24/54] SNOW-2183023: fixed tests --- src/snowflake/connector/wif_util.py | 3 --- test/csp_helpers.py | 3 +-- test/unit/test_auth_workload_identity.py | 2 +- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git 
a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 53faffadfa..4abaa81617 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -131,9 +131,6 @@ def extract_iss_and_sub_without_signature_verification( return claims["iss"], claims["sub"] -# --------------------------------------------------------------------------- # -# AWS helper utilities (token, credentials, region) # -# --------------------------------------------------------------------------- # def _imds_v2_token(session_manager: SessionManager | None = None) -> str | None: res = try_metadata_service_call( method="PUT", diff --git a/test/csp_helpers.py b/test/csp_helpers.py index c4c40912b0..e17bcbc0f6 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -291,8 +291,7 @@ def sign_request(self, request: AWSRequest) -> None: date_stamp = utc_now.strftime("%Y%m%d") request.headers["X-Amz-Date"] = amz_date - if self.util_creds.token: - request.headers["X-Amz-Security-Token"] = self.util_creds.token + request.headers["X-Amz-Security-Token"] = self.util_creds.token # Host header is already set by the test; add it if a future test forgets if "Host" not in request.headers: diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index c796d989c9..ef065af37f 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -473,7 +473,7 @@ def test_internal_signer_vs_botocore_prefix(fake_aws_environment: FakeAwsEnviron new_prefix = new_hdr.split("Signature=")[0] # --- headers from real botocore SigV4Auth ---------------------------- - creds = fake_aws_environment.util_creds # boto Credentials + creds = fake_aws_environment.boto_creds # boto Credentials region = fake_aws_environment.region url = ( f"https://sts.{region}.amazonaws.com/" From 9fbda91fe1e7927b726a40d1a4e20905208001cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 
2025 10:27:22 +0200 Subject: [PATCH 25/54] SNOW-2183023: removed http_client for now --- DESCRIPTION.md | 1 + src/snowflake/connector/http_client.py | 103 --------------------- src/snowflake/connector/session_manager.py | 65 ++++++++++++- src/snowflake/connector/wif_util.py | 2 +- 4 files changed, 63 insertions(+), 108 deletions(-) delete mode 100644 src/snowflake/connector/http_client.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index ba30b3b78b..0258e64624 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -17,6 +17,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne - Fixed `write_pandas` special characters usage in the location name. - Fixed usage of `use_virtual_url` when building the location for gcs storage client. - Added support for Snowflake OAuth for local applications. + - Removed boto and botocore dependencies. - v3.15.0(Apr 29,2025) - Bumped up min boto and botocore version to 1.24. diff --git a/src/snowflake/connector/http_client.py b/src/snowflake/connector/http_client.py deleted file mode 100644 index 1c802c2439..0000000000 --- a/src/snowflake/connector/http_client.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Any, Mapping - -from .session_manager import SessionManager -from .vendored.requests import Response - -logger = logging.getLogger(__name__) - - -class HttpClient: - """HTTP client that uses SessionManager for connection pooling and adapter management.""" - - def __init__(self, session_manager: SessionManager): - """Initialize HttpClient with a SessionManager. - - Args: - session_manager: SessionManager instance to use for all requests - """ - self.session_manager = session_manager - - def request( - self, - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - use_pooling: bool | None = None, - **kwargs: Any, - ) -> Response: - """Make an HTTP request using the configured SessionManager. 
- - Args: - method: HTTP method (GET, POST, etc.) - url: Target URL - headers: Optional HTTP headers - timeout_sec: Request timeout in seconds - use_pooling: Whether to use connection pooling (overrides session_manager setting) - **kwargs: Additional arguments passed to requests.Session.request - - Returns: - Response object from the request - """ - mgr = ( - self.session_manager - if use_pooling is None - else self.session_manager.clone(use_pooling=use_pooling) - ) - - with mgr.use_session(url) as session: - return session.request( - method=method.upper(), - url=url, - headers=headers, - timeout=timeout_sec, - **kwargs, - ) - - -# Convenience function for backwards compatibility and simple usage -def request( - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - session_manager: SessionManager | None = None, - use_pooling: bool | None = None, - **kwargs: Any, -) -> Response: - """Convenience function for making HTTP requests. - - Args: - method: HTTP method (GET, POST, etc.) 
- url: Target URL - headers: Optional HTTP headers - timeout_sec: Request timeout in seconds - session_manager: SessionManager instance to use (required) - use_pooling: Whether to use connection pooling (overrides session_manager setting) - **kwargs: Additional arguments passed to requests.Session.request - - Returns: - Response object from the request - - Raises: - ValueError: If session_manager is None - """ - if session_manager is None: - raise ValueError( - "session_manager is required - no default session manager available" - ) - - client = HttpClient(session_manager) - return client.request( - method=method, - url=url, - headers=headers, - timeout_sec=timeout_sec, - use_pooling=use_pooling, - **kwargs, - ) diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py index 85b46ccfd4..c6d12c9e43 100644 --- a/src/snowflake/connector/session_manager.py +++ b/src/snowflake/connector/session_manager.py @@ -4,11 +4,11 @@ import contextlib import itertools import logging -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable, Mapping from .compat import urlparse from .vendored import requests -from .vendored.requests import Session +from .vendored.requests import Response, Session from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.exceptions import InvalidProxyURL from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy @@ -134,8 +134,11 @@ def make_session(self) -> Session: return s @contextlib.contextmanager - def use_session(self, url: str | None = None): - if not self._use_pooling: + def use_session( + self, url: str | None = None, use_pooling: bool | None = None + ) -> Session: + use_pooling = use_pooling if use_pooling is not None else self._use_pooling + if not use_pooling: session = self.make_session() try: yield session @@ -150,6 +153,30 @@ def use_session(self, url: str | None = None): finally: pool.return_session(session) + def 
request( + self, + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + use_pooling: bool | None = None, + **kwargs: Any, + ) -> Response: + """Make a single HTTP request handled by this *SessionManager*. + + This wraps :pymeth:`use_session` so callers don’t have to manage the + context manager themselves. + """ + with self.use_session(url, use_pooling) as session: + return session.request( + method=method.upper(), + url=url, + headers=headers, + timeout=timeout_sec, + **kwargs, + ) + def close(self): for pool in self._sessions_map.values(): pool.close() @@ -160,3 +187,33 @@ def clone(self, *, use_pooling: bool | None = None) -> SessionManager: use_pooling=self._use_pooling if use_pooling is None else use_pooling, adapter_factory=self._adapter_factory, ) + + +def request( + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + timeout_sec: int | None = 3, + session_manager: SessionManager | None = None, + use_pooling: bool | None = None, + **kwargs: Any, +) -> Response: + """Convenience wrapper – *requires* an explicit ``session_manager``. + + This keeps a one-liner API equivalent to the old + ``snowflake.connector.http_client.request`` helper. 
+ """ + if session_manager is None: + raise ValueError( + "session_manager is required - no default session manager available" + ) + + return session_manager.request( + method=method, + url=url, + headers=headers, + timeout_sec=timeout_sec, + use_pooling=use_pooling, + **kwargs, + ) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 4abaa81617..ca04c0e3c6 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -15,8 +15,8 @@ from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError -from .http_client import request as http_request from .session_manager import SessionManager +from .session_manager import request as http_request from .vendored import requests from .vendored.requests import Response From 5e921d8bd7e4ca537f61a015284b2932be5e1374 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 11:07:01 +0200 Subject: [PATCH 26/54] SNOW-2183023: fixed test proxies --- test/unit/test_proxies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unit/test_proxies.py b/test/unit/test_proxies.py index 8835695aa2..fbd2d47268 100644 --- a/test/unit/test_proxies.py +++ b/test/unit/test_proxies.py @@ -60,7 +60,7 @@ def mock_proxy_manager_for_url_wiht_header(*args, **kwargs): # bad path with unittest.mock.patch( - "snowflake.connector.network.ProxySupportAdapter.proxy_manager_for", + "snowflake.connector.session_manager.ProxySupportAdapter.proxy_manager_for", mock_proxy_manager_for_url_no_header, ): with pytest.raises(OperationalError): @@ -77,7 +77,7 @@ def mock_proxy_manager_for_url_wiht_header(*args, **kwargs): # happy path with unittest.mock.patch( - "snowflake.connector.network.ProxySupportAdapter.proxy_manager_for", + "snowflake.connector.session_manager.ProxySupportAdapter.proxy_manager_for", mock_proxy_manager_for_url_wiht_header, ): with pytest.raises(OperationalError): From 
a3faefd3ac77c3fbb43c804dc49728f35147a583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 13:46:13 +0200 Subject: [PATCH 27/54] SNOW-2183023: fixed user id --- src/snowflake/connector/wif_util.py | 79 +----------------------- test/unit/test_auth_workload_identity.py | 9 +-- 2 files changed, 6 insertions(+), 82 deletions(-) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index ca04c0e3c6..bd03222144 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -243,79 +243,6 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None: return None -def get_aws_arn(session_manager: SessionManager | None = None) -> str | None: - """Get the current AWS workload's ARN by calling GetCallerIdentity. - - Note: This function makes a network call to AWS STS and is only used for - assertion content generation (logging and backward compatibility purposes). - The ARN is not required for authentication - it's just used as a unique - identifier for the workload in logs and assertion content. - - Returns the ARN of the current AWS identity, or None if it cannot be determined. 
- """ - credentials = get_aws_credentials(session_manager) - if not credentials: - logger.debug("No AWS credentials available for ARN lookup.") - return None - - region = get_aws_region(session_manager) - if not region: - logger.debug("No AWS region available for ARN lookup.") - return None - - try: - # Create the GetCallerIdentity request - sts_hostname = get_aws_sts_hostname(region) - url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15" - - base_headers = { - "Content-Type": "application/x-amz-json-1.1", - } - - signed_headers = aws_signature_v4_sign( - credentials=credentials, - method="POST", - url=url, - region=region, - service="sts", - headers=base_headers, - ) - - # Make the actual request to get caller identity - response = http_request( - method="POST", - url=url, - headers=signed_headers, - timeout_sec=10, - session_manager=session_manager, - ) - - if response and response.ok: - # Parse the XML response to extract the ARN - import xml.etree.ElementTree as ET - - # Ensure content is bytes and decode it - content = response.content - if isinstance(content, bytes): - content_str = content.decode("utf-8") - else: - content_str = str(content) if content else "" - - if content_str: - root = ET.fromstring(content_str) - - # Find the Arn element in the response - for elem in root.iter(): - if elem.tag.endswith("Arn") and elem.text: - return elem.text.strip() - - logger.debug("Failed to get ARN from GetCallerIdentity response.") - return None - except Exception as e: - logger.debug(f"Error getting AWS ARN: {e}") - return None - - def get_aws_sts_hostname(region: str) -> str | None: """Constructs the AWS STS hostname for a given region. 
@@ -481,11 +408,7 @@ def create_aws_attestation( credential = b64encode(json.dumps(attestation_request).encode("utf-8")).decode( "utf-8" ) - - # Get the ARN for user identifier components (used only for assertion content - logging and backward compatibility) - # The ARN is not required for authentication, but provides a unique identifier for the workload - arn = get_aws_arn(session_manager) - user_identifier_components = {"arn": arn} if arn else {} + user_identifier_components = {"region": region} return WorkloadIdentityAttestation( AttestationProvider.AWS, diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index ef065af37f..aca3a200f9 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -173,10 +173,11 @@ def test_explicit_aws_generates_unique_assertion_content( auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) auth_class.prepare(conn=None) - assert ( - '{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}' - == auth_class.assertion_content - ) + expected = { + "_provider": "AWS", + "region": fake_aws_environment.region, + } + assert json.loads(auth_class.assertion_content) == expected @pytest.mark.parametrize( From a2c02953e217bc2565dad69d3c28057b358b166e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Tue, 8 Jul 2025 13:52:51 +0200 Subject: [PATCH 28/54] SNOW-2183023: fixed mock --- test/csp_helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index e17bcbc0f6..a998c51a5c 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -349,6 +349,7 @@ def __enter__(self): mock.patch( "snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn, + create=True, ), # never contact IMDS for token mock.patch( From e176195f471c7e9ac10cd0607a91c4c1f0b1b5af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: 
Wed, 9 Jul 2025 20:26:19 +0200 Subject: [PATCH 29/54] wif credential working --- src/snowflake/connector/_aws_credentials.py | 130 ++++++++++++++++++++ src/snowflake/connector/wif_util.py | 35 +++++- 2 files changed, 160 insertions(+), 5 deletions(-) create mode 100644 src/snowflake/connector/_aws_credentials.py diff --git a/src/snowflake/connector/_aws_credentials.py b/src/snowflake/connector/_aws_credentials.py new file mode 100644 index 0000000000..4a5bfd12a1 --- /dev/null +++ b/src/snowflake/connector/_aws_credentials.py @@ -0,0 +1,130 @@ +""" +Lightweight AWS credential resolution without boto3. + +This replicates the standard AWS SDK credential chain (environment → container → EC2 IMDSv2). +It purposely returns a `botocore.credentials.Credentials` instance so existing +code that relies on `SigV4Auth` continues to work unchanged while we phase out +boto3 usage incrementally. +""" + +from __future__ import annotations + +import logging +import os + +from .vendored import requests + +try: + from botocore.credentials import Credentials # type: ignore +except Exception: # pragma: no cover + # botocore is still available at this migration stage; if it isn’t we’ll + # replace it in a later step. 
+ Credentials = None # type: ignore + +logger = logging.getLogger(__name__) + +# Internal constants +_ECS_CREDENTIALS_BASE_URI = "http://169.254.170.2" +_IMDS_BASE_URI = "http://169.254.169.254" + + +def _credentials_from_env() -> Credentials | None: + """Load credentials from environment variables.""" + access_key = os.getenv("AWS_ACCESS_KEY_ID") + secret_key = os.getenv("AWS_SECRET_ACCESS_KEY") + if access_key and secret_key: + token = os.getenv("AWS_SESSION_TOKEN") + return Credentials(access_key, secret_key, token) if Credentials else None + return None + + +def _credentials_from_container() -> Credentials | None: + """Retrieve credentials from ECS / EKS task metadata (IAM Roles for Tasks).""" + rel_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + full_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") + if not rel_uri and not full_uri: + return None + creds_url = full_uri or f"{_ECS_CREDENTIALS_BASE_URI}{rel_uri}" + try: + res = requests.get(creds_url, timeout=2) + if res.ok: + data = res.json() + return ( + Credentials( + data["AccessKeyId"], + data["SecretAccessKey"], + data.get("Token"), + ) + if Credentials + else None + ) + except Exception as exc: + logger.debug("Failed to fetch container credentials: %s", exc, exc_info=True) + return None + + +def _imds_v2_token() -> str | None: + """Fetch an IMDSv2 session token (falls back silently if IMDSv1).""" + try: + res = requests.put( + f"{_IMDS_BASE_URI}/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + timeout=1, + ) + if res.ok: + return res.text + except Exception: + pass + return None + + +def _credentials_from_imds() -> Credentials | None: + """Retrieve credentials from the EC2 Instance Metadata Service (IMDS).""" + token = _imds_v2_token() + headers = {"X-aws-ec2-metadata-token": token} if token else {} + try: + role_res = requests.get( + f"{_IMDS_BASE_URI}/latest/meta-data/iam/security-credentials/", + headers=headers, + timeout=1, + ) + if not role_res.ok: 
+ return None + role_name = role_res.text.strip() + creds_res = requests.get( + f"{_IMDS_BASE_URI}/latest/meta-data/iam/security-credentials/{role_name}", + headers=headers, + timeout=1, + ) + if not creds_res.ok: + return None + data = creds_res.json() + return ( + Credentials( + data["AccessKeyId"], + data["SecretAccessKey"], + data.get("Token"), + ) + if Credentials + else None + ) + except Exception as exc: + logger.debug("Failed to fetch IMDS credentials: %s", exc, exc_info=True) + return None + + +def load_default_credentials() -> Credentials | None: + """Attempt to load AWS credentials using the default resolution order. + + Order: environment → ECS/EKS task role → EC2 instance profile (IMDS). + Returns `None` if no credentials are found. + """ + for provider in ( + _credentials_from_env, + _credentials_from_container, + _credentials_from_imds, + ): + creds = provider() + if creds is not None: + return creds + return None diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 3449cdd5ef..6f1a9ef6de 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -7,12 +7,23 @@ from dataclasses import dataclass from enum import Enum, unique -import boto3 +try: + import boto3 # type: ignore +except ImportError: # pragma: no cover + boto3 = None # type: ignore + +try: + from botocore.auth import SigV4Auth # type: ignore + from botocore.awsrequest import AWSRequest # type: ignore + from botocore.utils import InstanceMetadataRegionFetcher # type: ignore +except ImportError: # pragma: no cover + SigV4Auth = None # type: ignore + AWSRequest = None # type: ignore + InstanceMetadataRegionFetcher = None # type: ignore + import jwt -from botocore.auth import SigV4Auth -from botocore.awsrequest import AWSRequest -from botocore.utils import InstanceMetadataRegionFetcher +from ._aws_credentials import load_default_credentials from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import 
ProgrammingError from .vendored import requests @@ -110,11 +121,19 @@ def get_aws_region() -> str | None: if "AWS_REGION" in os.environ: # Lambda return os.environ["AWS_REGION"] else: # EC2 + if InstanceMetadataRegionFetcher is None: + logger.debug("botocore is not available; cannot determine region via IMDS.") + return None return InstanceMetadataRegionFetcher().retrieve_region() def get_aws_arn() -> str | None: """Get the current AWS workload's ARN, if any.""" + if boto3 is None: + logger.debug( + "boto3 is not available; cannot call sts:GetCallerIdentity to fetch ARN." + ) + return None caller_identity = boto3.client("sts").get_caller_identity() if not caller_identity or "Arn" not in caller_identity: return None @@ -185,12 +204,14 @@ def get_aws_sts_hostname(region: str, partition: str) -> str | None: return None +# Ensure that botocore components are available before attempting to generate an +# AWS attestation. def create_aws_attestation() -> WorkloadIdentityAttestation | None: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, returns None. 
""" - aws_creds = boto3.session.Session().get_credentials() + aws_creds = load_default_credentials() if not aws_creds: logger.debug("No AWS credentials were found.") return None @@ -207,6 +228,10 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: logger.debug("No AWS partition was found.") return None + if AWSRequest is None or SigV4Auth is None: + logger.debug("botocore is not available; cannot generate AWS attestation.") + return None + sts_hostname = get_aws_sts_hostname(region, partition) request = AWSRequest( method="POST", From 1efb8143afb0fc67beb219cc23d7cff97c2856c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Wed, 9 Jul 2025 20:35:49 +0200 Subject: [PATCH 30/54] get region working --- src/snowflake/connector/_aws_credentials.py | 32 +++++++++++++++++++++ src/snowflake/connector/wif_util.py | 12 ++------ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/snowflake/connector/_aws_credentials.py b/src/snowflake/connector/_aws_credentials.py index 4a5bfd12a1..218d219b06 100644 --- a/src/snowflake/connector/_aws_credentials.py +++ b/src/snowflake/connector/_aws_credentials.py @@ -128,3 +128,35 @@ def load_default_credentials() -> Credentials | None: if creds is not None: return creds return None + + +def get_region() -> str | None: + """Return the AWS region for the current workload, if it can be determined. + + Resolution order: + 1. `AWS_REGION` or `AWS_DEFAULT_REGION` env vars (commonly set in Lambda/ECS). + 2. EC2 Instance Metadata Service (IMDS) – derive from availability zone. + """ + # 1. Environment variables + env_region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") + if env_region: + return env_region + + # 2. 
EC2 / on-prem metadata endpoint + token = _imds_v2_token() + headers = {"X-aws-ec2-metadata-token": token} if token else {} + try: + res = requests.get( + f"{_IMDS_BASE_URI}/latest/meta-data/placement/availability-zone", + headers=headers, + timeout=1, + ) + if res.ok: + az = res.text.strip() + # availability zone is region + letter, e.g. us-east-1a → us-east-1 + if len(az) >= 2 and az[-1].isalpha(): + return az[:-1] + except Exception as exc: + logger.debug("Failed to fetch region from IMDS: %s", exc, exc_info=True) + + return None diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 6f1a9ef6de..fe1f9967d9 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -23,7 +23,7 @@ import jwt -from ._aws_credentials import load_default_credentials +from ._aws_credentials import get_region, load_default_credentials from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError from .vendored import requests @@ -117,14 +117,8 @@ def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[st def get_aws_region() -> str | None: - """Get the current AWS workload's region, if any.""" - if "AWS_REGION" in os.environ: # Lambda - return os.environ["AWS_REGION"] - else: # EC2 - if InstanceMetadataRegionFetcher is None: - logger.debug("botocore is not available; cannot determine region via IMDS.") - return None - return InstanceMetadataRegionFetcher().retrieve_region() + """Determine AWS region using our lightweight helper.""" + return get_region() def get_aws_arn() -> str | None: From ec76ade7979fe6a38c594799b2464d72a13a871c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Wed, 9 Jul 2025 21:12:04 +0200 Subject: [PATCH 31/54] signing not yet working --- src/snowflake/connector/sign_v4.py | 99 +++++++++++++++++++++++++++++ src/snowflake/connector/wif_util.py | 25 ++++---- 2 files changed, 112 insertions(+), 12 deletions(-) create mode 
100644 src/snowflake/connector/sign_v4.py diff --git a/src/snowflake/connector/sign_v4.py b/src/snowflake/connector/sign_v4.py new file mode 100644 index 0000000000..16289d24cc --- /dev/null +++ b/src/snowflake/connector/sign_v4.py @@ -0,0 +1,99 @@ +# sign_v4.py (no external deps) +from __future__ import annotations + +import datetime +import hashlib +import hmac +import urllib.parse + +_ALGO = "AWS4-HMAC-SHA256" +_EMPTY_HASH = hashlib.sha256(b"").hexdigest() + + +def _hmac(key: bytes, msg: str) -> bytes: + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() + + +def _canonical_qs(qs: str) -> str: + pairs = urllib.parse.parse_qsl(qs, keep_blank_values=True) + pairs.sort() + safe = "-_.~" + return "&".join( + f"{urllib.parse.quote(k, safe=safe)}=" f"{urllib.parse.quote(v, safe=safe)}" + for k, v in pairs + ) + + +def sign_get_caller_identity( + url: str, + region: str, + access_key: str, + secret_key: str, + session_token: str | None = None, + extra_headers: dict[str, str] | None = None, + now: datetime.datetime | None = None, +) -> dict[str, str]: + """Return SigV4 headers for STS:GetCallerIdentity.""" + now = now or datetime.datetime.utcnow() + amz_date = now.strftime("%Y%m%dT%H%M%SZ") + date_stamp = now.strftime("%Y%m%d") + svc = "sts" + + parsed = urllib.parse.urlparse(url) + host = parsed.netloc + canonical_uri = urllib.parse.quote(parsed.path or "/", safe="/") + canonical_qs = _canonical_qs(parsed.query) + + # ---------- headers (lower-case keys) ---------- + headers = { + "host": host, + "x-amz-date": amz_date, + "x-snowflake-audience": "snowflakecomputing.com", + } + if session_token: + headers["x-amz-security-token"] = session_token + if extra_headers: + for k, v in extra_headers.items(): + headers[k.lower()] = v.strip() + + # CanonicalHeaders & SignedHeaders + sorted_hdrs = sorted((k, " ".join(v.split())) for k, v in headers.items()) + canonical_headers = "".join(f"{k}:{v}\n" for k, v in sorted_hdrs) + signed_headers = ";".join(k for k, _ in 
sorted_hdrs) + + canonical_request = "\n".join( + [ + "POST", + canonical_uri, + canonical_qs, + canonical_headers, + signed_headers, + _EMPTY_HASH, + ] + ) + hash_canonical = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest() + + # ---------- string to sign ---------- + scope = f"{date_stamp}/{region}/{svc}/aws4_request" + string_to_sign = "\n".join([_ALGO, amz_date, scope, hash_canonical]) + + # ---------- signing key ---------- + k_date = _hmac(("AWS4" + secret_key).encode(), date_stamp) + k_region = _hmac(k_date, region) + k_service = _hmac(k_region, svc) + k_signing = _hmac(k_service, "aws4_request") + + signature = hmac.new( + k_signing, string_to_sign.encode("utf-8"), hashlib.sha256 + ).hexdigest() + + authorization = ( + f"{_ALGO} Credential={access_key}/{scope}, " + f"SignedHeaders={signed_headers}, Signature={signature}" + ) + + # ---------- final headers ---------- + headers["authorization"] = authorization + headers["x-amz-content-sha256"] = _EMPTY_HASH + # canonicalisation used lower-case; restore Host capitalisation if desired + return {k.title() if k == "host" else k: v for k, v in headers.items()} diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index fe1f9967d9..c30b2d9378 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -26,6 +26,7 @@ from ._aws_credentials import get_region, load_default_credentials from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError +from .sign_v4 import sign_get_caller_identity from .vendored import requests from .vendored.requests import Response @@ -227,22 +228,22 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: return None sts_hostname = get_aws_sts_hostname(region, partition) - request = AWSRequest( - method="POST", - url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", - headers={ - "Host": sts_hostname, - "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, 
- }, - ) - SigV4Auth(aws_creds, "sts", region).add_auth(request) + sts_url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15" + signed_headers = sign_get_caller_identity( + url=sts_url, + region=region, + access_key=aws_creds.access_key, + secret_key=aws_creds.secret_key, + session_token=aws_creds.token, + ) assertion_dict = { - "url": request.url, - "method": request.method, - "headers": dict(request.headers.items()), + "url": sts_url, + "method": "POST", + "headers": signed_headers, } + credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8") return WorkloadIdentityAttestation( AttestationProvider.AWS, credential, {"arn": arn} From 12cf2cd5dedc03cd673dfcf0d8d4521829d833c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Wed, 9 Jul 2025 21:59:03 +0200 Subject: [PATCH 32/54] that worked --- src/snowflake/connector/sign_v4.py | 104 ++++++++++------------------ src/snowflake/connector/wif_util.py | 11 +-- 2 files changed, 40 insertions(+), 75 deletions(-) diff --git a/src/snowflake/connector/sign_v4.py b/src/snowflake/connector/sign_v4.py index 16289d24cc..86b649fe37 100644 --- a/src/snowflake/connector/sign_v4.py +++ b/src/snowflake/connector/sign_v4.py @@ -1,99 +1,69 @@ -# sign_v4.py (no external deps) +# wif_util/sign_v4.py from __future__ import annotations import datetime import hashlib import hmac -import urllib.parse +import urllib.parse as _u _ALGO = "AWS4-HMAC-SHA256" _EMPTY_HASH = hashlib.sha256(b"").hexdigest() +_SAFE = "-_.~" -def _hmac(key: bytes, msg: str) -> bytes: - return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() +def _h(key: bytes, msg: str) -> bytes: + return hmac.new(key, msg.encode(), hashlib.sha256).digest() def _canonical_qs(qs: str) -> str: - pairs = urllib.parse.parse_qsl(qs, keep_blank_values=True) + pairs = _u.parse_qsl(qs, keep_blank_values=True) pairs.sort() - safe = "-_.~" - return "&".join( - f"{urllib.parse.quote(k, safe=safe)}=" 
f"{urllib.parse.quote(v, safe=safe)}" - for k, v in pairs - ) + return "&".join(f"{_u.quote(k, _SAFE)}={_u.quote(v, _SAFE)}" for k, v in pairs) -def sign_get_caller_identity( - url: str, - region: str, - access_key: str, - secret_key: str, - session_token: str | None = None, - extra_headers: dict[str, str] | None = None, - now: datetime.datetime | None = None, -) -> dict[str, str]: - """Return SigV4 headers for STS:GetCallerIdentity.""" - now = now or datetime.datetime.utcnow() - amz_date = now.strftime("%Y%m%dT%H%M%SZ") - date_stamp = now.strftime("%Y%m%d") +def sign_get_caller_identity(url, region, access_key, secret_key, session_token=None): + now = datetime.datetime.utcnow() + amz_d = now.strftime("%Y%m%dT%H%M%SZ") + date = now.strftime("%Y%m%d") svc = "sts" - parsed = urllib.parse.urlparse(url) - host = parsed.netloc - canonical_uri = urllib.parse.quote(parsed.path or "/", safe="/") - canonical_qs = _canonical_qs(parsed.query) - - # ---------- headers (lower-case keys) ---------- - headers = { - "host": host, - "x-amz-date": amz_date, + p = _u.urlparse(url) + hdrs = { + "host": p.netloc.lower(), + "x-amz-date": amz_d, "x-snowflake-audience": "snowflakecomputing.com", + # "x-amz-content-sha256": _EMPTY_HASH, } if session_token: - headers["x-amz-security-token"] = session_token - if extra_headers: - for k, v in extra_headers.items(): - headers[k.lower()] = v.strip() - - # CanonicalHeaders & SignedHeaders - sorted_hdrs = sorted((k, " ".join(v.split())) for k, v in headers.items()) - canonical_headers = "".join(f"{k}:{v}\n" for k, v in sorted_hdrs) - signed_headers = ";".join(k for k, _ in sorted_hdrs) + hdrs["x-amz-security-token"] = session_token - canonical_request = "\n".join( + # ----- canonical request ----- + signed = ";".join(sorted(hdrs)) + can_req = "\n".join( [ "POST", - canonical_uri, - canonical_qs, - canonical_headers, - signed_headers, + _u.quote(p.path or "/", safe="/"), + _canonical_qs(p.query), + "".join(f"{k}:{hdrs[k]}\n" for k in sorted(hdrs)), 
+ signed, _EMPTY_HASH, ] ) - hash_canonical = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest() + hash_can = hashlib.sha256(can_req.encode()).hexdigest() - # ---------- string to sign ---------- - scope = f"{date_stamp}/{region}/{svc}/aws4_request" - string_to_sign = "\n".join([_ALGO, amz_date, scope, hash_canonical]) + # ----- string to sign ----- + scope = f"{date}/{region}/{svc}/aws4_request" + sts = "\n".join([_ALGO, amz_d, scope, hash_can]) - # ---------- signing key ---------- - k_date = _hmac(("AWS4" + secret_key).encode(), date_stamp) - k_region = _hmac(k_date, region) - k_service = _hmac(k_region, svc) - k_signing = _hmac(k_service, "aws4_request") + # ----- HMAC chain ----- + k = _h(("AWS4" + secret_key).encode(), date) + k = _h(k, region) + k = _h(k, svc) + k = _h(k, "aws4_request") + sig = hmac.new(k, sts.encode(), hashlib.sha256).hexdigest() - signature = hmac.new( - k_signing, string_to_sign.encode("utf-8"), hashlib.sha256 - ).hexdigest() - - authorization = ( + hdrs["authorization"] = ( f"{_ALGO} Credential={access_key}/{scope}, " - f"SignedHeaders={signed_headers}, Signature={signature}" + f"SignedHeaders={signed}, Signature={sig}" ) - - # ---------- final headers ---------- - headers["authorization"] = authorization - headers["x-amz-content-sha256"] = _EMPTY_HASH - # canonicalisation used lower-case; restore Host capitalisation if desired - return {k.title() if k == "host" else k: v for k, v in headers.items()} + return hdrs diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index c30b2d9378..9ffcb93614 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -230,7 +230,7 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: sts_hostname = get_aws_sts_hostname(region, partition) sts_url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15" - signed_headers = sign_get_caller_identity( + hdrs = sign_get_caller_identity( 
url=sts_url, region=region, access_key=aws_creds.access_key, @@ -238,13 +238,8 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: session_token=aws_creds.token, ) - assertion_dict = { - "url": sts_url, - "method": "POST", - "headers": signed_headers, - } - - credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8") + assertion_dict = {"url": sts_url, "method": "POST", "headers": hdrs} + credential = b64encode(json.dumps(assertion_dict).encode()).decode() return WorkloadIdentityAttestation( AttestationProvider.AWS, credential, {"arn": arn} ) From 0151c7d1c5d70645e11d62beec777245085d47b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Thu, 10 Jul 2025 10:12:34 +0200 Subject: [PATCH 33/54] boto3 removed --- .../connector/{sign_v4.py => _aws_sign_v4.py} | 1 - src/snowflake/connector/wif_util.py | 125 +++++++----------- 2 files changed, 48 insertions(+), 78 deletions(-) rename src/snowflake/connector/{sign_v4.py => _aws_sign_v4.py} (98%) diff --git a/src/snowflake/connector/sign_v4.py b/src/snowflake/connector/_aws_sign_v4.py similarity index 98% rename from src/snowflake/connector/sign_v4.py rename to src/snowflake/connector/_aws_sign_v4.py index 86b649fe37..de06eb987e 100644 --- a/src/snowflake/connector/sign_v4.py +++ b/src/snowflake/connector/_aws_sign_v4.py @@ -1,4 +1,3 @@ -# wif_util/sign_v4.py from __future__ import annotations import datetime diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 9ffcb93614..001c83d5c5 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -7,11 +7,6 @@ from dataclasses import dataclass from enum import Enum, unique -try: - import boto3 # type: ignore -except ImportError: # pragma: no cover - boto3 = None # type: ignore - try: from botocore.auth import SigV4Auth # type: ignore from botocore.awsrequest import AWSRequest # type: ignore @@ -24,9 +19,9 @@ import jwt from ._aws_credentials 
import get_region, load_default_credentials +from ._aws_sign_v4 import sign_get_caller_identity from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from .errors import ProgrammingError -from .sign_v4 import sign_get_caller_identity from .vendored import requests from .vendored.requests import Response @@ -122,19 +117,6 @@ def get_aws_region() -> str | None: return get_region() -def get_aws_arn() -> str | None: - """Get the current AWS workload's ARN, if any.""" - if boto3 is None: - logger.debug( - "boto3 is not available; cannot call sts:GetCallerIdentity to fetch ARN." - ) - return None - caller_identity = boto3.client("sts").get_caller_identity() - if not caller_identity or "Arn" not in caller_identity: - return None - return caller_identity["Arn"] - - def get_aws_partition(arn: str) -> str | None: """Get the current AWS partition from ARN, if any. @@ -156,7 +138,7 @@ def get_aws_partition(arn: str) -> str | None: return None -def get_aws_sts_hostname(region: str, partition: str) -> str | None: +def get_aws_sts_hostname(region: str) -> str | None: """Constructs the AWS STS hostname for a given region and partition. Args: @@ -172,77 +154,66 @@ def get_aws_sts_hostname(region: str, partition: str) -> str | None: - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html - https://docs.aws.amazon.com/general/latest/gr/sts.html """ - if ( - not region - or not partition - or not isinstance(region, str) - or not isinstance(partition, str) - ): - return None - - if partition == "aws": - # For the 'aws' partition, STS endpoints are generally regional - # except for the global endpoint (sts.amazonaws.com) which is - # generally resolved to us-east-1 under the hood by the SDKs - # when a region is not explicitly specified. 
- # However, for explicit regional endpoints, the format is sts..amazonaws.com - return f"sts.{region}.amazonaws.com" - elif partition == "aws-cn": - # China regions have a different domain suffix + partition = partition_from_region(region) + if partition is AWSPartition.CHINA: return f"sts.{region}.amazonaws.com.cn" - elif partition == "aws-us-gov": - return ( - f"sts.{region}.amazonaws.com" # GovCloud uses .com, but dedicated regions - ) + elif partition is AWSPartition.BASE or partition is AWSPartition.GOV: + return f"sts.{region}.amazonaws.com" else: - logger.warning("Invalid AWS partition: %s", partition) + logger.warning("Invalid AWS partition: %s", region) return None -# Ensure that botocore components are available before attempting to generate an -# AWS attestation. -def create_aws_attestation() -> WorkloadIdentityAttestation | None: - """Tries to create a workload identity attestation for AWS. +class AWSPartition(str, Enum): + BASE = "aws" + CHINA = "aws-cn" + GOV = "aws-us-gov" - If the application isn't running on AWS or no credentials were found, returns None. 
- """ - aws_creds = load_default_credentials() - if not aws_creds: - logger.debug("No AWS credentials were found.") - return None - region = get_aws_region() - if not region: - logger.debug("No AWS region was found.") - return None - arn = get_aws_arn() - if not arn: - logger.debug("No AWS caller identity was found.") - return None - partition = get_aws_partition(arn) - if not partition: - logger.debug("No AWS partition was found.") - return None - if AWSRequest is None or SigV4Auth is None: - logger.debug("botocore is not available; cannot generate AWS attestation.") +def partition_from_region(region: str) -> AWSPartition: + if region.startswith("cn-"): + return AWSPartition.CHINA + if region.startswith("us-gov-"): + return AWSPartition.GOV + return AWSPartition.BASE + + +def sts_host_from_region(region: str) -> str: + part = partition_from_region(region) + suffix = ".amazonaws.com.cn" if part is AWSPartition.CHINA else ".amazonaws.com" + return f"sts.{region}{suffix}" + + +def create_aws_attestation() -> WorkloadIdentityAttestation | None: + creds = load_default_credentials() + if not creds: + logger.debug("No AWS credentials available.") return None - sts_hostname = get_aws_sts_hostname(region, partition) + region = get_region() + if not region: + logger.debug("Region could not be determined.") + return None - sts_url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15" - hdrs = sign_get_caller_identity( + sts_url = ( + f"https://{sts_host_from_region(region)}" + "/?Action=GetCallerIdentity&Version=2011-06-15" + ) + signed_headers = sign_get_caller_identity( url=sts_url, region=region, - access_key=aws_creds.access_key, - secret_key=aws_creds.secret_key, - session_token=aws_creds.token, + access_key=creds.access_key, + secret_key=creds.secret_key, + session_token=creds.token, ) - assertion_dict = {"url": sts_url, "method": "POST", "headers": hdrs} - credential = b64encode(json.dumps(assertion_dict).encode()).decode() - return 
+sends when authenticating with *Workload Identity Federation* (WIF).
+It supports AWS, Azure, GCP and generic OIDC environments **without** pulling +in heavy SDKs such as *botocore* – we only need a small presigned STS request +for AWS and a couple of metadata‑server calls for Azure / GCP. +""" + import json import logging import os from base64 import b64encode from dataclasses import dataclass from enum import Enum, unique - -try: - from botocore.auth import SigV4Auth # type: ignore - from botocore.awsrequest import AWSRequest # type: ignore - from botocore.utils import InstanceMetadataRegionFetcher # type: ignore -except ImportError: # pragma: no cover - SigV4Auth = None # type: ignore - AWSRequest = None # type: ignore - InstanceMetadataRegionFetcher = None # type: ignore +from typing import Any import jwt @@ -58,36 +59,36 @@ class AttestationProvider(Enum): @staticmethod def from_string(provider: str) -> AttestationProvider: - """Converts a string to a strongly-typed enum value of AttestationProvider.""" + """Converts a string to a strongly-typed enum value of :class:`AttestationProvider`.""" return AttestationProvider[provider.upper()] @dataclass class WorkloadIdentityAttestation: provider: AttestationProvider - credential: str - user_identifier_components: dict + credential: str # **base64** JSON blob – provider‑specific + user_identifier_components: dict[str, Any] def try_metadata_service_call( - method: str, url: str, headers: dict, timeout_sec: int = 3 + method: str, url: str, headers: dict[str, str], *, timeout: int = 3 ) -> Response | None: - """Tries to make a HTTP request to the metadata service with the given URL, method, headers and timeout. + """Tries to make a HTTP request to the metadata service with the given URL, method, headers and timeout in seconds. If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. 
""" try: res: Response = requests.request( - method=method, url=url, headers=headers, timeout=timeout_sec + method=method, url=url, headers=headers, timeout=timeout ) - if not res.ok: - return None + return res if res.ok else None except requests.RequestException: return None - return res -def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[str, str]: +def extract_iss_and_sub_without_signature_verification( + jwt_str: str, +) -> tuple[str | None, str | None]: """Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature. Note: the real token verification (including signature verification) happens on the Snowflake side. The driver doesn't have @@ -99,68 +100,23 @@ def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[st If there are any errors in parsing the token or extracting iss and sub, this will return (None, None). """ - try: - claims = jwt.decode(jwt_str, options={"verify_signature": False}) - except jwt.exceptions.InvalidTokenError: - logger.warning("Token is not a valid JWT.", exc_info=True) + claims = _decode_jwt_without_validation(jwt_str) + if claims is None: return None, None - if not ("iss" in claims and "sub" in claims): + if "iss" not in claims or "sub" not in claims: logger.warning("Token is missing 'iss' or 'sub' claims.") return None, None return claims["iss"], claims["sub"] -def get_aws_region() -> str | None: - """Determine AWS region using our lightweight helper.""" - return get_region() - - -def get_aws_partition(arn: str) -> str | None: - """Get the current AWS partition from ARN, if any. - - Args: - arn (str): The Amazon Resource Name (ARN) string. - - Returns: - str | None: The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov') - if found, otherwise None. - - Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html. 
- """ - if not arn or not isinstance(arn, str): - return None - parts = arn.split(":") - if len(parts) > 1 and parts[0] == "arn" and parts[1]: - return parts[1] - logger.warning("Invalid AWS ARN: %s", arn) - return None - - -def get_aws_sts_hostname(region: str) -> str | None: - """Constructs the AWS STS hostname for a given region and partition. - - Args: - region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1'). - partition (str): The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov'). - - Returns: - str | None: The AWS STS hostname (e.g., 'sts.us-east-1.amazonaws.com') - if a valid hostname can be constructed, otherwise None. - - References: - - https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html - - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html - - https://docs.aws.amazon.com/general/latest/gr/sts.html - """ - partition = partition_from_region(region) - if partition is AWSPartition.CHINA: - return f"sts.{region}.amazonaws.com.cn" - elif partition is AWSPartition.BASE or partition is AWSPartition.GOV: - return f"sts.{region}.amazonaws.com" - else: - logger.warning("Invalid AWS partition: %s", region) +def _decode_jwt_without_validation(token: str) -> Any: + """Helper that decodes *token* with ``verify_signature=False``.:contentReference[oaicite:1]{index=1}""" + try: + return jwt.decode(token, options={"verify_signature": False}) + except jwt.exceptions.InvalidTokenError: + logger.warning("Token is not a valid JWT.", exc_info=True) return None @@ -170,7 +126,7 @@ class AWSPartition(str, Enum): GOV = "aws-us-gov" -def partition_from_region(region: str) -> AWSPartition: +def _partition_from_region(region: str) -> AWSPartition: if region.startswith("cn-"): return AWSPartition.CHINA if region.startswith("us-gov-"): @@ -178,13 +134,24 @@ def partition_from_region(region: str) -> AWSPartition: return AWSPartition.BASE -def sts_host_from_region(region: str) -> str: - part = 
partition_from_region(region) +def _sts_host_from_region(region: str) -> str: + """ + Construct the STS endpoint hostname for *region* according to the + regionalised-STS rules published by AWS. + + References: + - https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html + - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html + - https://docs.aws.amazon.com/general/latest/gr/sts.html + """ + part = _partition_from_region(region) suffix = ".amazonaws.com.cn" if part is AWSPartition.CHINA else ".amazonaws.com" return f"sts.{region}{suffix}" def create_aws_attestation() -> WorkloadIdentityAttestation | None: + """Return AWS attestation or *None* if we're not on AWS / creds missing.""" + creds = load_default_credentials() if not creds: logger.debug("No AWS credentials available.") @@ -192,11 +159,11 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: region = get_region() if not region: - logger.debug("Region could not be determined.") + logger.debug("AWS region could not be determined.") return None sts_url = ( - f"https://{sts_host_from_region(region)}" + f"https://{_sts_host_from_region(region)}" "/?Action=GetCallerIdentity&Version=2011-06-15" ) signed_headers = sign_get_caller_identity( From bd581790a969f76b843fc2be3b9ead8585290f15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Fri, 11 Jul 2025 08:36:21 +0200 Subject: [PATCH 36/54] cleanup to own Credential class --- src/snowflake/connector/_aws_credentials.py | 195 +++++++++----------- 1 file changed, 84 insertions(+), 111 deletions(-) diff --git a/src/snowflake/connector/_aws_credentials.py b/src/snowflake/connector/_aws_credentials.py index 218d219b06..1ebd62f6de 100644 --- a/src/snowflake/connector/_aws_credentials.py +++ b/src/snowflake/connector/_aws_credentials.py @@ -1,162 +1,135 @@ """ Lightweight AWS credential resolution without boto3. 
-This replicates the standard AWS SDK credential chain (environment → container → EC2 IMDSv2). -It purposely returns a `botocore.credentials.Credentials` instance so existing -code that relies on `SigV4Auth` continues to work unchanged while we phase out -boto3 usage incrementally. +Resolves credentials in the order: environment → ECS/EKS task metadata → EC2 IMDSv2. +Returns a minimal `Credentials` object that works with SigV4 signing helpers. """ from __future__ import annotations import logging import os +from dataclasses import dataclass +from functools import partial +from typing import Callable from .vendored import requests -try: - from botocore.credentials import Credentials # type: ignore -except Exception: # pragma: no cover - # botocore is still available at this migration stage; if it isn’t we’ll - # replace it in a later step. - Credentials = None # type: ignore - logger = logging.getLogger(__name__) -# Internal constants -_ECS_CREDENTIALS_BASE_URI = "http://169.254.170.2" -_IMDS_BASE_URI = "http://169.254.169.254" +_ECS_CRED_BASE_URL = "http://169.254.170.2" +_IMDS_BASE_URL = "http://169.254.169.254" +_IMDS_TOKEN_PATH = "/latest/api/token" +_IMDS_ROLE_PATH = "/latest/meta-data/iam/security-credentials/" +_IMDS_AZ_PATH = "/latest/meta-data/placement/availability-zone" + + +@dataclass +class Credentials: + """Minimal stand-in for ``botocore.credentials.Credentials``.""" + + access_key: str + secret_key: str + token: str | None = None -def _credentials_from_env() -> Credentials | None: - """Load credentials from environment variables.""" - access_key = os.getenv("AWS_ACCESS_KEY_ID") - secret_key = os.getenv("AWS_SECRET_ACCESS_KEY") - if access_key and secret_key: - token = os.getenv("AWS_SESSION_TOKEN") - return Credentials(access_key, secret_key, token) if Credentials else None +def get_env_credentials() -> Credentials | None: + """Static credentials from environment variables.""" + key, secret = os.getenv("AWS_ACCESS_KEY_ID"), 
os.getenv("AWS_SECRET_ACCESS_KEY") + if key and secret: + return Credentials(key, secret, os.getenv("AWS_SESSION_TOKEN")) return None -def _credentials_from_container() -> Credentials | None: - """Retrieve credentials from ECS / EKS task metadata (IAM Roles for Tasks).""" +def get_container_credentials(*, timeout: float) -> Credentials | None: + """Credentials from ECS/EKS task-metadata endpoint.""" rel_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") full_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") if not rel_uri and not full_uri: return None - creds_url = full_uri or f"{_ECS_CREDENTIALS_BASE_URI}{rel_uri}" + + url = full_uri or f"{_ECS_CRED_BASE_URL}{rel_uri}" try: - res = requests.get(creds_url, timeout=2) - if res.ok: - data = res.json() - return ( - Credentials( - data["AccessKeyId"], - data["SecretAccessKey"], - data.get("Token"), - ) - if Credentials - else None + response = requests.get(url, timeout=timeout) + if response.ok: + data = response.json() + return Credentials( + data["AccessKeyId"], data["SecretAccessKey"], data.get("Token") ) - except Exception as exc: - logger.debug("Failed to fetch container credentials: %s", exc, exc_info=True) + except (requests.Timeout, requests.ConnectionError, ValueError) as exc: + logger.debug("ECS credential fetch failed: %s", exc, exc_info=True) return None -def _imds_v2_token() -> str | None: - """Fetch an IMDSv2 session token (falls back silently if IMDSv1).""" +def _get_imds_v2_token(timeout: float) -> str | None: try: - res = requests.put( - f"{_IMDS_BASE_URI}/latest/api/token", + response = requests.put( + f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}", headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, - timeout=1, + timeout=timeout, ) - if res.ok: - return res.text - except Exception: - pass - return None + return response.text if response.ok else None + except (requests.Timeout, requests.ConnectionError): + return None -def _credentials_from_imds() -> Credentials | None: - """Retrieve 
credentials from the EC2 Instance Metadata Service (IMDS).""" - token = _imds_v2_token() +def get_imds_credentials(*, timeout: float) -> Credentials | None: + """Instance-profile credentials from the EC2 metadata service.""" + token = _get_imds_v2_token(timeout) headers = {"X-aws-ec2-metadata-token": token} if token else {} + try: - role_res = requests.get( - f"{_IMDS_BASE_URI}/latest/meta-data/iam/security-credentials/", - headers=headers, - timeout=1, + role_resp = requests.get( + f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}", headers=headers, timeout=timeout ) - if not role_res.ok: + if not role_resp.ok: return None - role_name = role_res.text.strip() - creds_res = requests.get( - f"{_IMDS_BASE_URI}/latest/meta-data/iam/security-credentials/{role_name}", + role_name = role_resp.text.strip() + + cred_resp = requests.get( + f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}{role_name}", headers=headers, - timeout=1, + timeout=timeout, ) - if not creds_res.ok: - return None - data = creds_res.json() - return ( - Credentials( - data["AccessKeyId"], - data["SecretAccessKey"], - data.get("Token"), + if cred_resp.ok: + data = cred_resp.json() + return Credentials( + data["AccessKeyId"], data["SecretAccessKey"], data.get("Token") ) - if Credentials - else None - ) - except Exception as exc: - logger.debug("Failed to fetch IMDS credentials: %s", exc, exc_info=True) - return None + except (requests.Timeout, requests.ConnectionError, ValueError) as exc: + logger.debug("IMDS credential fetch failed: %s", exc, exc_info=True) + return None -def load_default_credentials() -> Credentials | None: - """Attempt to load AWS credentials using the default resolution order. - - Order: environment → ECS/EKS task role → EC2 instance profile (IMDS). - Returns `None` if no credentials are found. 
- """ - for provider in ( - _credentials_from_env, - _credentials_from_container, - _credentials_from_imds, - ): - creds = provider() - if creds is not None: - return creds +def load_default_credentials(timeout: float = 2.0) -> Credentials | None: + """Resolve credentials using the default AWS chain (env → task → IMDS).""" + providers: tuple[Callable[[], Credentials | None], ...] = ( + get_env_credentials, + partial(get_container_credentials, timeout=timeout), + partial(get_imds_credentials, timeout=timeout), + ) + for try_fetch_credentials in providers: + credentials = try_fetch_credentials() + if credentials: + return credentials return None -def get_region() -> str | None: - """Return the AWS region for the current workload, if it can be determined. +def get_region(timeout: float = 1.0) -> str | None: + """Return the current AWS region if it can be discovered.""" + if region := os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION"): + return region - Resolution order: - 1. `AWS_REGION` or `AWS_DEFAULT_REGION` env vars (commonly set in Lambda/ECS). - 2. EC2 Instance Metadata Service (IMDS) – derive from availability zone. - """ - # 1. Environment variables - env_region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") - if env_region: - return env_region - - # 2. EC2 / on-prem metadata endpoint - token = _imds_v2_token() + token = _get_imds_v2_token(timeout) headers = {"X-aws-ec2-metadata-token": token} if token else {} try: - res = requests.get( - f"{_IMDS_BASE_URI}/latest/meta-data/placement/availability-zone", - headers=headers, - timeout=1, + response = requests.get( + f"{_IMDS_BASE_URL}{_IMDS_AZ_PATH}", headers=headers, timeout=timeout ) - if res.ok: - az = res.text.strip() - # availability zone is region + letter, e.g. 
us-east-1a → us-east-1 - if len(az) >= 2 and az[-1].isalpha(): - return az[:-1] - except Exception as exc: - logger.debug("Failed to fetch region from IMDS: %s", exc, exc_info=True) + if response.ok: + az = response.text.strip() + return az[:-1] if az and az[-1].isalpha() else None + except (requests.Timeout, requests.ConnectionError) as exc: + logger.debug("IMDS region lookup failed: %s", exc, exc_info=True) return None From 564ac06e837376d4bb35e4169fc7e77848d48674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Fri, 11 Jul 2025 08:47:05 +0200 Subject: [PATCH 37/54] refactored files --- src/snowflake/connector/_aws_credentials.py | 22 ++-- src/snowflake/connector/_aws_sign_v4.py | 137 +++++++++++++------- 2 files changed, 98 insertions(+), 61 deletions(-) diff --git a/src/snowflake/connector/_aws_credentials.py b/src/snowflake/connector/_aws_credentials.py index 1ebd62f6de..05cc26a154 100644 --- a/src/snowflake/connector/_aws_credentials.py +++ b/src/snowflake/connector/_aws_credentials.py @@ -2,7 +2,8 @@ Lightweight AWS credential resolution without boto3. Resolves credentials in the order: environment → ECS/EKS task metadata → EC2 IMDSv2. -Returns a minimal `Credentials` object that works with SigV4 signing helpers. +Returns a minimal `SfAWSCredentials` object that can be passed to SigV4 signing +helpers unchanged. 
""" from __future__ import annotations @@ -25,7 +26,7 @@ @dataclass -class Credentials: +class SfAWSCredentials: """Minimal stand-in for ``botocore.credentials.Credentials``.""" access_key: str @@ -33,15 +34,14 @@ class Credentials: token: str | None = None -def get_env_credentials() -> Credentials | None: - """Static credentials from environment variables.""" +def get_env_credentials() -> SfAWSCredentials | None: key, secret = os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY") if key and secret: - return Credentials(key, secret, os.getenv("AWS_SESSION_TOKEN")) + return SfAWSCredentials(key, secret, os.getenv("AWS_SESSION_TOKEN")) return None -def get_container_credentials(*, timeout: float) -> Credentials | None: +def get_container_credentials(*, timeout: float) -> SfAWSCredentials | None: """Credentials from ECS/EKS task-metadata endpoint.""" rel_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") full_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") @@ -53,7 +53,7 @@ def get_container_credentials(*, timeout: float) -> Credentials | None: response = requests.get(url, timeout=timeout) if response.ok: data = response.json() - return Credentials( + return SfAWSCredentials( data["AccessKeyId"], data["SecretAccessKey"], data.get("Token") ) except (requests.Timeout, requests.ConnectionError, ValueError) as exc: @@ -73,7 +73,7 @@ def _get_imds_v2_token(timeout: float) -> str | None: return None -def get_imds_credentials(*, timeout: float) -> Credentials | None: +def get_imds_credentials(*, timeout: float) -> SfAWSCredentials | None: """Instance-profile credentials from the EC2 metadata service.""" token = _get_imds_v2_token(timeout) headers = {"X-aws-ec2-metadata-token": token} if token else {} @@ -93,7 +93,7 @@ def get_imds_credentials(*, timeout: float) -> Credentials | None: ) if cred_resp.ok: data = cred_resp.json() - return Credentials( + return SfAWSCredentials( data["AccessKeyId"], data["SecretAccessKey"], data.get("Token") ) except 
(requests.Timeout, requests.ConnectionError, ValueError) as exc: @@ -101,9 +101,9 @@ def get_imds_credentials(*, timeout: float) -> Credentials | None: return None -def load_default_credentials(timeout: float = 2.0) -> Credentials | None: +def load_default_credentials(timeout: float = 2.0) -> SfAWSCredentials | None: """Resolve credentials using the default AWS chain (env → task → IMDS).""" - providers: tuple[Callable[[], Credentials | None], ...] = ( + providers: tuple[Callable[[], SfAWSCredentials | None], ...] = ( get_env_credentials, partial(get_container_credentials, timeout=timeout), partial(get_imds_credentials, timeout=timeout), diff --git a/src/snowflake/connector/_aws_sign_v4.py b/src/snowflake/connector/_aws_sign_v4.py index de06eb987e..f4efc4323e 100644 --- a/src/snowflake/connector/_aws_sign_v4.py +++ b/src/snowflake/connector/_aws_sign_v4.py @@ -1,68 +1,105 @@ from __future__ import annotations -import datetime -import hashlib -import hmac -import urllib.parse as _u +import datetime as _dt +import hashlib as _hashlib +import hmac as _hmac +import urllib.parse as _urlparse -_ALGO = "AWS4-HMAC-SHA256" -_EMPTY_HASH = hashlib.sha256(b"").hexdigest() -_SAFE = "-_.~" +_ALGORITHM: str = "AWS4-HMAC-SHA256" +_EMPTY_PAYLOAD_SHA256: str = _hashlib.sha256(b"").hexdigest() +_SAFE_CHARS: str = "-_.~" -def _h(key: bytes, msg: str) -> bytes: - return hmac.new(key, msg.encode(), hashlib.sha256).digest() +def _sign(key: bytes, msg: str) -> bytes: + """Return an HMAC-SHA256 of *msg* keyed with *key*.""" + return _hmac.new(key, msg.encode(), _hashlib.sha256).digest() -def _canonical_qs(qs: str) -> str: - pairs = _u.parse_qsl(qs, keep_blank_values=True) +def _canonical_query_string(query: str) -> str: + """Return the query string in canonical (sorted & URL-escaped) form.""" + pairs = _urlparse.parse_qsl(query, keep_blank_values=True) pairs.sort() - return "&".join(f"{_u.quote(k, _SAFE)}={_u.quote(v, _SAFE)}" for k, v in pairs) + return "&".join( + f"{_urlparse.quote(k, 
_SAFE_CHARS)}={_urlparse.quote(v, _SAFE_CHARS)}" + for k, v in pairs + ) + +def sign_get_caller_identity( + url: str, + region: str, + access_key: str, + secret_key: str, + session_token: str | None = None, +) -> dict[str, str]: + """ + Return the SigV4 headers needed for a presigned **POST** to AWS STS + `GetCallerIdentity`. -def sign_get_caller_identity(url, region, access_key, secret_key, session_token=None): - now = datetime.datetime.utcnow() - amz_d = now.strftime("%Y%m%dT%H%M%SZ") - date = now.strftime("%Y%m%d") - svc = "sts" + Parameters + ---------- + url + The full STS endpoint with query parameters + (e.g. ``https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15``) + region + The AWS region used for signing (``us-east-1``, ``us-gov-west-1`` …). + access_key + AWS access-key ID. + secret_key + AWS secret-access key. + session_token + (Optional) session token for temporary credentials. + """ + timestamp = _dt.datetime.utcnow() + amz_date = timestamp.strftime("%Y%m%dT%H%M%SZ") + short_date = timestamp.strftime("%Y%m%d") + service = "sts" - p = _u.urlparse(url) - hdrs = { - "host": p.netloc.lower(), - "x-amz-date": amz_d, + parsed = _urlparse.urlparse(url) + + headers: dict[str, str] = { + "host": parsed.netloc.lower(), + "x-amz-date": amz_date, "x-snowflake-audience": "snowflakecomputing.com", - # "x-amz-content-sha256": _EMPTY_HASH, } if session_token: - hdrs["x-amz-security-token"] = session_token + headers["x-amz-security-token"] = session_token - # ----- canonical request ----- - signed = ";".join(sorted(hdrs)) - can_req = "\n".join( - [ + # Canonical request + signed_headers = ";".join(sorted(headers)) # e.g. host;x-amz-date;... 
+ canonical_request = "\n".join( + ( "POST", - _u.quote(p.path or "/", safe="/"), - _canonical_qs(p.query), - "".join(f"{k}:{hdrs[k]}\n" for k in sorted(hdrs)), - signed, - _EMPTY_HASH, - ] + _urlparse.quote(parsed.path or "/", safe="/"), + _canonical_query_string(parsed.query), + "".join(f"{k}:{headers[k]}\n" for k in sorted(headers)), + signed_headers, + _EMPTY_PAYLOAD_SHA256, + ) + ) + canonical_request_hash = _hashlib.sha256(canonical_request.encode()).hexdigest() + + # String to sign + credential_scope = f"{short_date}/{region}/{service}/aws4_request" + string_to_sign = "\n".join( + (_ALGORITHM, amz_date, credential_scope, canonical_request_hash) ) - hash_can = hashlib.sha256(can_req.encode()).hexdigest() - - # ----- string to sign ----- - scope = f"{date}/{region}/{svc}/aws4_request" - sts = "\n".join([_ALGO, amz_d, scope, hash_can]) - - # ----- HMAC chain ----- - k = _h(("AWS4" + secret_key).encode(), date) - k = _h(k, region) - k = _h(k, svc) - k = _h(k, "aws4_request") - sig = hmac.new(k, sts.encode(), hashlib.sha256).hexdigest() - - hdrs["authorization"] = ( - f"{_ALGO} Credential={access_key}/{scope}, " - f"SignedHeaders={signed}, Signature={sig}" + + # Signature + key_date = _sign(("AWS4" + secret_key).encode(), short_date) + key_region = _sign(key_date, region) + key_service = _sign(key_region, service) + key_signing = _sign(key_service, "aws4_request") + signature = _hmac.new( + key_signing, string_to_sign.encode(), _hashlib.sha256 + ).hexdigest() + + # Final Authorization header + headers["authorization"] = ( + f"{_ALGORITHM} " + f"Credential={access_key}/{credential_scope}, " + f"SignedHeaders={signed_headers}, " + f"Signature={signature}" ) - return hdrs + + return headers From 8df9bdf1816d8e49828020dd912e50a4f9c85f11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sat, 12 Jul 2025 14:24:40 +0200 Subject: [PATCH 38/54] arn from env vars or only region --- setup.cfg | 4 +-- src/snowflake/connector/wif_util.py | 41 
++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 3a8ce0e243..ab2fc4e5a8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,8 +90,8 @@ development = pytest-timeout pytest-xdist pytzdata - botocore>=1.24 - boto3>=1.24 + botocore + boto3 pandas = pandas>=2.1.2,<3.0.0 pyarrow<19.0.0 diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 6826c08f38..59499e0717 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -149,6 +149,42 @@ def _sts_host_from_region(region: str) -> str: return f"sts.{region}{suffix}" +def _try_get_arn_from_env_vars() -> str | None: + """Try to get ARN already exposed by the runtime (no extra network I/O). + + • `AWS_ROLE_ARN` – web-identity / many FaaS runtimes + • `AWS_EC2_METADATA_ARN` – some IMDSv2 environments + • `AWS_SESSION_ARN` – recent AWS SDKs export this when assuming a role + """ + for possible_arn_env_var in ( + "AWS_ROLE_ARN", + "AWS_EC2_METADATA_ARN", + "AWS_SESSION_ARN", + ): + value = os.getenv(possible_arn_env_var) + if value and value.startswith("arn:"): + return value + return None + + +def try_compose_aws_user_identifier(region: str | None = None) -> dict[str, str]: + """Return an identifier for the running AWS workload. + + Always includes the AWS *region*; adds an *arn* key only if one is already + discoverable via common environment variables. 
Returns **{}** only if + the region cannot be determined.""" + region = region or get_region() + if not region: + return {} + + identifier: dict[str, str] = {"region": region} + + if arn := _try_get_arn_from_env_vars(): + identifier["arn"] = arn + + return identifier + + def create_aws_attestation() -> WorkloadIdentityAttestation | None: """Return AWS attestation or *None* if we're not on AWS / creds missing.""" @@ -180,7 +216,10 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None: ).encode() ).decode() - return WorkloadIdentityAttestation(AttestationProvider.AWS, attestation, {}) + user_identifier = try_compose_aws_user_identifier(region) + return WorkloadIdentityAttestation( + AttestationProvider.AWS, attestation, user_identifier + ) def create_gcp_attestation() -> WorkloadIdentityAttestation | None: From 04a11d84344cd86ef79ce5456231cb55e4d0ed4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sat, 12 Jul 2025 15:38:22 +0200 Subject: [PATCH 39/54] tests working --- src/snowflake/connector/wif_util.py | 5 +- test/csp_helpers.py | 14 +-- test/unit/test_auth_workload_identity.py | 149 ++++++++++++----------- 3 files changed, 82 insertions(+), 86 deletions(-) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 59499e0717..5ba26c6d4f 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -134,7 +134,7 @@ def _partition_from_region(region: str) -> AWSPartition: return AWSPartition.BASE -def _sts_host_from_region(region: str) -> str: +def _sts_host_from_region(region: str) -> str | None: """ Construct the STS endpoint hostname for *region* according to the regionalised-STS rules published by AWS.:contentReference[oaicite:2]{index=2} @@ -144,6 +144,9 @@ def _sts_host_from_region(region: str) -> str: - https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html - https://docs.aws.amazon.com/general/latest/gr/sts.html """ + 
if not region or not isinstance(region, str): + return None + part = _partition_from_region(region) suffix = ".amazonaws.com.cn" if part is AWSPartition.CHINA else ".amazonaws.com" return f"sts.{region}{suffix}" diff --git a/test/csp_helpers.py b/test/csp_helpers.py index ac35336166..aa8e3bd966 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -279,26 +279,16 @@ def __enter__(self): self.patchers = [] self.patchers.append( mock.patch( - "boto3.session.Session.get_credentials", + "snowflake.connector.wif_util.load_default_credentials", side_effect=self.get_credentials, ) ) self.patchers.append( mock.patch( - "botocore.auth.SigV4Auth.add_auth", side_effect=self.sign_request - ) - ) - self.patchers.append( - mock.patch( - "snowflake.connector.wif_util.get_aws_region", + "snowflake.connector.wif_util.get_region", side_effect=self.get_region, ) ) - self.patchers.append( - mock.patch( - "snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn - ) - ) for patcher in self.patchers: patcher.__enter__() return self diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index f2e42aae3e..1870f01bc3 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -17,8 +17,8 @@ from snowflake.connector.wif_util import ( AZURE_ISSUER_PREFIXES, AttestationProvider, - get_aws_partition, - get_aws_sts_hostname, + _partition_from_region, + _sts_host_from_region, ) from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token @@ -47,15 +47,17 @@ def verify_aws_token(token: str, region: str): assert decoded_token["method"] == "POST" headers = decoded_token["headers"] - assert set(headers.keys()) == { - "Host", - "X-Snowflake-Audience", - "X-Amz-Date", - "X-Amz-Security-Token", - "Authorization", + headers_lc = {k.lower(): v for k, v in headers.items()} + + expected_header_keys = { + "host", + "x-snowflake-audience", + "x-amz-date", + "authorization", } - 
assert headers["Host"] == f"sts.{region}.amazonaws.com" - assert headers["X-Snowflake-Audience"] == "snowflakecomputing.com" + assert set(headers_lc.keys()) == expected_header_keys + assert headers_lc["host"] == f"sts.{region}.amazonaws.com" + assert headers_lc["x-snowflake-audience"] == "snowflakecomputing.com" # -- OIDC Tests -- @@ -137,7 +139,7 @@ def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnviro data = extract_api_data(auth_class) decoded_token = json.loads(b64decode(data["TOKEN"])) hostname_from_url = urlparse(decoded_token["url"]).hostname - hostname_from_header = decoded_token["headers"]["Host"] + hostname_from_header = decoded_token["headers"]["host"] expected_hostname = "sts.antarctica-northeast-3.amazonaws.com" assert expected_hostname == hostname_from_url @@ -147,83 +149,84 @@ def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnviro def test_explicit_aws_generates_unique_assertion_content( fake_aws_environment: FakeAwsEnvironment, ): - fake_aws_environment.arn = ( - "arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab" - ) + # Change region to ensure assertion_content updates accordingly. 
+ fake_aws_environment.region = "antarctica-northeast-3" + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) auth_class.prepare() - assert ( - '{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}' - == auth_class.assertion_content - ) + expected = '{"_provider":"AWS","region":"' + fake_aws_environment.region + '"}' + assert auth_class.assertion_content == expected @pytest.mark.parametrize( - "arn, expected_partition", + "arn_env_var", [ - ("arn:aws:iam::123456789012:role/MyTestRole", "aws"), - ( - "arn:aws-cn:ec2:cn-north-1:987654321098:instance/i-1234567890abcdef0", - "aws-cn", - ), - ("arn:aws-us-gov:s3:::my-gov-bucket", "aws-us-gov"), - ("arn:aws:s3:::my-bucket/my/key", "aws"), - ("arn:aws:lambda:us-east-1:123456789012:function:my-function", "aws"), - ("arn:aws:sns:eu-west-1:111122223333:my-topic", "aws"), - # Edge cases / Invalid inputs - ("invalid-arn", None), - ("arn::service:region:account:resource", None), # Missing partition - ("arn:aws:iam:", "aws"), # Incomplete ARN, but partition is present - ("", None), # Empty string - (None, None), # None input - (123, None), # Non-string input + "AWS_ROLE_ARN", + "AWS_EC2_METADATA_ARN", + "AWS_SESSION_ARN", ], ) -def test_get_aws_partition_valid_and_invalid_arns(arn, expected_partition): - assert get_aws_partition(arn) == expected_partition +def test_explicit_aws_includes_arn_when_env_present( + fake_aws_environment: FakeAwsEnvironment, + monkeypatch, + arn_env_var, +): + dummy_arn = "arn:aws:sts::123456789012:assumed-role/MyRole/i-abcdef123456" + monkeypatch.setenv(arn_env_var, dummy_arn) + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + auth_class.prepare() + + # Parse the JSON to ignore ordering. 
+ assertion_data = json.loads(auth_class.assertion_content) + + assert assertion_data["_provider"] == "AWS" + assert assertion_data["region"] == fake_aws_environment.region + assert assertion_data["arn"] == dummy_arn @pytest.mark.parametrize( - "region, partition, expected_hostname", + "region, expected_partition", [ - # AWS partition - ("us-east-1", "aws", "sts.us-east-1.amazonaws.com"), - ("eu-west-2", "aws", "sts.eu-west-2.amazonaws.com"), - ("ap-southeast-1", "aws", "sts.ap-southeast-1.amazonaws.com"), - ( - "us-east-1", - "aws", - "sts.us-east-1.amazonaws.com", - ), # Redundant but good for coverage - # AWS China partition - ("cn-north-1", "aws-cn", "sts.cn-north-1.amazonaws.com.cn"), - ("cn-northwest-1", "aws-cn", "sts.cn-northwest-1.amazonaws.com.cn"), - ("", "aws-cn", None), # No global endpoint for 'aws-cn' without region - # AWS GovCloud partition - ("us-gov-west-1", "aws-us-gov", "sts.us-gov-west-1.amazonaws.com"), - ("us-gov-east-1", "aws-us-gov", "sts.us-gov-east-1.amazonaws.com"), - ("", "aws-us-gov", None), # No global endpoint for 'aws-us-gov' without region - # Invalid/Edge cases - ("us-east-1", "unknown-partition", None), # Unknown partition - ("some-region", "invalid-partition", None), # Invalid partition - (None, "aws", None), # None region - ("us-east-1", None, None), # None partition - (123, "aws", None), # Non-string region - ("us-east-1", 456, None), # Non-string partition - ("", "", None), # Empty region and partition - ("us-east-1", "", None), # Empty partition - ( - "invalid-region", - "aws", - "sts.invalid-region.amazonaws.com", - ), # Valid format, invalid region name + # — happy-path AWS commercial + ("us-east-1", "aws"), + ("eu-central-1", "aws"), + ("ap-south-1", "aws"), + # — China partitions + ("cn-north-1", "aws-cn"), + ("cn-northwest-1", "aws-cn"), + # — GovCloud partitions + ("us-gov-west-1", "aws-us-gov"), + ("us-gov-east-1", "aws-us-gov"), + # - Weird values also fall back to commercial + ("invalid-region", "aws"), + ("", 
"aws"), ], ) -def test_get_aws_sts_hostname_valid_and_invalid_inputs( - region, partition, expected_hostname -): - assert get_aws_sts_hostname(region, partition) == expected_hostname +def test_partition_from_region(region, expected_partition): + assert _partition_from_region(region).value == expected_partition + + +@pytest.mark.parametrize( + "region, expected_hostname", + [ + # commercial partition + ("us-east-1", "sts.us-east-1.amazonaws.com"), + ("eu-west-2", "sts.eu-west-2.amazonaws.com"), + # China + ("cn-north-1", "sts.cn-north-1.amazonaws.com.cn"), + # GovCloud + ("us-gov-east-1", "sts.us-gov-east-1.amazonaws.com"), + # unknown but syntactically valid - still formatted + ("invalid-region", "sts.invalid-region.amazonaws.com"), + ("", None), + (None, None), + (123, None), + ], +) +def test_sts_host_from_region_valid_inputs(region, expected_hostname): + assert _sts_host_from_region(region) == expected_hostname # -- GCP Tests -- From f018af5c66423c41f9ee42a71eb4e08f9859bc84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sat, 12 Jul 2025 18:22:19 +0200 Subject: [PATCH 40/54] botocore compatibility tests --- test/unit/test_boto_compatibility.py | 193 +++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 test/unit/test_boto_compatibility.py diff --git a/test/unit/test_boto_compatibility.py b/test/unit/test_boto_compatibility.py new file mode 100644 index 0000000000..c0e36e0885 --- /dev/null +++ b/test/unit/test_boto_compatibility.py @@ -0,0 +1,193 @@ +import urllib.parse as _urlparse + +import pytest +from botocore import session as _botocore_session # type: ignore +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from botocore.credentials import Credentials + +from snowflake.connector import _aws_credentials +from snowflake.connector._aws_sign_v4 import sign_get_caller_identity +from snowflake.connector.wif_util import _sts_host_from_region + + +@pytest.mark.parametrize( + 
"region", + [ + "us-east-1", + "eu-west-1", + "us-gov-west-1", + ], +) +def test_sign_get_caller_identity_matches_botocore(region): + """Ensure our lightweight SigV4 signing implementation stays in lock-step with botocore. + + The main reason for this test is to detect any behavioural changes introduced + by new botocore versions that we might need to replicate in our stripped-down + implementation. The test uses static credentials and a fixed request template + (POST GetCallerIdentity) so that both implementations should end up with an + identical *Authorization* header. + """ + url = f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15" + + access_key = "AKIDEXAMPLE" + secret_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + + driver_implementation_headers = sign_get_caller_identity( + url=url, + region=region, + access_key=access_key, + secret_key=secret_key, + ) + + botocore_request = AWSRequest( + method="POST", + url=url, + headers={ + "Host": f"sts.{region}.amazonaws.com", + "X-Amz-Date": driver_implementation_headers["x-amz-date"], + "X-Snowflake-Audience": "snowflakecomputing.com", + }, + ) + + creds = Credentials(access_key, secret_key) + SigV4Auth(creds, "sts", region).add_auth(botocore_request) + + botocore_headers = { + k.lower(): v + for k, v in botocore_request.headers.items() + if k.lower() != "user-agent" + } + + driver_implementation_headers_normalized = { + k.lower(): v for k, v in driver_implementation_headers.items() + } + + assert ( + driver_implementation_headers_normalized["authorization"] + == botocore_headers["authorization"] + ) + assert ( + driver_implementation_headers_normalized["x-amz-date"] + == botocore_headers["x-amz-date"] + ) + # All headers our implementation produces must be present in botocore's output. 
+ assert set(driver_implementation_headers_normalized.keys()).issubset( + set(botocore_headers.keys()) + ) + + +@pytest.mark.parametrize( + "region", + [ + "us-east-1", + "eu-west-1", + "us-gov-west-1", + ], +) +def test_sign_get_caller_identity_with_session_token_matches_botocore(region): + """SigV4 signing **with** temporary-session credentials must stay bit-for-bit compatible.""" + + url = f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15" + + access_key = "AKIDEXAMPLE" + secret_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + session_token = "IQoJb3JpZ2luX2VjEPr//////////wEaCXVzLWVhc3QtMSJHMEUCIQDO0f6o" + + driver_implementation_headers = sign_get_caller_identity( + url=url, + region=region, + access_key=access_key, + secret_key=secret_key, + session_token=session_token, + ) + + botocore_request = AWSRequest( + method="POST", + url=url, + headers={ + "Host": f"sts.{region}.amazonaws.com", + "X-Amz-Date": driver_implementation_headers["x-amz-date"], + "X-Snowflake-Audience": "snowflakecomputing.com", + "X-Amz-Security-Token": session_token, + }, + ) + + creds = Credentials(access_key, secret_key, token=session_token) + SigV4Auth(creds, "sts", region).add_auth(botocore_request) + + botocore_headers = { + k.lower(): v + for k, v in botocore_request.headers.items() + if k.lower() != "user-agent" + } + + driver_implementation_headers_normalized = { + k.lower(): v for k, v in driver_implementation_headers.items() + } + + assert ( + driver_implementation_headers_normalized["authorization"] + == botocore_headers["authorization"] + ) + assert ( + driver_implementation_headers_normalized["x-amz-date"] + == botocore_headers["x-amz-date"] + ) + assert ( + driver_implementation_headers_normalized["x-amz-security-token"] + == botocore_headers["x-amz-security-token"] + ) + assert set(driver_implementation_headers_normalized.keys()).issubset( + set(botocore_headers.keys()) + ) + + +@pytest.mark.parametrize( + "region", + [ + "us-east-1", + 
"eu-west-1", + "us-gov-west-1", + "cn-north-1", + ], +) +def test_sts_host_from_region_matches_botocore(region): + """Ensure we derive the same STS endpoint as botocore.""" + + driver_implementation_host = _sts_host_from_region(region) + + session = _botocore_session.Session() + client = session.create_client( + "sts", + region_name=region, + aws_access_key_id="dummy", + aws_secret_access_key="dummy", + ) + boto_host = _urlparse.urlparse(client.meta.endpoint_url).netloc.lower() + + assert driver_implementation_host == boto_host + + +@pytest.mark.parametrize("env_var", ["AWS_REGION", "AWS_DEFAULT_REGION"]) +def test_get_region_matches_botocore(monkeypatch, env_var): + """Our region helper should respect the same env-var precedence as botocore.""" + + test_region = "ap-southeast-2" + + monkeypatch.delenv("AWS_REGION", raising=False) + monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False) + monkeypatch.setenv(env_var, test_region) + + driver_region = _aws_credentials.get_region() + + session = _botocore_session.Session() + s3_client = session.create_client( + "s3", + region_name=None, + aws_access_key_id="dummy", + aws_secret_access_key="dummy", + ) + boto_region = s3_client.meta.region_name + + assert driver_region == boto_region == test_region From b5e35549be37c9e196afda7367958c462b833cd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 08:38:05 +0200 Subject: [PATCH 41/54] finished boto comp --- test/csp_helpers.py | 33 ++-- test/unit/test_boto_compatibility.py | 263 +++++++++++++-------------- 2 files changed, 149 insertions(+), 147 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index aa8e3bd966..dacfdbd31a 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -4,10 +4,12 @@ import logging import os from abc import ABC, abstractmethod +from contextlib import ExitStack from time import time from unittest import mock from urllib.parse import parse_qs, urlparse +import botocore.endpoint import jwt 
from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials @@ -275,24 +277,33 @@ def sign_request(self, request: AWSRequest): ) def __enter__(self): - # Patch the relevant functions to do what we want. - self.patchers = [] - self.patchers.append( + self._stack = ExitStack() + # patch connector helpers + self._stack.enter_context( mock.patch( "snowflake.connector.wif_util.load_default_credentials", side_effect=self.get_credentials, ) ) - self.patchers.append( + self._stack.enter_context( mock.patch( - "snowflake.connector.wif_util.get_region", - side_effect=self.get_region, + "snowflake.connector.wif_util.get_region", side_effect=self.get_region + ) + ) + + # hard-fail any botocore endpoint attempts – guarantees offline tests + def _no_http(*a, **k): + raise AssertionError("botocore attempted real HTTP call") + + self._stack.enter_context( + mock.patch.object( + botocore.endpoint.EndpointCreator, + "create_endpoint", + _no_http, + autospec=True, ) ) - for patcher in self.patchers: - patcher.__enter__() return self - def __exit__(self, *args, **kwargs): - for patcher in self.patchers: - patcher.__exit__(*args, **kwargs) + def __exit__(self, *exc): + self._stack.close() diff --git a/test/unit/test_boto_compatibility.py b/test/unit/test_boto_compatibility.py index c0e36e0885..e46177fc67 100644 --- a/test/unit/test_boto_compatibility.py +++ b/test/unit/test_boto_compatibility.py @@ -1,7 +1,10 @@ -import urllib.parse as _urlparse +from __future__ import annotations + +import datetime +import urllib.parse as urlparse import pytest -from botocore import session as _botocore_session # type: ignore +from botocore import session as boto_session from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials @@ -11,183 +14,171 @@ from snowflake.connector.wif_util import _sts_host_from_region -@pytest.mark.parametrize( - "region", - [ - "us-east-1", - "eu-west-1", - "us-gov-west-1", - ], -) 
-def test_sign_get_caller_identity_matches_botocore(region): - """Ensure our lightweight SigV4 signing implementation stays in lock-step with botocore. - - The main reason for this test is to detect any behavioural changes introduced - by new botocore versions that we might need to replicate in our stripped-down - implementation. The test uses static credentials and a fixed request template - (POST GetCallerIdentity) so that both implementations should end up with an - identical *Authorization* header. - """ - url = f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15" +def _normalise_headers(headers: dict[str, str]) -> dict[str, str]: + """Lower-case keys, trim values, drop User-Agent (botocore adds it).""" + return { + k.lower(): v.strip() for k, v in headers.items() if k.lower() != "user-agent" + } + + +@pytest.fixture +def freeze_utcnow(monkeypatch: pytest.MonkeyPatch): + """Freeze `datetime.datetime.utcnow()` for deterministic SigV4 signatures.""" + fixed = datetime.datetime(2025, 1, 1, 0, 0, 0) + + class _FrozenDateTime(datetime.datetime): + @classmethod + def utcnow(cls): # type: ignore[override] + return fixed + + monkeypatch.setattr(datetime, "datetime", _FrozenDateTime) + yield + + +@pytest.mark.parametrize("region", ["us-east-1", "eu-west-1", "us-gov-west-1"]) +def test_sigv4_parity_with_botocore(region: str, freeze_utcnow): + url = ( + f"https://{_sts_host_from_region(region)}" + "/?Action=GetCallerIdentity&Version=2011-06-15" + ) - access_key = "AKIDEXAMPLE" - secret_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + access_key_id = "AKIDEXAMPLE" + secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" - driver_implementation_headers = sign_get_caller_identity( + sf_driver_aws_headers = sign_get_caller_identity( url=url, region=region, - access_key=access_key, - secret_key=secret_key, + access_key=access_key_id, + secret_key=secret_access_key, ) - botocore_request = AWSRequest( + boto_req = AWSRequest( 
method="POST", url=url, headers={ - "Host": f"sts.{region}.amazonaws.com", - "X-Amz-Date": driver_implementation_headers["x-amz-date"], + "Host": sf_driver_aws_headers["host"], "X-Snowflake-Audience": "snowflakecomputing.com", + "X-Amz-Date": datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ"), }, ) - - creds = Credentials(access_key, secret_key) - SigV4Auth(creds, "sts", region).add_auth(botocore_request) - - botocore_headers = { - k.lower(): v - for k, v in botocore_request.headers.items() - if k.lower() != "user-agent" - } - - driver_implementation_headers_normalized = { - k.lower(): v for k, v in driver_implementation_headers.items() - } - - assert ( - driver_implementation_headers_normalized["authorization"] - == botocore_headers["authorization"] - ) - assert ( - driver_implementation_headers_normalized["x-amz-date"] - == botocore_headers["x-amz-date"] - ) - # All headers our implementation produces must be present in botocore's output. - assert set(driver_implementation_headers_normalized.keys()).issubset( - set(botocore_headers.keys()) + SigV4Auth(Credentials(access_key_id, secret_access_key), "sts", region).add_auth( + boto_req ) + assert "authorization" in sf_driver_aws_headers + assert _normalise_headers(sf_driver_aws_headers) == _normalise_headers( + boto_req.headers + ) -@pytest.mark.parametrize( - "region", - [ - "us-east-1", - "eu-west-1", - "us-gov-west-1", - ], -) -def test_sign_get_caller_identity_with_session_token_matches_botocore(region): - """SigV4 signing **with** temporary-session credentials must stay bit-for-bit compatible.""" - url = f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15" +@pytest.mark.parametrize("region", ["us-east-1", "eu-west-1", "us-gov-west-1"]) +def test_sigv4_parity_with_session_token(region: str, freeze_utcnow): + url = ( + f"https://{_sts_host_from_region(region)}" + "/?Action=GetCallerIdentity&Version=2011-06-15" + ) - access_key = "AKIDEXAMPLE" - secret_key = 
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" - session_token = "IQoJb3JpZ2luX2VjEPr//////////wEaCXVzLWVhc3QtMSJHMEUCIQDO0f6o" + access_key_id = "AKIDEXAMPLE" + secret_access_key = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + session_token = "IQoJb3JpZ2luX2VjEPr//////////wEaCXVzLWFz" - driver_implementation_headers = sign_get_caller_identity( + sf_driver_aws_headers = sign_get_caller_identity( url=url, region=region, - access_key=access_key, - secret_key=secret_key, + access_key=access_key_id, + secret_key=secret_access_key, session_token=session_token, ) - botocore_request = AWSRequest( + boto_req = AWSRequest( method="POST", url=url, headers={ - "Host": f"sts.{region}.amazonaws.com", - "X-Amz-Date": driver_implementation_headers["x-amz-date"], + "Host": sf_driver_aws_headers["host"], "X-Snowflake-Audience": "snowflakecomputing.com", + "X-Amz-Date": sf_driver_aws_headers["x-amz-date"], "X-Amz-Security-Token": session_token, }, ) + SigV4Auth( + Credentials(access_key_id, secret_access_key, token=session_token), + "sts", + region, + ).add_auth(boto_req) - creds = Credentials(access_key, secret_key, token=session_token) - SigV4Auth(creds, "sts", region).add_auth(botocore_request) - - botocore_headers = { - k.lower(): v - for k, v in botocore_request.headers.items() - if k.lower() != "user-agent" - } - - driver_implementation_headers_normalized = { - k.lower(): v for k, v in driver_implementation_headers.items() - } - - assert ( - driver_implementation_headers_normalized["authorization"] - == botocore_headers["authorization"] - ) - assert ( - driver_implementation_headers_normalized["x-amz-date"] - == botocore_headers["x-amz-date"] - ) - assert ( - driver_implementation_headers_normalized["x-amz-security-token"] - == botocore_headers["x-amz-security-token"] - ) - assert set(driver_implementation_headers_normalized.keys()).issubset( - set(botocore_headers.keys()) + assert _normalise_headers(sf_driver_aws_headers) == _normalise_headers( + boto_req.headers ) 
@pytest.mark.parametrize( - "region", - [ - "us-east-1", - "eu-west-1", - "us-gov-west-1", - "cn-north-1", - ], + "region", ["us-east-1", "eu-west-1", "us-gov-west-1", "cn-north-1"] ) -def test_sts_host_from_region_matches_botocore(region): - """Ensure we derive the same STS endpoint as botocore.""" - - driver_implementation_host = _sts_host_from_region(region) +def test_sts_host_from_region_matches_botocore( + monkeypatch: pytest.MonkeyPatch, region: str +): + sf_host = _sts_host_from_region(region) - session = _botocore_session.Session() - client = session.create_client( - "sts", - region_name=region, - aws_access_key_id="dummy", - aws_secret_access_key="dummy", - ) - boto_host = _urlparse.urlparse(client.meta.endpoint_url).netloc.lower() - - assert driver_implementation_host == boto_host + # Force botocore into **regional** mode so that it doesn’t fall back to the + # legacy global host (sts.amazonaws.com) for the particular regions (like us-east-1). + # Both approaches work correctly. + monkeypatch.setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") + boto_host = urlparse.urlparse( + boto_session.Session() + .create_client( + "sts", region_name=region, aws_access_key_id="x", aws_secret_access_key="y" + ) + .meta.endpoint_url + ).netloc.lower() -@pytest.mark.parametrize("env_var", ["AWS_REGION", "AWS_DEFAULT_REGION"]) -def test_get_region_matches_botocore(monkeypatch, env_var): - """Our region helper should respect the same env-var precedence as botocore.""" + assert sf_host == boto_host - test_region = "ap-southeast-2" +def test_region_env_var_default(monkeypatch: pytest.MonkeyPatch) -> None: + """ + Both libraries should resolve the region from AWS_DEFAULT_REGION + without any extra hints. 
+ """ + expected_region = "ap-southeast-2" monkeypatch.delenv("AWS_REGION", raising=False) - monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False) - monkeypatch.setenv(env_var, test_region) + monkeypatch.setenv("AWS_DEFAULT_REGION", expected_region) - driver_region = _aws_credentials.get_region() + # Driver + sf_region = _aws_credentials.get_region() + assert sf_region == expected_region - session = _botocore_session.Session() - s3_client = session.create_client( - "s3", - region_name=None, - aws_access_key_id="dummy", - aws_secret_access_key="dummy", + # Botocore + boto_region = ( + boto_session.Session() + .create_client("s3", aws_access_key_id="x", aws_secret_access_key="y") + .meta.region_name ) - boto_region = s3_client.meta.region_name + assert boto_region == sf_region - assert driver_region == boto_region == test_region + +def test_region_env_var_legacy(monkeypatch: pytest.MonkeyPatch) -> None: + """ + AWS_REGION is *ignored* by botocore currently, but should be introduced in the future: https://docs.aws.amazon.com/sdkref/latest/guide/feature-region.html + Therefore for now we set it as env_var for the driver and pass via explicit parameter to botocore. 
+ """ + desired_region = "ca-central-1" + monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False) + monkeypatch.setenv("AWS_REGION", desired_region) + + # Snowflake helper sees AWS_REGION + sf_region = _aws_credentials.get_region() + assert sf_region == desired_region + + # botocore needs an explicit region_name when AWS_REGION is set + boto_region = ( + boto_session.Session() + .create_client( + "s3", + region_name=desired_region, + aws_access_key_id="x", + aws_secret_access_key="y", + ) + .meta.region_name + ) + assert boto_region == desired_region From 6ab7a50fde49089a73c3083c342c085644f6fcbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 09:05:28 +0200 Subject: [PATCH 42/54] SNOW-2183023: removed session manager --- src/snowflake/connector/session_manager.py | 219 --------------------- 1 file changed, 219 deletions(-) delete mode 100644 src/snowflake/connector/session_manager.py diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py deleted file mode 100644 index c6d12c9e43..0000000000 --- a/src/snowflake/connector/session_manager.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -import collections -import contextlib -import itertools -import logging -from typing import TYPE_CHECKING, Any, Callable, Mapping - -from .compat import urlparse -from .vendored import requests -from .vendored.requests import Response, Session -from .vendored.requests.adapters import HTTPAdapter -from .vendored.requests.exceptions import InvalidProxyURL -from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy -from .vendored.urllib3.poolmanager import ProxyManager -from .vendored.urllib3.util.url import parse_url - -if TYPE_CHECKING: - from .vendored.urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool - -logger = logging.getLogger(__name__) - -# requests parameters -REQUESTS_RETRY = 1 # requests library builtin retry - - -class 
ProxySupportAdapter(HTTPAdapter): - """This Adapter creates proper headers for Proxy CONNECT messages.""" - - def get_connection( - self, url: str, proxies: dict | None = None - ) -> HTTPConnectionPool | HTTPSConnectionPool: - proxy = select_proxy(url, proxies) - parsed_url = urlparse(url) - - if proxy: - proxy = prepend_scheme_if_needed(proxy, "http") - proxy_url = parse_url(proxy) - if not proxy_url.host: - raise InvalidProxyURL( - "Please check proxy URL. It is malformed" - " and could be missing the host." - ) - proxy_manager = self.proxy_manager_for(proxy) - - if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname - else: - logger.debug( - f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" - f" it does not have attribute 'proxy_headers'." - ) - conn = proxy_manager.connection_from_url(url) - else: - # Only scheme should be lower case - url = parsed_url.geturl() - conn = self.poolmanager.connection_from_url(url) - - return conn - - -class SessionPool: - def __init__(self, manager: SessionManager) -> None: - # A stack of the idle sessions - self._idle_sessions = [] - self._active_sessions = set() - self._manager = manager - - def get_session(self) -> Session: - """Returns a session from the session pool or creates a new one.""" - try: - session = self._idle_sessions.pop() - except IndexError: - session = self._manager.make_session() - self._active_sessions.add(session) - return session - - def return_session(self, session: Session) -> None: - """Places an active session back into the idle session stack.""" - try: - self._active_sessions.remove(session) - except KeyError: - logger.debug("session doesn't exist in the active session pool. 
Ignored...") - self._idle_sessions.append(session) - - def __str__(self) -> str: - total_sessions = len(self._active_sessions) + len(self._idle_sessions) - return ( - f"SessionPool {len(self._active_sessions)}/{total_sessions} active sessions" - ) - - def close(self) -> None: - """Closes all active and idle sessions in this session pool.""" - if self._active_sessions: - logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for session in itertools.chain(self._active_sessions, self._idle_sessions): - try: - session.close() - except Exception as e: - logger.info(f"Session cleanup failed - failed to close session: {e}") - self._active_sessions.clear() - self._idle_sessions.clear() - - -class SessionManager: - def __init__( - self, - use_pooling: bool = True, - adapter_factory: ( - Callable[..., HTTPAdapter] | None - ) = lambda *args, **kwargs: None, - ): - self._use_pooling = use_pooling - self._adapter_factory = adapter_factory or ProxySupportAdapter - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( - lambda: SessionPool(self) - ) - - @property - def sessions_map(self) -> dict[str, SessionPool]: - return self._sessions_map - - def _mount_adapter(self, session: requests.Session) -> None: - adapter = self._adapter_factory(max_retries=REQUESTS_RETRY) - if adapter is not None: - session.mount("http://", adapter) - session.mount("https://", adapter) - - def make_session(self) -> Session: - s = requests.Session() - self._mount_adapter(s) - s._reuse_count = itertools.count() - return s - - @contextlib.contextmanager - def use_session( - self, url: str | None = None, use_pooling: bool | None = None - ) -> Session: - use_pooling = use_pooling if use_pooling is not None else self._use_pooling - if not use_pooling: - session = self.make_session() - try: - yield session - finally: - session.close() - else: - hostname = urlparse(url).hostname if url else None - pool = self._sessions_map[hostname] - session = pool.get_session() - try: 
- yield session - finally: - pool.return_session(session) - - def request( - self, - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - use_pooling: bool | None = None, - **kwargs: Any, - ) -> Response: - """Make a single HTTP request handled by this *SessionManager*. - - This wraps :pymeth:`use_session` so callers don’t have to manage the - context manager themselves. - """ - with self.use_session(url, use_pooling) as session: - return session.request( - method=method.upper(), - url=url, - headers=headers, - timeout=timeout_sec, - **kwargs, - ) - - def close(self): - for pool in self._sessions_map.values(): - pool.close() - - def clone(self, *, use_pooling: bool | None = None) -> SessionManager: - """Return an independent manager that reuses the adapter_factory.""" - return SessionManager( - use_pooling=self._use_pooling if use_pooling is None else use_pooling, - adapter_factory=self._adapter_factory, - ) - - -def request( - method: str, - url: str, - *, - headers: Mapping[str, str] | None = None, - timeout_sec: int | None = 3, - session_manager: SessionManager | None = None, - use_pooling: bool | None = None, - **kwargs: Any, -) -> Response: - """Convenience wrapper – *requires* an explicit ``session_manager``. - - This keeps a one-liner API equivalent to the old - ``snowflake.connector.http_client.request`` helper. 
- """ - if session_manager is None: - raise ValueError( - "session_manager is required - no default session manager available" - ) - - return session_manager.request( - method=method, - url=url, - headers=headers, - timeout_sec=timeout_sec, - use_pooling=use_pooling, - **kwargs, - ) From bbf50524ba4b9ef6aae32db40e5828e2c7f23510 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 11:14:09 +0200 Subject: [PATCH 43/54] SNOW-2183023: slight improved tests --- test/csp_helpers.py | 61 ++++++++++++++++----------------------------- 1 file changed, 21 insertions(+), 40 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index dacfdbd31a..4843c20a30 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,5 +1,6 @@ -#!/usr/bin/env python -import datetime +from __future__ import annotations + +import contextlib import json import logging import os @@ -9,7 +10,6 @@ from unittest import mock from urllib.parse import parse_qs, urlparse -import botocore.endpoint import jwt from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials @@ -52,8 +52,9 @@ def build_response(content: bytes, status_code: int = 200) -> Response: class FakeMetadataService(ABC): """Base class for fake metadata service implementations.""" - def __init__(self): + def __init__(self) -> None: self.reset_defaults() + self._context_stack: contextlib.ExitStack | None = None @abstractmethod def reset_defaults(self): @@ -63,10 +64,9 @@ def reset_defaults(self): """ pass - @property @abstractmethod - def expected_hostname(self): - """Hostname at which this metadata service is listening. + def is_expected_hostname(self, host: str | None) -> bool: + """Checks if passed hostname is the one at which this metadata service is listening. Used to raise a ConnectTimeout for requests not targeted to this hostname. 
""" @@ -82,7 +82,7 @@ def __call__(self, method, url, headers, timeout): logger.debug(f"Received request: {method} {url} {str(headers)}") parsed_url = urlparse(url) - if not parsed_url.hostname == self.expected_hostname: + if not self.is_expected_hostname(parsed_url.hostname): logger.debug( f"Received request to unexpected hostname {parsed_url.hostname}" ) @@ -93,29 +93,26 @@ def __call__(self, method, url, headers, timeout): def __enter__(self): """Patches the relevant HTTP calls when entering as a context manager.""" self.reset_defaults() - self.patchers = [] + self._context_stack = ExitStack() # requests.request is used by the direct metadata service API calls from our code. This is the main # thing being faked here. - self.patchers.append( + self._context_stack.enter_context( mock.patch( "snowflake.connector.vendored.requests.request", side_effect=self ) ) # HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we # simply raise a ConnectTimeout to avoid making real network calls. - self.patchers.append( + self._context_stack.enter_context( mock.patch( "urllib3.connection.HTTPConnection.request", side_effect=ConnectTimeout(), ) ) - for patcher in self.patchers: - patcher.__enter__() return self - def __exit__(self, *args, **kwargs): - for patcher in self.patchers: - patcher.__exit__(*args, **kwargs) + def __exit__(self, *exc): + self._context_stack.close() class NoMetadataService(FakeMetadataService): @@ -124,9 +121,8 @@ class NoMetadataService(FakeMetadataService): def reset_defaults(self): pass - @property - def expected_hostname(self): - return None # Always raise a ConnectTimeout. + def is_expected_hostname(self, host: str | None) -> bool: + return host is None # Always raise a ConnectTimeout. def handle_request(self, method, parsed_url, headers, timeout): # This should never be called because we always raise a ConnectTimeout. 
@@ -141,9 +137,8 @@ def reset_defaults(self): self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - @property - def expected_hostname(self): - return "169.254.169.254" + def is_expected_hostname(self, host: str | None) -> bool: + return host == "169.254.169.254" def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) @@ -176,9 +171,8 @@ def reset_defaults(self): self.identity_header = "FD80F6DA783A4881BE9FAFA365F58E7A" self.parsed_identity_endpoint = urlparse(self.identity_endpoint) - @property - def expected_hostname(self): - return self.parsed_identity_endpoint.hostname + def is_expected_hostname(self, host: str | None) -> bool: + return host == self.parsed_identity_endpoint.hostname def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) @@ -221,9 +215,8 @@ def reset_defaults(self): self.sub = "123" self.iss = "https://accounts.google.com" - @property - def expected_hostname(self): - return "169.254.169.254" + def is_expected_hostname(self, host: str | None) -> bool: + return host == "169.254.169.254" def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) @@ -291,18 +284,6 @@ def __enter__(self): ) ) - # hard-fail any botocore endpoint attempts – guarantees offline tests - def _no_http(*a, **k): - raise AssertionError("botocore attempted real HTTP call") - - self._stack.enter_context( - mock.patch.object( - botocore.endpoint.EndpointCreator, - "create_endpoint", - _no_http, - autospec=True, - ) - ) return self def __exit__(self, *exc): From 35e23cf278aca38f2738834001c354a822a2caf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 12:15:04 +0200 Subject: [PATCH 44/54] SNOW-2183023: TEMP commit metadata service initial intro --- test/csp_helpers.py | 319 ++++++++++++++++++++++++++------------------ 1 file 
changed, 190 insertions(+), 129 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index 4843c20a30..d997cf1452 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,6 +1,6 @@ from __future__ import annotations -import contextlib +import datetime import json import logging import os @@ -14,6 +14,12 @@ from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials +from snowflake.connector._aws_credentials import ( + _ECS_CRED_BASE_URL, + _IMDS_BASE_URL, + _IMDS_ROLE_PATH, + _IMDS_TOKEN_PATH, +) from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError from snowflake.connector.vendored.requests.models import Response @@ -21,11 +27,12 @@ def gen_dummy_id_token( - sub="test-subject", iss="test-issuer", aud="snowflakecomputing.com" + sub: str = "test-subject", + iss: str = "test-issuer", + aud: str = "snowflakecomputing.com", ) -> str: - """Generates a dummy ID token using the given subject and issuer.""" + """Generates a dummy HS256-signed JWT.""" now = int(time()) - key = "secret" payload = { "sub": sub, "iss": iss, @@ -33,76 +40,68 @@ def gen_dummy_id_token( "iat": now, "exp": now + 60 * 60, } - logger.debug(f"Generating dummy token with the following claims:\n{str(payload)}") - return jwt.encode( - payload=payload, - key=key, - algorithm="HS256", - ) + logger.debug("Generating dummy token with claims %s", payload) + return jwt.encode(payload, key="secret", algorithm="HS256") -def build_response(content: bytes, status_code: int = 200) -> Response: - """Builds a requests.Response object with the given status code and content.""" - response = Response() - response.status_code = status_code - response._content = content - return response +def build_response( + content: bytes, + status_code: int = 200, + headers: dict[str, str] | None = None, +) -> Response: + """Return a minimal Response object with canned body/headers.""" + resp = Response() + resp.status_code = status_code + 
resp._content = content + if headers: + resp.headers.update(headers) + return resp class FakeMetadataService(ABC): - """Base class for fake metadata service implementations.""" + """Base class for cloud-metadata fakes.""" def __init__(self) -> None: self.reset_defaults() - self._context_stack: contextlib.ExitStack | None = None + self._context_stack: ExitStack | None = None @abstractmethod - def reset_defaults(self): - """Resets any default values for test parameters. - - This is called in the constructor and when entering as a context manager. - """ - pass + def reset_defaults(self) -> None: ... @abstractmethod - def is_expected_hostname(self, host: str | None) -> bool: - """Checks if passed hostname is the one at which this metadata service is listening. - - Used to raise a ConnectTimeout for requests not targeted to this hostname. - """ - pass + def is_expected_hostname(self, host: str | None) -> bool: ... @abstractmethod - def handle_request(self, method, parsed_url, headers, timeout): - """Main business logic for handling this request. Should return a Response object.""" - pass - - def __call__(self, method, url, headers, timeout): - """Entry point for the requests mock.""" - logger.debug(f"Received request: {method} {url} {str(headers)}") - parsed_url = urlparse(url) - - if not self.is_expected_hostname(parsed_url.hostname): - logger.debug( - f"Received request to unexpected hostname {parsed_url.hostname}" - ) + def handle_request( + self, + method, + parsed_url, + headers, + timeout, + ) -> Response: ... 
+ + def __call__(self, method, url, headers=None, timeout=None, **_kw): + """Entry-point for the requests monkey-patch.""" + headers = headers or {} + parsed = urlparse(url) + logger.debug("FakeMetadataService received %s %s %s", method, url, headers) + + if not self.is_expected_hostname(parsed.hostname): + logger.debug("Unexpected hostname %s – timeout", parsed.hostname) raise ConnectTimeout() - return self.handle_request(method, parsed_url, headers, timeout) + return self.handle_request(method.upper(), parsed, headers, timeout) def __enter__(self): - """Patches the relevant HTTP calls when entering as a context manager.""" + """Patch requests & urllib3 so no real traffic escapes.""" self.reset_defaults() self._context_stack = ExitStack() - # requests.request is used by the direct metadata service API calls from our code. This is the main - # thing being faked here. self._context_stack.enter_context( mock.patch( - "snowflake.connector.vendored.requests.request", side_effect=self + "snowflake.connector.vendored.requests.request", + side_effect=self, ) ) - # HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we - # simply raise a ConnectTimeout to avoid making real network calls. self._context_stack.enter_context( mock.patch( "urllib3.connection.HTTPConnection.request", @@ -112,28 +111,26 @@ def __enter__(self): return self def __exit__(self, *exc): - self._context_stack.close() + self._context_stack.close() # type: ignore[arg-type] class NoMetadataService(FakeMetadataService): - """Emulates an environment without any metadata service.""" + """Always times out – simulates an environment without any metadata service.""" - def reset_defaults(self): + def reset_defaults(self) -> None: pass def is_expected_hostname(self, host: str | None) -> bool: - return host is None # Always raise a ConnectTimeout. 
+ return False - def handle_request(self, method, parsed_url, headers, timeout): - # This should never be called because we always raise a ConnectTimeout. - pass + def handle_request(self, *_): + raise ConnectTimeout() class FakeAzureVmMetadataService(FakeMetadataService): - """Emulates an environment with the Azure VM metadata service.""" + """Simulates Azure VM metadata endpoint.""" - def reset_defaults(self): - # Defaults used for generating an Entra ID token. Can be overriden in individual tests. + def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" @@ -141,77 +138,62 @@ def is_expected_hostname(self, host: str | None) -> bool: return host == "169.254.169.254" def handle_request(self, method, parsed_url, headers, timeout): - query_string = parse_qs(parsed_url.query) - - # Reject malformed requests. + qs = parse_qs(parsed_url.query) if not ( method == "GET" and parsed_url.path == "/metadata/identity/oauth2/token" and headers.get("Metadata") == "True" - and query_string["resource"] + and qs.get("resource") ): raise HTTPError() - logger.debug("Received request for Azure VM metadata service") - - resource = query_string["resource"][0] - self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) - return build_response(json.dumps({"access_token": self.token}).encode("utf-8")) + resource = qs["resource"][0] + self.token = gen_dummy_id_token(self.sub, self.iss, resource) + return build_response(json.dumps({"access_token": self.token}).encode()) class FakeAzureFunctionMetadataService(FakeMetadataService): - """Emulates an environment with the Azure Function metadata service.""" + """Simulates Azure Functions MSI endpoint.""" - def reset_defaults(self): - # Defaults used for generating an Entra ID token. Can be overriden in individual tests. 
+ def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - self.identity_endpoint = "http://169.254.255.2:8081/msi/token" self.identity_header = "FD80F6DA783A4881BE9FAFA365F58E7A" self.parsed_identity_endpoint = urlparse(self.identity_endpoint) + def __enter__(self): + os.environ["IDENTITY_ENDPOINT"] = self.identity_endpoint + os.environ["IDENTITY_HEADER"] = self.identity_header + return super().__enter__() + + def __exit__(self, *exc): + os.environ.pop("IDENTITY_ENDPOINT", None) + os.environ.pop("IDENTITY_HEADER", None) + return super().__exit__(*exc) + def is_expected_hostname(self, host: str | None) -> bool: return host == self.parsed_identity_endpoint.hostname def handle_request(self, method, parsed_url, headers, timeout): - query_string = parse_qs(parsed_url.query) - - # Reject malformed requests. + qs = parse_qs(parsed_url.query) if not ( method == "GET" and parsed_url.path == self.parsed_identity_endpoint.path and headers.get("X-IDENTITY-HEADER") == self.identity_header - and query_string["resource"] + and qs.get("resource") ): - logger.warning( - f"Received malformed request: {method} {parsed_url.path} {str(headers)} {str(query_string)}" - ) raise HTTPError() - logger.debug("Received request for Azure Functions metadata service") - - resource = query_string["resource"][0] - self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) - return build_response(json.dumps({"access_token": self.token}).encode("utf-8")) - - def __enter__(self): - # In addition to the normal patching, we need to set the environment variables that Azure Functions would set. 
- os.environ["IDENTITY_ENDPOINT"] = self.identity_endpoint - os.environ["IDENTITY_HEADER"] = self.identity_header - return super().__enter__() - - def __exit__(self, *args, **kwargs): - os.environ.pop("IDENTITY_ENDPOINT") - os.environ.pop("IDENTITY_HEADER") - return super().__exit__(*args, **kwargs) + resource = qs["resource"][0] + self.token = gen_dummy_id_token(self.sub, self.iss, resource) + return build_response(json.dumps({"access_token": self.token}).encode()) class FakeGceMetadataService(FakeMetadataService): - """Emulates an environment with the GCE metadata service.""" + """Simulates GCE metadata endpoint.""" - def reset_defaults(self): - # Defaults used for generating a token. Can be overriden in individual tests. + def reset_defaults(self) -> None: self.sub = "123" self.iss = "https://accounts.google.com" @@ -219,38 +201,90 @@ def is_expected_hostname(self, host: str | None) -> bool: return host == "169.254.169.254" def handle_request(self, method, parsed_url, headers, timeout): - query_string = parse_qs(parsed_url.query) - - # Reject malformed requests. + qs = parse_qs(parsed_url.query) if not ( method == "GET" and parsed_url.path == "/computeMetadata/v1/instance/service-accounts/default/identity" and headers.get("Metadata-Flavor") == "Google" - and query_string["audience"] + and qs.get("audience") ): raise HTTPError() - logger.debug("Received request for GCE metadata service") + audience = qs["audience"][0] + self.token = gen_dummy_id_token(self.sub, self.iss, audience) + return build_response(self.token.encode()) - audience = query_string["audience"][0] - self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) - return build_response(self.token.encode("utf-8")) +class _AwsMetadataService(FakeMetadataService): + """Low-level fake for IMDSv2 and ECS endpoints.""" -class FakeAwsEnvironment: - """Emulates the AWS environment-specific functions used in wif_util.py. 
+ def reset_defaults(self) -> None: + self.role_name = "MyRole" + self.access_key = "AKIA_TEST" + self.secret_key = "SK_TEST" + self.session_token = "STS_TOKEN" + self.imds_token = "IMDS_TOKEN" + + def is_expected_hostname(self, host: str | None) -> bool: + return host in { + urlparse(_IMDS_BASE_URL).hostname, + urlparse(_ECS_CRED_BASE_URL).hostname, + } + + def handle_request(self, method, parsed_url, headers, timeout): + url = f"{parsed_url.scheme}://{parsed_url.hostname}{parsed_url.path}" + + if method == "PUT" and url == f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}": + return build_response( + self.imds_token.encode(), + headers={"x-aws-ec2-metadata-token-ttl-seconds": "21600"}, + ) - Unlike the other metadata services, the HTTP calls made by AWS are deep within boto libaries, so - emulating them here would be complex and fragile. Instead, we emulate the higher-level functions - called by the connector code. - """ + if method == "GET" and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}": + return build_response(self.role_name.encode()) + + if ( + method == "GET" + and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}{self.role_name}" + ): + if self.access_key is None or self.secret_key is None: + return build_response(b"", status_code=404) + creds_json = json.dumps( + { + "AccessKeyId": self.access_key, + "SecretAccessKey": self.secret_key, + "Token": self.session_token, + } + ).encode() + return build_response(creds_json) + + ecs_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + if ecs_uri and method == "GET" and url == f"{_ECS_CRED_BASE_URL}{ecs_uri}": + creds_json = json.dumps( + { + "AccessKeyId": self.access_key, + "SecretAccessKey": self.secret_key, + "Token": self.session_token, + } + ).encode() + return build_response(creds_json) + + raise ConnectTimeout() + + +class FakeAwsEnvironment: + """Context-manager fixture that fakes AWS metadata plus helper functions.""" def __init__(self): - # Defaults used for generating a token. Can be overriden in individual tests. 
- self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" self.region = "us-east-1" - self.credentials = Credentials(access_key="ak", secret_key="sk") + self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" + self.credentials: Credentials | None = Credentials( + access_key="ak", secret_key="sk" + ) + + self._metadata = _AwsMetadataService() + self._stack: ExitStack | None = None def get_region(self): return self.region @@ -261,30 +295,57 @@ def get_arn(self): def get_credentials(self): return self.credentials - def sign_request(self, request: AWSRequest): - request.headers.add_header("X-Amz-Date", datetime.time().isoformat()) - request.headers.add_header("X-Amz-Security-Token", "") - request.headers.add_header( - "Authorization", - f"AWS4-HMAC-SHA256 Credential=, SignedHeaders={';'.join(request.headers.keys())}, Signature=", + def __enter__(self): + # Keep metadata service in sync with the top-level attrs each time we enter + self._metadata.access_key = ( + self.credentials.access_key if self.credentials else None + ) + self._metadata.secret_key = ( + self.credentials.secret_key if self.credentials else None + ) + self._metadata.session_token = ( + self.credentials.token if self.credentials else None ) - def __enter__(self): self._stack = ExitStack() - # patch connector helpers self._stack.enter_context( mock.patch( - "snowflake.connector.wif_util.load_default_credentials", - side_effect=self.get_credentials, + "snowflake.connector.vendored.requests.request", + side_effect=self._metadata, ) ) self._stack.enter_context( mock.patch( - "snowflake.connector.wif_util.get_region", side_effect=self.get_region + "urllib3.connection.HTTPConnection.request", + side_effect=ConnectTimeout(), + ) + ) + self._stack.enter_context( + mock.patch( + "snowflake.connector.wif_util.get_region", + side_effect=self.get_region, + ) + ) + # critical: ensure driver’s helper uses our current credential state + self._stack.enter_context( + 
mock.patch( + "snowflake.connector.wif_util.load_default_credentials", + side_effect=self.get_credentials, ) ) - return self def __exit__(self, *exc): - self._stack.close() + self._stack.close() # type: ignore[arg-type] + + # Helper occasionally used in SigV4 parity tests + @staticmethod + def sign_request(request: AWSRequest): + request.headers.add_header( + "X-Amz-Date", datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") + ) + request.headers.add_header("X-Amz-Security-Token", "") + request.headers.add_header( + "Authorization", + "AWS4-HMAC-SHA256 Credential=, SignedHeaders=host;x-amz-date,Signature=", + ) From 0b8c7a6b69304b389398d5f7fc162ecdec233779 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 13:13:06 +0200 Subject: [PATCH 45/54] SNOW-2183023: All tests passing with many aws envs --- test/csp_helpers.py | 194 +++++++++++++++++------ test/unit/conftest.py | 23 ++- test/unit/test_auth_workload_identity.py | 42 +++-- 3 files changed, 189 insertions(+), 70 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index d997cf1452..db26c65fc6 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,6 +1,5 @@ from __future__ import annotations -import datetime import json import logging import os @@ -11,7 +10,6 @@ from urllib.parse import parse_qs, urlparse import jwt -from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials from snowflake.connector._aws_credentials import ( @@ -66,10 +64,20 @@ def __init__(self) -> None: self._context_stack: ExitStack | None = None @abstractmethod - def reset_defaults(self) -> None: ... + def reset_defaults(self) -> None: + """Resets any default values for test parameters. + + This is called in the constructor and when entering as a context manager. + """ + pass @abstractmethod - def is_expected_hostname(self, host: str | None) -> bool: ... 
+ def is_expected_hostname(self, host: str | None) -> bool: + """Returns true if the passed hostname is the one at which this metadata service is listening. + + Used to raise a ConnectTimeout for requests not targeted to this hostname. + """ + pass @abstractmethod def handle_request( @@ -78,7 +86,9 @@ def handle_request( parsed_url, headers, timeout, - ) -> Response: ... + ) -> Response: + """Main business logic for handling this request. Should return a Response object.""" + pass def __call__(self, method, url, headers=None, timeout=None, **_kw): """Entry-point for the requests monkey-patch.""" @@ -87,13 +97,15 @@ def __call__(self, method, url, headers=None, timeout=None, **_kw): logger.debug("FakeMetadataService received %s %s %s", method, url, headers) if not self.is_expected_hostname(parsed.hostname): - logger.debug("Unexpected hostname %s – timeout", parsed.hostname) + logger.debug( + "Received request to unexpected hostname %s – timeout", parsed.hostname + ) raise ConnectTimeout() return self.handle_request(method.upper(), parsed, headers, timeout) def __enter__(self): - """Patch requests & urllib3 so no real traffic escapes.""" + """Patches the relevant HTTP calls when entering as a context manager.""" self.reset_defaults() self._context_stack = ExitStack() self._context_stack.enter_context( @@ -102,6 +114,8 @@ def __enter__(self): side_effect=self, ) ) + # HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we + # simply raise a ConnectTimeout to avoid making real network calls. 
self._context_stack.enter_context( mock.patch( "urllib3.connection.HTTPConnection.request", @@ -111,7 +125,7 @@ def __enter__(self): return self def __exit__(self, *exc): - self._context_stack.close() # type: ignore[arg-type] + self._context_stack.close() class NoMetadataService(FakeMetadataService): @@ -124,13 +138,17 @@ def is_expected_hostname(self, host: str | None) -> bool: return False def handle_request(self, *_): - raise ConnectTimeout() + # This should never be called because we always raise a ConnectTimeout. + raise AssertionError( + "This should never be called because we always raise a ConnectTimeout." + ) class FakeAzureVmMetadataService(FakeMetadataService): - """Simulates Azure VM metadata endpoint.""" + """Emulates an environment with the Azure VM metadata service.""" def reset_defaults(self) -> None: + # Defaults used for generating an Entra ID token. Can be overriden in individual tests. self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" @@ -138,24 +156,29 @@ def is_expected_hostname(self, host: str | None) -> bool: return host == "169.254.169.254" def handle_request(self, method, parsed_url, headers, timeout): - qs = parse_qs(parsed_url.query) + query_string = parse_qs(parsed_url.query) + + # Reject malformed requests. 
if not ( method == "GET" and parsed_url.path == "/metadata/identity/oauth2/token" and headers.get("Metadata") == "True" - and qs.get("resource") + and query_string.get("resource") ): raise HTTPError() - resource = qs["resource"][0] - self.token = gen_dummy_id_token(self.sub, self.iss, resource) - return build_response(json.dumps({"access_token": self.token}).encode()) + logger.debug("Received request for Azure VM metadata service") + + resource = query_string["resource"][0] + self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) + return build_response(json.dumps({"access_token": self.token}).encode("utf-8")) class FakeAzureFunctionMetadataService(FakeMetadataService): - """Simulates Azure Functions MSI endpoint.""" + """Emulates an environment with the Azure Function metadata service.""" def reset_defaults(self) -> None: + # Defaults used for generating an Entra ID token. Can be overriden in individual tests. self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" self.identity_endpoint = "http://169.254.255.2:8081/msi/token" @@ -163,29 +186,43 @@ def reset_defaults(self) -> None: self.parsed_identity_endpoint = urlparse(self.identity_endpoint) def __enter__(self): - os.environ["IDENTITY_ENDPOINT"] = self.identity_endpoint - os.environ["IDENTITY_HEADER"] = self.identity_header + # Inject the variables *without* touching os.environ directly + self._stack = mock.patch.dict( + os.environ, + { + "IDENTITY_ENDPOINT": self.identity_endpoint, + "IDENTITY_HEADER": self.identity_header, + }, + clear=False, + ) + self._stack.start() return super().__enter__() def __exit__(self, *exc): - os.environ.pop("IDENTITY_ENDPOINT", None) - os.environ.pop("IDENTITY_HEADER", None) + self._stack.stop() return super().__exit__(*exc) def is_expected_hostname(self, host: str | None) -> bool: return host == self.parsed_identity_endpoint.hostname def handle_request(self, method, parsed_url, headers, 
timeout): - qs = parse_qs(parsed_url.query) + query_string = parse_qs(parsed_url.query) + + # Reject malformed requests. if not ( method == "GET" and parsed_url.path == self.parsed_identity_endpoint.path and headers.get("X-IDENTITY-HEADER") == self.identity_header - and qs.get("resource") + and query_string["resource"] ): + logger.warning( + f"Received malformed request: {method} {parsed_url.path} {str(headers)} {str(query_string)}" + ) raise HTTPError() - resource = qs["resource"][0] + logger.debug("Received request for Azure Functions metadata service") + + resource = query_string["resource"][0] self.token = gen_dummy_id_token(self.sub, self.iss, resource) return build_response(json.dumps({"access_token": self.token}).encode()) @@ -194,6 +231,7 @@ class FakeGceMetadataService(FakeMetadataService): """Simulates GCE metadata endpoint.""" def reset_defaults(self) -> None: + # Defaults used for generating a token. Can be overriden in individual tests. self.sub = "123" self.iss = "https://accounts.google.com" @@ -201,19 +239,23 @@ def is_expected_hostname(self, host: str | None) -> bool: return host == "169.254.169.254" def handle_request(self, method, parsed_url, headers, timeout): - qs = parse_qs(parsed_url.query) + query_string = parse_qs(parsed_url.query) + + # Reject malformed requests. 
if not ( method == "GET" and parsed_url.path == "/computeMetadata/v1/instance/service-accounts/default/identity" and headers.get("Metadata-Flavor") == "Google" - and qs.get("audience") + and query_string.get("audience") ): raise HTTPError() - audience = qs["audience"][0] - self.token = gen_dummy_id_token(self.sub, self.iss, audience) - return build_response(self.token.encode()) + logger.debug("Received request for GCE metadata service") + + audience = query_string["audience"][0] + self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) + return build_response(self.token.encode("utf-8")) class _AwsMetadataService(FakeMetadataService): @@ -274,18 +316,25 @@ def handle_request(self, method, parsed_url, headers, timeout): class FakeAwsEnvironment: - """Context-manager fixture that fakes AWS metadata plus helper functions.""" + """ + Base context-manager for AWS runtime fakes. + Subclasses override `_prepare_runtime()` to tweak env-vars / creds. + """ def __init__(self): + # Defaults used for generating a token. Can be overriden in individual tests. 
self.region = "us-east-1" self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" self.credentials: Credentials | None = Credentials( - access_key="ak", secret_key="sk" + access_key="ak", secret_key="sk", token="tk" ) - self._metadata = _AwsMetadataService() self._stack: ExitStack | None = None + def _prepare_runtime(self): + """Sub-classes patch env / credentials here.""" + return None + def get_region(self): return self.region @@ -296,7 +345,7 @@ def get_credentials(self): return self.credentials def __enter__(self): - # Keep metadata service in sync with the top-level attrs each time we enter + # sync stub with current creds self._metadata.access_key = ( self.credentials.access_key if self.credentials else None ) @@ -308,6 +357,7 @@ def __enter__(self): ) self._stack = ExitStack() + # patch connector helpers self._stack.enter_context( mock.patch( "snowflake.connector.vendored.requests.request", @@ -321,31 +371,79 @@ def __enter__(self): ) ) self._stack.enter_context( - mock.patch( - "snowflake.connector.wif_util.get_region", - side_effect=self.get_region, - ) + mock.patch("snowflake.connector.wif_util.get_region", self.get_region) ) - # critical: ensure driver’s helper uses our current credential state self._stack.enter_context( mock.patch( "snowflake.connector.wif_util.load_default_credentials", - side_effect=self.get_credentials, + self.get_credentials, ) ) + + # runtime-specific tweaks + self._prepare_runtime() return self def __exit__(self, *exc): - self._stack.close() # type: ignore[arg-type] + self._stack.close() + + +class FakeAwsEc2(FakeAwsEnvironment): + """Default – IMDSv2 only.""" + + # nothing extra needed + + +class FakeAwsEcs(FakeAwsEnvironment): + """ECS/EKS task-role – exposes creds via task metadata endpoint.""" - # Helper occasionally used in SigV4 parity tests - @staticmethod - def sign_request(request: AWSRequest): - request.headers.add_header( - "X-Amz-Date", datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") + def 
_prepare_runtime(self): + self._stack.enter_context( + mock.patch.dict( + os.environ, + {"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI": "/v2/credentials/test-id"}, + clear=False, + ) ) - request.headers.add_header("X-Amz-Security-Token", "") - request.headers.add_header( - "Authorization", - "AWS4-HMAC-SHA256 Credential=, SignedHeaders=host;x-amz-date,Signature=", + + +class FakeAwsLambda(FakeAwsEnvironment): + """Lambda runtime – temporary credentials + runtime env-vars.""" + + def __init__(self): + super().__init__() + # Lambda always returns *session* credentials + self.credentials = Credentials( + access_key="ak", + secret_key="sk", + token="dummy-session-token", + ) + + def _prepare_runtime(self) -> None: + # Patch env vars via mock.patch.dict so nothing touches os.environ directly + self._stack.enter_context( + mock.patch.dict( + os.environ, + {"AWS_LAMBDA_FUNCTION_NAME": "dummy-fn"}, + clear=False, + ) + ) + + +class FakeAwsNoCreds(FakeAwsEnvironment): + """Negative path – no credentials anywhere.""" + + def _prepare_runtime(self): + self.credentials = None + self._stack.enter_context( + mock.patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "", + "AWS_SECRET_ACCESS_KEY": "", + "AWS_SESSION_TOKEN": "", + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI": "", + }, + clear=False, + ) ) diff --git a/test/unit/conftest.py b/test/unit/conftest.py index 65c2fb02f6..1a4fbf5b40 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -5,7 +5,10 @@ from snowflake.connector.telemetry_oob import TelemetryService from ..csp_helpers import ( - FakeAwsEnvironment, + FakeAwsEc2, + FakeAwsEcs, + FakeAwsLambda, + FakeAwsNoCreds, FakeAzureFunctionMetadataService, FakeAzureVmMetadataService, FakeGceMetadataService, @@ -30,10 +33,20 @@ def no_metadata_service(): yield server -@pytest.fixture -def fake_aws_environment(): - """Emulates the AWS environment, returning dummy credentials.""" - with FakeAwsEnvironment() as env: +@pytest.fixture( + params=[FakeAwsEc2, FakeAwsEcs, 
FakeAwsLambda], + ids=["aws_ec2", "aws_ecs", "aws_lambda"], +) +def fake_aws_environment(request): + """Runtimes that *do* expose credentials.""" + with request.param() as env: + yield env + + +@pytest.fixture(params=[FakeAwsNoCreds], ids=["aws_no_creds"]) +def malformed_aws_environment(request): + """Runtime where *no* credentials are discoverable (negative-path).""" + with request.param() as env: yield env diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index 1870f01bc3..5f98f70421 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -34,30 +34,36 @@ def extract_api_data(auth_class: AuthByWorkloadIdentity): def verify_aws_token(token: str, region: str): - """Performs some basic checks on a 'token' produced for AWS, to ensure it includes the expected fields.""" - decoded_token = json.loads(b64decode(token)) + """Accepts both SigV4 variants (with / without session token).""" + decoded_payload = json.loads(b64decode(token)) - parsed_url = urlparse(decoded_token["url"]) - assert parsed_url.scheme == "https" - assert parsed_url.hostname == f"sts.{region}.amazonaws.com" - query_string = parse_qs(parsed_url.query) - assert query_string.get("Action")[0] == "GetCallerIdentity" - assert query_string.get("Version")[0] == "2011-06-15" + # URL validation + sts_request_url = urlparse(decoded_payload["url"]) + assert sts_request_url.scheme == "https" + assert sts_request_url.hostname == f"sts.{region}.amazonaws.com" - assert decoded_token["method"] == "POST" + query_params = parse_qs(sts_request_url.query) + assert query_params["Action"][0] == "GetCallerIdentity" + assert query_params["Version"][0] == "2011-06-15" - headers = decoded_token["headers"] - headers_lc = {k.lower(): v for k, v in headers.items()} + # Method validation + assert decoded_payload["method"] == "POST" - expected_header_keys = { + # Header validation + headers = {k.lower(): v for k, v in 
decoded_payload["headers"].items()} + + mandatory_headers = { "host", "x-snowflake-audience", "x-amz-date", "authorization", } - assert set(headers_lc.keys()) == expected_header_keys - assert headers_lc["host"] == f"sts.{region}.amazonaws.com" - assert headers_lc["x-snowflake-audience"] == "snowflakecomputing.com" + optional_headers = {"x-amz-security-token"} + + assert mandatory_headers.issubset(headers) + assert set(headers).issubset(mandatory_headers | optional_headers) + assert headers["host"] == f"sts.{region}.amazonaws.com" + assert headers["x-snowflake-audience"] == "snowflakecomputing.com" # -- OIDC Tests -- @@ -109,8 +115,10 @@ def test_explicit_oidc_no_token_raises_error(): # -- AWS Tests -- -def test_explicit_aws_no_auth_raises_error(fake_aws_environment: FakeAwsEnvironment): - fake_aws_environment.credentials = None +def test_explicit_aws_no_auth_raises_error( + malformed_aws_environment: FakeAwsEnvironment, +): + malformed_aws_environment.credentials = None auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) with pytest.raises(ProgrammingError) as excinfo: From f1fd5fac2ba994689836f419508fc5967df99cd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 13:24:49 +0200 Subject: [PATCH 46/54] SNOW-2183023: all tests pass and only http is mocked --- test/csp_helpers.py | 82 +++++++++++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 28 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index db26c65fc6..ba7fc7d98c 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -323,7 +323,7 @@ class FakeAwsEnvironment: def __init__(self): # Defaults used for generating a token. Can be overriden in individual tests. 
- self.region = "us-east-1" + self._region = "us-east-1" self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" self.credentials: Credentials | None = Credentials( access_key="ak", secret_key="sk", token="tk" @@ -331,33 +331,41 @@ def __init__(self): self._metadata = _AwsMetadataService() self._stack: ExitStack | None = None - def _prepare_runtime(self): - """Sub-classes patch env / credentials here.""" - return None + @property + def region(self) -> str: + return self._region - def get_region(self): - return self.region + @region.setter + def region(self, new_region: str) -> None: + """Change runtime region and, **if** the env-vars already exist, + patch them via ExitStack so they’re cleaned up on __exit__. + """ + self._region = new_region - def get_arn(self): - return self.arn + if getattr(self, "_stack", None): + for key in ("AWS_REGION", "AWS_DEFAULT_REGION"): + if key in os.environ: # patch only if present + self._stack.enter_context( + mock.patch.dict(os.environ, {key: new_region}, clear=False) + ) - def get_credentials(self): - return self.credentials + def _prepare_runtime(self): + """Sub-classes patch env / credentials here.""" + return None def __enter__(self): - # sync stub with current creds - self._metadata.access_key = ( - self.credentials.access_key if self.credentials else None - ) - self._metadata.secret_key = ( - self.credentials.secret_key if self.credentials else None - ) - self._metadata.session_token = ( - self.credentials.token if self.credentials else None - ) + """Activate the fake AWS runtime. + * Only HTTP traffic is patched – no longer stubs `get_region` + or `load_default_credentials`. + * Region / credential discovery is driven entirely via + environment variables, so the real helper functions keep + working untouched. 
+ """ self._stack = ExitStack() - # patch connector helpers + + # Patch outgoing HTTP calls that rely on `requests` or the low-level + # urllib client, routing them to our metadata stub or timing-out. self._stack.enter_context( mock.patch( "snowflake.connector.vendored.requests.request", @@ -370,17 +378,35 @@ def __enter__(self): side_effect=ConnectTimeout(), ) ) - self._stack.enter_context( - mock.patch("snowflake.connector.wif_util.get_region", self.get_region) + + # Keep the metadata stub in sync with the final credential set. + self._metadata.access_key = ( + self.credentials.access_key if self.credentials else None + ) + self._metadata.secret_key = ( + self.credentials.secret_key if self.credentials else None ) + self._metadata.session_token = ( + self.credentials.token if self.credentials else None + ) + + # Expose region & creds *only* via env vars so that the real helper + # chain can resolve them without monkey-patching. + env_for_chain = { + "AWS_REGION": self.region, + "AWS_DEFAULT_REGION": self.region, + } + if self.credentials: + env_for_chain["AWS_ACCESS_KEY_ID"] = self.credentials.access_key + env_for_chain["AWS_SECRET_ACCESS_KEY"] = self.credentials.secret_key + if self.credentials.token: + env_for_chain["AWS_SESSION_TOKEN"] = self.credentials.token + self._stack.enter_context( - mock.patch( - "snowflake.connector.wif_util.load_default_credentials", - self.get_credentials, - ) + mock.patch.dict(os.environ, env_for_chain, clear=False) ) - # runtime-specific tweaks + # Runtime-specific tweaks (may change creds / env). 
self._prepare_runtime() return self From 2bc28256296ea2edadd00433b90eb200548e7c8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 15:28:39 +0200 Subject: [PATCH 47/54] SNOW-2183023: base fix of constants approach --- test/csp_helpers.py | 111 +++++++++++++++++++++++++++++++------------- 1 file changed, 78 insertions(+), 33 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index ba7fc7d98c..6c1a773748 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import json import logging import os @@ -23,6 +24,35 @@ logger = logging.getLogger(__name__) +AZURE_VM_METADATA_HOST = "169.254.169.254" +AZURE_VM_TOKEN_PATH = "/metadata/identity/oauth2/token" + +AZURE_FUNCTION_IDENTITY_ENDPOINT = "http://169.254.255.2:8081/msi/token" +AZURE_FUNCTION_IDENTITY_HEADER = "FD80F6DA783A4881BE9FAFA365F58E7A" + +GCE_METADATA_HOST = "169.254.169.254" +GCE_IDENTITY_PATH = "/computeMetadata/v1/instance/service-accounts/default/identity" + +AWS_REGION_ENV_KEYS = ("AWS_REGION", "AWS_DEFAULT_REGION") +AWS_CONTAINER_CRED_ENV = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" +AWS_LAMBDA_FUNCTION_ENV = "AWS_LAMBDA_FUNCTION_NAME" + +HDR_IDENTITY = "X-IDENTITY-HEADER" +HDR_METADATA = "Metadata" +HDR_METADATA_FLAVOR = "Metadata-Flavor" +HDR_IMDS_TOKEN_TTL = "x-aws-ec2-metadata-token-ttl-seconds" +IMDS_INSTANCE_IDENTITY_DOC = "/latest/dynamic/instance-identity/document" +IMDS_REGION_PATH = "/latest/meta-data/placement/region" + +AWS_CREDENTIAL_ENV_KEYS = ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_ROLE_ARN", + "AWS_EC2_METADATA_ARN", + "AWS_SESSION_ARN", +) + def gen_dummy_id_token( sub: str = "test-subject", @@ -153,7 +183,7 @@ def reset_defaults(self) -> None: self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" def is_expected_hostname(self, host: str | None) -> bool: - return host == "169.254.169.254" + return host 
== AZURE_VM_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) @@ -161,8 +191,8 @@ def handle_request(self, method, parsed_url, headers, timeout): # Reject malformed requests. if not ( method == "GET" - and parsed_url.path == "/metadata/identity/oauth2/token" - and headers.get("Metadata") == "True" + and parsed_url.path == AZURE_VM_TOKEN_PATH + and headers.get(HDR_METADATA) == "True" and query_string.get("resource") ): raise HTTPError() @@ -171,7 +201,7 @@ def handle_request(self, method, parsed_url, headers, timeout): resource = query_string["resource"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) - return build_response(json.dumps({"access_token": self.token}).encode("utf-8")) + return build_response(json.dumps({"access_token": self.token}).encode()) class FakeAzureFunctionMetadataService(FakeMetadataService): @@ -181,25 +211,36 @@ def reset_defaults(self) -> None: # Defaults used for generating an Entra ID token. Can be overriden in individual tests. 
self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - self.identity_endpoint = "http://169.254.255.2:8081/msi/token" - self.identity_header = "FD80F6DA783A4881BE9FAFA365F58E7A" + self.identity_endpoint = AZURE_FUNCTION_IDENTITY_ENDPOINT + self.identity_header = AZURE_FUNCTION_IDENTITY_HEADER self.parsed_identity_endpoint = urlparse(self.identity_endpoint) + self._stack: contextlib.ExitStack | None = None def __enter__(self): # Inject the variables *without* touching os.environ directly - self._stack = mock.patch.dict( - os.environ, - { - "IDENTITY_ENDPOINT": self.identity_endpoint, - "IDENTITY_HEADER": self.identity_header, - }, - clear=False, + self._stack = contextlib.ExitStack() + self._stack.enter_context( + mock.patch.dict( + os.environ, + { + "IDENTITY_ENDPOINT": self.identity_endpoint, + "IDENTITY_HEADER": self.identity_header, + }, + clear=False, + ) + ) + self._stack.enter_context( + mock.patch.dict( + os.environ, + {k: "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS}, + clear=False, + ) ) - self._stack.start() + return super().__enter__() def __exit__(self, *exc): - self._stack.stop() + self._stack.close() return super().__exit__(*exc) def is_expected_hostname(self, host: str | None) -> bool: @@ -212,7 +253,7 @@ def handle_request(self, method, parsed_url, headers, timeout): if not ( method == "GET" and parsed_url.path == self.parsed_identity_endpoint.path - and headers.get("X-IDENTITY-HEADER") == self.identity_header + and headers.get(HDR_IDENTITY) == self.identity_header and query_string["resource"] ): logger.warning( @@ -236,7 +277,7 @@ def reset_defaults(self) -> None: self.iss = "https://accounts.google.com" def is_expected_hostname(self, host: str | None) -> bool: - return host == "169.254.169.254" + return host == GCE_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) @@ -244,9 +285,8 @@ def 
handle_request(self, method, parsed_url, headers, timeout): # Reject malformed requests. if not ( method == "GET" - and parsed_url.path - == "/computeMetadata/v1/instance/service-accounts/default/identity" - and headers.get("Metadata-Flavor") == "Google" + and parsed_url.path == GCE_IDENTITY_PATH + and headers.get(HDR_METADATA_FLAVOR) == "Google" and query_string.get("audience") ): raise HTTPError() @@ -255,7 +295,7 @@ def handle_request(self, method, parsed_url, headers, timeout): audience = query_string["audience"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) - return build_response(self.token.encode("utf-8")) + return build_response(self.token.encode()) class _AwsMetadataService(FakeMetadataService): @@ -267,6 +307,7 @@ def reset_defaults(self) -> None: self.secret_key = "SK_TEST" self.session_token = "STS_TOKEN" self.imds_token = "IMDS_TOKEN" + self.region = "us-east-1" def is_expected_hostname(self, host: str | None) -> bool: return host in { @@ -280,7 +321,7 @@ def handle_request(self, method, parsed_url, headers, timeout): if method == "PUT" and url == f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}": return build_response( self.imds_token.encode(), - headers={"x-aws-ec2-metadata-token-ttl-seconds": "21600"}, + headers={HDR_IMDS_TOKEN_TTL: "21600"}, ) if method == "GET" and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}": @@ -301,7 +342,7 @@ def handle_request(self, method, parsed_url, headers, timeout): ).encode() return build_response(creds_json) - ecs_uri = os.getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + ecs_uri = os.getenv(AWS_CONTAINER_CRED_ENV) if ecs_uri and method == "GET" and url == f"{_ECS_CRED_BASE_URL}{ecs_uri}": creds_json = json.dumps( { @@ -312,6 +353,12 @@ def handle_request(self, method, parsed_url, headers, timeout): ).encode() return build_response(creds_json) + if method == "GET" and url == f"{_IMDS_BASE_URL}{IMDS_REGION_PATH}": + return build_response(self.region.encode()) + + if method == "GET" and url == 
f"{_IMDS_BASE_URL}{IMDS_INSTANCE_IDENTITY_DOC}": + return build_response(json.dumps({"region": self.region}).encode()) + raise ConnectTimeout() @@ -341,10 +388,10 @@ def region(self, new_region: str) -> None: patch them via ExitStack so they’re cleaned up on __exit__. """ self._region = new_region - + self._metadata.region = new_region if getattr(self, "_stack", None): - for key in ("AWS_REGION", "AWS_DEFAULT_REGION"): - if key in os.environ: # patch only if present + for key in AWS_REGION_ENV_KEYS: + if key in os.environ: self._stack.enter_context( mock.patch.dict(os.environ, {key: new_region}, clear=False) ) @@ -389,13 +436,11 @@ def __enter__(self): self._metadata.session_token = ( self.credentials.token if self.credentials else None ) + self._metadata.region = self.region if self.region else None # Expose region & creds *only* via env vars so that the real helper # chain can resolve them without monkey-patching. - env_for_chain = { - "AWS_REGION": self.region, - "AWS_DEFAULT_REGION": self.region, - } + env_for_chain = {key: self.region for key in AWS_REGION_ENV_KEYS} if self.credentials: env_for_chain["AWS_ACCESS_KEY_ID"] = self.credentials.access_key env_for_chain["AWS_SECRET_ACCESS_KEY"] = self.credentials.secret_key @@ -427,7 +472,7 @@ def _prepare_runtime(self): self._stack.enter_context( mock.patch.dict( os.environ, - {"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI": "/v2/credentials/test-id"}, + {AWS_CONTAINER_CRED_ENV: "/v2/credentials/test-id"}, clear=False, ) ) @@ -450,7 +495,7 @@ def _prepare_runtime(self) -> None: self._stack.enter_context( mock.patch.dict( os.environ, - {"AWS_LAMBDA_FUNCTION_NAME": "dummy-fn"}, + {AWS_LAMBDA_FUNCTION_ENV: "dummy-fn"}, clear=False, ) ) @@ -468,7 +513,7 @@ def _prepare_runtime(self): "AWS_ACCESS_KEY_ID": "", "AWS_SECRET_ACCESS_KEY": "", "AWS_SESSION_TOKEN": "", - "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI": "", + AWS_CONTAINER_CRED_ENV: "", }, clear=False, ) From ea742d6d6c01cefa888c049eeb9703406291767d Mon Sep 17 00:00:00 2001 
From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 15:47:51 +0200 Subject: [PATCH 48/54] SNOW-2183023: base fix of constants approach --- test/csp_helpers.py | 147 +++++++++++++++++++------------------------- 1 file changed, 64 insertions(+), 83 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index 6c1a773748..4c68507df4 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -24,15 +24,6 @@ logger = logging.getLogger(__name__) -AZURE_VM_METADATA_HOST = "169.254.169.254" -AZURE_VM_TOKEN_PATH = "/metadata/identity/oauth2/token" - -AZURE_FUNCTION_IDENTITY_ENDPOINT = "http://169.254.255.2:8081/msi/token" -AZURE_FUNCTION_IDENTITY_HEADER = "FD80F6DA783A4881BE9FAFA365F58E7A" - -GCE_METADATA_HOST = "169.254.169.254" -GCE_IDENTITY_PATH = "/computeMetadata/v1/instance/service-accounts/default/identity" - AWS_REGION_ENV_KEYS = ("AWS_REGION", "AWS_DEFAULT_REGION") AWS_CONTAINER_CRED_ENV = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" AWS_LAMBDA_FUNCTION_ENV = "AWS_LAMBDA_FUNCTION_NAME" @@ -40,9 +31,6 @@ HDR_IDENTITY = "X-IDENTITY-HEADER" HDR_METADATA = "Metadata" HDR_METADATA_FLAVOR = "Metadata-Flavor" -HDR_IMDS_TOKEN_TTL = "x-aws-ec2-metadata-token-ttl-seconds" -IMDS_INSTANCE_IDENTITY_DOC = "/latest/dynamic/instance-identity/document" -IMDS_REGION_PATH = "/latest/meta-data/placement/region" AWS_CREDENTIAL_ENV_KEYS = ( "AWS_ACCESS_KEY_ID", @@ -93,21 +81,20 @@ def __init__(self) -> None: self.reset_defaults() self._context_stack: ExitStack | None = None - @abstractmethod - def reset_defaults(self) -> None: - """Resets any default values for test parameters. + @staticmethod + def _clean_env_vars_for_scope() -> dict[str, str]: + """Return a mapping that blanks all AWS-specific env-vars. - This is called in the constructor and when entering as a context manager. + Used by Azure / GCP fakes so tests stay hermetic even when + executed inside a real AWS runner. 
""" - pass + return {k: "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS} @abstractmethod - def is_expected_hostname(self, host: str | None) -> bool: - """Returns true if the passed hostname is the one at which this metadata service is listening. + def reset_defaults(self) -> None: ... - Used to raise a ConnectTimeout for requests not targeted to this hostname. - """ - pass + @abstractmethod + def is_expected_hostname(self, host: str | None) -> bool: ... @abstractmethod def handle_request( @@ -116,12 +103,9 @@ def handle_request( parsed_url, headers, timeout, - ) -> Response: - """Main business logic for handling this request. Should return a Response object.""" - pass + ) -> Response: ... def __call__(self, method, url, headers=None, timeout=None, **_kw): - """Entry-point for the requests monkey-patch.""" headers = headers or {} parsed = urlparse(url) logger.debug("FakeMetadataService received %s %s %s", method, url, headers) @@ -135,7 +119,6 @@ def __call__(self, method, url, headers=None, timeout=None, **_kw): return self.handle_request(method.upper(), parsed, headers, timeout) def __enter__(self): - """Patches the relevant HTTP calls when entering as a context manager.""" self.reset_defaults() self._context_stack = ExitStack() self._context_stack.enter_context( @@ -144,8 +127,6 @@ def __enter__(self): side_effect=self, ) ) - # HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we - # simply raise a ConnectTimeout to avoid making real network calls. self._context_stack.enter_context( mock.patch( "urllib3.connection.HTTPConnection.request", @@ -168,7 +149,6 @@ def is_expected_hostname(self, host: str | None) -> bool: return False def handle_request(self, *_): - # This should never be called because we always raise a ConnectTimeout. raise AssertionError( "This should never be called because we always raise a ConnectTimeout." 
) @@ -177,28 +157,38 @@ def handle_request(self, *_): class FakeAzureVmMetadataService(FakeMetadataService): """Emulates an environment with the Azure VM metadata service.""" + VM_HOST = "169.254.169.254" + TOKEN_PATH = "/metadata/identity/oauth2/token" + def reset_defaults(self) -> None: - # Defaults used for generating an Entra ID token. Can be overriden in individual tests. self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" + def __enter__(self): + self._stack = contextlib.ExitStack() + self._stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + return super().__enter__() + + def __exit__(self, *exc): + self._stack.close() + return super().__exit__(*exc) + def is_expected_hostname(self, host: str | None) -> bool: - return host == AZURE_VM_METADATA_HOST + return host == self.__class__.VM_HOST def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) - - # Reject malformed requests. if not ( method == "GET" - and parsed_url.path == AZURE_VM_TOKEN_PATH + and parsed_url.path == self.__class__.TOKEN_PATH and headers.get(HDR_METADATA) == "True" and query_string.get("resource") ): raise HTTPError() logger.debug("Received request for Azure VM metadata service") - resource = query_string["resource"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) return build_response(json.dumps({"access_token": self.token}).encode()) @@ -207,17 +197,18 @@ def handle_request(self, method, parsed_url, headers, timeout): class FakeAzureFunctionMetadataService(FakeMetadataService): """Emulates an environment with the Azure Function metadata service.""" + IDENTITY_ENDPOINT = "http://169.254.255.2:8081/msi/token" + IDENTITY_HEADER = "FD80F6DA783A4881BE9FAFA365F58E7A" + def reset_defaults(self) -> None: - # Defaults used for generating an Entra ID token. Can be overriden in individual tests. 
self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - self.identity_endpoint = AZURE_FUNCTION_IDENTITY_ENDPOINT - self.identity_header = AZURE_FUNCTION_IDENTITY_HEADER + self.identity_endpoint = self.__class__.IDENTITY_ENDPOINT + self.identity_header = self.__class__.IDENTITY_HEADER self.parsed_identity_endpoint = urlparse(self.identity_endpoint) self._stack: contextlib.ExitStack | None = None def __enter__(self): - # Inject the variables *without* touching os.environ directly self._stack = contextlib.ExitStack() self._stack.enter_context( mock.patch.dict( @@ -230,13 +221,8 @@ def __enter__(self): ) ) self._stack.enter_context( - mock.patch.dict( - os.environ, - {k: "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS}, - clear=False, - ) + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) ) - return super().__enter__() def __exit__(self, *exc): @@ -248,8 +234,6 @@ def is_expected_hostname(self, host: str | None) -> bool: def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) - - # Reject malformed requests. 
if not ( method == "GET" and parsed_url.path == self.parsed_identity_endpoint.path @@ -257,12 +241,12 @@ def handle_request(self, method, parsed_url, headers, timeout): and query_string["resource"] ): logger.warning( - f"Received malformed request: {method} {parsed_url.path} {str(headers)} {str(query_string)}" + f"Received malformed request: {method} {parsed_url.path} " + f"{str(headers)} {str(query_string)}" ) raise HTTPError() logger.debug("Received request for Azure Functions metadata service") - resource = query_string["resource"][0] self.token = gen_dummy_id_token(self.sub, self.iss, resource) return build_response(json.dumps({"access_token": self.token}).encode()) @@ -271,28 +255,38 @@ def handle_request(self, method, parsed_url, headers, timeout): class FakeGceMetadataService(FakeMetadataService): """Simulates GCE metadata endpoint.""" + METADATA_HOST = "169.254.169.254" + IDENTITY_PATH = "/computeMetadata/v1/instance/service-accounts/default/identity" + def reset_defaults(self) -> None: - # Defaults used for generating a token. Can be overriden in individual tests. self.sub = "123" self.iss = "https://accounts.google.com" + def __enter__(self): + self._stack = contextlib.ExitStack() + self._stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + return super().__enter__() + + def __exit__(self, *exc): + self._stack.close() + return super().__exit__(*exc) + def is_expected_hostname(self, host: str | None) -> bool: - return host == GCE_METADATA_HOST + return host == self.__class__.METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) - - # Reject malformed requests. 
if not ( method == "GET" - and parsed_url.path == GCE_IDENTITY_PATH + and parsed_url.path == self.__class__.IDENTITY_PATH and headers.get(HDR_METADATA_FLAVOR) == "Google" and query_string.get("audience") ): raise HTTPError() logger.debug("Received request for GCE metadata service") - audience = query_string["audience"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) return build_response(self.token.encode()) @@ -301,6 +295,10 @@ def handle_request(self, method, parsed_url, headers, timeout): class _AwsMetadataService(FakeMetadataService): """Low-level fake for IMDSv2 and ECS endpoints.""" + HDR_IMDS_TOKEN_TTL = "x-aws-ec2-metadata-token-ttl-seconds" + IMDS_INSTANCE_IDENTITY_DOC = "/latest/dynamic/instance-identity/document" + IMDS_REGION_PATH = "/latest/meta-data/placement/region" + def reset_defaults(self) -> None: self.role_name = "MyRole" self.access_key = "AKIA_TEST" @@ -321,7 +319,7 @@ def handle_request(self, method, parsed_url, headers, timeout): if method == "PUT" and url == f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}": return build_response( self.imds_token.encode(), - headers={HDR_IMDS_TOKEN_TTL: "21600"}, + headers={self.__class__.HDR_IMDS_TOKEN_TTL: "21600"}, ) if method == "GET" and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}": @@ -353,10 +351,16 @@ def handle_request(self, method, parsed_url, headers, timeout): ).encode() return build_response(creds_json) - if method == "GET" and url == f"{_IMDS_BASE_URL}{IMDS_REGION_PATH}": + if ( + method == "GET" + and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_REGION_PATH}" + ): return build_response(self.region.encode()) - if method == "GET" and url == f"{_IMDS_BASE_URL}{IMDS_INSTANCE_IDENTITY_DOC}": + if ( + method == "GET" + and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_INSTANCE_IDENTITY_DOC}" + ): return build_response(json.dumps({"region": self.region}).encode()) raise ConnectTimeout() @@ -369,7 +373,6 @@ class FakeAwsEnvironment: """ def __init__(self): - # Defaults used for 
generating a token. Can be overriden in individual tests. self._region = "us-east-1" self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" self.credentials: Credentials | None = Credentials( @@ -384,9 +387,6 @@ def region(self) -> str: @region.setter def region(self, new_region: str) -> None: - """Change runtime region and, **if** the env-vars already exist, - patch them via ExitStack so they’re cleaned up on __exit__. - """ self._region = new_region self._metadata.region = new_region if getattr(self, "_stack", None): @@ -397,22 +397,10 @@ def region(self, new_region: str) -> None: ) def _prepare_runtime(self): - """Sub-classes patch env / credentials here.""" return None def __enter__(self): - """Activate the fake AWS runtime. - - * Only HTTP traffic is patched – no longer stubs `get_region` - or `load_default_credentials`. - * Region / credential discovery is driven entirely via - environment variables, so the real helper functions keep - working untouched. - """ self._stack = ExitStack() - - # Patch outgoing HTTP calls that rely on `requests` or the low-level - # urllib client, routing them to our metadata stub or timing-out. self._stack.enter_context( mock.patch( "snowflake.connector.vendored.requests.request", @@ -425,8 +413,6 @@ def __enter__(self): side_effect=ConnectTimeout(), ) ) - - # Keep the metadata stub in sync with the final credential set. self._metadata.access_key = ( self.credentials.access_key if self.credentials else None ) @@ -436,10 +422,8 @@ def __enter__(self): self._metadata.session_token = ( self.credentials.token if self.credentials else None ) - self._metadata.region = self.region if self.region else None + self._metadata.region = self.region - # Expose region & creds *only* via env vars so that the real helper - # chain can resolve them without monkey-patching. 
env_for_chain = {key: self.region for key in AWS_REGION_ENV_KEYS} if self.credentials: env_for_chain["AWS_ACCESS_KEY_ID"] = self.credentials.access_key @@ -451,7 +435,6 @@ def __enter__(self): mock.patch.dict(os.environ, env_for_chain, clear=False) ) - # Runtime-specific tweaks (may change creds / env). self._prepare_runtime() return self @@ -483,7 +466,6 @@ class FakeAwsLambda(FakeAwsEnvironment): def __init__(self): super().__init__() - # Lambda always returns *session* credentials self.credentials = Credentials( access_key="ak", secret_key="sk", @@ -491,7 +473,6 @@ def __init__(self): ) def _prepare_runtime(self) -> None: - # Patch env vars via mock.patch.dict so nothing touches os.environ directly self._stack.enter_context( mock.patch.dict( os.environ, From 0bee4cd41f9077ab85aeeb27c1576d454aef5076 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 15:57:22 +0200 Subject: [PATCH 49/54] SNOW-2183023: this breaks code a lot - wrong order of envs cleanup compare with 2 commits ago --- test/csp_helpers.py | 80 ++++++++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index 4c68507df4..200b4eb999 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -41,6 +41,18 @@ "AWS_SESSION_ARN", ) +AZURE_ENV_KEYS = ("IDENTITY_ENDPOINT", "IDENTITY_HEADER") +GCP_ENV_KEYS = ( + "GOOGLE_APPLICATION_CREDENTIALS", + "GOOGLE_CLOUD_PROJECT", + "GCLOUD_PROJECT", + "GCP_PROJECT", +) +CLOUD_ENV_KEYS = ( + AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS + AZURE_ENV_KEYS + GCP_ENV_KEYS +) +# --------------------------------------------------------------------------- + def gen_dummy_id_token( sub: str = "test-subject", @@ -83,12 +95,12 @@ def __init__(self) -> None: @staticmethod def _clean_env_vars_for_scope() -> dict[str, str]: - """Return a mapping that blanks all AWS-specific env-vars. + """Return a mapping that blanks all known cloud-specific env-vars. 
- Used by Azure / GCP fakes so tests stay hermetic even when - executed inside a real AWS runner. + Ensures every fake starts from a pristine state, regardless of which + provider the CI runner itself resides on. """ - return {k: "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS} + return {k: "" for k in CLOUD_ENV_KEYS} @abstractmethod def reset_defaults(self) -> None: ... @@ -121,6 +133,12 @@ def __call__(self, method, url, headers=None, timeout=None, **_kw): def __enter__(self): self.reset_defaults() self._context_stack = ExitStack() + + # Blanket scrub of all cloud-specific env vars + self._context_stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + self._context_stack.enter_context( mock.patch( "snowflake.connector.vendored.requests.request", @@ -157,38 +175,28 @@ def handle_request(self, *_): class FakeAzureVmMetadataService(FakeMetadataService): """Emulates an environment with the Azure VM metadata service.""" - VM_HOST = "169.254.169.254" - TOKEN_PATH = "/metadata/identity/oauth2/token" + AZURE_VM_METADATA_HOST = "169.254.169.254" + AZURE_VM_TOKEN_PATH = "/metadata/identity/oauth2/token" def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - def __enter__(self): - self._stack = contextlib.ExitStack() - self._stack.enter_context( - mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) - ) - return super().__enter__() - - def __exit__(self, *exc): - self._stack.close() - return super().__exit__(*exc) - def is_expected_hostname(self, host: str | None) -> bool: - return host == self.__class__.VM_HOST + return host == self.__class__.AZURE_VM_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) if not ( method == "GET" - and parsed_url.path == self.__class__.TOKEN_PATH + and parsed_url.path == 
self.__class__.AZURE_VM_TOKEN_PATH and headers.get(HDR_METADATA) == "True" and query_string.get("resource") ): raise HTTPError() logger.debug("Received request for Azure VM metadata service") + resource = query_string["resource"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) return build_response(json.dumps({"access_token": self.token}).encode()) @@ -197,14 +205,14 @@ def handle_request(self, method, parsed_url, headers, timeout): class FakeAzureFunctionMetadataService(FakeMetadataService): """Emulates an environment with the Azure Function metadata service.""" - IDENTITY_ENDPOINT = "http://169.254.255.2:8081/msi/token" - IDENTITY_HEADER = "FD80F6DA783A4881BE9FAFA365F58E7A" + AZURE_FUNCTION_IDENTITY_ENDPOINT = "http://169.254.255.2:8081/msi/token" + AZURE_FUNCTION_IDENTITY_HEADER = "FD80F6DA783A4881BE9FAFA365F58E7A" def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - self.identity_endpoint = self.__class__.IDENTITY_ENDPOINT - self.identity_header = self.__class__.IDENTITY_HEADER + self.identity_endpoint = self.__class__.AZURE_FUNCTION_IDENTITY_ENDPOINT + self.identity_header = self.__class__.AZURE_FUNCTION_IDENTITY_HEADER self.parsed_identity_endpoint = urlparse(self.identity_endpoint) self._stack: contextlib.ExitStack | None = None @@ -241,12 +249,12 @@ def handle_request(self, method, parsed_url, headers, timeout): and query_string["resource"] ): logger.warning( - f"Received malformed request: {method} {parsed_url.path} " - f"{str(headers)} {str(query_string)}" + f"Received malformed request: {method} {parsed_url.path} {headers} {query_string}" ) raise HTTPError() logger.debug("Received request for Azure Functions metadata service") + resource = query_string["resource"][0] self.token = gen_dummy_id_token(self.sub, self.iss, resource) return build_response(json.dumps({"access_token": self.token}).encode()) @@ -255,38 +263,28 @@ 
def handle_request(self, method, parsed_url, headers, timeout): class FakeGceMetadataService(FakeMetadataService): """Simulates GCE metadata endpoint.""" - METADATA_HOST = "169.254.169.254" - IDENTITY_PATH = "/computeMetadata/v1/instance/service-accounts/default/identity" + GCE_METADATA_HOST = "169.254.169.254" + GCE_IDENTITY_PATH = "/computeMetadata/v1/instance/service-accounts/default/identity" def reset_defaults(self) -> None: self.sub = "123" self.iss = "https://accounts.google.com" - def __enter__(self): - self._stack = contextlib.ExitStack() - self._stack.enter_context( - mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) - ) - return super().__enter__() - - def __exit__(self, *exc): - self._stack.close() - return super().__exit__(*exc) - def is_expected_hostname(self, host: str | None) -> bool: - return host == self.__class__.METADATA_HOST + return host == self.__class__.GCE_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) if not ( method == "GET" - and parsed_url.path == self.__class__.IDENTITY_PATH + and parsed_url.path == self.__class__.GCE_IDENTITY_PATH and headers.get(HDR_METADATA_FLAVOR) == "Google" and query_string.get("audience") ): raise HTTPError() logger.debug("Received request for GCE metadata service") + audience = query_string["audience"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) return build_response(self.token.encode()) @@ -389,6 +387,7 @@ def region(self) -> str: def region(self, new_region: str) -> None: self._region = new_region self._metadata.region = new_region + if getattr(self, "_stack", None): for key in AWS_REGION_ENV_KEYS: if key in os.environ: @@ -413,6 +412,7 @@ def __enter__(self): side_effect=ConnectTimeout(), ) ) + self._metadata.access_key = ( self.credentials.access_key if self.credentials else None ) From e3a00ab6960f79b2f92045f0f0923db421cbc4c6 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 16:13:35 +0200 Subject: [PATCH 50/54] SNOW-2183023: breaks code less - wrong order of envs cleanup compare with 2 commits ago --- test/csp_helpers.py | 232 ++++++++++++++++---------------------------- 1 file changed, 83 insertions(+), 149 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index 200b4eb999..aa45066b30 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,6 +1,5 @@ from __future__ import annotations -import contextlib import json import logging import os @@ -24,6 +23,15 @@ logger = logging.getLogger(__name__) +AZURE_VM_METADATA_HOST = "169.254.169.254" +AZURE_VM_TOKEN_PATH = "/metadata/identity/oauth2/token" + +AZURE_FUNCTION_IDENTITY_ENDPOINT = "http://169.254.255.2:8081/msi/token" +AZURE_FUNCTION_IDENTITY_HEADER = "FD80F6DA783A4881BE9FAFA365F58E7A" + +GCE_METADATA_HOST = "169.254.169.254" +GCE_IDENTITY_PATH = "/computeMetadata/v1/instance/service-accounts/default/identity" + AWS_REGION_ENV_KEYS = ("AWS_REGION", "AWS_DEFAULT_REGION") AWS_CONTAINER_CRED_ENV = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" AWS_LAMBDA_FUNCTION_ENV = "AWS_LAMBDA_FUNCTION_NAME" @@ -41,6 +49,7 @@ "AWS_SESSION_ARN", ) +# ------------ additional bundles to wipe up-front --------------------------- AZURE_ENV_KEYS = ("IDENTITY_ENDPOINT", "IDENTITY_HEADER") GCP_ENV_KEYS = ( "GOOGLE_APPLICATION_CREDENTIALS", @@ -51,7 +60,7 @@ CLOUD_ENV_KEYS = ( AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS + AZURE_ENV_KEYS + GCP_ENV_KEYS ) -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- def gen_dummy_id_token( @@ -59,16 +68,8 @@ def gen_dummy_id_token( iss: str = "test-issuer", aud: str = "snowflakecomputing.com", ) -> str: - """Generates a dummy HS256-signed JWT.""" now = int(time()) - payload = { - "sub": sub, - "iss": iss, - "aud": aud, - "iat": now, - "exp": now + 60 * 60, - } - logger.debug("Generating 
dummy token with claims %s", payload) + payload = {"sub": sub, "iss": iss, "aud": aud, "iat": now, "exp": now + 3600} return jwt.encode(payload, key="secret", algorithm="HS256") @@ -77,7 +78,6 @@ def build_response( status_code: int = 200, headers: dict[str, str] | None = None, ) -> Response: - """Return a minimal Response object with canned body/headers.""" resp = Response() resp.status_code = status_code resp._content = content @@ -87,21 +87,20 @@ def build_response( class FakeMetadataService(ABC): - """Base class for cloud-metadata fakes.""" + """Base class for all cloud-metadata fakes.""" def __init__(self) -> None: self.reset_defaults() self._context_stack: ExitStack | None = None + # ------------------------------------------------------------------ utils @staticmethod def _clean_env_vars_for_scope() -> dict[str, str]: - """Return a mapping that blanks all known cloud-specific env-vars. - - Ensures every fake starts from a pristine state, regardless of which - provider the CI runner itself resides on. - """ + """Blank all major cloud-specific env-vars for a hermetic test.""" return {k: "" for k in CLOUD_ENV_KEYS} + # ------------------------------------------------------------------------ + @abstractmethod def reset_defaults(self) -> None: ... @@ -109,36 +108,24 @@ def reset_defaults(self) -> None: ... def is_expected_hostname(self, host: str | None) -> bool: ... @abstractmethod - def handle_request( - self, - method, - parsed_url, - headers, - timeout, - ) -> Response: ... + def handle_request(self, method, parsed_url, headers, timeout) -> Response: ... 
+ # -------------------------------------------------------- context helpers def __call__(self, method, url, headers=None, timeout=None, **_kw): headers = headers or {} parsed = urlparse(url) - logger.debug("FakeMetadataService received %s %s %s", method, url, headers) - if not self.is_expected_hostname(parsed.hostname): - logger.debug( - "Received request to unexpected hostname %s – timeout", parsed.hostname - ) raise ConnectTimeout() - return self.handle_request(method.upper(), parsed, headers, timeout) def __enter__(self): self.reset_defaults() self._context_stack = ExitStack() - - # Blanket scrub of all cloud-specific env vars + # first – wipe every cloud env-var self._context_stack.enter_context( mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) ) - + # route HTTP calls through this fake self._context_stack.enter_context( mock.patch( "snowflake.connector.vendored.requests.request", @@ -156,69 +143,55 @@ def __enter__(self): def __exit__(self, *exc): self._context_stack.close() + # ------------------------------------------------------------------------ -class NoMetadataService(FakeMetadataService): - """Always times out – simulates an environment without any metadata service.""" - def reset_defaults(self) -> None: - pass +class NoMetadataService(FakeMetadataService): + def reset_defaults(self) -> None: ... def is_expected_hostname(self, host: str | None) -> bool: return False def handle_request(self, *_): - raise AssertionError( - "This should never be called because we always raise a ConnectTimeout." 
- ) + raise AssertionError +# --------------------------- Azure fakes ----------------------------------- class FakeAzureVmMetadataService(FakeMetadataService): - """Emulates an environment with the Azure VM metadata service.""" - - AZURE_VM_METADATA_HOST = "169.254.169.254" - AZURE_VM_TOKEN_PATH = "/metadata/identity/oauth2/token" - def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" def is_expected_hostname(self, host: str | None) -> bool: - return host == self.__class__.AZURE_VM_METADATA_HOST + return host == AZURE_VM_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): - query_string = parse_qs(parsed_url.query) + qs = parse_qs(parsed_url.query) if not ( method == "GET" - and parsed_url.path == self.__class__.AZURE_VM_TOKEN_PATH + and parsed_url.path == AZURE_VM_TOKEN_PATH and headers.get(HDR_METADATA) == "True" - and query_string.get("resource") + and qs.get("resource") ): raise HTTPError() - - logger.debug("Received request for Azure VM metadata service") - - resource = query_string["resource"][0] - self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) + resource = qs["resource"][0] + self.token = gen_dummy_id_token(self.sub, self.iss, resource) return build_response(json.dumps({"access_token": self.token}).encode()) class FakeAzureFunctionMetadataService(FakeMetadataService): - """Emulates an environment with the Azure Function metadata service.""" - - AZURE_FUNCTION_IDENTITY_ENDPOINT = "http://169.254.255.2:8081/msi/token" - AZURE_FUNCTION_IDENTITY_HEADER = "FD80F6DA783A4881BE9FAFA365F58E7A" - def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" - self.identity_endpoint = self.__class__.AZURE_FUNCTION_IDENTITY_ENDPOINT - self.identity_header = self.__class__.AZURE_FUNCTION_IDENTITY_HEADER + 
self.identity_endpoint = AZURE_FUNCTION_IDENTITY_ENDPOINT + self.identity_header = AZURE_FUNCTION_IDENTITY_HEADER self.parsed_identity_endpoint = urlparse(self.identity_endpoint) - self._stack: contextlib.ExitStack | None = None def __enter__(self): - self._stack = contextlib.ExitStack() - self._stack.enter_context( + # run the scrub + HTTP stubs first + super().__enter__() + # now add the two vars the Function runtime exposes + self._context_stack.enter_context( mock.patch.dict( os.environ, { @@ -228,71 +201,56 @@ def __enter__(self): clear=False, ) ) - self._stack.enter_context( - mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) - ) - return super().__enter__() - - def __exit__(self, *exc): - self._stack.close() - return super().__exit__(*exc) + return self # important! def is_expected_hostname(self, host: str | None) -> bool: return host == self.parsed_identity_endpoint.hostname def handle_request(self, method, parsed_url, headers, timeout): - query_string = parse_qs(parsed_url.query) + qs = parse_qs(parsed_url.query) if not ( method == "GET" and parsed_url.path == self.parsed_identity_endpoint.path and headers.get(HDR_IDENTITY) == self.identity_header - and query_string["resource"] + and qs.get("resource") ): - logger.warning( - f"Received malformed request: {method} {parsed_url.path} {headers} {query_string}" - ) raise HTTPError() - - logger.debug("Received request for Azure Functions metadata service") - - resource = query_string["resource"][0] + resource = qs["resource"][0] self.token = gen_dummy_id_token(self.sub, self.iss, resource) return build_response(json.dumps({"access_token": self.token}).encode()) -class FakeGceMetadataService(FakeMetadataService): - """Simulates GCE metadata endpoint.""" +# ----------------------------------------------------------------------------- - GCE_METADATA_HOST = "169.254.169.254" - GCE_IDENTITY_PATH = "/computeMetadata/v1/instance/service-accounts/default/identity" +# --------------------------- 
GCP fake -------------------------------------- +class FakeGceMetadataService(FakeMetadataService): def reset_defaults(self) -> None: self.sub = "123" self.iss = "https://accounts.google.com" def is_expected_hostname(self, host: str | None) -> bool: - return host == self.__class__.GCE_METADATA_HOST + return host == GCE_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): - query_string = parse_qs(parsed_url.query) + qs = parse_qs(parsed_url.query) if not ( method == "GET" - and parsed_url.path == self.__class__.GCE_IDENTITY_PATH + and parsed_url.path == GCE_IDENTITY_PATH and headers.get(HDR_METADATA_FLAVOR) == "Google" - and query_string.get("audience") + and qs.get("audience") ): raise HTTPError() + audience = qs["audience"][0] + self.token = gen_dummy_id_token(self.sub, self.iss, audience) + return build_response(self.token.encode()) - logger.debug("Received request for GCE metadata service") - audience = query_string["audience"][0] - self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) - return build_response(self.token.encode()) +# ----------------------------------------------------------------------------- +# --------------------------- AWS fake -------------------------------------- class _AwsMetadataService(FakeMetadataService): - """Low-level fake for IMDSv2 and ECS endpoints.""" - HDR_IMDS_TOKEN_TTL = "x-aws-ec2-metadata-token-ttl-seconds" IMDS_INSTANCE_IDENTITY_DOC = "/latest/dynamic/instance-identity/document" IMDS_REGION_PATH = "/latest/meta-data/placement/region" @@ -313,62 +271,54 @@ def is_expected_hostname(self, host: str | None) -> bool: def handle_request(self, method, parsed_url, headers, timeout): url = f"{parsed_url.scheme}://{parsed_url.hostname}{parsed_url.path}" - if method == "PUT" and url == f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}": return build_response( self.imds_token.encode(), headers={self.__class__.HDR_IMDS_TOKEN_TTL: "21600"}, ) - if method == "GET" and url == 
f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}": return build_response(self.role_name.encode()) - if ( method == "GET" and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}{self.role_name}" ): if self.access_key is None or self.secret_key is None: return build_response(b"", status_code=404) - creds_json = json.dumps( - { - "AccessKeyId": self.access_key, - "SecretAccessKey": self.secret_key, - "Token": self.session_token, - } - ).encode() - return build_response(creds_json) - + return build_response( + json.dumps( + { + "AccessKeyId": self.access_key, + "SecretAccessKey": self.secret_key, + "Token": self.session_token, + } + ).encode() + ) ecs_uri = os.getenv(AWS_CONTAINER_CRED_ENV) if ecs_uri and method == "GET" and url == f"{_ECS_CRED_BASE_URL}{ecs_uri}": - creds_json = json.dumps( - { - "AccessKeyId": self.access_key, - "SecretAccessKey": self.secret_key, - "Token": self.session_token, - } - ).encode() - return build_response(creds_json) - + return build_response( + json.dumps( + { + "AccessKeyId": self.access_key, + "SecretAccessKey": self.secret_key, + "Token": self.session_token, + } + ).encode() + ) if ( method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_REGION_PATH}" ): return build_response(self.region.encode()) - if ( method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_INSTANCE_IDENTITY_DOC}" ): return build_response(json.dumps({"region": self.region}).encode()) - raise ConnectTimeout() class FakeAwsEnvironment: - """ - Base context-manager for AWS runtime fakes. - Subclasses override `_prepare_runtime()` to tweak env-vars / creds. 
- """ + """Context-manager that wires up the AWS fake + env-vars.""" def __init__(self): self._region = "us-east-1" @@ -379,6 +329,7 @@ def __init__(self): self._metadata = _AwsMetadataService() self._stack: ExitStack | None = None + # ------------- region helper ------------------------------------------- @property def region(self) -> str: return self._region @@ -387,7 +338,6 @@ def region(self) -> str: def region(self, new_region: str) -> None: self._region = new_region self._metadata.region = new_region - if getattr(self, "_stack", None): for key in AWS_REGION_ENV_KEYS: if key in os.environ: @@ -395,9 +345,11 @@ def region(self, new_region: str) -> None: mock.patch.dict(os.environ, {key: new_region}, clear=False) ) - def _prepare_runtime(self): - return None + # ----------------------------------------------------------------------- + + def _prepare_runtime(self): ... + # ----------------------- context plumbing ------------------------------ def __enter__(self): self._stack = ExitStack() self._stack.enter_context( @@ -412,7 +364,6 @@ def __enter__(self): side_effect=ConnectTimeout(), ) ) - self._metadata.access_key = ( self.credentials.access_key if self.credentials else None ) @@ -423,18 +374,15 @@ def __enter__(self): self.credentials.token if self.credentials else None ) self._metadata.region = self.region - - env_for_chain = {key: self.region for key in AWS_REGION_ENV_KEYS} + env_for_chain = {k: self.region for k in AWS_REGION_ENV_KEYS} if self.credentials: env_for_chain["AWS_ACCESS_KEY_ID"] = self.credentials.access_key env_for_chain["AWS_SECRET_ACCESS_KEY"] = self.credentials.secret_key if self.credentials.token: env_for_chain["AWS_SESSION_TOKEN"] = self.credentials.token - self._stack.enter_context( mock.patch.dict(os.environ, env_for_chain, clear=False) ) - self._prepare_runtime() return self @@ -443,14 +391,10 @@ def __exit__(self, *exc): class FakeAwsEc2(FakeAwsEnvironment): - """Default – IMDSv2 only.""" - - # nothing extra needed + pass class 
FakeAwsEcs(FakeAwsEnvironment): - """ECS/EKS task-role – exposes creds via task metadata endpoint.""" - def _prepare_runtime(self): self._stack.enter_context( mock.patch.dict( @@ -462,29 +406,19 @@ def _prepare_runtime(self): class FakeAwsLambda(FakeAwsEnvironment): - """Lambda runtime – temporary credentials + runtime env-vars.""" - def __init__(self): super().__init__() - self.credentials = Credentials( - access_key="ak", - secret_key="sk", - token="dummy-session-token", - ) + self.credentials = Credentials("ak", "sk", "dummy-session-token") - def _prepare_runtime(self) -> None: + def _prepare_runtime(self): self._stack.enter_context( mock.patch.dict( - os.environ, - {AWS_LAMBDA_FUNCTION_ENV: "dummy-fn"}, - clear=False, + os.environ, {AWS_LAMBDA_FUNCTION_ENV: "dummy-fn"}, clear=False ) ) class FakeAwsNoCreds(FakeAwsEnvironment): - """Negative path – no credentials anywhere.""" - def _prepare_runtime(self): self.credentials = None self._stack.enter_context( From fbbc22326a8f55e71479acb91422d31a38a53d77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 17:34:21 +0200 Subject: [PATCH 51/54] SNOW-2183023: fixed http traffix --- DESCRIPTION.md | 3 + src/snowflake/connector/_aws_credentials.py | 7 +- test/csp_helpers.py | 279 ++++++++++++++------ test/unit/conftest.py | 11 + test/unit/test_auth_workload_identity.py | 25 ++ 5 files changed, 235 insertions(+), 90 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index ba30b3b78b..097454c2bc 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -7,6 +7,9 @@ https://docs.snowflake.com/ Source code is also available at: https://github.com/snowflakedb/snowflake-connector-python # Release Notes +- v3.17(TBD) + - Removed boto and botocore dependencies. + - v3.16.0(July 04,2025) - Bumped numpy dependency from <2.1.0 to <=2.2.4. - Added Windows support for Python 3.13. 
diff --git a/src/snowflake/connector/_aws_credentials.py b/src/snowflake/connector/_aws_credentials.py index 05cc26a154..88b2032a3f 100644 --- a/src/snowflake/connector/_aws_credentials.py +++ b/src/snowflake/connector/_aws_credentials.py @@ -63,7 +63,8 @@ def get_container_credentials(*, timeout: float) -> SfAWSCredentials | None: def _get_imds_v2_token(timeout: float) -> str | None: try: - response = requests.put( + response = requests.request( + "PUT", f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}", headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, timeout=timeout, @@ -123,8 +124,8 @@ def get_region(timeout: float = 1.0) -> str | None: token = _get_imds_v2_token(timeout) headers = {"X-aws-ec2-metadata-token": token} if token else {} try: - response = requests.get( - f"{_IMDS_BASE_URL}{_IMDS_AZ_PATH}", headers=headers, timeout=timeout + response = requests.request( + "GET", f"{_IMDS_BASE_URL}{_IMDS_AZ_PATH}", headers=headers, timeout=timeout ) if response.ok: az = response.text.strip() diff --git a/test/csp_helpers.py b/test/csp_helpers.py index aa45066b30..ab69b2803d 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import json import logging import os @@ -18,7 +19,7 @@ _IMDS_ROLE_PATH, _IMDS_TOKEN_PATH, ) -from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError +from snowflake.connector.vendored.requests.exceptions import ConnectTimeout from snowflake.connector.vendored.requests.models import Response logger = logging.getLogger(__name__) @@ -49,27 +50,22 @@ "AWS_SESSION_ARN", ) -# ------------ additional bundles to wipe up-front --------------------------- -AZURE_ENV_KEYS = ("IDENTITY_ENDPOINT", "IDENTITY_HEADER") -GCP_ENV_KEYS = ( - "GOOGLE_APPLICATION_CREDENTIALS", - "GOOGLE_CLOUD_PROJECT", - "GCLOUD_PROJECT", - "GCP_PROJECT", -) -CLOUD_ENV_KEYS = ( - AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS + AZURE_ENV_KEYS + GCP_ENV_KEYS -) -# 
---------------------------------------------------------------------------- - def gen_dummy_id_token( sub: str = "test-subject", iss: str = "test-issuer", aud: str = "snowflakecomputing.com", ) -> str: + """Generates a dummy HS256-signed JWT.""" now = int(time()) - payload = {"sub": sub, "iss": iss, "aud": aud, "iat": now, "exp": now + 3600} + payload = { + "sub": sub, + "iss": iss, + "aud": aud, + "iat": now, + "exp": now + 60 * 60, + } + logger.debug("Generating dummy token with claims %s", payload) return jwt.encode(payload, key="secret", algorithm="HS256") @@ -78,6 +74,7 @@ def build_response( status_code: int = 200, headers: dict[str, str] | None = None, ) -> Response: + """Return a minimal Response object with canned body/headers.""" resp = Response() resp.status_code = status_code resp._content = content @@ -87,19 +84,20 @@ def build_response( class FakeMetadataService(ABC): - """Base class for all cloud-metadata fakes.""" + """Base class for cloud-metadata fakes.""" def __init__(self) -> None: self.reset_defaults() self._context_stack: ExitStack | None = None - # ------------------------------------------------------------------ utils @staticmethod def _clean_env_vars_for_scope() -> dict[str, str]: - """Blank all major cloud-specific env-vars for a hermetic test.""" - return {k: "" for k in CLOUD_ENV_KEYS} + """Return a mapping that blanks all AWS-specific env-vars. - # ------------------------------------------------------------------------ + Used by Azure / GCP fakes so tests stay hermetic even when + executed inside a real AWS runner. + """ + return {k: "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS} @abstractmethod def reset_defaults(self) -> None: ... @@ -108,27 +106,39 @@ def reset_defaults(self) -> None: ... def is_expected_hostname(self, host: str | None) -> bool: ... @abstractmethod - def handle_request(self, method, parsed_url, headers, timeout) -> Response: ... 
+ def handle_request( + self, + method, + parsed_url, + headers, + timeout, + ) -> Response: ... - # -------------------------------------------------------- context helpers def __call__(self, method, url, headers=None, timeout=None, **_kw): headers = headers or {} parsed = urlparse(url) + logger.debug("FakeMetadataService received %s %s %s", method, url, headers) + if not self.is_expected_hostname(parsed.hostname): + logger.debug( + "Received request to unexpected hostname %s – timeout", parsed.hostname + ) raise ConnectTimeout() + return self.handle_request(method.upper(), parsed, headers, timeout) def __enter__(self): self.reset_defaults() self._context_stack = ExitStack() - # first – wipe every cloud env-var self._context_stack.enter_context( - mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + mock.patch( + "snowflake.connector.vendored.requests.request", + side_effect=self, + ) ) - # route HTTP calls through this fake self._context_stack.enter_context( mock.patch( - "snowflake.connector.vendored.requests.request", + "snowflake.connector.vendored.requests.sessions.Session.request", side_effect=self, ) ) @@ -143,43 +153,64 @@ def __enter__(self): def __exit__(self, *exc): self._context_stack.close() - # ------------------------------------------------------------------------ - class NoMetadataService(FakeMetadataService): - def reset_defaults(self) -> None: ... + """Always times out – simulates an environment without any metadata service.""" + + def reset_defaults(self) -> None: + pass def is_expected_hostname(self, host: str | None) -> bool: return False def handle_request(self, *_): - raise AssertionError + raise AssertionError( + "This should never be called because we always raise a ConnectTimeout." 
+ ) -# --------------------------- Azure fakes ----------------------------------- class FakeAzureVmMetadataService(FakeMetadataService): + """Emulates an environment with the Azure VM metadata service.""" + def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" + def __enter__(self): + self._stack = contextlib.ExitStack() + self._stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + return super().__enter__() + + def __exit__(self, *exc): + self._stack.close() + return super().__exit__(*exc) + def is_expected_hostname(self, host: str | None) -> bool: return host == AZURE_VM_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): - qs = parse_qs(parsed_url.query) + query_string = parse_qs(parsed_url.query) + if not ( method == "GET" and parsed_url.path == AZURE_VM_TOKEN_PATH - and headers.get(HDR_METADATA) == "True" - and qs.get("resource") + and headers.get(HDR_METADATA, "").lower() == "true" # <-- patched + and query_string.get("resource") ): - raise HTTPError() - resource = qs["resource"][0] - self.token = gen_dummy_id_token(self.sub, self.iss, resource) + raise ConnectTimeout() + + logger.debug("Received request for Azure VM metadata service") + + resource = query_string["resource"][0] + self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) return build_response(json.dumps({"access_token": self.token}).encode()) class FakeAzureFunctionMetadataService(FakeMetadataService): + """Emulates an environment with the Azure Function metadata service.""" + def reset_defaults(self) -> None: self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" @@ -188,10 +219,8 @@ def reset_defaults(self) -> None: self.parsed_identity_endpoint = urlparse(self.identity_endpoint) def __enter__(self): - # run the scrub + HTTP stubs 
first - super().__enter__() - # now add the two vars the Function runtime exposes - self._context_stack.enter_context( + self._stack = contextlib.ExitStack() + self._stack.enter_context( mock.patch.dict( os.environ, { @@ -201,59 +230,87 @@ def __enter__(self): clear=False, ) ) - return self # important! + self._stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + + return super().__enter__() + + def __exit__(self, *exc): + self._stack.close() + return super().__exit__(*exc) def is_expected_hostname(self, host: str | None) -> bool: return host == self.parsed_identity_endpoint.hostname def handle_request(self, method, parsed_url, headers, timeout): - qs = parse_qs(parsed_url.query) + query_string = parse_qs(parsed_url.query) + if not ( method == "GET" and parsed_url.path == self.parsed_identity_endpoint.path and headers.get(HDR_IDENTITY) == self.identity_header - and qs.get("resource") + and query_string["resource"] ): - raise HTTPError() - resource = qs["resource"][0] - self.token = gen_dummy_id_token(self.sub, self.iss, resource) - return build_response(json.dumps({"access_token": self.token}).encode()) + logger.warning( + f"Received malformed request: {method} {parsed_url.path} " + f"{str(headers)} {str(query_string)}" + ) + raise ConnectTimeout() + logger.debug("Received request for Azure Functions metadata service") -# ----------------------------------------------------------------------------- + resource = query_string["resource"][0] + self.token = gen_dummy_id_token(self.sub, self.iss, resource) + return build_response(json.dumps({"access_token": self.token}).encode()) -# --------------------------- GCP fake -------------------------------------- class FakeGceMetadataService(FakeMetadataService): + """Simulates GCE metadata endpoint.""" + def reset_defaults(self) -> None: self.sub = "123" self.iss = "https://accounts.google.com" + def __enter__(self): + self._stack = contextlib.ExitStack() + 
self._stack.enter_context( + mock.patch.dict(os.environ, self._clean_env_vars_for_scope(), clear=False) + ) + return super().__enter__() + + def __exit__(self, *exc): + self._stack.close() + return super().__exit__(*exc) + def is_expected_hostname(self, host: str | None) -> bool: return host == GCE_METADATA_HOST def handle_request(self, method, parsed_url, headers, timeout): - qs = parse_qs(parsed_url.query) + query_string = parse_qs(parsed_url.query) + if not ( method == "GET" and parsed_url.path == GCE_IDENTITY_PATH and headers.get(HDR_METADATA_FLAVOR) == "Google" - and qs.get("audience") + and query_string.get("audience") ): - raise HTTPError() - audience = qs["audience"][0] - self.token = gen_dummy_id_token(self.sub, self.iss, audience) - return build_response(self.token.encode()) + raise ConnectTimeout() + logger.debug("Received request for GCE metadata service") -# ----------------------------------------------------------------------------- + audience = query_string["audience"][0] + self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) + return build_response(self.token.encode()) -# --------------------------- AWS fake -------------------------------------- class _AwsMetadataService(FakeMetadataService): + """Low-level fake for IMDSv2 and ECS endpoints.""" + HDR_IMDS_TOKEN_TTL = "x-aws-ec2-metadata-token-ttl-seconds" IMDS_INSTANCE_IDENTITY_DOC = "/latest/dynamic/instance-identity/document" IMDS_REGION_PATH = "/latest/meta-data/placement/region" + IMDS_AZ_PATH = "/latest/meta-data/placement/availability-zone" def reset_defaults(self) -> None: self.role_name = "MyRole" @@ -271,54 +328,68 @@ def is_expected_hostname(self, host: str | None) -> bool: def handle_request(self, method, parsed_url, headers, timeout): url = f"{parsed_url.scheme}://{parsed_url.hostname}{parsed_url.path}" + if method == "PUT" and url == f"{_IMDS_BASE_URL}{_IMDS_TOKEN_PATH}": return build_response( self.imds_token.encode(), headers={self.__class__.HDR_IMDS_TOKEN_TTL: 
"21600"}, ) + if method == "GET" and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}": return build_response(self.role_name.encode()) + if ( method == "GET" and url == f"{_IMDS_BASE_URL}{_IMDS_ROLE_PATH}{self.role_name}" ): if self.access_key is None or self.secret_key is None: return build_response(b"", status_code=404) - return build_response( - json.dumps( - { - "AccessKeyId": self.access_key, - "SecretAccessKey": self.secret_key, - "Token": self.session_token, - } - ).encode() - ) + creds_json = json.dumps( + { + "AccessKeyId": self.access_key, + "SecretAccessKey": self.secret_key, + "Token": self.session_token, + } + ).encode() + return build_response(creds_json) + ecs_uri = os.getenv(AWS_CONTAINER_CRED_ENV) if ecs_uri and method == "GET" and url == f"{_ECS_CRED_BASE_URL}{ecs_uri}": - return build_response( - json.dumps( - { - "AccessKeyId": self.access_key, - "SecretAccessKey": self.secret_key, - "Token": self.session_token, - } - ).encode() - ) + creds_json = json.dumps( + { + "AccessKeyId": self.access_key, + "SecretAccessKey": self.secret_key, + "Token": self.session_token, + } + ).encode() + return build_response(creds_json) + if ( method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_REGION_PATH}" ): return build_response(self.region.encode()) + + # New: availability-zone path (region extracted by stripping last char) + if ( + method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_AZ_PATH}" + ): # <-- new + return build_response(f"{self.region}a".encode()) # <-- new + if ( method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_INSTANCE_IDENTITY_DOC}" ): return build_response(json.dumps({"region": self.region}).encode()) + raise ConnectTimeout() class FakeAwsEnvironment: - """Context-manager that wires up the AWS fake + env-vars.""" + """ + Base context-manager for AWS runtime fakes. + Subclasses override `_prepare_runtime()` to tweak env-vars / creds. 
+ """ def __init__(self): self._region = "us-east-1" @@ -329,7 +400,6 @@ def __init__(self): self._metadata = _AwsMetadataService() self._stack: ExitStack | None = None - # ------------- region helper ------------------------------------------- @property def region(self) -> str: return self._region @@ -338,6 +408,7 @@ def region(self) -> str: def region(self, new_region: str) -> None: self._region = new_region self._metadata.region = new_region + if getattr(self, "_stack", None): for key in AWS_REGION_ENV_KEYS: if key in os.environ: @@ -345,25 +416,31 @@ def region(self, new_region: str) -> None: mock.patch.dict(os.environ, {key: new_region}, clear=False) ) - # ----------------------------------------------------------------------- - - def _prepare_runtime(self): ... + def _prepare_runtime(self): + return None - # ----------------------- context plumbing ------------------------------ def __enter__(self): self._stack = ExitStack() + self._stack.enter_context( mock.patch( "snowflake.connector.vendored.requests.request", side_effect=self._metadata, ) ) + self._stack.enter_context( + mock.patch( + "snowflake.connector.vendored.requests.sessions.Session.request", + side_effect=self._metadata, + ) + ) self._stack.enter_context( mock.patch( "urllib3.connection.HTTPConnection.request", side_effect=ConnectTimeout(), ) ) + self._metadata.access_key = ( self.credentials.access_key if self.credentials else None ) @@ -374,15 +451,18 @@ def __enter__(self): self.credentials.token if self.credentials else None ) self._metadata.region = self.region - env_for_chain = {k: self.region for k in AWS_REGION_ENV_KEYS} + + env_for_chain = {key: self.region for key in AWS_REGION_ENV_KEYS} if self.credentials: env_for_chain["AWS_ACCESS_KEY_ID"] = self.credentials.access_key env_for_chain["AWS_SECRET_ACCESS_KEY"] = self.credentials.secret_key if self.credentials.token: env_for_chain["AWS_SESSION_TOKEN"] = self.credentials.token + self._stack.enter_context( mock.patch.dict(os.environ, 
env_for_chain, clear=False) ) + self._prepare_runtime() return self @@ -391,10 +471,12 @@ def __exit__(self, *exc): class FakeAwsEc2(FakeAwsEnvironment): - pass + """Default – IMDSv2 only.""" class FakeAwsEcs(FakeAwsEnvironment): + """ECS/EKS task-role – exposes creds via task metadata endpoint.""" + def _prepare_runtime(self): self._stack.enter_context( mock.patch.dict( @@ -406,20 +488,43 @@ def _prepare_runtime(self): class FakeAwsLambda(FakeAwsEnvironment): + """Lambda runtime – temporary credentials + runtime env-vars.""" + def __init__(self): super().__init__() - self.credentials = Credentials("ak", "sk", "dummy-session-token") + self.credentials = Credentials( + access_key="ak", + secret_key="sk", + token="dummy-session-token", + ) - def _prepare_runtime(self): + def _prepare_runtime(self) -> None: self._stack.enter_context( mock.patch.dict( - os.environ, {AWS_LAMBDA_FUNCTION_ENV: "dummy-fn"}, clear=False + os.environ, + {AWS_LAMBDA_FUNCTION_ENV: "dummy-fn"}, + clear=False, ) ) +class _AwsMetadataTimeout(_AwsMetadataService): + """IMDS/ECS stub that never answers – simulates a totally unreachable endpoint.""" + + def handle_request(self, *args, **kwargs): + raise ConnectTimeout() + + class FakeAwsNoCreds(FakeAwsEnvironment): + """Negative path – no credentials anywhere *and* IMDS/ECS completely unreachable.""" + + def __init__(self): + super().__init__() + # Use the timeout-only IMDS stub + self._metadata = _AwsMetadataTimeout() + def _prepare_runtime(self): + # Strip every env-var that could satisfy the AWS credential chain self.credentials = None self._stack.enter_context( mock.patch.dict( diff --git a/test/unit/conftest.py b/test/unit/conftest.py index 1a4fbf5b40..89cd796497 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -43,6 +43,17 @@ def fake_aws_environment(request): yield env +@pytest.fixture +def imds_only_aws_environment(fake_aws_environment, monkeypatch): + """ + Same fake runtime, but with AWS_REGION / AWS_DEFAULT_REGION removed 
+ so the code *must* query IMDS to discover the region. + """ + for key in ("AWS_REGION", "AWS_DEFAULT_REGION"): + monkeypatch.delenv(key, raising=False) + yield fake_aws_environment + + @pytest.fixture(params=[FakeAwsNoCreds], ids=["aws_no_creds"]) def malformed_aws_environment(request): """Runtime where *no* credentials are discoverable (negative-path).""" diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index 5f98f70421..a9590289e7 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -467,3 +467,28 @@ def test_autodetect_no_provider_raises_error(no_metadata_service): assert "No workload identity credential was found for 'auto-detect" in str( excinfo.value ) + + +def test_explicit_aws_region_falls_back_to_imds(imds_only_aws_environment): + """ + When region env-vars are absent, the connector must discover the region via + the runtime metadata service (IMDS / task-metadata / lambda env). + """ + # Advertise a non-default region through the fake metadata service + imds_only_aws_environment.region = "us-west-2" + + auth = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + auth.prepare() + + verify_aws_token(extract_api_data(auth)["TOKEN"], "us-west-2") + + +def test_autodetect_prefers_gcp_when_no_aws_env(fake_gce_metadata_service): + """ + No AWS env-vars + a responsive GCP metadata server ⇒ GCP selected. 
+ """ + auth_class = AuthByWorkloadIdentity(provider=None) + auth_class.prepare() + + assert extract_api_data(auth_class)["PROVIDER"] == "GCP" + assert extract_api_data(auth_class)["TOKEN"] == fake_gce_metadata_service.token From fab3af9dfb03997be6fcd9b3ceb15a23c425ddb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 17:35:21 +0200 Subject: [PATCH 52/54] SNOW-2183023: fixed http traffix --- test/csp_helpers.py | 45 +++++++++++++++++++----- test/unit/test_auth_workload_identity.py | 2 +- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/test/csp_helpers.py b/test/csp_helpers.py index ab69b2803d..332ee021c3 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -100,10 +100,20 @@ def _clean_env_vars_for_scope() -> dict[str, str]: return {k: "" for k in AWS_CREDENTIAL_ENV_KEYS + AWS_REGION_ENV_KEYS} @abstractmethod - def reset_defaults(self) -> None: ... + def reset_defaults(self) -> None: + """Resets any default values for test parameters. + + This is called in the constructor and when entering as a context manager. + """ + pass @abstractmethod - def is_expected_hostname(self, host: str | None) -> bool: ... + def is_expected_hostname(self, host: str | None) -> bool: + """Returns true if the passed hostname is the one at which this metadata service is listening. + + Used to raise a ConnectTimeout for requests not targeted to this hostname. + """ + pass @abstractmethod def handle_request( @@ -112,9 +122,12 @@ def handle_request( parsed_url, headers, timeout, - ) -> Response: ... + ) -> Response: + """Main business logic for handling this request. 
Should return a Response object.""" + pass def __call__(self, method, url, headers=None, timeout=None, **_kw): + """Entry-point for the requests monkey-patch.""" headers = headers or {} parsed = urlparse(url) logger.debug("FakeMetadataService received %s %s %s", method, url, headers) @@ -128,6 +141,7 @@ def __call__(self, method, url, headers=None, timeout=None, **_kw): return self.handle_request(method.upper(), parsed, headers, timeout) def __enter__(self): + """Patches the relevant HTTP calls when entering as a context manager.""" self.reset_defaults() self._context_stack = ExitStack() self._context_stack.enter_context( @@ -164,6 +178,7 @@ def is_expected_hostname(self, host: str | None) -> bool: return False def handle_request(self, *_): + # This should never be called because we always raise a ConnectTimeout. raise AssertionError( "This should never be called because we always raise a ConnectTimeout." ) @@ -173,6 +188,7 @@ class FakeAzureVmMetadataService(FakeMetadataService): """Emulates an environment with the Azure VM metadata service.""" def reset_defaults(self) -> None: + # Defaults used for generating an Entra ID token. Can be overriden in individual tests. self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" @@ -193,6 +209,7 @@ def is_expected_hostname(self, host: str | None) -> bool: def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) + # Reject malformed requests. 
if not ( method == "GET" and parsed_url.path == AZURE_VM_TOKEN_PATH @@ -205,7 +222,7 @@ def handle_request(self, method, parsed_url, headers, timeout): resource = query_string["resource"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource) - return build_response(json.dumps({"access_token": self.token}).encode()) + return build_response(json.dumps({"access_token": self.token}).encode("utf-8")) class FakeAzureFunctionMetadataService(FakeMetadataService): @@ -220,6 +237,7 @@ def reset_defaults(self) -> None: def __enter__(self): self._stack = contextlib.ExitStack() + # Inject the variables without touching os.environ directly self._stack.enter_context( mock.patch.dict( os.environ, @@ -246,6 +264,7 @@ def is_expected_hostname(self, host: str | None) -> bool: def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) + # Reject malformed requests. if not ( method == "GET" and parsed_url.path == self.parsed_identity_endpoint.path @@ -269,6 +288,7 @@ class FakeGceMetadataService(FakeMetadataService): """Simulates GCE metadata endpoint.""" def reset_defaults(self) -> None: + # Defaults used for generating a token. Can be overriden in individual tests. self.sub = "123" self.iss = "https://accounts.google.com" @@ -289,6 +309,7 @@ def is_expected_hostname(self, host: str | None) -> bool: def handle_request(self, method, parsed_url, headers, timeout): query_string = parse_qs(parsed_url.query) + # Reject malformed requests. 
if not ( method == "GET" and parsed_url.path == GCE_IDENTITY_PATH @@ -370,11 +391,8 @@ def handle_request(self, method, parsed_url, headers, timeout): ): return build_response(self.region.encode()) - # New: availability-zone path (region extracted by stripping last char) - if ( - method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_AZ_PATH}" - ): # <-- new - return build_response(f"{self.region}a".encode()) # <-- new + if method == "GET" and url == f"{_IMDS_BASE_URL}{self.__class__.IMDS_AZ_PATH}": + return build_response(f"{self.region}a".encode()) if ( method == "GET" @@ -392,6 +410,7 @@ class FakeAwsEnvironment: """ def __init__(self): + # Defaults used for generating a token. Can be overriden in individual tests. self._region = "us-east-1" self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" self.credentials: Credentials | None = Credentials( @@ -406,6 +425,9 @@ def region(self) -> str: @region.setter def region(self, new_region: str) -> None: + """Change runtime region and, if the env-vars already exist, + patch them via ExitStack so they’re cleaned up on __exit__. + """ self._region = new_region self._metadata.region = new_region @@ -417,6 +439,7 @@ def region(self, new_region: str) -> None: ) def _prepare_runtime(self): + """Sub-classes patch env / credentials here.""" return None def __enter__(self): @@ -441,6 +464,7 @@ def __enter__(self): ) ) + # Keep the metadata stub in sync with the final credential set. self._metadata.access_key = ( self.credentials.access_key if self.credentials else None ) @@ -463,6 +487,7 @@ def __enter__(self): mock.patch.dict(os.environ, env_for_chain, clear=False) ) + # Runtime-specific tweaks (may change creds / env). 
self._prepare_runtime() return self @@ -492,6 +517,7 @@ class FakeAwsLambda(FakeAwsEnvironment): def __init__(self): super().__init__() + # Lambda always returns *session* credentials self.credentials = Credentials( access_key="ak", secret_key="sk", @@ -499,6 +525,7 @@ def __init__(self): ) def _prepare_runtime(self) -> None: + # Patch env vars via mock.patch.dict so nothing touches os.environ directly self._stack.enter_context( mock.patch.dict( os.environ, diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index a9590289e7..791a254f3f 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -485,7 +485,7 @@ def test_explicit_aws_region_falls_back_to_imds(imds_only_aws_environment): def test_autodetect_prefers_gcp_when_no_aws_env(fake_gce_metadata_service): """ - No AWS env-vars + a responsive GCP metadata server ⇒ GCP selected. + No AWS env-vars + a responsive GCP metadata server -> GCP selected. """ auth_class = AuthByWorkloadIdentity(provider=None) auth_class.prepare() From f9b891820a224d031db5a362c623154c0fe6c4a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 22:36:28 +0200 Subject: [PATCH 53/54] SNOW-2183023: comments cleanup --- src/snowflake/connector/_aws_sign_v4.py | 6 +++--- src/snowflake/connector/wif_util.py | 16 ++++++++-------- test/unit/test_boto_compatibility.py | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/snowflake/connector/_aws_sign_v4.py b/src/snowflake/connector/_aws_sign_v4.py index f4efc4323e..fa5dc6d472 100644 --- a/src/snowflake/connector/_aws_sign_v4.py +++ b/src/snowflake/connector/_aws_sign_v4.py @@ -33,11 +33,11 @@ def sign_get_caller_identity( session_token: str | None = None, ) -> dict[str, str]: """ - Return the SigV4 headers needed for a presigned **POST** to AWS STS + Return the SigV4 headers needed for a presigned POST to AWS STS `GetCallerIdentity`. 
- Parameters - ---------- + Parameters: + url The full STS endpoint with query parameters (e.g. ``https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15``) diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 5ba26c6d4f..9a624d27d6 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -3,9 +3,9 @@ """Workload‑identity attestation helpers. This module builds the attestation token that the Snowflake Python connector -sends when Authenticating with *Workload Identity Federation* (WIF). -It supports AWS, Azure, GCP and generic OIDC environments **without** pulling -in heavy SDKs such as *botocore* – we only need a small presigned STS request +sends when Authenticating with Workload Identity Federation (WIF). +It supports AWS, Azure, GCP and generic OIDC environments without pulling +in heavy SDKs such as botocore – we only need a small presigned STS request for AWS and a couple of metadata‑server calls for Azure / GCP. """ @@ -66,7 +66,7 @@ def from_string(provider: str) -> AttestationProvider: @dataclass class WorkloadIdentityAttestation: provider: AttestationProvider - credential: str # **base64** JSON blob – provider‑specific + credential: str # base64 JSON blob – provider‑specific user_identifier_components: dict[str, Any] @@ -136,7 +136,7 @@ def _partition_from_region(region: str) -> AWSPartition: def _sts_host_from_region(region: str) -> str | None: """ - Construct the STS endpoint hostname for *region* according to the + Construct the STS endpoint hostname for region according to the regionalised-STS rules published by AWS.:contentReference[oaicite:2]{index=2} References: @@ -173,8 +173,8 @@ def _try_get_arn_from_env_vars() -> str | None: def try_compose_aws_user_identifier(region: str | None = None) -> dict[str, str]: """Return an identifier for the running AWS workload. 
- Always includes the AWS *region*; adds an *arn* key only if one is already - discoverable via common environment variables. Returns **{}** only if + Always includes the AWS region; adds an *arn* key only if one is already + discoverable via common environment variables. Returns {} only if the region cannot be determined.""" region = region or get_region() if not region: @@ -189,7 +189,7 @@ def try_compose_aws_user_identifier(region: str | None = None) -> dict[str, str] def create_aws_attestation() -> WorkloadIdentityAttestation | None: - """Return AWS attestation or *None* if we're not on AWS / creds missing.""" + """Return AWS attestation or None if we're not on AWS / creds missing.""" creds = load_default_credentials() if not creds: diff --git a/test/unit/test_boto_compatibility.py b/test/unit/test_boto_compatibility.py index e46177fc67..e50d15f615 100644 --- a/test/unit/test_boto_compatibility.py +++ b/test/unit/test_boto_compatibility.py @@ -119,7 +119,7 @@ def test_sts_host_from_region_matches_botocore( ): sf_host = _sts_host_from_region(region) - # Force botocore into **regional** mode so that it doesn’t fall back to the + # Force botocore into regional mode so that it doesn’t fall back to the # legacy global host (sts.amazonaws.com) for the particular regions (like us-east-1). # Both approaches work correctly. 
monkeypatch.setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") From 935efd165af669adeb09498f6f8a9db818b3ef76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sun, 13 Jul 2025 23:02:17 +0200 Subject: [PATCH 54/54] SNOW-2183023: http session unmanaged requests --- src/snowflake/connector/_aws_sign_v4.py | 28 ++++++++++++------------- test/unit/test_boto_compatibility.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/snowflake/connector/_aws_sign_v4.py b/src/snowflake/connector/_aws_sign_v4.py index fa5dc6d472..4210aeba9a 100644 --- a/src/snowflake/connector/_aws_sign_v4.py +++ b/src/snowflake/connector/_aws_sign_v4.py @@ -1,26 +1,26 @@ from __future__ import annotations -import datetime as _dt -import hashlib as _hashlib -import hmac as _hmac -import urllib.parse as _urlparse +import datetime +import hashlib +import hmac +import urllib.parse as urlparse _ALGORITHM: str = "AWS4-HMAC-SHA256" -_EMPTY_PAYLOAD_SHA256: str = _hashlib.sha256(b"").hexdigest() +_EMPTY_PAYLOAD_SHA256: str = hashlib.sha256(b"").hexdigest() _SAFE_CHARS: str = "-_.~" def _sign(key: bytes, msg: str) -> bytes: """Return an HMAC-SHA256 of *msg* keyed with *key*.""" - return _hmac.new(key, msg.encode(), _hashlib.sha256).digest() + return hmac.new(key, msg.encode(), hashlib.sha256).digest() def _canonical_query_string(query: str) -> str: """Return the query string in canonical (sorted & URL-escaped) form.""" - pairs = _urlparse.parse_qsl(query, keep_blank_values=True) + pairs = urlparse.parse_qsl(query, keep_blank_values=True) pairs.sort() return "&".join( - f"{_urlparse.quote(k, _SAFE_CHARS)}={_urlparse.quote(v, _SAFE_CHARS)}" + f"{urlparse.quote(k, _SAFE_CHARS)}={urlparse.quote(v, _SAFE_CHARS)}" for k, v in pairs ) @@ -50,12 +50,12 @@ def sign_get_caller_identity( session_token (Optional) session token for temporary credentials. 
""" - timestamp = _dt.datetime.utcnow() + timestamp = datetime.datetime.utcnow() amz_date = timestamp.strftime("%Y%m%dT%H%M%SZ") short_date = timestamp.strftime("%Y%m%d") service = "sts" - parsed = _urlparse.urlparse(url) + parsed = urlparse.urlparse(url) headers: dict[str, str] = { "host": parsed.netloc.lower(), @@ -70,14 +70,14 @@ def sign_get_caller_identity( canonical_request = "\n".join( ( "POST", - _urlparse.quote(parsed.path or "/", safe="/"), + urlparse.quote(parsed.path or "/", safe="/"), _canonical_query_string(parsed.query), "".join(f"{k}:{headers[k]}\n" for k in sorted(headers)), signed_headers, _EMPTY_PAYLOAD_SHA256, ) ) - canonical_request_hash = _hashlib.sha256(canonical_request.encode()).hexdigest() + canonical_request_hash = hashlib.sha256(canonical_request.encode()).hexdigest() # String to sign credential_scope = f"{short_date}/{region}/{service}/aws4_request" @@ -90,8 +90,8 @@ def sign_get_caller_identity( key_region = _sign(key_date, region) key_service = _sign(key_region, service) key_signing = _sign(key_service, "aws4_request") - signature = _hmac.new( - key_signing, string_to_sign.encode(), _hashlib.sha256 + signature = hmac.new( + key_signing, string_to_sign.encode(), hashlib.sha256 ).hexdigest() # Final Authorization header diff --git a/test/unit/test_boto_compatibility.py b/test/unit/test_boto_compatibility.py index e50d15f615..fe3a8f8951 100644 --- a/test/unit/test_boto_compatibility.py +++ b/test/unit/test_boto_compatibility.py @@ -28,7 +28,7 @@ def freeze_utcnow(monkeypatch: pytest.MonkeyPatch): class _FrozenDateTime(datetime.datetime): @classmethod - def utcnow(cls): # type: ignore[override] + def utcnow(cls): return fixed monkeypatch.setattr(datetime, "datetime", _FrozenDateTime) @@ -159,7 +159,7 @@ def test_region_env_var_default(monkeypatch: pytest.MonkeyPatch) -> None: def test_region_env_var_legacy(monkeypatch: pytest.MonkeyPatch) -> None: """ - AWS_REGION is *ignored* by botocore currently, but should be introduced in the future: 
https://docs.aws.amazon.com/sdkref/latest/guide/feature-region.html + AWS_REGION is ignored by botocore currently, but should be introduced in the future: https://docs.aws.amazon.com/sdkref/latest/guide/feature-region.html Therefore for now we set it as env_var for the driver and pass via explicit parameter to botocore. """ desired_region = "ca-central-1"