diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index a8ef81393..a79cd6383 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -1264,22 +1264,12 @@ def create_concurrent_cursor_from_datetime_based_cursor( component_definition: ComponentDefinition, stream_name: str, stream_namespace: Optional[str], + stream_state: MutableMapping[str, Any], config: Config, message_repository: Optional[MessageRepository] = None, runtime_lookback_window: Optional[datetime.timedelta] = None, - stream_state_migrations: Optional[List[Any]] = None, **kwargs: Any, ) -> ConcurrentCursor: - # Per-partition incremental streams can dynamically create child cursors which will pass their current - # state via the stream_state keyword argument. Incremental syncs without parent streams use the - # incoming state and connector_state_manager that is initialized when the component factory is created - stream_state = ( - self._connector_state_manager.get_stream_state(stream_name, stream_namespace) - if "stream_state" not in kwargs - else kwargs["stream_state"] - ) - stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state) - component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: raise ValueError( @@ -1498,21 +1488,11 @@ def create_concurrent_cursor_from_incrementing_count_cursor( component_definition: ComponentDefinition, stream_name: str, stream_namespace: Optional[str], + stream_state: MutableMapping[str, Any], config: Config, message_repository: Optional[MessageRepository] = None, - stream_state_migrations: Optional[List[Any]] = None, **kwargs: Any, ) -> ConcurrentCursor: - # Per-partition incremental streams can dynamically create child cursors which will pass their current - # state via the stream_state keyword argument. Incremental syncs without parent streams use the - # incoming state and connector_state_manager that is initialized when the component factory is created - stream_state = ( - self._connector_state_manager.get_stream_state(stream_name, stream_namespace) - if "stream_state" not in kwargs - else kwargs["stream_state"] - ) - stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state) - component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: raise ValueError( @@ -1587,7 +1567,6 @@ def create_concurrent_cursor_from_perpartition_cursor( config: Config, stream_state: MutableMapping[str, Any], partition_router: PartitionRouter, - stream_state_migrations: Optional[List[Any]] = None, attempt_to_create_cursor_if_not_provided: bool = False, **kwargs: Any, ) -> ConcurrentPerPartitionCursor: @@ -1647,11 +1626,9 @@ def create_concurrent_cursor_from_perpartition_cursor( stream_namespace=stream_namespace, config=config, message_repository=NoopMessageRepository(), - # stream_state_migrations=stream_state_migrations, # FIXME is it expected to run migration on per partition state too? ) ) - stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state) # Per-partition state doesn't make sense for GroupingPartitionRouter, so force the global state use_global_cursor = isinstance( partition_router, GroupingPartitionRouter @@ -1974,6 +1951,7 @@ def create_default_stream( self, model: DeclarativeStreamModel, config: Config, is_parent: bool = False, **kwargs: Any ) -> AbstractStream: primary_key = model.primary_key.__root__ if model.primary_key else None + self._migrate_state(model, config) partition_router = self._build_stream_slicer_from_partition_router( model.retriever, @@ -2135,6 +2113,23 @@ def create_default_stream( supports_file_transfer=hasattr(model, "file_uploader") and bool(model.file_uploader), ) + def _migrate_state(self, model: DeclarativeStreamModel, config: Config) -> None: + stream_name = model.name or "" + stream_state = self._connector_state_manager.get_stream_state( + stream_name=stream_name, namespace=None + ) + if model.state_migrations: + state_transformations = [ + self._create_component_from_model(state_migration, config, declarative_stream=model) + for state_migration in model.state_migrations + ] + else: + state_transformations = [] + stream_state = self.apply_stream_state_migrations(state_transformations, stream_state) + self._connector_state_manager.update_state_for_stream( + stream_name=stream_name, namespace=None, value=stream_state + ) + def _is_stop_condition_on_cursor(self, model: DeclarativeStreamModel) -> bool: return bool( model.incremental_sync @@ -2206,17 +2201,7 @@ def _build_concurrent_cursor( config: Config, ) -> Cursor: stream_name = model.name or "" - stream_state = self._connector_state_manager.get_stream_state( - stream_name=stream_name, namespace=None - ) - - if model.state_migrations: - state_transformations = [ - self._create_component_from_model(state_migration, config, declarative_stream=model) - for state_migration in model.state_migrations - ] - else: - state_transformations = [] + stream_state = self._connector_state_manager.get_stream_state(stream_name, None) if ( model.incremental_sync @@ -2228,10 +2213,9 @@ def _build_concurrent_cursor( model_type=DatetimeBasedCursorModel, component_definition=model.incremental_sync.__dict__, stream_name=stream_name, + stream_state=stream_state, stream_namespace=None, config=config or {}, - stream_state=stream_state, - stream_state_migrations=state_transformations, partition_router=stream_slicer, attempt_to_create_cursor_if_not_provided=True, # FIXME can we remove that now? ) @@ -2242,8 +2226,8 @@ def _build_concurrent_cursor( component_definition=model.incremental_sync.__dict__, stream_name=stream_name, stream_namespace=None, + stream_state=stream_state, config=config or {}, - stream_state_migrations=state_transformations, ) elif type(model.incremental_sync) == DatetimeBasedCursorModel: return self.create_concurrent_cursor_from_datetime_based_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing @@ -2251,8 +2235,8 @@ def _build_concurrent_cursor( component_definition=model.incremental_sync.__dict__, stream_name=stream_name, stream_namespace=None, + stream_state=stream_state, config=config or {}, - stream_state_migrations=state_transformations, attempt_to_create_cursor_if_not_provided=True, ) else: diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index 93c675de2..dc3334650 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -195,6 +195,7 @@ } CONFIG_START_TIME = ab_datetime_parse(input_config["start_time"]) CONFIG_END_TIME = ab_datetime_parse(input_config["end_time"]) +_NO_STATE = {} def get_factory_with_parameters( @@ -3325,21 +3326,7 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields( stream_name = "test" - connector_state_manager = ConnectorStateManager( - state=[ - AirbyteStateMessage( - type=AirbyteStateType.STREAM, - stream=AirbyteStreamState( - stream_descriptor=StreamDescriptor(name=stream_name), - stream_state=AirbyteStateBlob(stream_state), - ), - ) - ] - ) - - connector_builder_factory = ModelToComponentFactory( - emit_connector_builder_messages=True, connector_state_manager=connector_state_manager - ) + connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) cursor_component_definition = { "type": "DatetimeBasedCursor", @@ -3360,13 +3347,13 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields( component_definition=cursor_component_definition, stream_name=stream_name, stream_namespace=None, + stream_state=stream_state, config=config, ) ) assert concurrent_cursor._stream_name == stream_name assert not concurrent_cursor._stream_namespace - assert concurrent_cursor._connector_state_manager == connector_state_manager assert concurrent_cursor.cursor_field.cursor_field_key == expected_cursor_field assert concurrent_cursor._slice_range == expected_step assert concurrent_cursor._lookback_window == expected_lookback_window @@ -3481,8 +3468,8 @@ def test_create_concurrent_cursor_from_datetime_based_cursor( component_definition=cursor_component_definition, stream_name=stream_name, stream_namespace=None, + stream_state=_NO_STATE, config=config, - stream_state={}, ) else: concurrent_cursor = ( @@ -3492,131 +3479,184 @@ def test_create_concurrent_cursor_from_datetime_based_cursor( component_definition=cursor_component_definition, stream_name=stream_name, stream_namespace=None, + stream_state=_NO_STATE, config=config, - stream_state={}, ) ) assert getattr(concurrent_cursor, assertion_field) == expected_value -def test_create_concurrent_cursor_from_datetime_based_cursor_runs_state_migrations(): - class DummyStateMigration: - def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: - return True - - def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: - updated_at = stream_state["updated_at"] - return { - "states": [ - { - "partition": {"type": "type_1"}, - "cursor": {"updated_at": updated_at}, - }, - { - "partition": {"type": "type_2"}, - "cursor": {"updated_at": updated_at}, - }, - ] - } +def test_create_default_stream_with_datetime_cursor_then_runs_state_migrations(): + content = """ + type: DeclarativeStream + primary_key: "id" + name: test + schema_loader: + type: InlineSchemaLoader + schema: + $schema: "http://json-schema.org/draft-07/schema" + type: object + properties: + id: + type: string + incremental_sync: + type: "DatetimeBasedCursor" + cursor_field: "updated_at" + datetime_format: "%Y-%m-%dT%H:%M:%S.%f%z" + start_datetime: "{{ config['start_time'] }}" + end_datetime: "{{ config['end_time'] }}" + partition_field_start: "custom_start" + partition_field_end: "custom_end" + step: "P10D" + cursor_granularity: "PT0.000001S" + lookback_window: "P3D" + retriever: + type: SimpleRetriever + name: test + requester: + type: HttpRequester + name: "test" + url_base: "https://api.test.com/v3/" + http_method: "GET" + authenticator: + type: NoAuth + record_selector: + type: RecordSelector + extractor: + type: DpathExtractor + field_path: [] + state_migrations: + - type: CustomStateMigration + class_name: unit_tests.sources.declarative.parsers.testing_components.TestingStateMigration + """ - stream_name = "test" - config = { - "start_time": "2024-08-01T00:00:00.000000Z", - "end_time": "2024-09-01T00:00:00.000000Z", - } - stream_state = {"updated_at": "2025-01-01T00:00:00.000000Z"} - connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) - connector_state_manager = ConnectorStateManager() - cursor_component_definition = { - "type": "DatetimeBasedCursor", - "cursor_field": "updated_at", - "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ", - "start_datetime": "{{ config['start_time'] }}", - "end_datetime": "{{ config['end_time'] }}", - "partition_field_start": "custom_start", - "partition_field_end": "custom_end", - "step": "P10D", - "cursor_granularity": "PT0.000001S", - "lookback_window": "P3D", - } - concurrent_cursor = ( - connector_builder_factory.create_concurrent_cursor_from_datetime_based_cursor( - state_manager=connector_state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=cursor_component_definition, - stream_name=stream_name, - stream_namespace=None, - config=config, - stream_state=stream_state, - stream_state_migrations=[DummyStateMigration()], - ) + stream_state = {"updated_at": "2025-01-01T00:00:00.000000+00:00"} + connector_state_manager = ConnectorStateManager( + state=[ + AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="test"), + stream_state=AirbyteStateBlob(stream_state), + ), + ) + ] ) - assert concurrent_cursor.state["states"] == [ + factory = ModelToComponentFactory( + emit_connector_builder_messages=True, connector_state_manager=connector_state_manager + ) + stream = factory.create_component( + model_type=DeclarativeStreamModel, + component_definition=YamlDeclarativeSource._parse(content), + config=input_config, + ) + assert stream.cursor.state["states"] == [ {"cursor": {"updated_at": stream_state["updated_at"]}, "partition": {"type": "type_1"}}, {"cursor": {"updated_at": stream_state["updated_at"]}, "partition": {"type": "type_2"}}, ] -def test_create_concurrent_cursor_from_perpartition_cursor_runs_state_migrations(): - class DummyStateMigration: - def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: - return True - - def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: - stream_state["lookback_window"] = 10 * 2 - return stream_state +def test_create_concurrent_cursor_from_perpartition_cursor_runs_state_migrations_on_both_child_and_parent(): + content = """ + type: DeclarativeStream + primary_key: "id" + name: test + schema_loader: + type: InlineSchemaLoader + schema: + $schema: "http://json-schema.org/draft-07/schema" + type: object + properties: + id: + type: string + incremental_sync: + type: "DatetimeBasedCursor" + cursor_field: "updated_at" + datetime_format: "%Y-%m-%dT%H:%M:%S.%f%z" + start_datetime: "{{ config['start_time'] }}" + retriever: + type: SimpleRetriever + name: test + requester: + type: HttpRequester + name: "test" + url_base: "https://api.test.com/v3/" + http_method: "GET" + authenticator: + type: NoAuth + record_selector: + type: RecordSelector + extractor: + type: DpathExtractor + field_path: [] + partition_router: + type: SubstreamPartitionRouter + parent_stream_configs: + - type: ParentStreamConfig + parent_key: id + partition_field: id + incremental_dependency: true + stream: + type: DeclarativeStream + primary_key: id + name: parent_stream + schema_loader: + type: InlineSchemaLoader + schema: + $schema: "http://json-schema.org/draft-07/schema" + type: object + properties: + id: + type: string + incremental_sync: + type: "DatetimeBasedCursor" + cursor_field: "updated_at" + datetime_format: "%Y-%m-%dT%H:%M:%S.%f%z" + start_datetime: "{{ config['start_time'] }}" + retriever: + type: SimpleRetriever + requester: + type: HttpRequester + url_base: "https://api.test.com/v3/parent" + http_method: "GET" + record_selector: + type: RecordSelector + extractor: + type: DpathExtractor + field_path: [] + state_migrations: + - type: CustomStateMigration + class_name: unit_tests.sources.declarative.parsers.testing_components.TestingStateMigrationWithParentState + """ - state = { - "states": [ - { - "partition": {"type": "typ_1"}, - "cursor": {"updated_at": "2024-08-01T00:00:00.000000Z"}, - } - ], - "state": {"updated_at": "2024-08-01T00:00:00.000000Z"}, - "lookback_window": 10, - "parent_state": {"parent_test": {"last_updated": "2024-08-01T00:00:00.000000Z"}}, - } - config = { - "start_time": "2024-08-01T00:00:00.000000Z", - "end_time": "2024-09-01T00:00:00.000000Z", + stream_state = { + "state": {"updated_at": "2025-01-01T00:00:00.000000+00:00"}, + "parent_state": {"parent_stream": {"updated_at": "2025-01-01T00:00:00.000000+00:00"}}, } - list_partition_router = ListPartitionRouter( - cursor_field="id", - values=["type_1", "type_2", "type_3"], - config=config, - parameters={}, + connector_state_manager = ConnectorStateManager( + state=[ + AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="test"), + stream_state=AirbyteStateBlob(stream_state), + ), + ) + ] ) - connector_state_manager = ConnectorStateManager() - stream_name = "test" - cursor_component_definition = { - "type": "DatetimeBasedCursor", - "cursor_field": "updated_at", - "datetime_format": "%Y-%m-%dT%H:%M:%S.%fZ", - "start_datetime": "{{ config['start_time'] }}", - "end_datetime": "{{ config['end_time'] }}", - "partition_field_start": "custom_start", - "partition_field_end": "custom_end", - "step": "P10D", - "cursor_granularity": "PT0.000001S", - "lookback_window": "P3D", - } - connector_builder_factory = ModelToComponentFactory(emit_connector_builder_messages=True) - cursor = connector_builder_factory.create_concurrent_cursor_from_perpartition_cursor( - state_manager=connector_state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=cursor_component_definition, - stream_name=stream_name, - stream_namespace=None, - config=config, - stream_state=state, - partition_router=list_partition_router, - stream_state_migrations=[DummyStateMigration()], + factory = ModelToComponentFactory( + emit_connector_builder_messages=True, connector_state_manager=connector_state_manager ) - assert cursor.state["lookback_window"] != 10, "State migration wasn't called" - assert cursor.state["lookback_window"] == 20, ( - "State migration was called, but actual state don't match expected" + stream = factory.create_component( + model_type=DeclarativeStreamModel, + component_definition=YamlDeclarativeSource._parse(content), + config=input_config, + ) + assert stream.cursor.state["lookback_window"] == 20 + assert ( + stream.cursor._partition_router.parent_stream_configs[0].stream.cursor.state["updated_at"] + == "2024-02-01T00:00:00.000000+0000" ) @@ -3669,8 +3709,8 @@ def test_create_concurrent_cursor_uses_min_max_datetime_format_if_defined(): component_definition=cursor_component_definition, stream_name=stream_name, stream_namespace=None, + stream_state=_NO_STATE, config=config, - stream_state={}, ) ) @@ -3769,8 +3809,8 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_with_clamping( component_definition=cursor_component_definition, stream_name=stream_name, stream_namespace=None, + stream_state=_NO_STATE, config=config, - stream_state={}, ) else: @@ -3781,8 +3821,8 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_with_clamping( component_definition=cursor_component_definition, stream_name=stream_name, stream_namespace=None, + stream_state=_NO_STATE, config=config, - stream_state={}, ) ) diff --git a/unit_tests/sources/declarative/parsers/testing_components.py b/unit_tests/sources/declarative/parsers/testing_components.py index 0b9a68e6b..d37bb9307 100644 --- a/unit_tests/sources/declarative/parsers/testing_components.py +++ b/unit_tests/sources/declarative/parsers/testing_components.py @@ -3,9 +3,10 @@ # from dataclasses import dataclass, field -from typing import ClassVar, List, Optional +from typing import Any, ClassVar, List, Mapping, Optional from airbyte_cdk.sources.declarative.extractors import DpathExtractor +from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration from airbyte_cdk.sources.declarative.partition_routers import SubstreamPartitionRouter from airbyte_cdk.sources.declarative.requesters import RequestOption from airbyte_cdk.sources.declarative.requesters.error_handlers import DefaultErrorHandler @@ -49,3 +50,35 @@ class TestingCustomSubstreamPartitionRouter(SubstreamPartitionRouter): @dataclass class TestingCustomRetriever(SimpleRetriever): pass + + +class TestingStateMigration(StateMigration): + def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: + return True + + def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: + updated_at = stream_state["updated_at"] + return { + "states": [ + { + "partition": {"type": "type_1"}, + "cursor": {"updated_at": updated_at}, + }, + { + "partition": {"type": "type_2"}, + "cursor": {"updated_at": updated_at}, + }, + ] + } + + +class TestingStateMigrationWithParentState(StateMigration): + def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: + return True + + def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]: + stream_state["lookback_window"] = 20 + stream_state["parent_state"]["parent_stream"] = { + "updated_at": "2024-02-01T00:00:00.000000+00:00" + } + return stream_state