Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions airbyte_cdk/sources/declarative/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,27 @@
import json
from dataclasses import InitVar, dataclass
from datetime import datetime
from typing import Any, Mapping, Optional, Union
from typing import Any, Mapping, Optional, Union, cast

import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes

from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString

# Type alias for keys that JWT library accepts
JwtKeyTypes = Union[
RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey, str, bytes
]


class JwtAlgorithm(str):
"""
Expand Down Expand Up @@ -74,6 +86,7 @@ class JwtAuthenticator(DeclarativeAuthenticator):
aud: Optional[Union[InterpolatedString, str]] = None
additional_jwt_headers: Optional[Mapping[str, Any]] = None
additional_jwt_payload: Optional[Mapping[str, Any]] = None
passphrase: Optional[Union[InterpolatedString, str]] = None

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters)
Expand Down Expand Up @@ -103,6 +116,11 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._additional_jwt_payload = InterpolatedMapping(
self.additional_jwt_payload or {}, parameters=parameters
)
self._passphrase = (
InterpolatedString.create(self.passphrase, parameters=parameters)
if self.passphrase
else None
)

def _get_jwt_headers(self) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -149,16 +167,27 @@ def _get_jwt_payload(self) -> dict[str, Any]:
payload["nbf"] = nbf
return payload

def _get_secret_key(self) -> str:
def _get_secret_key(self) -> JwtKeyTypes:
"""
Returns the secret key used to sign the JWT.
"""
secret_key: str = self._secret_key.eval(self.config, json_loads=json.loads)
return (
base64.b64encode(secret_key.encode()).decode()
if self._base64_encode_secret_key
else secret_key
)

if self._passphrase:
# Load encrypted private key and cast to JWT-compatible type
# The JWT algorithms we support (RSA, ECDSA, EdDSA) use compatible key types
private_key = serialization.load_pem_private_key(
secret_key.encode(),
password=self._passphrase.eval(self.config, json_loads=json.loads).encode(),
backend=default_backend(),
)
return cast(JwtKeyTypes, private_key)
else:
return (
base64.b64encode(secret_key.encode()).decode()
if self._base64_encode_secret_key
else secret_key
)

def _get_signed_token(self) -> Union[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,12 @@ definitions:
title: Additional JWT Payload Properties
description: Additional properties to be added to the JWT payload.
additionalProperties: true
passphrase:
title: Passphrase
description: A passphrase/password used to encrypt the private key. Only provide a passphrase if required by the API for JWT authentication. The API will typically provide the passphrase when generating the public/private key pair.
type: string
examples:
- "{{ config['passphrase'] }}"
$parameters:
type: object
additionalProperties: true
Expand Down
12 changes: 12 additions & 0 deletions airbyte_cdk/sources/declarative/interpolation/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import re
import typing
import uuid
from typing import Optional, Union
from urllib.parse import quote_plus

Expand Down Expand Up @@ -207,6 +208,16 @@ def camel_case_to_snake_case(value: str) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "_", value).lower()


def generate_uuid() -> str:
"""
Generates a UUID4

