68 changes: 32 additions & 36 deletions airbyte_cdk/connector_builder/connector_builder_handler.py
@@ -3,8 +3,8 @@
#


from dataclasses import asdict, dataclass, field
from typing import Any, ClassVar, Dict, List, Mapping
from dataclasses import asdict
from typing import Any, Dict, List, Mapping, Optional

from airbyte_cdk.connector_builder.test_reader import TestReader
from airbyte_cdk.models import (
@@ -15,45 +15,32 @@
Type,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
ConcurrentDeclarativeSource,
TestLimits,
)
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
ModelToComponentFactory,
)
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
from airbyte_cdk.utils.datetime_helpers import ab_datetime_now
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
DEFAULT_MAXIMUM_RECORDS = 100
DEFAULT_MAXIMUM_STREAMS = 100

MAX_PAGES_PER_SLICE_KEY = "max_pages_per_slice"
MAX_SLICES_KEY = "max_slices"
MAX_RECORDS_KEY = "max_records"
MAX_STREAMS_KEY = "max_streams"


@dataclass
class TestLimits:
__test__: ClassVar[bool] = False # Tell Pytest this is not a Pytest class, despite its name

max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS)
max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE)
max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES)
max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS)


def get_limits(config: Mapping[str, Any]) -> TestLimits:
command_config = config.get("__test_read_config", {})
max_pages_per_slice = (
command_config.get(MAX_PAGES_PER_SLICE_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE
return TestLimits(
max_records=command_config.get(MAX_RECORDS_KEY, TestLimits.DEFAULT_MAX_RECORDS),
max_pages_per_slice=command_config.get(
MAX_PAGES_PER_SLICE_KEY, TestLimits.DEFAULT_MAX_PAGES_PER_SLICE
),
max_slices=command_config.get(MAX_SLICES_KEY, TestLimits.DEFAULT_MAX_SLICES),
max_streams=command_config.get(MAX_STREAMS_KEY, TestLimits.DEFAULT_MAX_STREAMS),
)
max_slices = command_config.get(MAX_SLICES_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_SLICES
max_records = command_config.get(MAX_RECORDS_KEY) or DEFAULT_MAXIMUM_RECORDS
max_streams = command_config.get(MAX_STREAMS_KEY) or DEFAULT_MAXIMUM_STREAMS
return TestLimits(max_records, max_pages_per_slice, max_slices, max_streams)


def should_migrate_manifest(config: Mapping[str, Any]) -> bool:
@@ -75,21 +62,30 @@ def should_normalize_manifest(config: Mapping[str, Any]) -> bool:
return config.get("__should_normalize", False)


def create_source(config: Mapping[str, Any], limits: TestLimits) -> ManifestDeclarativeSource:
def create_source(
config: Mapping[str, Any],
limits: TestLimits,
catalog: Optional[ConfiguredAirbyteCatalog],
state: Optional[List[AirbyteStateMessage]],
) -> ConcurrentDeclarativeSource[Optional[List[AirbyteStateMessage]]]:
manifest = config["__injected_declarative_manifest"]
return ManifestDeclarativeSource(

# We enforce a concurrency level of 1 so that the stream is processed on a single thread
# to retain ordering for the grouping of the builder message responses.
if "concurrency_level" in manifest:
manifest["concurrency_level"]["default_concurrency"] = 1
else:
manifest["concurrency_level"] = {"type": "ConcurrencyLevel", "default_concurrency": 1}

return ConcurrentDeclarativeSource(
catalog=catalog,
config=config,
emit_connector_builder_messages=True,
state=state,
source_config=manifest,
emit_connector_builder_messages=True,
migrate_manifest=should_migrate_manifest(config),
normalize_manifest=should_normalize_manifest(config),
component_factory=ModelToComponentFactory(
emit_connector_builder_messages=True,
limit_pages_fetched_per_slice=limits.max_pages_per_slice,
limit_slices_fetched=limits.max_slices,
disable_retries=True,
disable_cache=True,
),
limits=limits,
)


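For orientation, the limits above are read from the reserved `__test_read_config` block of the connector config, and their defaults now live on `TestLimits` (imported from `concurrent_declarative_source`) instead of module-level constants. A minimal sketch of how a test read might override them; the numeric values and the empty manifest are placeholders rather than real inputs:

```python
# Sketch only: key names come from this diff; the values are arbitrary examples.
from airbyte_cdk.connector_builder.connector_builder_handler import get_limits

config = {
    "__injected_declarative_manifest": {},  # placeholder; create_source() needs a real manifest
    "__test_read_config": {
        "max_records": 50,         # MAX_RECORDS_KEY
        "max_pages_per_slice": 2,  # MAX_PAGES_PER_SLICE_KEY
        "max_slices": 3,           # MAX_SLICES_KEY
        # "max_streams" omitted -> falls back to TestLimits.DEFAULT_MAX_STREAMS
    },
}

limits = get_limits(config)
# The limits are then passed straight through to the source:
# source = create_source(config=config, limits=limits, catalog=None, state=None)
```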
6 changes: 3 additions & 3 deletions airbyte_cdk/connector_builder/main.py
@@ -91,12 +91,12 @@ def handle_connector_builder_request(
def handle_request(args: List[str]) -> str:
command, config, catalog, state = get_config_and_catalog_from_args(args)
limits = get_limits(config)
source = create_source(config, limits)
return orjson.dumps(
source = create_source(config=config, limits=limits, catalog=catalog, state=state)
return orjson.dumps( # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage
AirbyteMessageSerializer.dump(
handle_connector_builder_request(source, command, config, catalog, state, limits)
)
).decode() # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage
).decode()


if __name__ == "__main__":
26 changes: 24 additions & 2 deletions airbyte_cdk/connector_builder/test_reader/helpers.py
@@ -5,7 +5,7 @@
import json
from copy import deepcopy
from json import JSONDecodeError
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional, Union

from airbyte_cdk.connector_builder.models import (
AuxiliaryRequest,
@@ -17,6 +17,8 @@
from airbyte_cdk.models import (
AirbyteLogMessage,
AirbyteMessage,
AirbyteStateBlob,
AirbyteStateMessage,
OrchestratorType,
TraceType,
)
@@ -466,7 +468,7 @@ def handle_current_slice(
return StreamReadSlices(
pages=current_slice_pages,
slice_descriptor=current_slice_descriptor,
state=[latest_state_message] if latest_state_message else [],
state=[convert_state_blob_to_mapping(latest_state_message)] if latest_state_message else [],
auxiliary_requests=auxiliary_requests if auxiliary_requests else [],
)

@@ -718,3 +720,23 @@ def get_auxiliary_request_type(stream: dict, http: dict) -> str:  # type: ignore
Determines the type of the auxiliary request based on the stream and HTTP properties.
"""
return "PARENT_STREAM" if stream.get("is_substream", False) else str(http.get("type", None))


def convert_state_blob_to_mapping(
state_message: Union[AirbyteStateMessage, Dict[str, Any]],
) -> Dict[str, Any]:
"""
    The AirbyteStreamState stores state as an AirbyteStateBlob which, deceptively, is not
    a dictionary but rather a collection of kwargs fields. As a result, it is not properly
    turned into a dictionary when the connector_builder_handler translates it back into
    response output using asdict().
"""

if isinstance(state_message, AirbyteStateMessage) and state_message.stream:
state_value = state_message.stream.stream_state
if isinstance(state_value, AirbyteStateBlob):
state_value_mapping = {k: v for k, v in state_value.__dict__.items()}
state_message.stream.stream_state = state_value_mapping # type: ignore # we intentionally set this as a Dict so that StreamReadSlices is translated properly in the resulting HTTP response
return state_message # type: ignore # See above, but when this is an AirbyteStateMessage we must convert AirbyteStateBlob to a Dict
else:
return state_message # type: ignore # This is guaranteed to be a Dict since we check isinstance AirbyteStateMessage above
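To make the problem concrete, here is a small sketch of the conversion; the stream name and cursor field are made up, and the protocol model constructors are assumed to accept the fields shown:

```python
# Sketch only: shows the AirbyteStateBlob attributes becoming a plain dict so that
# dataclasses.asdict() on StreamReadSlices serializes the state correctly.
from airbyte_cdk.connector_builder.test_reader.helpers import convert_state_blob_to_mapping
from airbyte_cdk.models import (
    AirbyteStateBlob,
    AirbyteStateMessage,
    AirbyteStateType,
    AirbyteStreamState,
    StreamDescriptor,
)

state_message = AirbyteStateMessage(
    type=AirbyteStateType.STREAM,
    stream=AirbyteStreamState(
        stream_descriptor=StreamDescriptor(name="users"),  # hypothetical stream
        stream_state=AirbyteStateBlob(updated_at="2024-01-01T00:00:00Z"),  # hypothetical cursor field
    ),
)

converted = convert_state_blob_to_mapping(state_message)
# The blob's kwargs are now a real dict, so asdict() expands them as expected.
assert isinstance(converted.stream.stream_state, dict)  # type: ignore[union-attr]
```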
@@ -95,7 +95,7 @@ def get_message_groups(
latest_state_message: Optional[Dict[str, Any]] = None
slice_auxiliary_requests: List[AuxiliaryRequest] = []

while records_count < limit and (message := next(messages, None)):
while message := next(messages, None):
json_message = airbyte_message_to_json(message)

if is_page_http_request_for_different_stream(json_message, stream_name):
36 changes: 15 additions & 21 deletions airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
@@ -2,6 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
import os
from typing import Dict, Iterable, List, Optional, Set

from airbyte_cdk.exception_handler import generate_failed_streams_error_message
@@ -95,11 +96,14 @@ def on_partition(self, partition: Partition) -> None:
"""
stream_name = partition.stream_name()
self._streams_to_running_partitions[stream_name].add(partition)
cursor = self._stream_name_to_instance[stream_name].cursor
if self._slice_logger.should_log_slice_message(self._logger):
self._message_repository.emit_message(
self._slice_logger.create_slice_log_message(partition.to_slice())
)
self._thread_pool_manager.submit(self._partition_reader.process_partition, partition)
self._thread_pool_manager.submit(
self._partition_reader.process_partition, partition, cursor
)

def on_partition_complete_sentinel(
self, sentinel: PartitionCompleteSentinel
@@ -112,26 +116,16 @@ def on_partition_complete_sentinel(
"""
partition = sentinel.partition

try:
if sentinel.is_successful:
stream = self._stream_name_to_instance[partition.stream_name()]
stream.cursor.close_partition(partition)
except Exception as exception:
self._flag_exception(partition.stream_name(), exception)
yield AirbyteTracedException.from_exception(
exception, stream_descriptor=StreamDescriptor(name=partition.stream_name())
).as_sanitized_airbyte_message()
finally:
partitions_running = self._streams_to_running_partitions[partition.stream_name()]
if partition in partitions_running:
partitions_running.remove(partition)
# If all partitions were generated and this was the last one, the stream is done
if (
partition.stream_name() not in self._streams_currently_generating_partitions
and len(partitions_running) == 0
):
yield from self._on_stream_is_done(partition.stream_name())
yield from self._message_repository.consume_queue()
partitions_running = self._streams_to_running_partitions[partition.stream_name()]
if partition in partitions_running:
partitions_running.remove(partition)
# If all partitions were generated and this was the last one, the stream is done
if (
partition.stream_name() not in self._streams_currently_generating_partitions
and len(partitions_running) == 0
):
yield from self._on_stream_is_done(partition.stream_name())
yield from self._message_repository.consume_queue()

def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
"""
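The change above stops closing the cursor in `on_partition_complete_sentinel()` and instead hands the stream's cursor to `PartitionReader.process_partition()`. The reader itself is not part of this hunk, so the following is only an assumed sketch of the resulting division of labour, not the CDK's actual implementation (`observe()` and `close_partition()` are the existing concurrent-cursor methods; everything else is illustrative):

```python
# Assumed sketch: closing the partition on the worker thread that read it keeps
# record and state ordering intact, which the connector builder relies on when
# grouping messages; the main loop then only routes queue items.
from queue import Queue
from typing import Any


class PartitionReaderSketch:
    def __init__(self, queue: "Queue[Any]") -> None:
        self._queue = queue

    def process_partition(self, partition: Any, cursor: Any) -> None:
        try:
            for record in partition.read():
                self._queue.put(record)        # records flow back to the main thread
                cursor.observe(record)         # cursor tracks progress per record
            cursor.close_partition(partition)  # state is emitted from the worker, not the main loop
            self._queue.put(("success", partition))  # stand-in for PartitionCompleteSentinel
        except Exception as exception:  # surfaced to _handle_item() via the shared queue
            self._queue.put(exception)
```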
48 changes: 30 additions & 18 deletions airbyte_cdk/sources/concurrent_source/concurrent_source.py
@@ -1,10 +1,11 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import concurrent
import logging
from queue import Queue
from typing import Iterable, Iterator, List
from typing import Iterable, Iterator, List, Optional

from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
@@ -16,7 +17,7 @@
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer
from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionLogger, PartitionReader
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.types import (
PartitionCompleteSentinel,
@@ -43,6 +44,7 @@ def create(
logger: logging.Logger,
slice_logger: SliceLogger,
message_repository: MessageRepository,
queue: Optional[Queue[QueueItem]] = None,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
) -> "ConcurrentSource":
is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1
@@ -59,19 +61,21 @@
logger,
)
return ConcurrentSource(
threadpool,
logger,
slice_logger,
message_repository,
initial_number_of_partitions_to_generate,
timeout_seconds,
threadpool=threadpool,
logger=logger,
slice_logger=slice_logger,
queue=queue,
message_repository=message_repository,
initial_number_partitions_to_generate=initial_number_of_partitions_to_generate,
timeout_seconds=timeout_seconds,
)

def __init__(
self,
threadpool: ThreadPoolManager,
logger: logging.Logger,
slice_logger: SliceLogger = DebugSliceLogger(),
queue: Optional[Queue[QueueItem]] = None,
message_repository: MessageRepository = InMemoryMessageRepository(),
initial_number_partitions_to_generate: int = 1,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
@@ -91,33 +95,36 @@ def __init__(
self._initial_number_partitions_to_generate = initial_number_partitions_to_generate
self._timeout_seconds = timeout_seconds

        # We set a maxsize so that the main thread can process record items when the queue size grows. This assumes that there are fewer
        # threads generating partitions than the max number of workers. If that weren't the case, we could have threads only generating
        # partitions, which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more
        # information and might even need to be configurable depending on the source.
self._queue = queue or Queue(maxsize=10_000)

def read(
self,
streams: List[AbstractStream],
) -> Iterator[AirbyteMessage]:
self._logger.info("Starting syncing")

        # We set a maxsize so that the main thread can process record items when the queue size grows. This assumes that there are fewer
        # threads generating partitions than the max number of workers. If that weren't the case, we could have threads only generating
        # partitions, which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more
        # information and might even need to be configurable depending on the source.
queue: Queue[QueueItem] = Queue(maxsize=10_000)
concurrent_stream_processor = ConcurrentReadProcessor(
streams,
PartitionEnqueuer(queue, self._threadpool),
PartitionEnqueuer(self._queue, self._threadpool),
self._threadpool,
self._logger,
self._slice_logger,
self._message_repository,
PartitionReader(queue),
PartitionReader(
self._queue,
PartitionLogger(self._slice_logger, self._logger, self._message_repository),
),
)

# Enqueue initial partition generation tasks
yield from self._submit_initial_partition_generators(concurrent_stream_processor)

# Read from the queue until all partitions were generated and read
yield from self._consume_from_queue(
queue,
self._queue,
concurrent_stream_processor,
)
self._threadpool.check_for_errors_and_shutdown()
@@ -141,7 +148,10 @@ def _consume_from_queue(
airbyte_message_or_record_or_exception,
concurrent_stream_processor,
)
if concurrent_stream_processor.is_done() and queue.empty():
# In the event that a partition raises an exception, anything remaining in
# the queue will be missed because is_done() can raise an exception and exit
# out of this loop before remaining items are consumed
if queue.empty() and concurrent_stream_processor.is_done():
# all partitions were generated and processed. we're done here
break

@@ -161,5 +171,7 @@ def _handle_item(
yield from concurrent_stream_processor.on_partition_complete_sentinel(queue_item)
elif isinstance(queue_item, Record):
yield from concurrent_stream_processor.on_record(queue_item)
elif isinstance(queue_item, AirbyteMessage):
yield queue_item
else:
raise ValueError(f"Unknown queue item type: {type(queue_item)}")
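Because `create()` and `__init__()` now accept an optional `queue`, a caller such as the connector builder can share a single queue between the source and other producers, and `_handle_item()` will pass through any `AirbyteMessage` placed on it. A rough usage sketch; the `num_workers` and `initial_number_of_partitions_to_generate` argument names are assumed from the surrounding code rather than shown in this hunk:

```python
# Sketch only: wiring is simplified; the logger and repository choices are illustrative.
import logging
from queue import Queue

from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem
from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger

logger = logging.getLogger("airbyte")
shared_queue: "Queue[QueueItem]" = Queue(maxsize=10_000)

source = ConcurrentSource.create(
    num_workers=2,
    initial_number_of_partitions_to_generate=1,
    logger=logger,
    slice_logger=DebugSliceLogger(),
    message_repository=InMemoryMessageRepository(),
    queue=shared_queue,  # optional; when omitted, an internal Queue(maxsize=10_000) is used
)
# messages = source.read(streams)  # streams: List[AbstractStream]
```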