2626from botocore .waiter import Waiter
2727from jinja2 import StrictUndefined
2828
29- from airflow .exceptions import TaskDeferred
29+ from airflow .exceptions import AirflowProviderDeprecationWarning , TaskDeferred
3030from airflow .models import DAG , DagRun , TaskInstance
3131from airflow .providers .amazon .aws .operators .emr import EmrCreateJobFlowOperator
3232from airflow .providers .amazon .aws .triggers .emr import EmrCreateJobFlowTrigger
@@ -216,34 +216,26 @@ def test_execute_returns_job_id(self, mocked_hook_client):
216216 mocked_hook_client .run_job_flow .return_value = RUN_JOB_FLOW_SUCCESS_RETURN
217217 assert self .operator .execute (self .mock_context ) == JOB_FLOW_ID
218218
219- @pytest .mark .parametrize (
220- "wait_policy" ,
221- [
222- pytest .param (WaitPolicy .WAIT_FOR_COMPLETION , id = "with wait for completion" ),
223- pytest .param (WaitPolicy .WAIT_FOR_STEPS_COMPLETION , id = "with wait for steps completion policy" ),
224- ],
225- )
226219 @mock .patch ("botocore.waiter.get_service_module_name" , return_value = "emr" )
227220 @mock .patch .object (Waiter , "wait" )
228- def test_execute_with_wait_policy (self , mock_waiter , _ , mocked_hook_client , wait_policy : WaitPolicy ):
221+ def test_execute_with_wait_for_completion (self , mock_waiter , _ , mocked_hook_client ):
229222 mocked_hook_client .run_job_flow .return_value = RUN_JOB_FLOW_SUCCESS_RETURN
230223
231- # Mock out the emr_client creator
232- self .operator .wait_policy = wait_policy
224+ self .operator .wait_for_completion = True
233225
234226 assert self .operator .execute (self .mock_context ) == JOB_FLOW_ID
235227 mock_waiter .assert_called_once_with (mock .ANY , ClusterId = JOB_FLOW_ID , WaiterConfig = mock .ANY )
236- assert_expected_waiter_type (mock_waiter , WAITER_POLICY_NAME_MAPPING [wait_policy ])
228+ assert_expected_waiter_type (mock_waiter , WAITER_POLICY_NAME_MAPPING [WaitPolicy . WAIT_FOR_COMPLETION ])
237229
238230 def test_create_job_flow_deferrable (self , mocked_hook_client ):
239231 """
240232 Test to make sure that the operator raises a TaskDeferred exception
241- if run in deferrable mode and wait_policy is set.
233+ if run in deferrable mode and wait_for_completion is set.
242234 """
243235 mocked_hook_client .run_job_flow .return_value = RUN_JOB_FLOW_SUCCESS_RETURN
244236
245237 self .operator .deferrable = True
246- self .operator .wait_policy = WaitPolicy . WAIT_FOR_COMPLETION
238+ self .operator .wait_for_completion = True
247239 with pytest .raises (TaskDeferred ) as exc :
248240 self .operator .execute (self .mock_context )
249241
@@ -254,14 +246,22 @@ def test_create_job_flow_deferrable(self, mocked_hook_client):
254246 def test_create_job_flow_deferrable_no_wait (self , mocked_hook_client ):
255247 """
256248 Test to make sure that the operator does NOT raise a TaskDeferred exception
257- if run in deferrable mode but wait_policy is not set.
249+ if run in deferrable mode but wait_for_completion is not set.
258250 """
259251 mocked_hook_client .run_job_flow .return_value = RUN_JOB_FLOW_SUCCESS_RETURN
260252
261253 self .operator .deferrable = True
262- # wait_policy is None by default
254+ # wait_for_completion is None by default
263255 result = self .operator .execute (self .mock_context )
264256 assert result == JOB_FLOW_ID
265257
266258 def test_template_fields (self ):
267259 validate_template_fields (self .operator )
260+
261+ def test_wait_policy_deprecation_warning (self ):
262+ """Test that using wait_policy raises a deprecation warning."""
263+ with pytest .warns (AirflowProviderDeprecationWarning , match = "`wait_policy` parameter is deprecated" ):
264+ EmrCreateJobFlowOperator (
265+ task_id = TASK_ID ,
266+ wait_policy = WaitPolicy .WAIT_FOR_COMPLETION ,
267+ )
0 commit comments