Usage:
`"{{ generate_uuid() }}"`
"""
return str(uuid.uuid4())


_macros_list = [
now_utc,
today_utc,
Expand All @@ -220,5 +231,6 @@ def camel_case_to_snake_case(value: str) -> str:
str_to_datetime,
sanitize_url,
camel_case_to_snake_case,
generate_uuid,
]
macros = {f.__name__: f for f in _macros_list}
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,12 @@ class JwtAuthenticator(BaseModel):
description="Additional properties to be added to the JWT payload.",
title="Additional JWT Payload Properties",
)
passphrase: Optional[str] = Field(
None,
description="A passphrase/password used to encrypt the private key. Only provide a passphrase if required by the API for JWT authentication. The API will typically provide the passphrase when generating the public/private key pair.",
examples=["{{ config['passphrase'] }}"],
title="Passphrase",
)
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2705,6 +2705,7 @@ def create_jwt_authenticator(
aud=jwt_payload.aud,
additional_jwt_headers=model.additional_jwt_headers,
additional_jwt_payload=model.additional_jwt_payload,
passphrase=model.passphrase,
)

def create_list_partition_router(
Expand Down
100 changes: 100 additions & 0 deletions unit_tests/sources/declarative/auth/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import freezegun
import jwt
import pytest
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator

Expand Down Expand Up @@ -185,3 +188,100 @@ def test_get_header_prefix(self, header_prefix, expected):
header_prefix=header_prefix,
)
assert authenticator._get_header_prefix() == expected

def test_get_secret_key_with_passphrase(self):
"""Test _get_secret_key method with encrypted private key and passphrase."""
# Generate a test RSA private key
private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)

passphrase = b"test_passphrase"
encrypted_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
)

authenticator = JwtAuthenticator(
config={},
parameters={},
secret_key=encrypted_pem.decode(),
algorithm="RS256",
token_duration=1200,
passphrase="test_passphrase",
)

result_key = authenticator._get_secret_key()

assert isinstance(result_key, rsa.RSAPrivateKey)

original_public_key = private_key.public_key()
result_public_key = result_key.public_key()

original_public_numbers = original_public_key.public_numbers()
result_public_numbers = result_public_key.public_numbers()

assert original_public_numbers.n == result_public_numbers.n
assert original_public_numbers.e == result_public_numbers.e

def test_get_secret_key_with_wrong_passphrase_raises_error(self):
"""Test that _get_secret_key raises error with wrong passphrase."""
private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)

passphrase = b"correct_passphrase"
encrypted_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
)

authenticator = JwtAuthenticator(
config={},
parameters={},
secret_key=encrypted_pem.decode(),
algorithm="RS256",
token_duration=1200,
passphrase="wrong_passphrase",
)

with pytest.raises(Exception):
authenticator._get_secret_key()

def test_get_signed_token_with_passphrase_protected_key(self):
"""Test that JWT signing works with passphrase-protected RSA private key."""
private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)

passphrase = b"test_passphrase"
encrypted_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
)

authenticator = JwtAuthenticator(
config={},
parameters={},
secret_key=encrypted_pem.decode(),
algorithm="RS256",
token_duration=1000,
passphrase="test_passphrase",
typ="JWT",
iss="test_issuer",
)

signed_token = authenticator._get_signed_token()

assert isinstance(signed_token, str)
assert len(signed_token.split(".")) == 3

public_key = private_key.public_key()
decoded_payload = jwt.decode(signed_token, public_key, algorithms=["RS256"])

assert decoded_payload["iss"] == "test_issuer"
assert "iat" in decoded_payload
assert "exp" in decoded_payload
25 changes: 25 additions & 0 deletions unit_tests/sources/declarative/interpolation/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#

import datetime
import uuid

import pytest

Expand All @@ -20,6 +21,7 @@
("test_format_datetime", "format_datetime", True),
("test_duration", "duration", True),
("test_camel_case_to_snake_case", "camel_case_to_snake_case", True),
("test_generate_uuid", "generate_uuid", True),
("test_not_a_macro", "thisisnotavalidmacro", False),
],
)
Expand Down Expand Up @@ -275,3 +277,26 @@ def test_sanitize_url(test_name, input_value, expected_output):
)
def test_camel_case_to_snake_case(value, expected_value):
assert macros["camel_case_to_snake_case"](value) == expected_value


def test_generate_uuid():
"""Test uuid macro generates valid UUID4 strings."""
uuid_fn = macros["generate_uuid"]

# Test that uuid function returns a string
result = uuid_fn()
assert isinstance(result, str)

# Test that the result is a valid UUID format
# This will raise ValueError if not a valid UUID
parsed_uuid = uuid.UUID(result)

# Test that it's specifically a UUID4 (version 4)
assert parsed_uuid.version == 4

# Test that multiple calls return different UUIDs
result2 = uuid_fn()
assert result != result2

# Test that both results are valid UUIDs
uuid.UUID(result2) # Will raise ValueError if invalid
Loading