diff --git a/schemachange/config/DeployConfig.py b/schemachange/config/DeployConfig.py index 036033f2..4702cca4 100644 --- a/schemachange/config/DeployConfig.py +++ b/schemachange/config/DeployConfig.py @@ -9,6 +9,7 @@ from schemachange.config.utils import ( get_snowflake_identifier_string, get_snowflake_password, + get_snowflake_private_key, ) @@ -93,4 +94,11 @@ def get_session_kwargs(self) -> dict: if snowflake_password is not None and snowflake_password: session_kwargs["password"] = snowflake_password + private_key_path, private_key_passphrase = get_snowflake_private_key() + if private_key_path: + session_kwargs["private_key_path"] = private_key_path + + if private_key_passphrase: + session_kwargs["private_key_passphrase"] = private_key_passphrase + return {k: v for k, v in session_kwargs.items() if v is not None} diff --git a/schemachange/config/utils.py b/schemachange/config/utils.py index af3f284b..e09b7ab0 100644 --- a/schemachange/config/utils.py +++ b/schemachange/config/utils.py @@ -161,3 +161,9 @@ def get_snowflake_password() -> str | None: return snowsql_pwd else: return None + + +def get_snowflake_private_key() -> tuple[str | None, str | None]: + private_key_path = os.getenv("SNOWFLAKE_PRIVATE_KEY_PATH") + private_key_passphrase = os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE") + return private_key_path, private_key_passphrase diff --git a/schemachange/session/SnowflakeSession.py b/schemachange/session/SnowflakeSession.py index 1346b761..a4805e4a 100644 --- a/schemachange/session/SnowflakeSession.py +++ b/schemachange/session/SnowflakeSession.py @@ -10,7 +10,7 @@ from schemachange.config.ChangeHistoryTable import ChangeHistoryTable from schemachange.config.utils import get_snowflake_identifier_string -from schemachange.session.Script import VersionedScript, RepeatableScript, AlwaysScript +from schemachange.session.Script import AlwaysScript, RepeatableScript, VersionedScript class SnowflakeSession: @@ -63,9 +63,8 @@ def __init__( "schema": schema, # TODO: Remove when connections.toml is enforced "role": role, # TODO: Remove when connections.toml is enforced "warehouse": warehouse, # TODO: Remove when connections.toml is enforced - "private_key_file": kwargs.get( - "private_key_path" - ), # TODO: Remove when connections.toml is enforced + "private_key_path": kwargs.get("private_key_path"), + "private_key_file_pwd": kwargs.get("private_key_file_pwd"), "token": kwargs.get( "oauth_token" ), # TODO: Remove when connections.toml is enforced