Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from typing import Any

from airflow.providers.apache.kafka.triggers.await_message import AwaitMessageTrigger
from airflow.providers.common.compat.sdk import BaseOperator
from airflow.providers.common.compat.sdk import BaseSensorOperator

VALID_COMMIT_CADENCE = {"never", "end_of_batch", "end_of_operator"}


class AwaitMessageSensor(BaseOperator):
class AwaitMessageSensor(BaseSensorOperator):
"""
An Airflow sensor that defers until a specific message is published to Kafka.

Expand Down Expand Up @@ -53,6 +53,10 @@ class AwaitMessageSensor(BaseOperator):
:param poll_interval: How long the kafka consumer should sleep after reaching the end of the Kafka log,
defaults to 5
:param xcom_push_key: the name of a key to push the returned message to, defaults to None
:param soft_fail: Set to true to mark the task as SKIPPED on failure
:param timeout: Time elapsed before the task times out and fails (in seconds)
:param poke_interval: This parameter is inherited but not used in this deferrable implementation
:param mode: This parameter is inherited but not used in this deferrable implementation


"""
Expand Down Expand Up @@ -111,7 +115,7 @@ def execute_complete(self, context, event=None):
return event


class AwaitMessageTriggerFunctionSensor(BaseOperator):
class AwaitMessageTriggerFunctionSensor(BaseSensorOperator):
"""
Defer until a specific message is published to Kafka, trigger a registered function, then resume waiting.

Expand All @@ -137,6 +141,10 @@ class AwaitMessageTriggerFunctionSensor(BaseOperator):
cluster, defaults to 1
:param poll_interval: How long the kafka consumer should sleep after reaching the end of the Kafka log,
defaults to 5
:param soft_fail: Set to true to mark the task as SKIPPED on failure
:param timeout: Time elapsed before the task times out and fails (in seconds)
:param poke_interval: This parameter is inherited but not used in this deferrable implementation
:param mode: This parameter is inherited but not used in this deferrable implementation


"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,53 @@ def test_await_message_trigger_event_execute_complete(self):
# task should immediately come out of deferred
with pytest.raises(TaskDeferred):
sensor.execute_complete(context={})

def test_await_message_with_timeout_parameter(self):
"""Test that AwaitMessageSensor accepts timeout parameter."""
sensor = AwaitMessageSensor(
kafka_config_id="kafka_d",
topics=["test"],
task_id="test",
apply_function=_return_true,
timeout=600, # This should now work without errors
)

assert sensor.timeout == 600

def test_await_message_with_soft_fail_parameter(self):
"""Test that AwaitMessageSensor accepts soft_fail parameter."""
sensor = AwaitMessageSensor(
kafka_config_id="kafka_d",
topics=["test"],
task_id="test",
apply_function=_return_true,
soft_fail=True, # This should now work without errors
)

assert sensor.soft_fail is True

def test_await_message_trigger_function_with_timeout_parameter(self):
"""Test that AwaitMessageTriggerFunctionSensor accepts timeout parameter."""
sensor = AwaitMessageTriggerFunctionSensor(
kafka_config_id="kafka_d",
topics=["test"],
task_id="test",
apply_function=_return_true,
event_triggered_function=_return_true,
timeout=600,
)

assert sensor.timeout == 600

def test_await_message_trigger_function_with_soft_fail_parameter(self):
"""Test that AwaitMessageTriggerFunctionSensor accepts soft_fail parameter."""
sensor = AwaitMessageTriggerFunctionSensor(
kafka_config_id="kafka_d",
topics=["test"],
task_id="test",
apply_function=_return_true,
event_triggered_function=_return_true,
soft_fail=True,
)

assert sensor.soft_fail is True