|
29 | 29 |
|
30 | 30 | if AIRFLOW_V_3_0_PLUS: |
31 | 31 | from airflow.sdk.execution_time.comms import XComResult |
| 32 | + from airflow.sdk.execution_time.xcom import XCom |
| 33 | +else: |
| 34 | + from airflow.models.xcom import XCom # type: ignore[no-redef] |
32 | 35 |
|
33 | 36 | TEST_LOCATION = "test-location" |
34 | 37 | TEST_CLUSTER_ID = "test-cluster-id" |
@@ -128,3 +131,73 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis |
128 | 131 | ) |
129 | 132 | actual_url = link.get_link(operator=ti.task, ti_key=ti.key) |
130 | 133 | assert actual_url == expected_url |
| 134 | + |
| 135 | + @pytest.mark.db_test |
| 136 | + @mock.patch.object(XCom, "get_value") |
| 137 | + def test_get_link_uses_xcom_url_and_skips_get_config( |
| 138 | + self, |
| 139 | + mock_get_value, |
| 140 | + create_task_instance_of_operator, |
| 141 | + session, |
| 142 | + ): |
| 143 | + xcom_url = "https://console.cloud.google.com/some/service?project=test-proj" |
| 144 | + mock_get_value.return_value = xcom_url |
| 145 | + |
| 146 | + link = GoogleLink() |
| 147 | + ti = create_task_instance_of_operator( |
| 148 | + MyOperator, |
| 149 | + dag_id="test_link_dag", |
| 150 | + task_id="test_link_task", |
| 151 | + location=TEST_LOCATION, |
| 152 | + cluster_id=TEST_CLUSTER_ID, |
| 153 | + project_id=TEST_PROJECT_ID, |
| 154 | + ) |
| 155 | + session.add(ti) |
| 156 | + session.commit() |
| 157 | + |
| 158 | + with mock.patch.object(GoogleLink, "get_config", autospec=True) as m_get_config: |
| 159 | + actual_url = link.get_link(operator=ti.task, ti_key=ti.key) |
| 160 | + |
| 161 | + assert actual_url == xcom_url |
| 162 | + m_get_config.assert_not_called() |
| 163 | + |
| 164 | + @pytest.mark.db_test |
| 165 | + @mock.patch.object(XCom, "get_value") |
| 166 | + def test_get_link_falls_back_to_get_config_when_xcom_not_http( |
| 167 | + self, |
| 168 | + mock_get_value, |
| 169 | + create_task_instance_of_operator, |
| 170 | + session, |
| 171 | + ): |
| 172 | + mock_get_value.return_value = "gs://bucket/path" |
| 173 | + |
| 174 | + link = GoogleLink() |
| 175 | + ti = create_task_instance_of_operator( |
| 176 | + MyOperator, |
| 177 | + dag_id="test_link_dag", |
| 178 | + task_id="test_link_task", |
| 179 | + location=TEST_LOCATION, |
| 180 | + cluster_id=TEST_CLUSTER_ID, |
| 181 | + project_id=TEST_PROJECT_ID, |
| 182 | + ) |
| 183 | + session.add(ti) |
| 184 | + session.commit() |
| 185 | + |
| 186 | + expected_formatted = "https://console.cloud.google.com/expected/link?project=test-proj" |
| 187 | + with ( |
| 188 | + mock.patch.object( |
| 189 | + GoogleLink, |
| 190 | + "get_config", |
| 191 | + return_value={ |
| 192 | + "project_id": ti.task.project_id, |
| 193 | + "location": ti.task.location, |
| 194 | + "cluster_id": ti.task.cluster_id, |
| 195 | + }, |
| 196 | + ) as m_get_config, |
| 197 | + mock.patch.object(GoogleLink, "_format_link", return_value=expected_formatted) as m_fmt, |
| 198 | + ): |
| 199 | + actual_url = link.get_link(operator=ti.task, ti_key=ti.key) |
| 200 | + |
| 201 | + assert actual_url == expected_formatted |
| 202 | + m_get_config.assert_called_once() |
| 203 | + m_fmt.assert_called_once() |
0 commit comments