diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 61c80607ca..1d2bf53100 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -7,6 +7,13 @@ https://docs.snowflake.com/ Source code is also available at: https://github.com/snowflakedb/snowflake-connector-python # Release Notes +- v3.17(TBD) + - Bumped numpy dependency from <2.1.0 to <=2.2.4. + - Added Windows support for Python 3.13. + - Added `bulk_upload_chunks` parameter to the `write_pandas` function. Setting this parameter to True changes the behaviour of the `write_pandas` function to first write all the data chunks to the local disk and then perform a wildcard upload of the chunks folder to the stage. In the default behaviour the chunks are saved, uploaded and deleted one by one. + - Added `headers_customizers` parameter to the `connect` function. Setting this parameter allows enriching outgoing request headers using a list of customizers. Only header enrichment is supported — modifying query parameters or overwriting the existing headers is not allowed. + - Added support for the new authentication mechanism PAT with external session ID. + - v3.16.1(TBD) - Added in-band OCSP exception telemetry. - Added `APPLICATION_PATH` within `CLIENT_ENVIRONMENT` to distinguish between multiple scripts using the PythonConnector in the same environment. 
diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index aa81318aca..ee6fafe8b2 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -20,7 +20,16 @@ from logging import getLogger from threading import Lock from types import TracebackType -from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence +from typing import ( + Any, + Callable, + Generator, + Iterable, + Iterator, + MutableSequence, + NamedTuple, + Sequence, +) from uuid import UUID from cryptography.hazmat.backends import default_backend @@ -102,6 +111,11 @@ ER_NOT_IMPLICITY_SNOWFLAKE_DATATYPE, ) from .errors import DatabaseError, Error, OperationalError, ProgrammingError +from .http_interceptor import ( + HeadersCustomizer, + HeadersCustomizerInterceptor, + HttpInterceptor, +) from .log_configuration import EasyLoggingConfigPython from .network import ( DEFAULT_AUTHENTICATOR, @@ -116,6 +130,9 @@ REQUEST_ID, USR_PWD_MFA_AUTHENTICATOR, WORKLOAD_IDENTITY_AUTHENTICATOR, + HttpAdapterFactory, + InterceptingAdapter, + ProxySupportAdapter, ReauthenticationRequest, SnowflakeRestful, ) @@ -368,6 +385,10 @@ def _get_private_bytes_from_file( True, bool, ), # SNOW-XXXXX: remove the check_arrow_conversion_error_on_every_column flag + "headers_customizers": ( + None, + (type(None), MutableSequence[HeadersCustomizer]), + ), "external_session_id": ( None, str, @@ -468,6 +489,7 @@ class SnowflakeConnection: token_file_path: The file path of the token file. If both token and token_file_path are provided, the token in token_file_path will be used. unsafe_file_write: When true, files downloaded by GET will be saved with 644 permissions. Otherwise, files will be saved with safe - owner-only permissions: 600. check_arrow_conversion_error_on_every_column: When true, the error check after the conversion from arrow to python types will happen for every column in the row. 
This is a new behaviour which fixes the bug that caused the type errors to trigger silently when occurring at any place other than last column in a row. To revert the previous (faulty) behaviour, please set this flag to false. + headers_customizers: List of headers customizers (HeadersCustomizer class). Setting this parameter allows enriching outgoing request headers using a list of customizers. Only header enrichment is supported — modifying query parameters or overwriting the existing headers is not allowed. """ OCSP_ENV_LOCK = Lock() @@ -888,7 +910,61 @@ def check_arrow_conversion_error_on_every_column(self) -> bool: def check_arrow_conversion_error_on_every_column(self, value: bool) -> bool: self._check_arrow_conversion_error_on_every_column = value + @property + def request_interceptors(self) -> MutableSequence[HttpInterceptor]: + return self._request_interceptors + + @property + def headers_customizers(self) -> MutableSequence[HeadersCustomizer]: + return self._headers_customizers + + @headers_customizers.setter + def headers_customizers(self, value: MutableSequence[HeadersCustomizer]) -> None: + self._headers_customizers = value + request_interceptors = self._create_interceptor_for_headers_customizers(value) + self._request_interceptors = ( + [ + request_interceptors, + ] + if request_interceptors + else [] + ) + + def add_headers_customizer( + self, new_customizer: HeadersCustomizer + ) -> SnowflakeConnection: + """ + Builder method to add a single headers customizer to the list of headers customizers. 
+ """ + if new_customizer in self._headers_customizers or not new_customizer: + return self + + self._headers_customizers.append(new_customizer) + self._request_interceptors.append( + HeadersCustomizerInterceptor([new_customizer]) + ) + return self + + def clear_headers_customizers(self) -> None: + self._headers_customizers.clear() + self._request_interceptors[:] = [ + interceptor + for interceptor in self._request_interceptors + if not isinstance(interceptor, HeadersCustomizerInterceptor) + ] + + @staticmethod + def _create_interceptor_for_headers_customizers( + headers_customizers: MutableSequence[HeadersCustomizer], + ) -> HeadersCustomizerInterceptor | None: + return ( + HeadersCustomizerInterceptor(headers_customizers) + if headers_customizers + else None + ) + def connect(self, **kwargs) -> None: + ... """Establishes connection to Snowflake.""" logger.debug("connect") if len(kwargs) > 0: @@ -1178,7 +1254,6 @@ def __open_connection(self): ): raise TypeError("auth_class must be a child class of AuthByKeyPair") # TODO: add telemetry for custom auth - self.auth_class = self.auth_class elif self._authenticator == DEFAULT_AUTHENTICATOR: self.auth_class = AuthByDefault( password=self._password, @@ -1419,6 +1494,37 @@ def __config(self, **kwargs): if "host" not in kwargs: self._host = construct_hostname(kwargs.get("region"), self._account) + self._headers_customizers = kwargs.get("headers_customizers", []) + if self._headers_customizers: + header_customizer_interceptor = ( + self._create_interceptor_for_headers_customizers( + self._headers_customizers + ) + ) + request_interceptors = ( + [ + header_customizer_interceptor, + ] + if header_customizer_interceptor + else [] + ) + HttpAdapterFactory.register_for_connection( + connection=self, + adapter_cls=InterceptingAdapter, + interceptors=request_interceptors, + ) + else: + # Default adapter + HttpAdapterFactory.register_for_connection( + connection=self, adapter_cls=ProxySupportAdapter + ) + self._request_interceptors 
= [] + + if self._headers_customizers: + logger.info( + f"{len(self._headers_customizers)} custom headers customizers were provided. Requests will be enriched according to the defined conditions." + ) + logger.info( f"Connecting to {_DOMAIN_NAME_MAP.get(extract_top_level_domain_from_hostname(self._host), 'GLOBAL')} Snowflake domain" ) diff --git a/src/snowflake/connector/http_interceptor.py b/src/snowflake/connector/http_interceptor.py new file mode 100644 index 0000000000..1b2a76ddc8 --- /dev/null +++ b/src/snowflake/connector/http_interceptor.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from enum import Enum, auto +from typing import Any, Generator, Iterable, MutableSequence, NamedTuple + +METHODS = { + "GET", + "PUT", + "POST", + "HEAD", + "DELETE", +} + +Headers = dict[str, Any] + +logger = logging.getLogger(__name__) + + +class RequestDTO(NamedTuple): + url: str | bytes + method: str + headers: Headers + + +class ConflictDTO(NamedTuple): + original_key: str + value: Any + had_conflict: bool = True + + +class HeadersCustomizer(ABC): + @abstractmethod + def applies_to(self, request: RequestDTO): + raise NotImplementedError() + + @abstractmethod + def is_invoked_once(self): + """recommended to increase performance""" + raise NotImplementedError() + + @abstractmethod + def get_new_headers(self, request: RequestDTO) -> Headers: + raise NotImplementedError() + + +class HttpInterceptor(ABC): + class InterceptionHook(Enum): + BEFORE_RETRY = auto() + BEFORE_REQUEST_ISSUED = auto() + + @abstractmethod + def intercept_on( + self, hook: HttpInterceptor.InterceptionHook, request: RequestDTO + ) -> RequestDTO: + raise NotImplementedError() + + +class HeadersCustomizerInterceptor(HttpInterceptor): + def __init__( + self, + headers_customizers: Iterable[HeadersCustomizer] | None = None, + static_headers_customizers: Iterable[HeadersCustomizer] | None = None, + dynamic_headers_customizers: 
Iterable[HeadersCustomizer] | None = None, + ): + if headers_customizers is not None: + self._static_headers_customizers, self._dynamic_headers_customizers = ( + self.split_customizers(headers_customizers) + ) + else: + self._static_headers_customizers = static_headers_customizers or [] + self._dynamic_headers_customizers = dynamic_headers_customizers or [] + + @staticmethod + def split_customizers( + headers_customizers: Iterable[HeadersCustomizer] | None, + ) -> tuple[MutableSequence[HeadersCustomizer], MutableSequence[HeadersCustomizer]]: + static, dynamic = [], [] + if headers_customizers: + for customizer in headers_customizers: + if customizer.is_invoked_once(): + static.append(customizer) + else: + dynamic.append(customizer) + return static, dynamic + + @staticmethod + def iter_non_conflicting_headers( + original_headers: Headers, other_headers: Headers, case_sensitive: bool = False + ) -> Generator[ConflictDTO]: + """ + Yields (key, value, is_conflicting) for each key in other_headers, + telling you if it conflicts with original_headers. 
+ """ + original_keys = ( + set(original_headers) + if case_sensitive + else {k.lower() for k in original_headers} + ) + + for key, value in other_headers.items(): + comp_key = key if case_sensitive else key.lower() + if comp_key in original_keys: + yield ConflictDTO(key, value, True) + else: + yield ConflictDTO(key, value, False) + + def intercept_on( + self, hook: HttpInterceptor.InterceptionHook, request: RequestDTO + ) -> RequestDTO: + if hook is HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED: + customizers_to_apply = ( + self._static_headers_customizers + self._dynamic_headers_customizers + ) + return self._handle_headers_customization(request, customizers_to_apply) + elif hook is HttpInterceptor.InterceptionHook.BEFORE_RETRY: + return self._handle_headers_customization( + request, self._dynamic_headers_customizers + ) + return request + + def _handle_headers_customization( + self, request: RequestDTO, headers_customizers: Iterable[HeadersCustomizer] + ) -> RequestDTO: + # copy preventing mutation in the registered customizer + result_headers = dict(request.headers) if request.headers else {} + + for header_customizer in headers_customizers: + try: + if header_customizer.applies_to(request): + additional_headers = header_customizer.get_new_headers(request) + + for key, value, is_conflicting in self.iter_non_conflicting_headers( + result_headers, additional_headers + ): + if is_conflicting: + logger.warning( + f"Overwriting header '{key}' detected. Skipping this key." + ) + else: + result_headers[key] = value + except Exception as ex: + # Custom logic failure is treated as non-fatal for the connection + logger.warning("Unable to customize headers: %s. 
Skipping...", ex) + + return RequestDTO( + url=request.url, + method=request.method, + headers=result_headers, + ) diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 96a55ad031..0e22fd3a1e 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -10,13 +10,23 @@ import re import time import uuid +import weakref from collections import OrderedDict from threading import Lock -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generator, NamedTuple import OpenSSL.SSL +try: + from snowflake.connector.http_interceptor import ( + Headers, + HttpInterceptor, + RequestDTO, + ) +except ImportError: + pass from snowflake.connector.secret_detector import SecretDetector +from snowflake.connector.vendored import requests from snowflake.connector.vendored.requests.models import PreparedRequest from snowflake.connector.vendored.urllib3.connectionpool import ( HTTPConnectionPool, @@ -98,7 +108,6 @@ get_time_millis, ) from .tool.probe_connection import probe_connection -from .vendored import requests from .vendored.requests import Response, Session from .vendored.requests.adapters import HTTPAdapter from .vendored.requests.auth import AuthBase @@ -251,6 +260,63 @@ def is_login_request(url: str) -> bool: return "login-request" in parse_url(url).path +class AdapterConfig(NamedTuple): + adapter_cls: type[requests.adapters.HTTPAdapter] + partial_kwargs: dict[str, Any] + + +class HttpAdapterFactory: + _global_adapter_class: type[requests.adapters.HTTPAdapter] | None = None + _global_partial_kwargs: dict[str, Any] = {} + + # Weak-key map of connection -> AdapterConfig + _per_connection_adapter_config: weakref.WeakKeyDictionary = ( + weakref.WeakKeyDictionary() + ) + + @classmethod + def register_global( + cls, adapter_cls: type[requests.adapters.HTTPAdapter], **partial_kwargs + ) -> None: + cls._global_adapter_class = adapter_cls + cls._global_partial_kwargs = partial_kwargs + + 
@classmethod + def register_for_connection( + cls, + connection: Any, + *, + adapter_cls: type[requests.adapters.HTTPAdapter] | None = None, + **partial_kwargs, + ) -> None: + cls._per_connection_adapter_config[connection] = AdapterConfig( + adapter_cls, + partial_kwargs, + ) + + @classmethod + def try_create_adapter( + cls, connection: Any | None = None, **runtime_kwargs + ) -> requests.adapters.HTTPAdapter | None: + if connection in cls._per_connection_adapter_config: + config = cls._per_connection_adapter_config[connection] + elif cls._global_adapter_class is not None: + config = AdapterConfig( + cls._global_adapter_class, cls._global_partial_kwargs + ) + else: + return None + + final_kwargs = {**config.partial_kwargs, **runtime_kwargs} + return config.adapter_cls(**final_kwargs) + + @classmethod + def reset(cls) -> None: + cls._global_adapter_class = None + cls._global_partial_kwargs = {} + cls._per_connection_adapter_config.clear() + + class ProxySupportAdapter(HTTPAdapter): """This Adapter creates proper headers for Proxy CONNECT messages.""" @@ -287,6 +353,33 @@ def get_connection( return conn +class InterceptingAdapter(ProxySupportAdapter): + def __init__(self, interceptors=None, **kwargs): + super().__init__(**kwargs) + self._interceptors = interceptors or [] + + def send(self, request, **kwargs): + retry = kwargs.get("retries") + is_retry = retry is not None and getattr(retry, "history", None) + + hook = ( + HttpInterceptor.InterceptionHook.BEFORE_RETRY + if is_retry + else HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED + ) + + dto = RequestDTO( + method=request.method, url=request.url, headers=request.headers + ) + for interceptor in self._interceptors: + dto = interceptor.intercept_on(hook, dto) + + request.headers.clear() + request.headers.update(dto.headers) + + return super().send(request, **kwargs) + + class RetryRequest(Exception): """Signal to retry request.""" @@ -349,7 +442,7 @@ def get_session(self) -> Session: try: session = 
self._idle_sessions.pop() except IndexError: - session = self._rest.make_requests_session() + session = self._rest.make_requests_session(connection=self._rest.connection) self._active_sessions.add(session) return session @@ -473,6 +566,10 @@ def mfa_token(self, value: str) -> None: def server_url(self) -> str: return f"{self._protocol}://{self._host}:{self._port}" + @property + def connection(self): + return self._connection + def close(self) -> None: if hasattr(self, "_token"): del self._token @@ -876,7 +973,7 @@ def fetch( self, method: str, full_url: str, - headers: dict[str, Any], + headers: Headers, data: dict[str, Any] | None = None, timeout: int | None = None, **kwargs, @@ -1155,6 +1252,7 @@ def _request_exec( # socket timeout is constant. You should be able to receive # the response within the time. If not, ConnectReadTimeout or # ReadTimeout is raised. + auth = ( PATWithExternalSessionAuth(token, external_session_id) if (external_session_id is not None and token is not None) @@ -1272,15 +1370,24 @@ def _request_exec( except Exception as err: raise err - def make_requests_session(self) -> Session: + def make_requests_session( + self, connection: SnowflakeConnection | None = None + ) -> Session: s = requests.Session() - s.mount("http://", ProxySupportAdapter(max_retries=REQUESTS_RETRY)) - s.mount("https://", ProxySupportAdapter(max_retries=REQUESTS_RETRY)) + adapter = HttpAdapterFactory.try_create_adapter( + connection=connection, max_retries=REQUESTS_RETRY + ) + if adapter is not None: + s.mount("http://", adapter) + s.mount("https://", adapter) + s._reuse_count = itertools.count() return s @contextlib.contextmanager - def _use_requests_session(self, url: str | None = None): + def _use_requests_session( + self, url: str | bytes | None = None + ) -> Generator[Session, Any, None]: """Session caching context manager. 
Notes: @@ -1288,7 +1395,7 @@ def _use_requests_session(self, url: str | None = None): """ # short-lived session, not added to the _sessions_map if self._connection.disable_request_pooling: - session = self.make_requests_session() + session = self.make_requests_session(connection=self._connection) try: yield session finally: diff --git a/src/snowflake/connector/storage_client.py b/src/snowflake/connector/storage_client.py index 7fc8b67dfa..bdc5421870 100644 --- a/src/snowflake/connector/storage_client.py +++ b/src/snowflake/connector/storage_client.py @@ -15,6 +15,7 @@ import OpenSSL +from . import SnowflakeConnection from .constants import ( HTTP_HEADER_CONTENT_ENCODING, REQUEST_CONNECTION_TIMEOUT, @@ -269,6 +270,13 @@ def finish_upload(self) -> None: def _has_expired_token(self, response: requests.Response) -> bool: pass + @property + def connection(self) -> SnowflakeConnection | None: + if self.meta.sfagent: + return self.meta.sfagent._cursor.connection + else: + return None + def _send_request_with_retry( self, verb: str, @@ -277,18 +285,16 @@ def _send_request_with_retry( ) -> requests.Response: rest_call = METHODS[verb] url = b"" - conn = None - if self.meta.sfagent and self.meta.sfagent._cursor.connection: - conn = self.meta.sfagent._cursor.connection while self.retry_count[retry_id] < self.max_retry: logger.debug(f"retry #{self.retry_count[retry_id]}") cur_timestamp = self.credentials.timestamp url, rest_kwargs = get_request_args() + rest_kwargs["timeout"] = (REQUEST_CONNECTION_TIMEOUT, REQUEST_READ_TIMEOUT) try: - if conn: - with conn._rest._use_requests_session(url) as session: + if self.connection: + with self.connection._rest._use_requests_session(url) as session: logger.debug(f"storage client request with session {session}") response = session.request(verb, url, **rest_kwargs) else: diff --git a/test/conftest.py b/test/conftest.py index e8a8081b20..3e017b80f5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,6 +5,12 @@ from contextlib import 
contextmanager from logging import getLogger from pathlib import Path + +try: + from test.test_utils.cross_module_fixtures.http_fixtures import * # NOQA +except (ImportError, NameError): + pass +from test.test_utils.cross_module_fixtures.wiremock_fixtures import * # NOQA from typing import Generator import pytest diff --git a/test/data/wiremock/mappings/auth/password/successful_flow.json b/test/data/wiremock/mappings/auth/password/successful_flow.json new file mode 100644 index 0000000000..58045d6fe3 --- /dev/null +++ b/test/data/wiremock/mappings/auth/password/successful_flow.json @@ -0,0 +1,60 @@ +{ + "mappings": [ + { + "request": { + "urlPathPattern": "/session/v1/login-request.*", + "method": "POST", + "bodyPatterns": [ + { + "equalToJson" : { + "data": { + "LOGIN_NAME": "testUser", + "PASSWORD": "testPassword" + } + }, + "ignoreExtraElements" : true + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "data": { + "masterToken": "master token", + "token": "session token", + "validityInSeconds": 3600, + "masterValidityInSeconds": 14400, + "displayUserName": "TEST_USER", + "serverVersion": "8.48.0 b2024121104444034239f05", + "firstLogin": false, + "remMeToken": null, + "remMeValidityInSeconds": 0, + "healthCheckInterval": 45, + "newClientForUpgrade": "3.12.3", + "sessionId": 1172562260498, + "parameters": [ + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + } + ], + "sessionInfo": { + "databaseName": "TEST_DB", + "schemaName": "TEST_GO", + "warehouseName": "TEST_XSMALL", + "roleName": "ANALYST" + }, + "idToken": null, + "idTokenValidityInSeconds": 0, + "responseData": null, + "mfaToken": null, + "mfaTokenValidityInSeconds": 0 + }, + "code": null, + "message": null, + "success": true + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/generic/telemetry.json b/test/data/wiremock/mappings/generic/telemetry.json new file mode 100644 index 0000000000..9b734a0cf2 --- /dev/null +++ b/test/data/wiremock/mappings/generic/telemetry.json @@ -0,0 
+1,18 @@ +{ + "scenarioName": "Successful telemetry flow", + "request": { + "urlPathPattern": "/telemetry/send", + "method": "POST" + }, + "response": { + "status": 200, + "jsonBody": { + "data": { + "code": null, + "data": "Log Received", + "message": null, + "success": true + } + } + } +} diff --git a/test/data/wiremock/mappings/queries/chunk_1.json b/test/data/wiremock/mappings/queries/chunk_1.json new file mode 100644 index 0000000000..246874d3c4 --- /dev/null +++ b/test/data/wiremock/mappings/queries/chunk_1.json @@ -0,0 +1,14 @@ +{ + "request": { + "method": "GET", + "url": "/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_1?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip" + }, + "response": { + "status": 200, + "headers": { + "Content-Encoding": "gzip", + "x-amz-server-side-encryption-customer-algorithm": "AES256" + }, + "base64Body": "H4sIAM7YUGgC/4s21FHAiWIB81FB/x4AAAA=" + } +} diff --git a/test/data/wiremock/mappings/queries/chunk_2.json b/test/data/wiremock/mappings/queries/chunk_2.json new file mode 100644 index 0000000000..60f2756d0e --- /dev/null +++ b/test/data/wiremock/mappings/queries/chunk_2.json @@ -0,0 +1,14 @@ +{ + "request": { + "method": "GET", + "url": "/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_2?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip" + }, + "response": { + "status": 200, + "headers": { + "Content-Encoding": "gzip", + "x-amz-server-side-encryption-customer-algorithm": "AES256" + }, + "base64Body": "H4sIAM7YUGgC/4s21FHAiWIB81FB/x4AAAA=" + } +} diff --git a/test/data/wiremock/mappings/queries/select_1_successful.json b/test/data/wiremock/mappings/queries/select_1_successful.json new file mode 100644 index 0000000000..99fdcb7103 --- /dev/null +++ b/test/data/wiremock/mappings/queries/select_1_successful.json @@ -0,0 +1,199 @@ +{ + "scenarioName": 
"Successful SELECT 1 flow", + "request": { + "urlPathPattern": "/queries/v1/query-request.*", + "method": "POST", + "headers": { + "Authorization": { + "equalTo": "Snowflake Token=\"session token\"" + } + } + }, + "response": { + "status": 200, + "jsonBody": { + "data": { + "parameters": [ + { + "name": "TIMESTAMP_OUTPUT_FORMAT", + "value": "YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM" + }, + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + }, + { + "name": "TIME_OUTPUT_FORMAT", + "value": "HH24:MI:SS" + }, + { + "name": "CLIENT_RESULT_CHUNK_SIZE", + "value": 16 + }, + { + "name": "TIMESTAMP_TZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE", + "value": false + }, + { + "name": "QUERY_CONTEXT_CACHE_SIZE", + "value": 5 + }, + { + "name": "CLIENT_METADATA_USE_SESSION_DATABASE", + "value": false + }, + { + "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", + "value": false + }, + { + "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", + "value": true + }, + { + "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", + "value": "YYYY-MM-DD HH24:MI:SS.FF3" + }, + { + "name": "CLIENT_RESULT_PREFETCH_THREADS", + "value": 1 + }, + { + "name": "CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX", + "value": false + }, + { + "name": "CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ", + "value": true + }, + { + "name": "CLIENT_MEMORY_LIMIT", + "value": 1536 + }, + { + "name": "CLIENT_TIMESTAMP_TYPE_MAPPING", + "value": "TIMESTAMP_LTZ" + }, + { + "name": "TIMEZONE", + "value": "America/Los_Angeles" + }, + { + "name": "SERVICE_NAME", + "value": "" + }, + { + "name": "CLIENT_RESULT_PREFETCH_SLOTS", + "value": 2 + }, + { + "name": "CLIENT_TELEMETRY_ENABLED", + "value": true + }, + { + "name": "CLIENT_DISABLE_INCIDENTS", + "value": true + }, + { + "name": "CLIENT_USE_V1_QUERY_API", + "value": true + }, + { + "name": "CLIENT_RESULT_COLUMN_CASE_INSENSITIVE", + "value": false + }, + { + "name": "CSV_TIMESTAMP_FORMAT", + "value": "" + }, + { + "name": "BINARY_OUTPUT_FORMAT", + "value": "HEX" + }, 
+ { + "name": "CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS", + "value": false + }, + { + "name": "CLIENT_TELEMETRY_SESSIONLESS_ENABLED", + "value": true + }, + { + "name": "DATE_OUTPUT_FORMAT", + "value": "YYYY-MM-DD" + }, + { + "name": "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD", + "value": 65280 + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", + "value": 3600 + }, + { + "name": "CLIENT_SESSION_CLONE", + "value": false + }, + { + "name": "AUTOCOMMIT", + "value": true + }, + { + "name": "TIMESTAMP_LTZ_OUTPUT_FORMAT", + "value": "" + } + ], + "rowtype": [ + { + "name": "1", + "database": "", + "schema": "", + "table": "", + "nullable": false, + "length": null, + "type": "fixed", + "scale": 0, + "precision": 1, + "byteLength": null, + "collation": null + } + ], + "rowset": [ + [ + "1" + ] + ], + "total": 1, + "returned": 1, + "queryId": "01ba13b4-0104-e9fd-0000-0111029ca00e", + "databaseProvider": null, + "finalDatabaseName": null, + "finalSchemaName": null, + "finalWarehouseName": "TEST_XSMALL", + "numberOfBinds": 0, + "arrayBindSupported": false, + "statementTypeId": 4096, + "version": 1, + "sendResultTime": 1738317395581, + "queryResultFormat": "json", + "queryContext": { + "entries": [ + { + "id": 0, + "timestamp": 1738317395574564, + "priority": 0, + "context": "CPbPTg==" + } + ] + } + }, + "code": null, + "message": null, + "success": true + } + } +} diff --git a/test/data/wiremock/mappings/queries/select_large_request_successful.json b/test/data/wiremock/mappings/queries/select_large_request_successful.json new file mode 100644 index 0000000000..61ee3135a6 --- /dev/null +++ b/test/data/wiremock/mappings/queries/select_large_request_successful.json @@ -0,0 +1,413 @@ +{ + "scenarioName": "Successful SELECT 1 flow", + "request": { + "urlPathPattern": "/queries/v1/query-request.*", + "method": "POST", + "headers": { + "Authorization": { + "equalTo": "Snowflake Token=\"session token\"" + } + } + }, + "response": { + "status": 200, + "jsonBody": { + 
"data": { + "parameters": [ + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + }, + { + "name": "TIMESTAMP_OUTPUT_FORMAT", + "value": "DY, DD MON YYYY HH24:MI:SS TZHTZM" + }, + { + "name": "PYTHON_SNOWPARK_CLIENT_MIN_VERSION_FOR_AST", + "value": "1.29.0" + }, + { + "name": "TIME_OUTPUT_FORMAT", + "value": "HH24:MI:SS" + }, + { + "name": "CLIENT_RESULT_CHUNK_SIZE", + "value": 160 + }, + { + "name": "TIMESTAMP_TZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE", + "value": false + }, + { + "name": "PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_VERSION", + "value": "1.31.1" + }, + { + "name": "CLIENT_METADATA_USE_SESSION_DATABASE", + "value": false + }, + { + "name": "QUERY_CONTEXT_CACHE_SIZE", + "value": 5 + }, + { + "name": "PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_VERSION", + "value": "" + }, + { + "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", + "value": false + }, + { + "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_RESULT_PREFETCH_THREADS", + "value": 1 + }, + { + "name": "CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX", + "value": false + }, + { + "name": "CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ", + "value": true + }, + { + "name": "CLIENT_MEMORY_LIMIT", + "value": 1536 + }, + { + "name": "CLIENT_TIMESTAMP_TYPE_MAPPING", + "value": "TIMESTAMP_LTZ" + }, + { + "name": "TIMEZONE", + "value": "UTC" + }, + { + "name": "PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER", + "value": true + }, + { + "name": "SNOWPARK_REQUEST_TIMEOUT_IN_SECONDS", + "value": 86400 + }, + { + "name": "PYTHON_SNOWPARK_USE_AST", + "value": false + }, + { + "name": "SERVICE_NAME", + "value": "" + }, + { + "name": "PYTHON_CONNECTOR_USE_NANOARROW", + "value": true + }, + { + "name": "CLIENT_RESULT_PREFETCH_SLOTS", + "value": 2 + }, + { + "name": "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND", + "value": 10000000 + }, + { + "name": "PYTHON_SNOWPARK_GENERATE_MULTILINE_QUERIES", + "value": true + }, + { + "name": 
"CLIENT_DISABLE_INCIDENTS", + "value": true + }, + { + "name": "CSV_TIMESTAMP_FORMAT", + "value": "" + }, + { + "name": "BINARY_OUTPUT_FORMAT", + "value": "HEX" + }, + { + "name": "CLIENT_TELEMETRY_SESSIONLESS_ENABLED", + "value": true + }, + { + "name": "DATE_OUTPUT_FORMAT", + "value": "YYYY-MM-DD" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", + "value": 3600 + }, + { + "name": "PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED", + "value": false + }, + { + "name": "AUTOCOMMIT", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED", + "value": false + }, + { + "name": "CLIENT_SESSION_CLONE", + "value": false + }, + { + "name": "TIMESTAMP_LTZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", + "value": false + }, + { + "name": "PYTHON_SNOWPARK_DATAFRAME_JOIN_ALIAS_FIX_VERSION", + "value": "" + }, + { + "name": "PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION", + "value": "1.28.0" + }, + { + "name": "PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED_VERSION", + "value": "" + }, + { + "name": "CLIENT_TELEMETRY_ENABLED", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED", + "value": false + }, + { + "name": "CLIENT_USE_V1_QUERY_API", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", + "value": true + }, + { + "name": "CLIENT_RESULT_COLUMN_CASE_INSENSITIVE", + "value": false + }, + { + "name": "CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS", + "value": false + }, + { + "name": "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD", + "value": 65280 + }, + { + "name": "PYTHON_SNOWPARK_COMPILATION_STAGE_ENABLED", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND", + "value": 12000000 + }, + { + "name": "PYTHON_SNOWPARK_CLIENT_AST_MODE", + "value": 0 + } + ], + "rowtype": [ + { + "name": "C0", + "database": "TESTDB", + "schema": 
"PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C1", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C2", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C3", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C4", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C5", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C6", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": 
"PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C7", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C8", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C9", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + } + ], + + "rowset": [ + [ + "1" + ] + ], + "qrmk": "+ZSmIj7I0L0BnU3zdVnSaHH5MW6cwY0GmLtz/Un5zSM=", + "chunkHeaders": { + "x-amz-server-side-encryption-customer-key": "+ZSmIj7I0L0BnU3zdVnSaHH5MW6cwY0GmLtz/Un5zSM=", + "x-amz-server-side-encryption-customer-key-md5": "ByrEgrMhjgAEMRr1QA/nGg==" + }, + "chunks": [ + { + "url": "{{WIREMOCK_HTTP_HOST_WITH_PORT}}/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_1?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip", + "rowCount": 4096, + "uncompressedSize": 331328, + "compressedSize": 326422 + }, + { + "url": 
"{{WIREMOCK_HTTP_HOST_WITH_PORT}}/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_2?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip", + "rowCount": 4096, + "uncompressedSize": 331328, + "compressedSize": 326176 + } + ], + "total": 50000, + "returned": 50000, + "queryId": "01bd137c-0100-0001-0000-0000001005b1", + "databaseProvider": null, + "finalDatabaseName": "TESTDB", + "finalSchemaName": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "finalWarehouseName": "REGRESS", + "finalRoleName": "ACCOUNTADMIN", + "numberOfBinds": 0, + "arrayBindSupported": false, + "statementTypeId": 4096, + "version": 1, + "sendResultTime": 1750110502822, + "queryResultFormat": "json", + "queryContext": { + "entries": [ + { + "id": 0, + "timestamp": 1748552075465658, + "priority": 0, + "context": "CAQ=" + } + ] + } + }, + "code": null, + "message": null, + "success": true + } + } +} diff --git a/test/integ/test_http_interceptor.py b/test/integ/test_http_interceptor.py new file mode 100644 index 0000000000..1f57748b7e --- /dev/null +++ b/test/integ/test_http_interceptor.py @@ -0,0 +1,323 @@ +import os +import pathlib +from typing import Deque + +import pytest + +from snowflake.connector import SnowflakeConnection + +from ..generate_test_files import generate_k_lines_of_n_files +from ..test_utils.http_test_utils import RequestTracker +from .test_large_result_set import ingest_data # NOQA + +try: + from snowflake.connector.http_interceptor import RequestDTO + from snowflake.connector.util_text import random_string + +except (ImportError, NameError): # Keep olddrivertest from breaking + from ..randomize import random_string + + +@pytest.fixture(scope="session") +def wiremock_auth_dir(wiremock_mapping_dir) -> pathlib.Path: + return wiremock_mapping_dir / "auth" + + +@pytest.fixture(scope="session") +def wiremock_queries_dir(wiremock_mapping_dir) -> pathlib.Path: + return 
wiremock_mapping_dir / "queries" + + +@pytest.fixture(scope="session") +def wiremock_password_auth_dir(wiremock_auth_dir) -> pathlib.Path: + return wiremock_auth_dir / "password" + + +@pytest.fixture(scope="session") +def current_provider(): + return os.getenv("cloud_provider", "dev") + + +@pytest.mark.parametrize("execute_on_wiremock", (True, False)) +@pytest.mark.skipolddriver +def test_interceptor_detects_expected_requests_in_successful_flow_select_1( + request, + execute_on_wiremock, + wiremock_password_auth_dir, + wiremock_generic_mappings_dir, + wiremock_queries_dir, + static_collecting_customizer, + conn_cnx, + conn_cnx_wiremock, +) -> None: + """ + This kind of test ensures that the request interceptor correctly captures all expected HTTP requests + during a known use-case scenario of the driver. + + By covering both real and mock (WireMock) executions via the `execute_on_wiremock` param, + we validate two critical things: + 1. **Real execution**: Interceptor must detect and log every actual request in a live environment. + 2. **WireMock execution**: Verifies that the stubbed mappings in WireMock match the current behavior + of the driver, so any future API changes (added/removed requests) will cause the test to fail. + + This dual setup: + - Prevents stale mocks (detects missing or extra requests not aligned with the driver). + - Guarantees that the interceptor observes requests **in correct order** and no unexpected traffic occurs. + - Ensures the customizer receives all invocations and no requests go untracked. + + If new driver code introduces additional HTTP calls, this test will fail until WireMock is updated + to reflect them, prompting to verify those new requests in the flow. 
+ """ + + def assert_expected_requests_occurred(conn: SnowflakeConnection) -> None: + requests: Deque[RequestDTO] = static_collecting_customizer.invocations + tracker = RequestTracker(requests) + + with conn as connection_context: + tracker.assert_login_issued() + cursor = connection_context.cursor().execute("select 1") + tracker.assert_sql_query_issued() + + result = cursor.fetchall() + assert len(result) == 1, "Result should contain exactly one row" + assert ( + result[0][0] == 1 + ), "Result should contain the value 1 in the first row and the first column" + + tracker.assert_telemetry_send_issued() + tracker.assert_disconnect_issued() + + if execute_on_wiremock: + local_wiremock_client = request.getfixturevalue("wiremock_client") + local_wiremock_client.import_mapping( + wiremock_password_auth_dir / "successful_flow.json" + ) + local_wiremock_client.add_mapping( + wiremock_queries_dir / "select_1_successful.json" + ) + local_wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + local_wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "telemetry.json" + ) + + connection = conn_cnx_wiremock( + headers_customizers=[static_collecting_customizer] + ) + else: + connection = conn_cnx(headers_customizers=[static_collecting_customizer]) + + assert_expected_requests_occurred(connection) + + +@pytest.mark.parametrize("execute_on_wiremock", (True, False)) +@pytest.mark.skipolddriver +def test_interceptor_detects_expected_requests_in_successful_flow_with_chunks( + request, + execute_on_wiremock, + wiremock_password_auth_dir, + wiremock_generic_mappings_dir, + wiremock_queries_dir, + static_collecting_customizer, + conn_cnx, + conn_cnx_wiremock, + db_parameters, + default_db_wiremock_parameters, +) -> None: + + def assert_expected_requests_occurred( + conn: SnowflakeConnection, expected_large_table_name: str + ) -> None: + requests: Deque[RequestDTO] = static_collecting_customizer.invocations + tracker = 
RequestTracker(requests) + + with conn as connection_context: + tracker.assert_login_issued() + sql = f"select * from {expected_large_table_name} order by 1" + cursor = connection_context.cursor().execute(sql) + tracker.assert_sql_query_issued() + cursor.fetchall() + tracker.assert_get_chunk_issued() + + tracker.assert_telemetry_send_issued() + tracker.assert_disconnect_issued() + + if execute_on_wiremock: + local_wiremock_client = request.getfixturevalue("wiremock_client") + local_wiremock_client.import_mapping( + wiremock_password_auth_dir / "successful_flow.json" + ) + local_wiremock_client.add_mapping( + wiremock_queries_dir / "select_large_request_successful.json", + placeholders=local_wiremock_client.http_placeholders, + ) + local_wiremock_client.add_mapping( + wiremock_queries_dir / "chunk_1.json", + placeholders=local_wiremock_client.http_placeholders, + ) + local_wiremock_client.add_mapping( + wiremock_queries_dir / "chunk_2.json", + placeholders=local_wiremock_client.http_placeholders, + ) + local_wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + local_wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "telemetry.json" + ) + + connection = conn_cnx_wiremock( + headers_customizers=[static_collecting_customizer] + ) + large_table_name = default_db_wiremock_parameters["name"] + else: + request.getfixturevalue("ingest_data") + connection = conn_cnx(headers_customizers=[static_collecting_customizer]) + large_table_name = db_parameters["name"] + + assert_expected_requests_occurred(connection, large_table_name) + + +@pytest.mark.skipolddriver +def test_interceptor_detects_expected_requests_in_successful_flow_put_get( + tmp_path: pathlib.Path, + static_collecting_customizer, + conn_cnx, + current_provider, +): + def _assert_expected_requests_occurred(conn: SnowflakeConnection) -> None: + requests: Deque[RequestDTO] = static_collecting_customizer.invocations + tracker = RequestTracker(requests) 
+ + test_file = tmp_path / "single_part.txt" + test_file.write_text("test,data\n") + download_dir = tmp_path / "download" + download_dir.mkdir() + stage_name = random_string(5, "test_put_get_") + + with conn as cxn: + with cxn.cursor() as cursor: + tracker.assert_login_issued() + + cursor.execute(f"create temporary stage {stage_name}") + tracker.assert_sql_query_issued() + + put_sql = f"PUT file://{test_file} @{stage_name} AUTO_COMPRESS = FALSE" + cursor.execute(put_sql) + tracker.assert_sql_query_issued() + if current_provider in ("aws", "dev"): + tracker.assert_aws_get_accelerate_issued(optional=True) + + tracker.assert_file_head_issued(test_file.name) + tracker.assert_put_file_issued(filename=test_file.name) + + get_sql = f"GET @{stage_name}/{test_file.name} file://{download_dir}" + cursor.execute(get_sql) + tracker.assert_sql_query_issued() + if current_provider in ("aws", "dev"): + tracker.assert_aws_get_accelerate_issued(optional=True) + tracker.assert_file_head_issued(test_file.name) + tracker.assert_get_file_issued(test_file.name) + + tracker.assert_telemetry_send_issued() + tracker.assert_disconnect_issued() + + connection = conn_cnx(headers_customizers=[static_collecting_customizer]) + _assert_expected_requests_occurred(connection) + + +@pytest.mark.aws +@pytest.mark.azure +@pytest.mark.skipolddriver +def test_interceptor_detects_expected_requests_in_successful_multipart_put_get( + tmp_path: pathlib.Path, + static_collecting_customizer, + dynamic_collecting_customizer, + conn_cnx, + current_provider, +): + """Verifies request flow for multipart PUT and GET of a large file, with MD5 check and optional WireMock.""" + + def _assert_expected_requests_occurred_multipart( + connection: SnowflakeConnection, + ) -> None: + requests: Deque[RequestDTO] = static_collecting_customizer.invocations + tracker = RequestTracker(requests) + + big_folder = tmp_path / "big" + big_folder.mkdir() + generate_k_lines_of_n_files(3_000_000, 1, tmp_dir=str(big_folder)) + 
big_test_file = big_folder / "file0" + + stage_name = random_string(5, "test_multipart_put_get_") + stage_path = "bigdata" + download_dir = tmp_path / "download" + download_dir.mkdir() + big_test_file_stage_path = f"{stage_path}/{big_test_file.name}" + + with connection as cnx: + with cnx.cursor() as cur: + tracker.assert_login_issued() + + cur.execute(f"create temporary stage {stage_name}") + tracker.assert_sql_query_issued() + + clean_file_path = str(big_test_file).replace("\\", "/") + cur.execute( + f"PUT 'file://{clean_file_path}' " + f"@{stage_name}/{stage_path} AUTO_COMPRESS=FALSE" + ) + tracker.assert_sql_query_issued() + if current_provider in ("aws", "dev"): + tracker.assert_aws_get_accelerate_issued(optional=True) + + tracker.assert_file_head_issued(big_test_file.name, sequentially=False) + + if current_provider in ("aws", "dev"): + tracker.assert_post_start_for_multipart_file_issued( + sequentially=False, file_path=big_test_file_stage_path + ) + + tracker.assert_multiple_put_file_issued( + big_test_file.name, current_provider, sequentially=False + ) + tracker.assert_end_for_multipart_file_issued( + cloud_platform=current_provider + ) + + cur.execute( + f"GET @{stage_name}/{big_test_file_stage_path} file://{download_dir}" + ) + tracker.assert_sql_query_issued() + if current_provider in ("aws", "dev"): + tracker.assert_aws_get_accelerate_issued(optional=True) + tracker.assert_file_head_issued(big_test_file.name) + tracker.assert_get_file_issued(big_test_file.name) + + tracker.assert_telemetry_send_issued() + tracker.assert_disconnect_issued() + + conn = conn_cnx( + headers_customizers=[ + static_collecting_customizer, + dynamic_collecting_customizer, + ] + ) + try: + _assert_expected_requests_occurred_multipart(conn) + except AssertionError as ex: + list_of_inv = ( + str(ex) + + "\n\n" + + str( + "\n".join( + map( + lambda r: f"{r.method} {r.url}", + dynamic_collecting_customizer.invocations, + ) + ) + ) + ) + print(list_of_inv) + raise 
AssertionError(list_of_inv) diff --git a/test/wiremock/__init__.py b/test/test_utils/__init__.py similarity index 100% rename from test/wiremock/__init__.py rename to test/test_utils/__init__.py diff --git a/test/test_utils/cross_module_fixtures/__init__.py b/test/test_utils/cross_module_fixtures/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test_utils/cross_module_fixtures/http_fixtures.py b/test/test_utils/cross_module_fixtures/http_fixtures.py new file mode 100644 index 0000000000..2e761a8007 --- /dev/null +++ b/test/test_utils/cross_module_fixtures/http_fixtures.py @@ -0,0 +1,13 @@ +import pytest + +from ..http_test_utils import DynamicCollectingCustomizer, StaticCollectingCustomizer + + +@pytest.fixture +def static_collecting_customizer(): + return StaticCollectingCustomizer() + + +@pytest.fixture +def dynamic_collecting_customizer(): + return DynamicCollectingCustomizer() diff --git a/test/test_utils/cross_module_fixtures/wiremock_fixtures.py b/test/test_utils/cross_module_fixtures/wiremock_fixtures.py new file mode 100644 index 0000000000..14e6c72d46 --- /dev/null +++ b/test/test_utils/cross_module_fixtures/wiremock_fixtures.py @@ -0,0 +1,66 @@ +import pathlib +import uuid +from contextlib import contextmanager +from functools import partial +from typing import Any, Callable, ContextManager, Generator, Union + +import pytest + +import snowflake.connector + +from ..wiremock.wiremock_utils import WiremockClient + + +@pytest.fixture(scope="session") +def wiremock_mapping_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent.parent / "data" / "wiremock" / "mappings" + ) + + +@pytest.fixture(scope="session") +def wiremock_generic_mappings_dir(wiremock_mapping_dir) -> pathlib.Path: + return wiremock_mapping_dir / "generic" + + +@pytest.fixture(scope="session") +def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: + with WiremockClient() as client: + yield client + + +@pytest.fixture +def 
default_db_wiremock_parameters(wiremock_client: WiremockClient) -> dict[str, Any]: + db_params = { + "account": "testAccount", + "user": "testUser", + "password": "testPassword", + "host": wiremock_client.wiremock_host, + "port": wiremock_client.wiremock_http_port, + "protocol": "http", + "name": "python_tests_" + str(uuid.uuid4()).replace("-", "_"), + } + return db_params + + +@contextmanager +def db_wiremock( + default_db_wiremock_parameters: dict[str, Any], + **kwargs, +) -> Generator[snowflake.connector.SnowflakeConnection, None, None]: + ret = default_db_wiremock_parameters + ret.update(kwargs) + cnx = snowflake.connector.connect(**ret) + try: + yield cnx + finally: + cnx.close() + + +@pytest.fixture +def conn_cnx_wiremock( + default_db_wiremock_parameters, +) -> Callable[..., ContextManager[snowflake.connector.SnowflakeConnection]]: + return partial( + db_wiremock, default_db_wiremock_parameters=default_db_wiremock_parameters + ) diff --git a/test/test_utils/http_test_utils.py b/test/test_utils/http_test_utils.py new file mode 100644 index 0000000000..4265185bde --- /dev/null +++ b/test/test_utils/http_test_utils.py @@ -0,0 +1,468 @@ +import re +import threading +from collections import deque +from typing import ( + Any, + Deque, + Dict, + FrozenSet, + Iterable, + NamedTuple, + Optional, + Tuple, + Union, +) + +try: + from snowflake.connector.http_interceptor import ( + Headers, + HeadersCustomizer, + RequestDTO, + ) +except ImportError: + HeadersCustomizer = object + RequestDTO = Tuple + Headers = Dict[str, Any] + + +class CollectingCustomizer(HeadersCustomizer): + def __init__(self) -> None: + self.invocations: Deque[RequestDTO] = deque() + self._lock: threading.Lock = threading.Lock() + + def applies_to(self, request: RequestDTO) -> bool: + return True + + def get_new_headers(self, request: RequestDTO) -> Dict[str, str]: + with self._lock: + self.invocations.append(request) + return {"test-header": "test-value"} + + +class 
StaticCollectingCustomizer(CollectingCustomizer): + def is_invoked_once(self) -> bool: + return True + + +class DynamicCollectingCustomizer(CollectingCustomizer): + def is_invoked_once(self) -> bool: + return False + + +class ExpectedRequestInfo(NamedTuple): + method: str + url_regexp: Optional[str] = None + + def is_matching(self, request: object) -> bool: + if isinstance(request, RequestDTO): + return request.method.upper() == self.method.upper() and bool( + re.fullmatch(self.url_regexp, request.url) + ) + elif isinstance(request, ExpectedRequestInfo): + return ( + request.method == self.method.upper() + and request.url_regexp == self.url_regexp + ) + return False + + +class RequestTracker: + DEFAULT_REQUESTS_TO_IGNORE_IN_CHECKS: FrozenSet[ExpectedRequestInfo] = frozenset( + [ExpectedRequestInfo("GET", r".*/__admin/health")] + ) + + def __init__( + self, + requests: Deque[RequestDTO], + ignored: Optional[ + Iterable[ExpectedRequestInfo] + ] = DEFAULT_REQUESTS_TO_IGNORE_IN_CHECKS, + ) -> None: + self.requests = requests + self.ignored = ignored or () + self._last_request = None + self._last_expected_info = None + + @staticmethod + def _assert_headers_were_added( + headers: Headers, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + ) -> None: + expected_headers = dict(expected_headers) + headers_lowercased = {k.lower(): v for k, v in headers.items()} + for expected_header_name, expected_header_value in expected_headers.items(): + assert ( + headers_lowercased.get(expected_header_name.lower()) + == expected_header_value + ), f"Custom header not found: {expected_header_name}" + + def _should_ignore(self, request: RequestDTO) -> bool: + return any(ignored_info.is_matching(request) for ignored_info in self.ignored) + + def assert_request_occurred( + self, + expected: ExpectedRequestInfo, + raise_on_missing: bool = True, + ) -> Optional[RequestDTO]: + for i, request in enumerate(self.requests): + if 
self._should_ignore(request): + continue + + if expected.is_matching(request): + self._last_request = request + self._last_expected_info = expected + # Pop the matched request from the associated index in the deque, while searching the whole collection + del self.requests[i] + return request + if raise_on_missing: + raise AssertionError( + f"Expected request '{expected.method} {expected.url_regexp}' not found" + ) + else: + return None + + def assert_request_occurred_sequentially( + self, + expected: ExpectedRequestInfo, + raise_on_missing: bool = True, + skip_previous_request_retries: bool = True, + ) -> Optional[RequestDTO]: + while self.requests: + request = self.requests.popleft() + if self._should_ignore(request): + continue + + if expected.is_matching(request): + self._last_request = request + self._last_expected_info = expected + return request + + if ( + skip_previous_request_retries + and self._last_expected_info + and self._last_expected_info.is_matching(request) + ): + self._last_request = request # skip retry + continue + + # Rollback and exit + self.requests.appendleft(request) + if raise_on_missing: + raise AssertionError(f"Unexpected request: {request}") + return None + if raise_on_missing: + raise AssertionError( + f"Expected request '{expected.method} {expected.url_regexp}' not found" + ) + return None + + def _assert_issued_with_custom_headers( + self, + expected: ExpectedRequestInfo, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]], + sequentially: bool, + optional: bool, + ) -> Optional[RequestDTO]: + found_request: Optional[RequestDTO] = ( + self.assert_request_occurred_sequentially( + expected, raise_on_missing=not optional + ) + if sequentially + else self.assert_request_occurred(expected, raise_on_missing=not optional) + ) + if found_request: + self._assert_headers_were_added(found_request.headers, expected_headers) + return found_request + + def assert_login_issued( + self, + expected_headers: Union[Dict[str, Any], 
Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo("POST", r".*/session/v1/login-request.*"), + expected_headers, + sequentially, + optional, + ) + + def assert_sql_query_issued( + self, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo("POST", r".*/queries/v1/query-request.*"), + expected_headers, + sequentially, + optional, + ) + + def assert_telemetry_send_issued( + self, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo("POST", r".*/telemetry/send.*"), + expected_headers, + sequentially, + optional, + ) + + def assert_disconnect_issued( + self, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo("POST", r".*/session\?delete=true(\&request_guid=.*)?"), + expected_headers, + sequentially, + optional, + ) + + def assert_get_chunk_issued( + self, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo( + "GET", + r".*(amazonaws|blob\.core\.windows|storage\.googleapis).*/results/.*main.*data.*\?.*", + ), + expected_headers, + sequentially, + 
optional, + ) + + def assert_aws_get_accelerate_issued( + self, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially=True, + optional=True, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo("GET", r".*\.s3(.*)?\.amazonaws.*/\?accelerate(.*)?"), + expected_headers, + sequentially, + optional, + ) + + def assert_get_file_issued( + self, + filename: Optional[str] = None, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo( + "GET", + r".*(s3(.*)?\.amazonaws|blob\.core\.windows|storage\.googleapis).*" + + (filename or "") + + r"(.*)?", + ), + expected_headers, + sequentially, + optional, + ) + + def assert_put_file_issued( + self, + filename: Optional[str] = None, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo( + "PUT", + r".*(s3(.*)?\.amazonaws|blob\.core\.windows|storage\.googleapis).*stages.*" + + (filename or "") + + r"(.*)?", + ), + expected_headers, + sequentially, + optional, + ) + + def assert_put_file_part_in_multipart_issued( + self, + filename: Optional[str] = None, + cloud_platform: Union[str, None] = None, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + expected_request = None + optional_filename_regexp = (filename + "(.*)?") if filename else "" + if cloud_platform == "azure": + expected_request = ExpectedRequestInfo( + "PUT", + 
r".*blob\.core\.windows.*stages.*" + + optional_filename_regexp + + r"\?comp=block&blockid=.*", + ) + elif cloud_platform in ("aws", "dev", "gcp"): + expected_request = ExpectedRequestInfo( + "PUT", + r".*(s3(.*)?\.amazonaws|storage\.googleapis).*stages.*" + + optional_filename_regexp, + ) + return self._assert_issued_with_custom_headers( + expected_request, + expected_headers, + sequentially, + optional, + ) + + def assert_file_head_issued( + self, + filename: Optional[str] = None, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo( + "HEAD", + r".*(amazonaws|blob\.core\.windows|storage\.googleapis).*" + + (filename or "") + + r"(.*)?", + ), + expected_headers, + sequentially, + optional, + ) + + def assert_post_start_for_multipart_file_issued( + self, + file_path: Optional[str] = None, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo( + "POST", + r".*s3.*\.amazonaws.*stages.*" + (file_path or "") + r"\?uploads", + ), + expected_headers, + sequentially, + optional, + ) + + def assert_post_end_for_multipart_on_aws_file_issued( + self, + file_path: Optional[str] = None, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo( + "POST", + r".*s3(.*)?\.amazonaws.*/stages/.*" + + (file_path or "") + + r".*uploadId=.*", + ), + expected_headers, + sequentially, + optional, + ) + + def 
assert_put_end_for_multipart_on_azure_file_issued( + self, + file_path: Optional[str] = None, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + ) -> Optional[RequestDTO]: + return self._assert_issued_with_custom_headers( + ExpectedRequestInfo( + "PUT", + r".*blob\.core\.windows.*stages.*" + + (file_path or "") + + r".*?comp=blocklist(.*)?", + ), + expected_headers, + sequentially, + optional, + ) + + def assert_end_for_multipart_file_issued( + self, + cloud_platform: str, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = False, + file_path: Optional[str] = None, + ) -> Optional[RequestDTO]: + if cloud_platform in ("aws", "dev"): + return self.assert_post_end_for_multipart_on_aws_file_issued( + file_path, expected_headers, sequentially, optional + ) + elif cloud_platform == "azure": + return self.assert_put_end_for_multipart_on_azure_file_issued( + file_path, expected_headers, sequentially, optional + ) + return None + + def assert_multiple_put_file_issued( + self, + filename: Optional[str] = None, + cloud_platform: Union[str, None] = None, + expected_headers: Union[Dict[str, Any], Tuple[Tuple[str, Any], ...]] = ( + ("test-header", "test-value"), + ), + sequentially: bool = True, + optional: bool = True, + ) -> None: + self.assert_put_file_part_in_multipart_issued( + filename, cloud_platform, expected_headers, sequentially=sequentially + ) + while self.assert_put_file_part_in_multipart_issued( + filename, + cloud_platform, + expected_headers, + sequentially=sequentially, + optional=optional, + ): + continue diff --git a/test/test_utils/wiremock/__init__.py b/test/test_utils/wiremock/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/wiremock/wiremock_utils.py b/test/test_utils/wiremock/wiremock_utils.py 
similarity index 74% rename from test/wiremock/wiremock_utils.py rename to test/test_utils/wiremock/wiremock_utils.py index 1d036a8023..6e54a8be98 100644 --- a/test/wiremock/wiremock_utils.py +++ b/test/test_utils/wiremock/wiremock_utils.py @@ -38,7 +38,9 @@ def __init__(self, forbidden_ports: Optional[List[int]] = None) -> None: self.wiremock_https_port = None self.forbidden_ports = forbidden_ports if forbidden_ports is not None else [] - self.wiremock_dir = pathlib.Path(__file__).parent.parent.parent / ".wiremock" + self.wiremock_dir = ( + pathlib.Path(__file__).parent.parent.parent.parent / ".wiremock" + ) assert self.wiremock_dir.exists(), f"{self.wiremock_dir} does not exist" self.wiremock_jar_path = self.wiremock_dir / self.wiremock_filename @@ -46,6 +48,14 @@ def __init__(self, forbidden_ports: Optional[List[int]] = None) -> None: self.wiremock_jar_path.exists() ), f"{self.wiremock_jar_path} does not exist" + @property + def http_host_with_port(self) -> str: + return f"http://{self.wiremock_host}:{self.wiremock_http_port}" + + @property + def http_placeholders(self) -> dict[str, str]: + return {"{{WIREMOCK_HTTP_HOST_WITH_PORT}}": self.http_host_with_port} + def _start_wiremock(self): self.wiremock_http_port = self._find_free_port( forbidden_ports=self.forbidden_ports, @@ -76,14 +86,22 @@ def _start_wiremock(self): self._wait_for_wiremock() def _stop_wiremock(self): - response = self._wiremock_post( - f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/shutdown" - ) - if response.status_code != 200: - logger.info("Wiremock shutdown failed, the process will be killed") + if self.wiremock_process.poll() is not None: + logger.warning("Wiremock process already exited, skipping shutdown") + return + + try: + response = self._wiremock_post( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/shutdown" + ) + if response.status_code != 200: + logger.info("Wiremock shutdown failed, the process will be killed") + self.wiremock_process.kill() + 
else: + logger.debug("Wiremock shutdown gracefully") + except requests.exceptions.RequestException as e: + logger.warning(f"Shutdown request failed: {e}. Killing process directly.") self.wiremock_process.kill() - else: - logger.debug("Wiremock shutdown gracefully") def _wait_for_wiremock(self): retry_count = 0 @@ -139,19 +157,36 @@ def _wiremock_post( headers = {"Accept": "application/json", "Content-Type": "application/json"} return requests.post(endpoint, data=body, headers=headers) - def import_mapping(self, mapping: Union[str, dict, pathlib.Path]): + def _replace_placeholders_in_mapping( + self, mapping_str: str, placeholders: Optional[dict[str, object]] + ) -> str: + if not placeholders: + return mapping_str + for key, value in placeholders.items(): + mapping_str = mapping_str.replace(str(key), str(value)) + return mapping_str + + def import_mapping( + self, + mapping: Union[str, dict, pathlib.Path], + placeholders: Optional[dict[str, object]] = None, + ): self._reset_wiremock() - import_mapping_endpoint = f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/mappings/import" + import_mapping_endpoint = f"{self.http_host_with_port}/__admin/mappings/import" mapping_str = _get_mapping_str(mapping) + mapping_str = self._replace_placeholders_in_mapping(mapping_str, placeholders) response = self._wiremock_post(import_mapping_endpoint, mapping_str) if response.status_code != requests.codes.ok: raise RuntimeError("Failed to import mapping") - def add_mapping(self, mapping: Union[str, dict, pathlib.Path]): - add_mapping_endpoint = ( - f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/mappings" - ) + def add_mapping( + self, + mapping: Union[str, dict, pathlib.Path], + placeholders: Optional[dict[str, object]] = None, + ): + add_mapping_endpoint = f"{self.http_host_with_port}/__admin/mappings" mapping_str = _get_mapping_str(mapping) + mapping_str = self._replace_placeholders_in_mapping(mapping_str, placeholders) response = 
self._wiremock_post(add_mapping_endpoint, mapping_str) if response.status_code != requests.codes.created: raise RuntimeError("Failed to add mapping") diff --git a/test/unit/conftest.py b/test/unit/conftest.py index 65c2fb02f6..1494af8468 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -1,7 +1,10 @@ from __future__ import annotations +from typing import Callable + import pytest +from snowflake.connector.http_interceptor import Headers, HeadersCustomizer, RequestDTO from snowflake.connector.telemetry_oob import TelemetryService from ..csp_helpers import ( @@ -52,3 +55,68 @@ def fake_gce_metadata_service(): """Emulates the GCE metadata service, returning a dummy token.""" with FakeGceMetadataService() as server: yield server + + +@pytest.fixture +def sample_request_factory(): + def _factory( + url="https://test.snowflakecomputing.com/api/v1", + method="GET", + headers=None, + ): + return RequestDTO( + url=url, + method=method, + headers=headers or {"User-Agent": "SnowflakeDriver/1.0"}, + ) + + return _factory + + +@pytest.fixture +def headers_customizer_factory(): + def _customizer_factory( + applies: Callable[[RequestDTO], bool] | bool = True, + invoke_once: bool = True, + headers: Callable[[RequestDTO], Headers] | Headers = None, + ): + class MockCustomizer(HeadersCustomizer): + def applies_to(self, request: RequestDTO) -> bool: + if callable(applies): + return applies(request) + return applies + + def is_invoked_once(self) -> bool: + return invoke_once + + def get_new_headers(self, request: RequestDTO) -> Headers: + if callable(headers): + return headers(request) + return headers or {} + + return MockCustomizer() + + return _customizer_factory + + +@pytest.fixture +def dynamic_customizer_factory(): + def _dynamic_factory(): + counter = {"count": 0} + + class DynamicCustomizer(HeadersCustomizer): + def applies_to(self, request: RequestDTO) -> bool: + return True + + def is_invoked_once(self) -> bool: + return False + + def get_new_headers(self, 
request: RequestDTO) -> Headers: + counter["count"] += 1 + return { + f"X-Dynamic-{counter['count']}": f"DynamicVal-{counter['count']}" + } + + return DynamicCustomizer() + + return _dynamic_factory diff --git a/test/unit/test_http_interceptor.py b/test/unit/test_http_interceptor.py new file mode 100644 index 0000000000..c91bdcd8d1 --- /dev/null +++ b/test/unit/test_http_interceptor.py @@ -0,0 +1,214 @@ +from unittest.mock import Mock + +import pytest + +try: + from snowflake.connector.http_interceptor import ( + HeadersCustomizer, + HeadersCustomizerInterceptor, + HttpInterceptor, + ) +except ImportError: + pass + + +@pytest.mark.skipolddriver +def test_no_interceptors_does_nothing(sample_request_factory): + request = sample_request_factory() + interceptor = HeadersCustomizerInterceptor([]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED, request + ) + assert result == request + + +@pytest.mark.skipolddriver +def test_non_applying_interceptor_does_nothing( + sample_request_factory, headers_customizer_factory +): + request = sample_request_factory() + customizer = headers_customizer_factory(applies=False) + interceptor = HeadersCustomizerInterceptor([customizer]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED, request + ) + assert result == request + + +@pytest.mark.skipolddriver +def test_non_applying_interceptor_not_called(sample_request_factory): + request = sample_request_factory() + + # Create a mock customizer with applies_to returning False + customizer = Mock(spec=HeadersCustomizer) + customizer.applies_to.return_value = False + customizer.is_invoked_once.return_value = True + customizer.get_new_headers.return_value = {} + + interceptor = HeadersCustomizerInterceptor([customizer]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED, request + ) + + # Assert result is unchanged + assert result == request + + # Check call 
counts + customizer.applies_to.assert_called_once_with(request) + customizer.get_new_headers.assert_not_called() + + +@pytest.mark.skipolddriver +def test_dynamic_customizer_adds_different_headers( + sample_request_factory, dynamic_customizer_factory +): + request = sample_request_factory() + interceptor = HeadersCustomizerInterceptor([dynamic_customizer_factory()]) + + result1 = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_RETRY, request + ) + result2 = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_RETRY, request + ) + + assert result1.headers["X-Dynamic-1"] != result2.headers["X-Dynamic-2"] + + +@pytest.mark.skipolddriver +def test_invoke_once_skips_on_retry(sample_request_factory, headers_customizer_factory): + request = sample_request_factory() + customizer = headers_customizer_factory(applies=True, invoke_once=True) + interceptor = HeadersCustomizerInterceptor([customizer]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_RETRY, request + ) + assert result == request + + +@pytest.mark.skipolddriver +def test_invoke_always_runs_on_retry( + sample_request_factory, headers_customizer_factory +): + request = sample_request_factory() + customizer = headers_customizer_factory( + applies=True, invoke_once=False, headers={"X-Retry": "RetryVal"} + ) + interceptor = HeadersCustomizerInterceptor([customizer]) + result1 = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_RETRY, request + ) + result2 = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_RETRY, request + ) + assert result1.headers["X-Retry"] == "RetryVal" + assert result2.headers["X-Retry"] == "RetryVal" + + +@pytest.mark.skipolddriver +def test_prevents_header_overwrite(sample_request_factory, headers_customizer_factory): + request = sample_request_factory(headers={"User-Agent": "SnowflakeDriver/1.0"}) + customizer = headers_customizer_factory( + applies=True, invoke_once=True, 
headers={"User-Agent": "MaliciousAgent"} + ) + interceptor = HeadersCustomizerInterceptor([customizer]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED, request + ) + assert result.headers["User-Agent"] == "SnowflakeDriver/1.0" + assert result.headers["User-Agent"] != "MaliciousAgent" + + +@pytest.mark.skipolddriver +def test_partial_header_overwrite_ignores_only_conflicting_keys( + sample_request_factory, headers_customizer_factory +): + request = sample_request_factory(headers={"User-Agent": "SnowflakeDriver/1.0"}) + + customizer = headers_customizer_factory( + applies=True, + invoke_once=True, + headers={ + "User-Agent": "MaliciousAgent", # should be blocked + "X-New-Header": "NewValue", # should be added + }, + ) + + interceptor = HeadersCustomizerInterceptor([customizer]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED, request + ) + + # Original value preserved + assert result.headers["User-Agent"] == "SnowflakeDriver/1.0" + assert result.headers["User-Agent"] != "MaliciousAgent" + + # Non-conflicting key added + assert result.headers["X-New-Header"] == "NewValue" + + +@pytest.mark.skipolddriver +def test_multiple_customizers_add_headers( + sample_request_factory, headers_customizer_factory +): + request = sample_request_factory() + customizer1 = headers_customizer_factory( + applies=True, headers={"X-Custom1": "Val1"} + ) + customizer2 = headers_customizer_factory( + applies=True, headers={"X-Custom2": "Val2"} + ) + interceptor = HeadersCustomizerInterceptor([customizer1, customizer2]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED, request + ) + assert result.headers["X-Custom1"] == "Val1" + assert result.headers["X-Custom2"] == "Val2" + + +@pytest.mark.skipolddriver +def test_multi_value_headers(sample_request_factory, headers_customizer_factory): + request = sample_request_factory() + customizer = 
headers_customizer_factory( + applies=True, headers={"X-Multi": ["ValA", "ValB"]} + ) + interceptor = HeadersCustomizerInterceptor([customizer]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED, request + ) + values = result.headers["X-Multi"] + if isinstance(values, list): + assert "ValA" in values + assert "ValB" in values + else: + assert "ValA" in values or "ValB" in values + + +@pytest.mark.parametrize( + "url,should_apply", + [ + ("https://test.snowflakecomputing.com/api", True), + ("https://example.com/api", False), + ], +) +@pytest.mark.skipolddriver +def test_customizer_applies_only_to_specific_domain( + sample_request_factory, headers_customizer_factory, url, should_apply +): + request = sample_request_factory(url=url) + + def snowflake_only(req): + return "snowflakecomputing.com" in req.url + + customizer = headers_customizer_factory( + applies=snowflake_only, headers={"X-Domain-Specific": "True"} + ) + interceptor = HeadersCustomizerInterceptor([customizer]) + result = interceptor.intercept_on( + HttpInterceptor.InterceptionHook.BEFORE_REQUEST_ISSUED, request + ) + + if should_apply: + assert result.headers["X-Domain-Specific"] == "True" + else: + assert "X-Domain-Specific" not in result.headers diff --git a/test/unit/test_oauth_token.py b/test/unit/test_oauth_token.py index cae2465453..1ea09d067c 100644 --- a/test/unit/test_oauth_token.py +++ b/test/unit/test_oauth_token.py @@ -5,7 +5,6 @@ import logging import pathlib from threading import Thread -from typing import Any, Generator, Union from unittest import mock from unittest.mock import Mock, patch @@ -16,17 +15,11 @@ from snowflake.connector.auth import AuthByOauthCredentials from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType -from ..wiremock.wiremock_utils import WiremockClient +from ..test_utils.wiremock.wiremock_utils import WiremockClient logger = logging.getLogger(__name__) -@pytest.fixture(scope="session") -def 
wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: - with WiremockClient() as client: - yield client - - @pytest.fixture(scope="session") def wiremock_oauth_authorization_code_dir() -> pathlib.Path: return ( @@ -53,17 +46,6 @@ def wiremock_oauth_client_creds_dir() -> pathlib.Path: ) -@pytest.fixture(scope="session") -def wiremock_generic_mappings_dir() -> pathlib.Path: - return ( - pathlib.Path(__file__).parent.parent - / "data" - / "wiremock" - / "mappings" - / "generic" - ) - - @pytest.fixture(scope="session") def wiremock_oauth_refresh_token_dir() -> pathlib.Path: return ( diff --git a/test/unit/test_programmatic_access_token.py b/test/unit/test_programmatic_access_token.py index 7d6ecb175e..fdf5bc0c9d 100644 --- a/test/unit/test_programmatic_access_token.py +++ b/test/unit/test_programmatic_access_token.py @@ -1,5 +1,4 @@ import pathlib -from typing import Any, Generator, Union import pytest @@ -9,13 +8,7 @@ except ImportError: pass -from ..wiremock.wiremock_utils import WiremockClient - - -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: - with WiremockClient() as client: - yield client +from ..test_utils.wiremock.wiremock_utils import WiremockClient @pytest.mark.skipolddriver diff --git a/test/unit/test_wiremock_client.py b/test/unit/test_wiremock_client.py index b471f39df7..19625c42c0 100644 --- a/test/unit/test_wiremock_client.py +++ b/test/unit/test_wiremock_client.py @@ -1,7 +1,3 @@ -from typing import Any, Generator - -import pytest - # old driver support try: from snowflake.connector.vendored import requests @@ -9,16 +5,6 @@ import requests -from ..wiremock.wiremock_utils import WiremockClient - - -@pytest.mark.skipolddriver -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[WiremockClient, Any, None]: - with WiremockClient() as client: - yield client - - def test_wiremock(wiremock_client): connection_reset_by_peer_mapping = { "mappings": [