Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ repos:
(?x)^src/snowflake/connector/(
constants
| compat
| cursor
| dbapi
| description
| errorcode
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/connector/bind_upload_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .errors import BindUploadError, Error

if TYPE_CHECKING: # pragma: no cover
from .cursor import SnowflakeCursor
from .cursor import SnowflakeCursorBase

logger = getLogger(__name__)

Expand All @@ -23,7 +23,7 @@ class BindUploadAgent:

def __init__(
self,
cursor: SnowflakeCursor,
cursor: SnowflakeCursorBase,
rows: list[bytes],
stream_buffer_size: int = 1024 * 1024 * 10,
) -> None:
Expand Down
39 changes: 25 additions & 14 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,13 @@
ER_NO_USER,
ER_NOT_IMPLICITY_SNOWFLAKE_DATATYPE,
)
from .errors import DatabaseError, Error, OperationalError, ProgrammingError
from .errors import (
DatabaseError,
Error,
ErrorHandler,
OperationalError,
ProgrammingError,
)
from .log_configuration import EasyLoggingConfigPython
from .network import (
DEFAULT_AUTHENTICATOR,
Expand Down Expand Up @@ -500,6 +506,12 @@ class SnowflakeConnection:

OCSP_ENV_LOCK = Lock()

# Tell mypy these fields exist.
# TODO: Replace the dynamic setattr() with static methods
_interpolate_empty_sequences: bool
_reraise_error_in_file_transfer_work_function: bool
_reuse_results: bool

def __init__(
self,
connection_name: str | None = None,
Expand Down Expand Up @@ -780,12 +792,12 @@ def application(self) -> str:
return self._application

@property
def errorhandler(self) -> Callable: # TODO: callable args
def errorhandler(self) -> ErrorHandler:
return self._errorhandler

@errorhandler.setter
# Note: Callable doesn't implement operator|
def errorhandler(self, value: Callable | None) -> None:
def errorhandler(self, value: ErrorHandler | None) -> None:
# TODO: Why is value `ErrorHandler | None` if it always errors on None?
if value is None:
raise ProgrammingError("None errorhandler is specified")
self._errorhandler = value
Expand Down Expand Up @@ -1096,9 +1108,9 @@ def execute_string(
sql_text: str,
remove_comments: bool = False,
return_cursors: bool = True,
cursor_class: SnowflakeCursor = SnowflakeCursor,
cursor_class: SnowflakeCursorBase = SnowflakeCursor,
**kwargs,
) -> Iterable[SnowflakeCursor]:
) -> Iterable[SnowflakeCursorBase]:
"""Executes a SQL text including multiple statements. This is a non-standard convenience method."""
stream = StringIO(sql_text)
stream_generator = self.execute_stream(
Expand All @@ -1111,9 +1123,9 @@ def execute_stream(
self,
stream: StringIO,
remove_comments: bool = False,
cursor_class: SnowflakeCursor = SnowflakeCursor,
cursor_class: SnowflakeCursorBase = SnowflakeCursor,
**kwargs,
) -> Generator[SnowflakeCursor]:
) -> Generator[SnowflakeCursorBase]:
"""Executes a stream of SQL statements. This is a non-standard convenient method."""
split_statements_list = split_statements(
stream, remove_comments=remove_comments
Expand Down Expand Up @@ -1830,7 +1842,6 @@ def _write_params_to_byte_rows(

Args:
params: Binding parameters to bulk array insertion query with qmark/numeric format.
cursor: SnowflakeCursor.

Returns:
List of bytes string corresponding to rows
Expand All @@ -1847,7 +1858,7 @@ def _write_params_to_byte_rows(

def _get_snowflake_type_and_binding(
self,
cursor: SnowflakeCursor | None,
cursor: SnowflakeCursorBase | None,
v: tuple[str, Any] | Any,
) -> TypeAndBinding:
if isinstance(v, tuple):
Expand Down Expand Up @@ -1888,7 +1899,7 @@ def _get_snowflake_type_and_binding(
def _process_params_qmarks(
self,
params: Sequence | None,
cursor: SnowflakeCursor | None = None,
cursor: SnowflakeCursorBase | None = None,
) -> dict[str, dict[str, str]] | None:
if not params:
return None
Expand Down Expand Up @@ -1922,14 +1933,14 @@ def _process_params_qmarks(
def _process_params_pyformat(
self,
params: Any | Sequence[Any] | dict[Any, Any] | None,
cursor: SnowflakeCursor | None = None,
cursor: SnowflakeCursorBase | None = None,
) -> tuple[Any] | dict[str, Any] | None:
"""Process parameters for client-side parameter binding.

Args:
params: Either a sequence, or a dictionary of parameters, if anything else
is given then it will be put into a list and processed that way.
cursor: The SnowflakeCursor used to report errors if necessary.
cursor: The SnowflakeCursorBase used to report errors if necessary.
"""
if params is None:
if self._interpolate_empty_sequences:
Expand Down Expand Up @@ -1961,7 +1972,7 @@ def _process_params_pyformat(
)

def _process_params_dict(
self, params: dict[Any, Any], cursor: SnowflakeCursor | None = None
self, params: dict[Any, Any], cursor: SnowflakeCursorBase | None = None
) -> dict:
try:
res = {k: self._process_single_param(v) for k, v in params.items()}
Expand Down
Loading
Loading