Skip to content

Commit 165a763

Browse files
laksh-krishna-sharmadabla
authored and committed
Refactor: deprecate wait_policy in EmrCreateJobFlowOperator in favor of wait_for_completion (apache#56158)
* refactor: deprecate wait_policy in EmrCreateJobFlowOperator in favor of wait_for_completion
* added unit test for refactor
* resolved copilot comments
* resolved copilot comments
* fixed failing test
* fixed: refactor of wait_policy
* ensured backward compatibility
* removed "self.wait_policy = wait_policy"
1 parent c355832 commit 165a763

File tree

2 files changed

+35
-26
lines changed

2 files changed

+35
-26
lines changed

providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -654,11 +654,10 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
654654
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
655655
:param verify: Whether or not to verify SSL certificates. See:
656656
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
657-
:param wait_for_completion: Deprecated - use `wait_policy` instead.
658-
Whether to finish task immediately after creation (False) or wait for jobflow
657+
:param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
659658
completion (True)
660659
(default: None)
661-
:param wait_policy: Whether to finish the task immediately after creation (None) or:
660+
:param wait_policy: Deprecated. Use `wait_for_completion` instead. Whether to finish the task immediately after creation (None) or:
662661
- wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION)
663662
- wait for the jobflow completion and cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION)
664663
(default: None)
@@ -698,19 +697,29 @@ def __init__(
698697
super().__init__(**kwargs)
699698
self.emr_conn_id = emr_conn_id
700699
self.job_flow_overrides = job_flow_overrides or {}
701-
self.wait_policy = wait_policy
700+
self.wait_for_completion = wait_for_completion
702701
self.waiter_max_attempts = waiter_max_attempts or 60
703702
self.waiter_delay = waiter_delay or 60
704703
self.deferrable = deferrable
705704

706-
if wait_for_completion is not None:
705+
if wait_policy is not None:
707706
warnings.warn(
708-
"`wait_for_completion` parameter is deprecated, please use `wait_policy` instead.",
707+
"`wait_policy` parameter is deprecated and will be removed in a future release; "
708+
"please use `wait_for_completion` (bool) instead.",
709709
AirflowProviderDeprecationWarning,
710710
stacklevel=2,
711711
)
712-
# preserve previous behaviour
713-
self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION if wait_for_completion else None
712+
713+
if wait_for_completion is not None:
714+
raise ValueError(
715+
"Cannot specify both `wait_for_completion` and deprecated `wait_policy`. "
716+
"Please use `wait_for_completion` (bool)."
717+
)
718+
719+
self.wait_for_completion = wait_policy in (
720+
WaitPolicy.WAIT_FOR_COMPLETION,
721+
WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
722+
)
714723

715724
@property
716725
def _hook_parameters(self):
@@ -748,8 +757,8 @@ def execute(self, context: Context) -> str | None:
748757
job_flow_id=self._job_flow_id,
749758
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
750759
)
751-
if self.wait_policy:
752-
waiter_name = WAITER_POLICY_NAME_MAPPING[self.wait_policy]
760+
if self.wait_for_completion:
761+
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
753762

754763
if self.deferrable:
755764
self.defer(

providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
from botocore.waiter import Waiter
2727
from jinja2 import StrictUndefined
2828

29-
from airflow.exceptions import TaskDeferred
29+
from airflow.exceptions import AirflowProviderDeprecationWarning, TaskDeferred
3030
from airflow.models import DAG, DagRun, TaskInstance
3131
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
3232
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
@@ -216,34 +216,26 @@ def test_execute_returns_job_id(self, mocked_hook_client):
216216
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
217217
assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
218218

219-
@pytest.mark.parametrize(
220-
"wait_policy",
221-
[
222-
pytest.param(WaitPolicy.WAIT_FOR_COMPLETION, id="with wait for completion"),
223-
pytest.param(WaitPolicy.WAIT_FOR_STEPS_COMPLETION, id="with wait for steps completion policy"),
224-
],
225-
)
226219
@mock.patch("botocore.waiter.get_service_module_name", return_value="emr")
227220
@mock.patch.object(Waiter, "wait")
228-
def test_execute_with_wait_policy(self, mock_waiter, _, mocked_hook_client, wait_policy: WaitPolicy):
221+
def test_execute_with_wait_for_completion(self, mock_waiter, _, mocked_hook_client):
229222
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
230223

231-
# Mock out the emr_client creator
232-
self.operator.wait_policy = wait_policy
224+
self.operator.wait_for_completion = True
233225

234226
assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
235227
mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)
236-
assert_expected_waiter_type(mock_waiter, WAITER_POLICY_NAME_MAPPING[wait_policy])
228+
assert_expected_waiter_type(mock_waiter, WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION])
237229

238230
def test_create_job_flow_deferrable(self, mocked_hook_client):
239231
"""
240232
Test to make sure that the operator raises a TaskDeferred exception
241-
if run in deferrable mode and wait_policy is set.
233+
if run in deferrable mode and wait_for_completion is set.
242234
"""
243235
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
244236

245237
self.operator.deferrable = True
246-
self.operator.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION
238+
self.operator.wait_for_completion = True
247239
with pytest.raises(TaskDeferred) as exc:
248240
self.operator.execute(self.mock_context)
249241

@@ -254,14 +246,22 @@ def test_create_job_flow_deferrable(self, mocked_hook_client):
254246
def test_create_job_flow_deferrable_no_wait(self, mocked_hook_client):
255247
"""
256248
Test to make sure that the operator does NOT raise a TaskDeferred exception
257-
if run in deferrable mode but wait_policy is not set.
249+
if run in deferrable mode but wait_for_completion is not set.
258250
"""
259251
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
260252

261253
self.operator.deferrable = True
262-
# wait_policy is None by default
254+
# wait_for_completion is None by default
263255
result = self.operator.execute(self.mock_context)
264256
assert result == JOB_FLOW_ID
265257

266258
def test_template_fields(self):
267259
validate_template_fields(self.operator)
260+
261+
def test_wait_policy_deprecation_warning(self):
262+
"""Test that using wait_policy raises a deprecation warning."""
263+
with pytest.warns(AirflowProviderDeprecationWarning, match="`wait_policy` parameter is deprecated"):
264+
EmrCreateJobFlowOperator(
265+
task_id=TASK_ID,
266+
wait_policy=WaitPolicy.WAIT_FOR_COMPLETION,
267+
)

0 commit comments

Comments
 (0)