diff --git a/airbyte_cdk/sources/declarative/auth/jwt.py b/airbyte_cdk/sources/declarative/auth/jwt.py index c83d081bb..86e13a99d 100644 --- a/airbyte_cdk/sources/declarative/auth/jwt.py +++ b/airbyte_cdk/sources/declarative/auth/jwt.py @@ -6,15 +6,26 @@ import json from dataclasses import InitVar, dataclass from datetime import datetime -from typing import Any, Mapping, Optional, Union +from typing import Any, Mapping, Optional, Union, cast import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey +from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString +# Type alias for keys that JWT library accepts +JwtKeyTypes = Union[ + RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey, str, bytes +] + class JwtAlgorithm(str): """ @@ -74,6 +85,7 @@ class JwtAuthenticator(DeclarativeAuthenticator): aud: Optional[Union[InterpolatedString, str]] = None additional_jwt_headers: Optional[Mapping[str, Any]] = None additional_jwt_payload: Optional[Mapping[str, Any]] = None + passphrase: Optional[Union[InterpolatedString, str]] = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters) @@ -103,6 +115,11 @@ def __post_init__(self, parameters: 
Mapping[str, Any]) -> None: self._additional_jwt_payload = InterpolatedMapping( self.additional_jwt_payload or {}, parameters=parameters ) + self._passphrase = ( + InterpolatedString.create(self.passphrase, parameters=parameters) + if self.passphrase + else None + ) def _get_jwt_headers(self) -> dict[str, Any]: """ @@ -149,11 +166,21 @@ def _get_jwt_payload(self) -> dict[str, Any]: payload["nbf"] = nbf return payload - def _get_secret_key(self) -> str: + def _get_secret_key(self) -> JwtKeyTypes: """ Returns the secret key used to sign the JWT. """ secret_key: str = self._secret_key.eval(self.config, json_loads=json.loads) + + if self._passphrase: + passphrase_value = self._passphrase.eval(self.config, json_loads=json.loads) + if passphrase_value: + private_key = serialization.load_pem_private_key( + secret_key.encode(), + password=passphrase_value.encode(), + ) + return cast(JwtKeyTypes, private_key) + return ( base64.b64encode(secret_key.encode()).decode() if self._base64_encode_secret_key diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 50e80b601..35c3064a6 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -1270,6 +1270,12 @@ definitions: title: Additional JWT Payload Properties description: Additional properties to be added to the JWT payload. additionalProperties: true + passphrase: + title: Passphrase + description: A passphrase/password used to encrypt the private key. Only provide a passphrase if required by the API for JWT authentication. The API will typically provide the passphrase when generating the public/private key pair. 
+ type: string + examples: + - "{{ config['passphrase'] }}" $parameters: type: object additionalProperties: true diff --git a/airbyte_cdk/sources/declarative/interpolation/macros.py b/airbyte_cdk/sources/declarative/interpolation/macros.py index edac9501d..9b8aca336 100644 --- a/airbyte_cdk/sources/declarative/interpolation/macros.py +++ b/airbyte_cdk/sources/declarative/interpolation/macros.py @@ -6,6 +6,7 @@ import datetime import re import typing +import uuid from typing import Optional, Union from urllib.parse import quote_plus @@ -207,6 +208,16 @@ def camel_case_to_snake_case(value: str) -> str: return re.sub(r"(?<!^)(?=[A-Z])", "_", value).lower() +def generate_uuid() -> str: + """ + Generates a UUID4 + + Usage: + `"{{ generate_uuid() }}"` + """ + return str(uuid.uuid4()) + + _macros_list = [ now_utc, today_utc, @@ -220,5 +231,6 @@ def camel_case_to_snake_case(value: str) -> str: str_to_datetime, sanitize_url, camel_case_to_snake_case, + generate_uuid, ] macros = {f.__name__: f for f in _macros_list} diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index e207e18f4..9525f9d00 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -448,6 +448,12 @@ class JwtAuthenticator(BaseModel): description="Additional properties to be added to the JWT payload.", title="Additional JWT Payload Properties", ) + passphrase: Optional[str] = Field( + None, + description="A passphrase/password used to encrypt the private key. Only provide a passphrase if required by the API for JWT authentication. 
The API will typically provide the passphrase when generating the public/private key pair.", + examples=["{{ config['passphrase'] }}"], + title="Passphrase", + ) parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index a79cd6383..9fd5d2d44 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -2705,6 +2705,7 @@ def create_jwt_authenticator( aud=jwt_payload.aud, additional_jwt_headers=model.additional_jwt_headers, additional_jwt_payload=model.additional_jwt_payload, + passphrase=model.passphrase, ) def create_list_partition_router( diff --git a/unit_tests/sources/declarative/auth/test_jwt.py b/unit_tests/sources/declarative/auth/test_jwt.py index 49b7ea570..4996e5388 100644 --- a/unit_tests/sources/declarative/auth/test_jwt.py +++ b/unit_tests/sources/declarative/auth/test_jwt.py @@ -8,6 +8,9 @@ import freezegun import jwt import pytest +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator @@ -185,3 +188,100 @@ def test_get_header_prefix(self, header_prefix, expected): header_prefix=header_prefix, ) assert authenticator._get_header_prefix() == expected + + def test_get_secret_key_with_passphrase(self): + """Test _get_secret_key method with encrypted private key and passphrase.""" + # Generate a test RSA private key + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + + passphrase = b"test_passphrase" + encrypted_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + 
encryption_algorithm=serialization.BestAvailableEncryption(passphrase), + ) + + authenticator = JwtAuthenticator( + config={}, + parameters={}, + secret_key=encrypted_pem.decode(), + algorithm="RS256", + token_duration=1200, + passphrase="test_passphrase", + ) + + result_key = authenticator._get_secret_key() + + assert isinstance(result_key, rsa.RSAPrivateKey) + + original_public_key = private_key.public_key() + result_public_key = result_key.public_key() + + original_public_numbers = original_public_key.public_numbers() + result_public_numbers = result_public_key.public_numbers() + + assert original_public_numbers.n == result_public_numbers.n + assert original_public_numbers.e == result_public_numbers.e + + def test_get_secret_key_with_wrong_passphrase_raises_error(self): + """Test that _get_secret_key raises error with wrong passphrase.""" + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + + passphrase = b"correct_passphrase" + encrypted_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption(passphrase), + ) + + authenticator = JwtAuthenticator( + config={}, + parameters={}, + secret_key=encrypted_pem.decode(), + algorithm="RS256", + token_duration=1200, + passphrase="wrong_passphrase", + ) + + with pytest.raises(Exception): + authenticator._get_secret_key() + + def test_get_signed_token_with_passphrase_protected_key(self): + """Test that JWT signing works with passphrase-protected RSA private key.""" + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + + passphrase = b"test_passphrase" + encrypted_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption(passphrase), + ) + + authenticator = 
JwtAuthenticator( + config={}, + parameters={}, + secret_key=encrypted_pem.decode(), + algorithm="RS256", + token_duration=1000, + passphrase="test_passphrase", + typ="JWT", + iss="test_issuer", + ) + + signed_token = authenticator._get_signed_token() + + assert isinstance(signed_token, str) + assert len(signed_token.split(".")) == 3 + + public_key = private_key.public_key() + decoded_payload = jwt.decode(signed_token, public_key, algorithms=["RS256"]) + + assert decoded_payload["iss"] == "test_issuer" + assert "iat" in decoded_payload + assert "exp" in decoded_payload diff --git a/unit_tests/sources/declarative/interpolation/test_macros.py b/unit_tests/sources/declarative/interpolation/test_macros.py index 42da205db..1102cbc35 100644 --- a/unit_tests/sources/declarative/interpolation/test_macros.py +++ b/unit_tests/sources/declarative/interpolation/test_macros.py @@ -3,6 +3,7 @@ # import datetime +import uuid import pytest @@ -20,6 +21,7 @@ ("test_format_datetime", "format_datetime", True), ("test_duration", "duration", True), ("test_camel_case_to_snake_case", "camel_case_to_snake_case", True), + ("test_generate_uuid", "generate_uuid", True), ("test_not_a_macro", "thisisnotavalidmacro", False), ], ) @@ -275,3 +277,26 @@ def test_sanitize_url(test_name, input_value, expected_output): ) def test_camel_case_to_snake_case(value, expected_value): assert macros["camel_case_to_snake_case"](value) == expected_value + + +def test_generate_uuid(): + """Test uuid macro generates valid UUID4 strings.""" + uuid_fn = macros["generate_uuid"] + + # Test that uuid function returns a string + result = uuid_fn() + assert isinstance(result, str) + + # Test that the result is a valid UUID format + # This will raise ValueError if not a valid UUID + parsed_uuid = uuid.UUID(result) + + # Test that it's specifically a UUID4 (version 4) + assert parsed_uuid.version == 4 + + # Test that multiple calls return different UUIDs + result2 = uuid_fn() + assert result != result2 + + # Test that 
both results are valid UUIDs + uuid.UUID(result2) # Will raise ValueError if invalid