diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 43aabd2e9..5038df20b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -81,6 +81,7 @@ repos:
             (?x)^src/snowflake/connector/(
                 constants
                 | compat
+                | cursor
                 | dbapi
                 | description
                 | errorcode
diff --git a/src/snowflake/connector/bind_upload_agent.py b/src/snowflake/connector/bind_upload_agent.py
index d01751cad..efbc0dfea 100644
--- a/src/snowflake/connector/bind_upload_agent.py
+++ b/src/snowflake/connector/bind_upload_agent.py
@@ -14,7 +14,7 @@ from .errors import BindUploadError, Error

 if TYPE_CHECKING:  # pragma: no cover
-    from .cursor import SnowflakeCursor
+    from .cursor import SnowflakeCursorBase

 logger = getLogger(__name__)

@@ -23,7 +23,7 @@ class BindUploadAgent:

     def __init__(
         self,
-        cursor: SnowflakeCursor,
+        cursor: SnowflakeCursorBase,
         rows: list[bytes],
         stream_buffer_size: int = 1024 * 1024 * 10,
     ) -> None:
diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py
index c4efe25f8..79031b9bc 100644
--- a/src/snowflake/connector/connection.py
+++ b/src/snowflake/connector/connection.py
@@ -109,7 +109,13 @@
     ER_NO_USER,
     ER_NOT_IMPLICITY_SNOWFLAKE_DATATYPE,
 )
-from .errors import DatabaseError, Error, OperationalError, ProgrammingError
+from .errors import (
+    DatabaseError,
+    Error,
+    ErrorHandler,
+    OperationalError,
+    ProgrammingError,
+)
 from .log_configuration import EasyLoggingConfigPython
 from .network import (
     DEFAULT_AUTHENTICATOR,
@@ -500,6 +506,12 @@ class SnowflakeConnection:

     OCSP_ENV_LOCK = Lock()

+    # Tell mypy these fields exist.
+    # TODO: Replace the dynamic setattr() with static methods
+    _interpolate_empty_sequences: bool
+    _reraise_error_in_file_transfer_work_function: bool
+    _reuse_results: bool
+
     def __init__(
         self,
         connection_name: str | None = None,
@@ -780,12 +792,12 @@ def application(self) -> str:
         return self._application

     @property
-    def errorhandler(self) -> Callable:  # TODO: callable args
+    def errorhandler(self) -> ErrorHandler:
         return self._errorhandler

     @errorhandler.setter
-    # Note: Callable doesn't implement operator|
-    def errorhandler(self, value: Callable | None) -> None:
+    def errorhandler(self, value: ErrorHandler | None) -> None:
+        # TODO: Why is value `ErrorHandler | None` if it always errors on None?
         if value is None:
             raise ProgrammingError("None errorhandler is specified")
         self._errorhandler = value
@@ -1096,9 +1108,9 @@ def execute_string(
         sql_text: str,
         remove_comments: bool = False,
         return_cursors: bool = True,
-        cursor_class: SnowflakeCursor = SnowflakeCursor,
+        cursor_class: SnowflakeCursorBase = SnowflakeCursor,
         **kwargs,
-    ) -> Iterable[SnowflakeCursor]:
+    ) -> Iterable[SnowflakeCursorBase]:
         """Executes a SQL text including multiple statements. This is a non-standard convenience method."""
         stream = StringIO(sql_text)
         stream_generator = self.execute_stream(
@@ -1111,9 +1123,9 @@ def execute_stream(
         self,
         stream: StringIO,
         remove_comments: bool = False,
-        cursor_class: SnowflakeCursor = SnowflakeCursor,
+        cursor_class: SnowflakeCursorBase = SnowflakeCursor,
         **kwargs,
-    ) -> Generator[SnowflakeCursor]:
+    ) -> Generator[SnowflakeCursorBase]:
         """Executes a stream of SQL statements. This is a non-standard convenient method."""
         split_statements_list = split_statements(
             stream, remove_comments=remove_comments
@@ -1830,7 +1842,6 @@ def _write_params_to_byte_rows(

         Args:
             params: Binding parameters to bulk array insertion query with qmark/numeric format.
-            cursor: SnowflakeCursor.

         Returns:
             List of bytes string corresponding to rows
@@ -1847,7 +1858,7 @@ def _write_params_to_byte_rows(

     def _get_snowflake_type_and_binding(
         self,
-        cursor: SnowflakeCursor | None,
+        cursor: SnowflakeCursorBase | None,
         v: tuple[str, Any] | Any,
     ) -> TypeAndBinding:
         if isinstance(v, tuple):
@@ -1888,7 +1899,7 @@ def _get_snowflake_type_and_binding(
     def _process_params_qmarks(
         self,
         params: Sequence | None,
-        cursor: SnowflakeCursor | None = None,
+        cursor: SnowflakeCursorBase | None = None,
     ) -> dict[str, dict[str, str]] | None:
         if not params:
             return None
@@ -1922,14 +1933,14 @@ def _process_params_qmarks(
     def _process_params_pyformat(
         self,
         params: Any | Sequence[Any] | dict[Any, Any] | None,
-        cursor: SnowflakeCursor | None = None,
+        cursor: SnowflakeCursorBase | None = None,
     ) -> tuple[Any] | dict[str, Any] | None:
         """Process parameters for client-side parameter binding.

         Args:
             params: Either a sequence, or a dictionary of parameters, if anything else is
                 given then it will be put into a list and processed that way.
-            cursor: The SnowflakeCursor used to report errors if necessary.
+            cursor: The SnowflakeCursorBase used to report errors if necessary.
         """
         if params is None:
             if self._interpolate_empty_sequences:
@@ -1961,7 +1972,7 @@ def _process_params_pyformat(
         )

     def _process_params_dict(
-        self, params: dict[Any, Any], cursor: SnowflakeCursor | None = None
+        self, params: dict[Any, Any], cursor: SnowflakeCursorBase | None = None
     ) -> dict:
         try:
             res = {k: self._process_single_param(v) for k, v in params.items()}
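A brief usage sketch of the loosened `execute_string` signature above (illustrative, not part of the patch; credentials are placeholders): any `SnowflakeCursorBase` subclass can be supplied as `cursor_class`, e.g. the stock `DictCursor`.

    import snowflake.connector
    from snowflake.connector import DictCursor

    con = snowflake.connector.connect(
        account="<account>", user="<user>", password="<password>"  # placeholders
    )
    for cur in con.execute_string(
        "select 1 as a; select 2 as b;",
        cursor_class=DictCursor,  # any SnowflakeCursorBase subclass type-checks now
    ):
        for row in cur:  # DictCursor yields dict rows instead of tuples
            print(row)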
diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py
index c13ab242c..1f197477c 100644
--- a/src/snowflake/connector/cursor.py
+++ b/src/snowflake/connector/cursor.py
@@ -69,6 +69,7 @@
 from .errors import (
     DatabaseError,
     Error,
+    ErrorHandler,
     IntegrityError,
     InterfaceError,
     NotSupportedError,
@@ -223,7 +224,7 @@ def from_column(cls, col: dict[str, Any]) -> ResultMetadataV2:
         )

         fields = col.get("fields")
-        processed_fields: Optional[List[ResultMetadataV2]] = None
+        processed_fields: list[ResultMetadataV2] | None = None
         if fields is not None:
             if col_type in {"VECTOR", "ARRAY", "OBJECT", "MAP"}:
                 processed_fields = [
@@ -372,12 +373,9 @@ def __init__(

         Args:
             connection: The connection that created this cursor.
""" - self._connection: SnowflakeConnection = connection + self._connection: SnowflakeConnection | None = connection - self._errorhandler: Callable[ - [SnowflakeConnection, SnowflakeCursor, type[Error], dict[str, str]], - None, - ] = Error.default_errorhandler + self._errorhandler = Error.default_errorhandler self.messages: list[ tuple[type[Error] | type[Exception], dict[str, str | bool]] ] = [] @@ -413,11 +411,11 @@ def __init__( self._lock_canceling = Lock() - self._first_chunk_time = None + self._first_chunk_time: int | None = None self._log_max_query_length = connection.log_max_query_length self._inner_cursor: SnowflakeCursorBase | None = None - self._prefetch_hook = None + self._prefetch_hook: Callable[[], None] | None = None self._rownumber: int | None = None self.reset() @@ -436,14 +434,14 @@ def _use_dict_result(self) -> bool: pass @property - def description(self) -> list[ResultMetadata]: + def description(self) -> list[ResultMetadata] | None: if self._description is None: return None return [meta._to_result_metadata_v1() for meta in self._description] @property - def _description_internal(self) -> list[ResultMetadataV2]: + def _description_internal(self) -> list[ResultMetadataV2] | None: """Return the new format of result metadata for a query. This method is for internal use only. @@ -456,7 +454,11 @@ def rowcount(self) -> int | None: @property def rownumber(self) -> int | None: - return self._rownumber if self._rownumber >= 0 else None + return ( + self._rownumber + if self._rownumber is not None and self._rownumber >= 0 + else None + ) @property def sfqid(self) -> str | None: @@ -519,16 +521,17 @@ def arraysize(self, value) -> None: self._arraysize = int(value) @property - def connection(self) -> SnowflakeConnection: + def connection(self) -> SnowflakeConnection | None: return self._connection @property - def errorhandler(self) -> Callable: + def errorhandler(self) -> ErrorHandler: return self._errorhandler @errorhandler.setter - def errorhandler(self, value: Callable | None) -> None: + def errorhandler(self, value: ErrorHandler | None) -> None: logger.debug("setting errorhandler: %s", value) + # TODO: why is value `ErrorHandler | None` if it always errors on None? if value is None: raise ProgrammingError("Invalid errorhandler is specified") self._errorhandler = value @@ -559,6 +562,7 @@ def callproc(self, procname: str, args=tuple()): Returns: The input parameters. """ + assert self._connection is not None marker_format = "%s" if self._connection.is_pyformat else "?" command = ( f"CALL {procname}({', '.join([marker_format for _ in range(len(args))])})" @@ -590,7 +594,7 @@ def _execute_helper( query: str, timeout: int = 0, statement_params: dict[str, str] | None = None, - binding_params: tuple | dict[str, dict[str, str]] = None, + binding_params: tuple | dict[str, dict[str, str]] | None = None, binding_stage: str | None = None, is_internal: bool = False, describe_only: bool = False, @@ -631,6 +635,7 @@ def _execute_helper( "JSON" ) + assert self._connection is not None self._sequence_counter = self._connection._next_sequence_counter() # If requestId is contained in statement parameters, use it to set request id. 
@@ -776,8 +781,10 @@ def _preprocess_pyformat_query(
     ) -> str:
         # pyformat/format paramstyle
         # client side binding
+        assert self._connection is not None
         processed_params = self._connection._process_params_pyformat(params, self)
         # SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement
+        assert self.connection is not None
         if params is not None and len(params) == 0:
             self._log_telemetry_job_data(
                 TelemetryField.EMPTY_SEQ_INTERPOLATION,
@@ -798,7 +805,7 @@ def _preprocess_pyformat_query(
             and processed_params is not None
         ) or (
             not self.connection._interpolate_empty_sequences
-            and len(processed_params) > 0
+            and len(processed_params) > 0  # type: ignore[arg-type]  # TODO: handle when processed_params is None
         ):
             query = command % processed_params
         else:
@@ -815,14 +822,14 @@ def execute(
         _exec_async: bool = False,
         _no_retry: bool = False,
         _do_reset: bool = True,
-        _put_callback: SnowflakeProgressPercentage = None,
-        _put_azure_callback: SnowflakeProgressPercentage = None,
+        _put_callback: SnowflakeProgressPercentage | None = None,
+        _put_azure_callback: SnowflakeProgressPercentage | None = None,
         _put_callback_output_stream: IO[str] = sys.stdout,
-        _get_callback: SnowflakeProgressPercentage = None,
-        _get_azure_callback: SnowflakeProgressPercentage = None,
+        _get_callback: SnowflakeProgressPercentage | None = None,
+        _get_azure_callback: SnowflakeProgressPercentage | None = None,
         _get_callback_output_stream: IO[str] = sys.stdout,
         _show_progress_bar: bool = True,
-        _statement_params: dict[str, str] | None = None,
+        _statement_params: dict[str, str | int] | None = None,
         _is_internal: bool = False,
         _describe_only: bool = False,
         _no_results: Literal[False] = False,
@@ -832,6 +839,7 @@ def execute(
         _skip_upload_on_content_match: bool = False,
         file_stream: IO[bytes] | None = None,
         num_statements: int | None = None,
+        _force_qmark_paramstyle: bool = False,
         _dataframe_ast: str | None = None,
     ) -> Self | None: ...

@@ -845,14 +853,14 @@ def execute(
         _exec_async: bool = False,
         _no_retry: bool = False,
         _do_reset: bool = True,
-        _put_callback: SnowflakeProgressPercentage = None,
-        _put_azure_callback: SnowflakeProgressPercentage = None,
+        _put_callback: SnowflakeProgressPercentage | None = None,
+        _put_azure_callback: SnowflakeProgressPercentage | None = None,
         _put_callback_output_stream: IO[str] = sys.stdout,
-        _get_callback: SnowflakeProgressPercentage = None,
-        _get_azure_callback: SnowflakeProgressPercentage = None,
+        _get_callback: SnowflakeProgressPercentage | None = None,
+        _get_azure_callback: SnowflakeProgressPercentage | None = None,
         _get_callback_output_stream: IO[str] = sys.stdout,
         _show_progress_bar: bool = True,
-        _statement_params: dict[str, str] | None = None,
+        _statement_params: dict[str, str | int] | None = None,
         _is_internal: bool = False,
         _describe_only: bool = False,
         _no_results: Literal[True] = True,
@@ -862,6 +870,7 @@ def execute(
         _skip_upload_on_content_match: bool = False,
         file_stream: IO[bytes] | None = None,
         num_statements: int | None = None,
+        _force_qmark_paramstyle: bool = False,
         _dataframe_ast: str | None = None,
     ) -> dict[str, Any] | None: ...
@@ -874,14 +883,14 @@ def execute(
         _exec_async: bool = False,
         _no_retry: bool = False,
         _do_reset: bool = True,
-        _put_callback: SnowflakeProgressPercentage = None,
-        _put_azure_callback: SnowflakeProgressPercentage = None,
+        _put_callback: SnowflakeProgressPercentage | None = None,
+        _put_azure_callback: SnowflakeProgressPercentage | None = None,
         _put_callback_output_stream: IO[str] = sys.stdout,
-        _get_callback: SnowflakeProgressPercentage = None,
-        _get_azure_callback: SnowflakeProgressPercentage = None,
+        _get_callback: SnowflakeProgressPercentage | None = None,
+        _get_azure_callback: SnowflakeProgressPercentage | None = None,
         _get_callback_output_stream: IO[str] = sys.stdout,
         _show_progress_bar: bool = True,
-        _statement_params: dict[str, str] | None = None,
+        _statement_params: dict[str, str | int] | None = None,
         _is_internal: bool = False,
         _describe_only: bool = False,
         _no_results: bool = False,
@@ -976,6 +985,7 @@ def execute(
             "dataframe_ast": _dataframe_ast,
         }

+        assert self._connection is not None
         if self._connection.is_pyformat and not _force_qmark_paramstyle:
             query = self._preprocess_pyformat_query(command, params)
         else:
@@ -994,7 +1004,7 @@ def execute(
             )

             kwargs["binding_params"] = self._connection._process_params_qmarks(
-                params, self
+                params, self  # type: ignore[arg-type]  # TODO: handle when params is dict and errorhandler_wrapper doesn't throw an exception
             )

         m = DESC_TABLE_RE.match(query)
@@ -1040,6 +1050,8 @@ def execute(
                     # session parameters
                     param = m.group(1).upper()
                     value = m.group(2)
+                    assert self._connection is not None
+                    assert self._connection.converter is not None
                     self._connection.converter.set_parameter(param, value)

                 if "resultIds" in data:
@@ -1076,7 +1088,8 @@ def execute(
                 self._total_rowcount = len(data["rowset"]) if "rowset" in data else -1

             if _exec_async:
-                self.connection._async_sfqids[self._sfqid] = None
+                assert self.connection is not None
+                self.connection._async_sfqids[self._sfqid] = None  # type: ignore[index]  # TODO: Handle when self._sfqid is None
             if _no_results:
                 self._total_rowcount = (
                     ret["data"]["total"]
@@ -1130,7 +1143,7 @@ def execute_async(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
         kwargs["_exec_async"] = True
         return self.execute(*args, **kwargs)

-    def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
+    def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata] | None:
         """Obtain the schema of the result without executing the query.

         This function takes the same arguments as execute, please refer to that function
@@ -1146,7 +1159,9 @@ def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
             return None
         return [meta._to_result_metadata_v1() for meta in self._description]

-    def _describe_internal(self, *args: Any, **kwargs: Any) -> list[ResultMetadataV2]:
+    def _describe_internal(
+        self, *args: Any, **kwargs: Any
+    ) -> list[ResultMetadataV2] | None:
         """Obtain the schema of the result without executing the query.

         This function takes the same arguments as execute, please refer to that function
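Sketch of the `describe` change above (illustrative, not part of the patch; `cur` is an open cursor): the return type now admits None, so callers narrow before indexing.

    meta = cur.describe("select 1 as x, 'a' as y")  # compiles the query, runs nothing
    if meta is not None:
        print([m.name for m in meta])  # ['X', 'Y']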
@@ -1162,6 +1177,7 @@ def _describe_internal(self, *args: Any, **kwargs: Any) -> list[ResultMetadataV2
         return self._description

     def _format_query_for_log(self, query: str) -> str:
+        assert self._connection is not None
         return self._connection._format_query_for_log(query)

     def _is_dml(self, data: dict[Any, Any]) -> bool:
@@ -1178,7 +1194,7 @@ def _init_result_and_meta(self, data: dict[Any, Any]) -> None:
         if self._total_rowcount == -1 and not is_dml and data.get("total") is not None:
             self._total_rowcount = data["total"]

-        self._description: list[ResultMetadataV2] = [
+        self._description = [
             ResultMetadataV2.from_column(col) for col in data["rowtype"]
         ]

@@ -1191,6 +1207,7 @@ def _init_result_and_meta(self, data: dict[Any, Any]) -> None:
                 "Number of results in first chunk: %s", result_chunks[0].rowcount
             )

+        assert self._connection is not None
         self._result_set = ResultSet(
             self,
             result_chunks,
@@ -1238,6 +1255,7 @@ def check_can_use_arrow_resultset(self) -> None:
         global CAN_USE_ARROW_RESULT_FORMAT

         if not CAN_USE_ARROW_RESULT_FORMAT:
+            assert self._connection is not None
             if self._connection.application == "SnowSQL":
                 msg = "Currently SnowSQL doesn't support the result set in Apache Arrow format."
                 errno = ER_NO_PYARROW_SNOWSQL
@@ -1273,9 +1291,11 @@ def check_can_use_pandas(self) -> None:
             },
         )

-    def query_result(self, qid: str) -> SnowflakeCursor:
+    def query_result(self, qid: str) -> Self:
         """Query the result of a previously executed query."""
         url = f"/queries/{qid}/result"
+        assert self._connection is not None
+        assert self._connection.rest is not None
         ret = self._connection.rest.request(url=url, method="get")
         self._sfqid = (
             ret["data"]["queryId"]
@@ -1319,6 +1339,7 @@ def fetch_arrow_batches(self) -> Iterator[Table]:
         self._log_telemetry_job_data(
             TelemetryField.ARROW_FETCH_BATCHES, TelemetryData.TRUE
         )
+        assert self._result_set is not None
         return self._result_set._fetch_arrow_batches()

     @overload
@@ -1341,6 +1362,7 @@ def fetch_arrow_all(self, force_return_table: bool = False) -> Table | None:
         if self._query_result_format != "arrow":
             raise NotSupportedError
         self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE)
+        assert self._result_set is not None
         return self._result_set._fetch_arrow_all(force_return_table=force_return_table)

     def fetch_pandas_batches(self, **kwargs: Any) -> Iterator[DataFrame]:
@@ -1353,6 +1375,7 @@ def fetch_pandas_batches(self, **kwargs: Any) -> Iterator[DataFrame]:
         self._log_telemetry_job_data(
             TelemetryField.PANDAS_FETCH_BATCHES, TelemetryData.TRUE
         )
+        assert self._result_set is not None
         return self._result_set._fetch_pandas_batches(**kwargs)

     def fetch_pandas_all(self, **kwargs: Any) -> DataFrame:
@@ -1375,10 +1398,13 @@ def fetch_pandas_all(self, **kwargs: Any) -> DataFrame:
         self._log_telemetry_job_data(
             TelemetryField.PANDAS_FETCH_ALL, TelemetryData.TRUE
         )
+        assert self._result_set is not None
         return self._result_set._fetch_pandas_all(**kwargs)

     def abort_query(self, qid: str) -> bool:
         url = f"/queries/{qid}/abort-request"
+        assert self._connection is not None
+        assert self._connection.rest is not None
         ret = self._connection.rest.request(url=url, method="post")
         return ret.get("success")

@@ -1387,10 +1413,10 @@ def executemany(
         command: str,
         seqparams: Sequence[Any] | dict[str, Any],
         **kwargs: Any,
-    ) -> SnowflakeCursor:
+    ) -> Self:
         """Executes a command/query with the given set of parameters sequentially."""
         logger.debug("executing many SQLs/commands")
-        command = command.strip(" \t\n\r") if command else None
+        command = command.strip(" \t\n\r") if command else None  # type: ignore[assignment]  # TODO: None will break subsequent code

         if not seqparams:
             logger.warning(
@@ -1401,6 +1427,8 @@ def executemany(
         if self.INSERT_SQL_RE.match(command) and (
             "num_statements" not in kwargs or kwargs.get("num_statements") == 1
         ):
+            assert self._connection is not None
+            assert self.connection is not None
             if self._connection.is_pyformat:
                 # TODO(SNOW-940692) - utilize multi-statement instead of rewriting the query and
                 #  accumulate results to mock the result from a single insert statement as formatted below
@@ -1418,7 +1446,7 @@ def executemany(
                     },
                 )

-                fmt = m.group(1)
+                fmt = m.group(1)  # type: ignore[union-attr]  # TODO: Handle when m is None and errorhandler_wrapper doesn't throw an exception
                 values = []
                 for param in seqparams:
                     logger.debug(f"parameter: {param}")
@@ -1431,7 +1459,7 @@ def executemany(
             else:
                 logger.debug("bulk insert")
                 # sanity check
-                row_size = len(seqparams[0])
+                row_size = len(seqparams[0])  # type: ignore[index]  # TODO: handle when seqparams is dict
                 for row in seqparams:
                     if len(row) != row_size:
                         error_value = {
@@ -1445,16 +1473,16 @@ def executemany(
                     return self
                 bind_size = len(seqparams) * row_size
                 bind_stage = None
-                if (
-                    bind_size
-                    >= self.connection._session_parameters[
-                        "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD"
-                    ]
-                    > 0
-                ):
+                binding_threshold = self.connection._session_parameters[
+                    "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD"
+                ]
+                assert isinstance(binding_threshold, int)
+                if bind_size >= binding_threshold > 0:
                     # bind stage optimization
                     try:
-                        rows = self.connection._write_params_to_byte_rows(seqparams)
+                        rows = self.connection._write_params_to_byte_rows(
+                            seqparams  # type: ignore[arg-type]  # TODO: handle when seqparams is dict
+                        )
                         bind_uploader = BindUploadAgent(self, rows)
                         bind_uploader.upload()
                         bind_stage = bind_uploader.stage_path
@@ -1472,7 +1500,8 @@ def executemany(
             return self

         self.reset()
-        if "num_statements" not in kwargs:
+        num_statements = kwargs.get("num_statements")
+        if num_statements is None:
             # fall back to old driver behavior when the user does not provide the parameter to enable
             # multi-statement optimizations for executemany
             for param in seqparams:
@@ -1480,6 +1509,7 @@ def executemany(
         else:
             if re.search(";/s*$", command) is None:
                 command = command + "; "
+            assert self._connection is not None
             if self._connection.is_pyformat and not kwargs.get(
                 "_force_qmark_paramstyle", False
             ):
@@ -1493,16 +1523,14 @@ def executemany(
                 query = command * len(seqparams)
                 params = [param for parameters in seqparams for param in parameters]

-            kwargs["num_statements"]: int = kwargs.get("num_statements") * len(
-                seqparams
-            )
+            kwargs["num_statements"] = num_statements * len(seqparams)

             self.execute(query, params, _do_reset=False, **kwargs)

         return self

     @abc.abstractmethod
-    def fetchone(self) -> FetchRow:
+    def fetchone(self) -> FetchRow | None:
         pass

     def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
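Usage sketch for the `executemany` path above (illustrative; the table `t(id, v)` is hypothetical): with qmark binding, the driver switches to a server-side bind stage once `bind_size` crosses `CLIENT_STAGE_ARRAY_BINDING_THRESHOLD`, which is why that threshold is now asserted to be an int before the comparison.

    with con.cursor() as cur:
        rows = [(1, "a"), (2, "b"), (3, "c")]
        cur.executemany("insert into t (id, v) values (?, ?)", rows)
        print(cur.rowcount)  # number of rows inserted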
@@ -1519,6 +1547,13 @@ def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
             self._result_state = ResultState.VALID

         try:
+            # Keep throwing the same TypeError as before, except tell mypy
+            # this branch doesn't continue
+            # TODO: Handle this explicitly, instead of simply catching TypeErrors
+            if self._result is None:
+                next(None)  # type: ignore
+                raise Exception("unreachable: the previous line should have thrown")
+
             _next = next(self._result, None)
             if isinstance(_next, Exception):
                 Error.errorhandler_wrapper_from_ready_exception(
@@ -1527,6 +1562,7 @@ def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
                     _next,
                 )
             if _next is not None:
+                assert self._rownumber is not None
                 self._rownumber += 1
                 return _next
         except TypeError as err:
@@ -1571,7 +1607,7 @@ def fetchall(self) -> list[FetchRow]:
             ret.append(row)
         return ret

-    def nextset(self) -> SnowflakeCursor | None:
+    def nextset(self) -> Self | None:
         """
         Fetches the next set of results if the previously executed query was multi-statement so that subsequent calls
         to any of the fetch*() methods will return rows from the next query's set of results. Returns None if no more
@@ -1625,10 +1661,11 @@ def reset(self, closing: bool = False) -> None:
         self._result = None
         self._inner_cursor = None
         self._prefetch_hook = None
+        assert self.connection is not None
         if not self.connection._reuse_results:
             self._result_set = None

-    def __iter__(self) -> Iterator[dict] | Iterator[tuple]:
+    def __iter__(self) -> Iterator[FetchRow]:
         """Iteration over the result set."""
         while True:
             _next = self.fetchone()
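Sketch of the `__iter__` change above (illustrative, not part of the patch): iteration now advertises a single `Iterator[FetchRow]` and still terminates when `fetchone()` returns None.

    cur.execute("select seq4() as n from table(generator(rowcount => 3))")
    for row in cur:  # stops once fetchone() yields None
        print(row)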
@@ -1640,7 +1677,11 @@ def __cancel_query(self, query) -> None:
         if self._sequence_counter >= 0 and not self.is_closed():
             logger.debug("canceled. %s, request_id: %s", query, self._request_id)
             with self._lock_canceling:
-                self._connection._cancel_query(query, self._request_id)
+                assert self._connection is not None
+                self._connection._cancel_query(
+                    query,
+                    self._request_id,  # type: ignore[arg-type]  # TODO: What should happen if self._request_id is None?
+                )

     def _log_telemetry_job_data(
         self, telemetry_field: TelemetryField, value: Any
@@ -1648,7 +1689,7 @@
         """Builds an instance of TelemetryData with the given field and logs it."""
         ts = get_time_millis()
         try:
-            self._connection._log_telemetry(
+            self._connection._log_telemetry(  # type: ignore[union-attr]  # TODO: replace try..except with explicit check
                 TelemetryData.from_telemetry_data_dict(
                     from_dict={
                         TelemetryField.KEY_TYPE.value: telemetry_field.value,
@@ -1688,6 +1729,7 @@ def wait_until_ready() -> None:
             no_data_counter = 0
             retry_pattern_pos = 0
             while True:
+                assert self.connection is not None
                 status, status_resp = self.connection._get_query_status(sfqid)
                 self.connection._cache_query_status(sfqid, status)
                 if not self.connection.is_still_running(status):
@@ -1713,6 +1755,7 @@ def wait_until_ready() -> None:
                     error_message=f"Status of query '{sfqid}' is {status.name}, results are unavailable",
                     error_cls=DatabaseError,
                 )
+            assert self._inner_cursor is not None
             self._inner_cursor.execute(f"select * from table(result_scan('{sfqid}'))")
             self._result = self._inner_cursor._result
             self._query_result_format = self._inner_cursor._query_result_format
@@ -1728,6 +1771,8 @@ def wait_until_ready() -> None:
                 self._inner_cursor.fetchall()
             ):
                 url = f"/queries/{sfqid}/result"
+                assert self._connection is not None
+                assert self._connection.rest is not None
                 ret = self._connection.rest.request(url=url, method="get")
                 if "data" in ret and "resultIds" in ret["data"]:
                     self._init_multi_statement_results(ret["data"])
@@ -1745,6 +1790,7 @@ def _is_successful_multi_stmt(rows: list[Any]) -> bool:
             else:
                 return False

+        assert self.connection is not None
         self.connection.get_query_status_throw_if_error(
             sfqid
         )  # Trigger an exception if query failed
@@ -1788,6 +1834,7 @@ def _download(
         self.reset()

         # Interpret the file operation.
+        assert self.connection is not None
         ret = self.connection._file_operation_parser.parse_file_operation(
             stage_location=stage_location,
             local_file_name=None,
@@ -1825,6 +1872,7 @@ def _upload(
         self.reset()

         # Interpret the file operation.
+        assert self.connection is not None
         ret = self.connection._file_operation_parser.parse_file_operation(
             stage_location=stage_location,
             local_file_name=local_file_name,
@@ -1856,6 +1904,7 @@ def _download_stream(
             IO[bytes]: A stream to read from.
         """
         # Interpret the file operation.
+        assert self.connection is not None
         ret = self.connection._file_operation_parser.parse_file_operation(
             stage_location=stage_location,
             local_file_name=None,
@@ -1889,6 +1938,7 @@ def _upload_stream(
         self.reset()

         # Interpret the file operation.
+        assert self.connection is not None
         ret = self.connection._file_operation_parser.parse_file_operation(
             stage_location=stage_location,
             local_file_name=None,
@@ -1917,6 +1967,7 @@ def _create_file_transfer_agent(
     ) -> SnowflakeFileTransferAgent:
         from .file_transfer_agent import SnowflakeFileTransferAgent

+        assert self._connection is not None
         return SnowflakeFileTransferAgent(
             self,
             command,
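For the prefetch-hook machinery above, a hedged sketch of the async flow it serves (public API names; setup elided, `big_table` is hypothetical):

    cur.execute_async("select count(*) from big_table")
    qid = cur.sfqid
    # ... later, possibly from another cursor:
    cur2 = con.cursor()
    cur2.get_results_from_sfqid(qid)  # installs the wait_until_ready prefetch hook
    print(cur2.fetchone())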
diff --git a/src/snowflake/connector/errors.py b/src/snowflake/connector/errors.py
index 447f8b9f0..96610e5b4 100644
--- a/src/snowflake/connector/errors.py
+++ b/src/snowflake/connector/errors.py
@@ -6,7 +6,7 @@
 import re
 import traceback
 from logging import getLogger
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Protocol

 from .errorcode import ER_HTTP_GENERAL_ERROR
 from .secret_detector import SecretDetector
@@ -15,7 +15,7 @@

 if TYPE_CHECKING:  # pragma: no cover
     from .connection import SnowflakeConnection
-    from .cursor import SnowflakeCursor
+    from .cursor import SnowflakeCursorBase

 logger = getLogger(__name__)
 connector_base_path = os.path.join("snowflake", "connector")
@@ -36,7 +36,7 @@ def __init__(
         query: str | None = None,
         done_format_msg: bool | None = None,
         connection: SnowflakeConnection | None = None,
-        cursor: SnowflakeCursor | None = None,
+        cursor: SnowflakeCursorBase | None = None,
         errtype: TelemetryField = TelemetryField.SQL_EXCEPTION,
         send_telemetry: bool = True,
     ) -> None:
@@ -172,7 +172,7 @@ def send_exception_telemetry(
     def exception_telemetry(
         self,
         msg: str,
-        cursor: SnowflakeCursor | None,
+        cursor: SnowflakeCursorBase | None,
         connection: SnowflakeConnection | None,
     ) -> None:
         """Main method to generate and send telemetry data for exceptions."""
@@ -197,7 +197,7 @@ def exception_telemetry(
     @staticmethod
     def default_errorhandler(
         connection: SnowflakeConnection,
-        cursor: SnowflakeCursor,
+        cursor: SnowflakeCursorBase,
         error_class: type[Error],
         error_value: dict[str, str],
     ) -> None:
@@ -231,7 +231,7 @@ def default_errorhandler(
     def errorhandler_wrapper_from_cause(
         connection: SnowflakeConnection,
         cause: Error | Exception,
-        cursor: SnowflakeCursor | None = None,
+        cursor: SnowflakeCursorBase | None = None,
     ) -> None:
         """Wrapper for errorhandler_wrapper, it is called with a cause instead of a dictionary.

@@ -263,7 +263,7 @@ def errorhandler_wrapper_from_cause(
     @staticmethod
     def errorhandler_wrapper(
         connection: SnowflakeConnection | None,
-        cursor: SnowflakeCursor | None,
+        cursor: SnowflakeCursorBase | None,
         error_class: type[Error] | type[Exception],
         error_value: dict[str, Any],
     ) -> None:
@@ -298,7 +298,7 @@ def errorhandler_wrapper(
     @staticmethod
     def errorhandler_wrapper_from_ready_exception(
         connection: SnowflakeConnection | None,
-        cursor: SnowflakeCursor | None,
+        cursor: SnowflakeCursorBase | None,
         error_exc: Error | Exception,
     ) -> None:
         """Like errorhandler_wrapper, but it takes a ready to go Exception."""
@@ -324,7 +324,7 @@ def errorhandler_wrapper_from_ready_exception(
     @staticmethod
     def hand_to_other_handler(
         connection: SnowflakeConnection | None,
-        cursor: SnowflakeCursor | None,
+        cursor: SnowflakeCursorBase | None,
         error_class: type[Error] | type[Exception],
         error_value: dict[str, str | bool],
     ) -> bool:
@@ -363,6 +363,19 @@ def errorhandler_make_exception(
         return error_class(error_value)


+# Defining as Protocol instead of alias to Callable because mypy
+# doesn't seem to like Callable as a type alias in files with
+# circular imports.
+class ErrorHandler(Protocol):
+    def __call__(
+        self,
+        connection: SnowflakeConnection,
+        cursor: SnowflakeCursorBase,
+        error_class: type[Error],
+        error_value: dict[str, str],
+    ) -> None: ...
+
+
 class _Warning(Exception):
     """Exception for important warnings."""
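To show what the new `ErrorHandler` Protocol checks (an illustrative sketch; the handler below is hypothetical): any callable with a matching `__call__` shape satisfies it structurally, so assignments to `connection.errorhandler` or `cursor.errorhandler` now type-check without subclassing.

    import logging
    from snowflake.connector.errors import ErrorHandler

    def logging_errorhandler(connection, cursor, error_class, error_value) -> None:
        # Log instead of raising; mirrors Error.default_errorhandler's signature.
        logging.getLogger(__name__).error("%s: %s", error_class.__name__, error_value)

    handler: ErrorHandler = logging_errorhandler  # structural match, no base class
    con.errorhandler = handler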
diff --git a/src/snowflake/connector/file_transfer_agent.py b/src/snowflake/connector/file_transfer_agent.py
index 2f22078b2..a44eb98b8 100644
--- a/src/snowflake/connector/file_transfer_agent.py
+++ b/src/snowflake/connector/file_transfer_agent.py
@@ -59,7 +59,7 @@

 if TYPE_CHECKING:  # pragma: no cover
     from .connection import SnowflakeConnection
-    from .cursor import SnowflakeCursor
+    from .cursor import SnowflakeCursorBase
     from .file_compression_type import CompressionType

 VALID_STORAGE = [LOCAL_FS, S3_FS, AZURE_FS, GCS_FS]
@@ -338,7 +338,7 @@ class SnowflakeFileTransferAgent:

     def __init__(
         self,
-        cursor: SnowflakeCursor,
+        cursor: SnowflakeCursorBase,
         command: str,
         ret: dict[str, Any],
         put_callback: type[SnowflakeProgressPercentage] | None = None,