Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions airbyte_cdk/sources/declarative/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@
import json
from dataclasses import InitVar, dataclass
from datetime import datetime
from typing import Any, Mapping, Optional, Union, cast
from typing import Any, Mapping, MutableMapping, Optional, Union, cast

import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes

from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.requesters.request_option import (
RequestOption,
RequestOptionType,
)

# Type alias for keys that JWT library accepts
JwtKeyTypes = Union[
Expand Down Expand Up @@ -86,6 +89,7 @@ class JwtAuthenticator(DeclarativeAuthenticator):
additional_jwt_headers: Optional[Mapping[str, Any]] = None
additional_jwt_payload: Optional[Mapping[str, Any]] = None
passphrase: Optional[Union[InterpolatedString, str]] = None
request_option: Optional[RequestOption] = None

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters)
Expand Down Expand Up @@ -121,6 +125,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
else None
)

# When we first implemented the JWT authenticator, we assumed that the signed token was always supposed
# to be loaded into the request headers under the `Authorization` key. This is not always the case, but
# this default option allows for backwards compatibility to be retained for existing connectors
self._request_option = self.request_option or RequestOption(
inject_into=RequestOptionType.header, field_name="Authorization", parameters=parameters
)

def _get_jwt_headers(self) -> dict[str, Any]:
"""
Builds and returns the headers used when signing the JWT.
Expand Down Expand Up @@ -213,7 +224,8 @@ def _get_header_prefix(self) -> Union[str, None]:

@property
def auth_header(self) -> str:
return "Authorization"
options = self._get_request_options(RequestOptionType.header)
return next(iter(options.keys()), "")

@property
def token(self) -> str:
Expand All @@ -222,3 +234,18 @@ def token(self) -> str:
if self._get_header_prefix()
else self._get_signed_token()
)

def get_request_params(self) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.request_parameter)

def get_request_body_data(self) -> Union[Mapping[str, Any], str]:
return self._get_request_options(RequestOptionType.body_data)

def get_request_body_json(self) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_json)

def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, Any]:
options: MutableMapping[str, Any] = {}
if self._request_option.inject_into == option_type:
self._request_option.inject_into_request(options, self.token, self.config)
return options
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,10 @@ definitions:
type: string
examples:
- "{{ config['passphrase'] }}"
request_option:
title: Request Option
description: A request option describing where the signed JWT token that is generated should be injected into the outbound API request.
"$ref": "#/definitions/RequestOption"
$parameters:
type: object
additionalProperties: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,35 @@ class Algorithm(Enum):
EdDSA = "EdDSA"


class InjectInto(Enum):
request_parameter = "request_parameter"
header = "header"
body_data = "body_data"
body_json = "body_json"


class RequestOption(BaseModel):
type: Literal["RequestOption"]
inject_into: InjectInto = Field(
...,
description="Configures where the descriptor should be set on the HTTP requests. Note that request parameters that are already encoded in the URL path will not be duplicated.",
examples=["request_parameter", "header", "body_data", "body_json"],
title="Inject Into",
)
field_name: Optional[str] = Field(
None,
description="Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder.",
examples=["segment_id"],
title="Field Name",
)
field_path: Optional[List[str]] = Field(
None,
description="Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)",
examples=[["data", "viewer", "id"]],
title="Field Path",
)


class JwtHeaders(BaseModel):
class Config:
extra = Extra.forbid
Expand Down Expand Up @@ -454,6 +483,11 @@ class JwtAuthenticator(BaseModel):
examples=["{{ config['passphrase'] }}"],
title="Passphrase",
)
request_option: Optional[RequestOption] = Field(
None,
description="A request option describing where the generated JWT token should be injected into the outbound API request.",
title="Request Option",
)
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


Expand Down Expand Up @@ -1294,35 +1328,6 @@ class RequestPath(BaseModel):
type: Literal["RequestPath"]


class InjectInto(Enum):
request_parameter = "request_parameter"
header = "header"
body_data = "body_data"
body_json = "body_json"


class RequestOption(BaseModel):
type: Literal["RequestOption"]
inject_into: InjectInto = Field(
...,
description="Configures where the descriptor should be set on the HTTP requests. Note that request parameters that are already encoded in the URL path will not be duplicated.",
examples=["request_parameter", "header", "body_data", "body_json"],
title="Inject Into",
)
field_name: Optional[str] = Field(
None,
description="Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder.",
examples=["segment_id"],
title="Field Name",
)
field_path: Optional[List[str]] = Field(
None,
description="Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)",
examples=[["data", "viewer", "id"]],
title="Field Path",
)


class Schemas(BaseModel):
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2683,12 +2683,16 @@ def create_json_file_schema_loader(
file_path=model.file_path or "", config=config, parameters=model.parameters or {}
)

@staticmethod
def create_jwt_authenticator(
model: JwtAuthenticatorModel, config: Config, **kwargs: Any
self, model: JwtAuthenticatorModel, config: Config, **kwargs: Any
) -> JwtAuthenticator:
jwt_headers = model.jwt_headers or JwtHeadersModel(kid=None, typ="JWT", cty=None)
jwt_payload = model.jwt_payload or JwtPayloadModel(iss=None, sub=None, aud=None)
request_option = (
self._create_component_from_model(model.request_option, config)
if model.request_option
else None
)
return JwtAuthenticator(
config=config,
parameters=model.parameters or {},
Expand All @@ -2706,6 +2710,7 @@ def create_jwt_authenticator(
additional_jwt_headers=model.additional_jwt_headers,
additional_jwt_payload=model.additional_jwt_payload,
passphrase=model.passphrase,
request_option=request_option,
)

def create_list_partition_router(
Expand Down
107 changes: 107 additions & 0 deletions unit_tests/sources/declarative/auth/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from cryptography.hazmat.primitives.asymmetric import rsa

from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
from airbyte_cdk.sources.declarative.requesters.request_option import (
RequestOption,
RequestOptionType,
)

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -285,3 +289,106 @@ def test_get_signed_token_with_passphrase_protected_key(self):
assert decoded_payload["iss"] == "test_issuer"
assert "iat" in decoded_payload
assert "exp" in decoded_payload

@pytest.mark.parametrize(
"request_option, expected_request_key",
[
pytest.param(
RequestOption(
inject_into=RequestOptionType.request_parameter,
field_name="custom_parameter",
parameters={},
),
"custom_parameter",
id="test_get_request_headers",
),
pytest.param(
RequestOption(
inject_into=RequestOptionType.body_data, field_name="custom_body", parameters={}
),
"custom_body",
id="test_get_request_headers",
),
pytest.param(
RequestOption(
inject_into=RequestOptionType.body_json, field_name="custom_json", parameters={}
),
"custom_json",
id="test_get_request_headers",
),
],
)
def test_get_request_options(self, request_option, expected_request_key):
authenticator = JwtAuthenticator(
config={},
parameters={},
algorithm="HS256",
secret_key="test_key",
token_duration=1000,
iss="test_iss",
sub="test_sub",
aud="test_aud",
additional_jwt_payload={"kid": "test_kid"},
request_option=request_option,
)

expected_request_options = {
expected_request_key: jwt.encode(
payload=authenticator._get_jwt_payload(),
key=authenticator._get_secret_key(),
algorithm=authenticator._algorithm,
headers=authenticator._get_jwt_headers(),
)
}

match request_option.inject_into:
case RequestOptionType.request_parameter:
actual_request_options = authenticator.get_request_params()
case RequestOptionType.body_data:
actual_request_options = authenticator.get_request_body_data()
case RequestOptionType.body_json:
actual_request_options = authenticator.get_request_body_json()
case _:
actual_request_options = None

assert actual_request_options == expected_request_options

@pytest.mark.parametrize(
"request_option, expected_header_key",
[
pytest.param(
RequestOption(
inject_into=RequestOptionType.header,
field_name="custom_authorization",
parameters={},
),
"custom_authorization",
id="test_get_request_headers",
),
pytest.param(None, "Authorization", id="test_with_default_authorization_header"),
],
)
def test_get_request_headers(self, request_option, expected_header_key):
authenticator = JwtAuthenticator(
config={},
parameters={},
algorithm="HS256",
secret_key="test_key",
token_duration=1000,
iss="test_iss",
sub="test_sub",
aud="test_aud",
additional_jwt_payload={"kid": "test_kid"},
request_option=request_option,
)

expected_headers = {
expected_header_key: jwt.encode(
payload=authenticator._get_jwt_payload(),
key=authenticator._get_secret_key(),
algorithm=authenticator._algorithm,
headers=authenticator._get_jwt_headers(),
)
}

assert authenticator.get_auth_header() == expected_headers
Original file line number Diff line number Diff line change
Expand Up @@ -3034,6 +3034,10 @@ def test_create_custom_retriever():
aud: "test aud"
additional_jwt_payload:
test: "test custom payload"
request_option:
type: RequestOption
inject_into: body_json
field_name: authorization
""",
{
"secret_key": "secret_key",
Expand Down Expand Up @@ -3141,6 +3145,11 @@ def test_create_jwt_authenticator(config, manifest, expected):
)
assert authenticator._get_jwt_payload() == jwt_payload

if authenticator_manifest.get("request_option"):
assert authenticator._request_option.inject_into.value == authenticator_manifest.get(
"request_option", {}
).get("inject_into")


def test_use_request_options_provider_for_datetime_based_cursor():
config = {
Expand Down
Loading