Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions schemachange/session/SnowflakeSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,26 @@ def __init__(
self.change_history_table = change_history_table
self.autocommit = autocommit
self.logger = logger

self.session_parameters = {"QUERY_TAG": f"schemachange {schemachange_version}"}
self.session_parameters = {}
snowflake_kwargs = {
"connection_name": connection_name,
"connections_file_path": connections_file_path,
"application": application
}
snowflake_kwargs = {k: v for k, v in snowflake_kwargs.items() if v is not None}
temp_con = snowflake.connector.connect(**snowflake_kwargs)
if hasattr(temp_con, '_session_parameters'):
self.session_parameters.update(temp_con._session_parameters)
temp_con.close()

query_tag_value = f"schemachange {schemachange_version}"
if query_tag:
self.session_parameters["QUERY_TAG"] += f";{query_tag}"
query_tag_value += f";{query_tag}"

if "QUERY_TAG" in self.session_parameters:
self.session_parameters["QUERY_TAG"] += f";{query_tag_value}"
else:
self.session_parameters["QUERY_TAG"] = query_tag_value

connect_kwargs = {
"account": account, # TODO: Remove when connections.toml is enforced
Expand Down
53 changes: 53 additions & 0 deletions tests/session/test_session_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from unittest import mock

import pytest
import structlog

from schemachange.config.ChangeHistoryTable import ChangeHistoryTable
from schemachange.session.SnowflakeSession import SnowflakeSession


@pytest.fixture
def mock_snowflake_connect():
with mock.patch("snowflake.connector.connect") as mock_connect:
mock_con = mock.MagicMock()
mock_con._session_parameters = {
"QUOTED_IDENTIFIERS_IGNORE_CASE": False,
"QUERY_TAG": "existing_tag"
}
mock_connect.return_value = mock_con
yield mock_connect


def test_session_parameters_from_toml(mock_snowflake_connect):
"""Test that session parameters from connections.toml are respected and merged with QUERY_TAG"""
change_history_table = ChangeHistoryTable()
logger = structlog.testing.CapturingLogger()

with mock.patch("schemachange.session.SnowflakeSession.get_snowflake_identifier_string"):
session = SnowflakeSession(
user="user",
account="account",
role="role",
warehouse="warehouse",
schemachange_version="3.6.1.dev",
application="schemachange",
change_history_table=change_history_table,
logger=logger,
connections_file_path="connections.toml",
connection_name="test_connection",
query_tag="custom_tag"
)

first_call_kwargs = mock_snowflake_connect.call_args_list[0][1]
assert first_call_kwargs["connections_file_path"] == "connections.toml"
assert first_call_kwargs["connection_name"] == "test_connection"
assert "session_parameters" not in first_call_kwargs

second_call_kwargs = mock_snowflake_connect.call_args_list[1][1]
assert second_call_kwargs["session_parameters"] == {
"QUOTED_IDENTIFIERS_IGNORE_CASE": False,
"QUERY_TAG": "existing_tag;schemachange 3.6.1.dev;custom_tag"
}