|
4 | 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../ads_incrementality_dap_collector')))
|
5 | 5 |
|
6 | 6 | import pytest
|
7 |
| - |
| 7 | +import re |
8 | 8 | from unittest import TestCase
|
9 | 9 | from unittest.mock import patch
|
10 | 10 |
|
|
13 | 13 | mock_nimbus_experiment, mock_control_row, mock_treatment_a_row, mock_treatment_b_row,
|
14 | 14 | mock_task_id, mock_nimbus_unparseable_experiment,
|
15 | 15 | mock_tasks_to_collect, mock_dap_config, mock_experiment_config,
|
16 |
| - mock_dap_subprocess_success, mock_dap_subprocess_fail, mock_dap_subprocess_raise |
| 16 | + mock_dap_subprocess_success, mock_dap_subprocess_fail, mock_dap_subprocess_raise, |
| 17 | + mock_collected_tasks, mock_bq_config, |
| 18 | + mock_create_dataset_success, mock_create_table_success, mock_insert_rows_json_success, |
| 19 | + mock_create_dataset_fail, mock_create_table_fail, mock_insert_rows_json_fail |
| 20 | +) |
| 21 | +from ads_incrementality_dap_collector.helpers import ( |
| 22 | + get_experiment, prepare_results_rows, collect_dap_results, write_results_to_bq |
17 | 23 | )
|
18 |
| -from ads_incrementality_dap_collector.helpers import get_experiment, prepare_results_rows, collect_dap_results |
19 | 24 |
|
20 | 25 | class TestHelpers(TestCase):
|
21 | 26 | @patch("requests.get", side_effect=mock_nimbus_success)
|
@@ -69,3 +74,43 @@ def test_collect_dap_results_raise(self, mock_dap_subprocess_raise):
|
69 | 74 | with pytest.raises(Exception, match=f'Collection failed for {task_id}, 1, stderr: Uh-oh'):
|
70 | 75 | collect_dap_results(tasks_to_collect, mock_dap_config(), mock_experiment_config())
|
71 | 76 | self.assertEqual(1, mock_dap_subprocess_success.call_count)
|
| 77 | + |
| 78 | + @patch("google.cloud.bigquery.Client.create_dataset", side_effect=mock_create_dataset_success) |
| 79 | + @patch("google.cloud.bigquery.Client.create_table", side_effect=mock_create_table_success) |
| 80 | + @patch("google.cloud.bigquery.Client.insert_rows_json", side_effect=mock_insert_rows_json_success) |
| 81 | + def test_write_results_to_bq_success(self, mock_insert_rows_json_success, mock_create_table_success, mock_create_dataset_success): |
| 82 | + collected_tasks = mock_collected_tasks() |
| 83 | + write_results_to_bq(collected_tasks, mock_bq_config()) |
| 84 | + self.assertEqual(1, mock_create_dataset_success.call_count) |
| 85 | + self.assertEqual(1, mock_create_table_success.call_count) |
| 86 | + self.assertEqual(len(collected_tasks["mubArkO3So8Co1X98CBo62-lSCM4tB-NZPOUGJ83N1o"]), mock_insert_rows_json_success.call_count) |
| 87 | + |
| 88 | + @patch("google.cloud.bigquery.Client.create_dataset", side_effect=mock_create_dataset_fail) |
| 89 | + @patch("google.cloud.bigquery.Client.create_table", side_effect=mock_create_table_success) |
| 90 | + @patch("google.cloud.bigquery.Client.insert_rows_json", side_effect=mock_insert_rows_json_success) |
| 91 | + def test_write_results_to_bq_create_dataset_fail(self, mock_insert_rows_json_success, mock_create_table_success, mock_create_dataset_fail): |
| 92 | + with pytest.raises(Exception, match='BQ create dataset Uh-oh'): |
| 93 | + write_results_to_bq(mock_collected_tasks(), mock_bq_config()) |
| 94 | + self.assertEqual(1, mock_create_dataset_fail.call_count) |
| 95 | + self.assertEqual(0, mock_create_table_success.call_count) |
| 96 | + self.assertEqual(0, mock_insert_rows_json_success.call_count) |
| 97 | + |
| 98 | + @patch("google.cloud.bigquery.Client.create_dataset", side_effect=mock_create_dataset_success) |
| 99 | + @patch("google.cloud.bigquery.Client.create_table", side_effect=mock_create_table_fail) |
| 100 | + @patch("google.cloud.bigquery.Client.insert_rows_json", side_effect=mock_insert_rows_json_success) |
| 101 | + def test_write_results_to_bq_create_table_fail(self, mock_insert_rows_json_success, mock_create_table_fail, mock_create_dataset_success): |
| 102 | + with pytest.raises(Exception, match='Failed to create BQ table: some-gcp-project-id.ads_dap.incrementality'): |
| 103 | + write_results_to_bq(mock_collected_tasks(), mock_bq_config()) |
| 104 | + self.assertEqual(1, mock_create_dataset_success.call_count) |
| 105 | + self.assertEqual(1, mock_create_dataset_fail.call_count) |
| 106 | + self.assertEqual(0, mock_insert_rows_json_success.call_count) |
| 107 | + |
| 108 | + @patch("google.cloud.bigquery.Client.create_dataset", side_effect=mock_create_dataset_success) |
| 109 | + @patch("google.cloud.bigquery.Client.create_table", side_effect=mock_create_table_success) |
| 110 | + @patch("google.cloud.bigquery.Client.insert_rows_json", side_effect=mock_insert_rows_json_fail) |
| 111 | + def test_write_results_to_bq_insert_rows_fail(self, mock_insert_rows_json_fail, mock_create_table_success, mock_create_dataset_success): |
| 112 | + with pytest.raises(Exception, match=re.escape("Error inserting rows into some-gcp-project-id.ads_dap.incrementality: [{'key': 0, 'errors': 'Problem writing bucket 1 results'}, {'key': 1, 'errors': 'Problem writing bucket 2 results'}, {'key': 2, 'errors': 'Problem writing bucket 3 results'}]")): |
| 113 | + write_results_to_bq(mock_collected_tasks(), mock_bq_config()) |
| 114 | + self.assertEqual(1, mock_create_dataset_success.call_count) |
| 115 | + self.assertEqual(1, mock_create_table_success.call_count) |
| 116 | + self.assertEqual(1, mock_insert_rows_json_fail.call_count) |
0 commit comments