From aecd9c3eadb26e5e016b7cf919ed16db0dd2b78a Mon Sep 17 00:00:00 2001 From: Siddharth Shettigar Date: Wed, 21 May 2025 23:08:18 -0700 Subject: [PATCH 1/4] Add check for jira_ticket parameter being passed in spark_args for adhoc spark jobs --- service_configuration_lib/spark_config.py | 26 +++++ tests/spark_config_test.py | 133 ++++++++++++++++++++++ 2 files changed, 159 insertions(+) diff --git a/service_configuration_lib/spark_config.py b/service_configuration_lib/spark_config.py index 80e2671..201f288 100644 --- a/service_configuration_lib/spark_config.py +++ b/service_configuration_lib/spark_config.py @@ -36,6 +36,7 @@ CLUSTERMAN_METRICS_YAML_FILE_PATH = '/nail/srv/configs/clusterman_metrics.yaml' CLUSTERMAN_YAML_FILE_PATH = '/nail/srv/configs/clusterman.yaml' SPARK_TRON_JOB_USER = 'TRON' +JIRA_TICKET_PATTERN = re.compile(r'^[A-Z]+-[0-9]+$') NON_CONFIGURABLE_SPARK_OPTS = { 'spark.master', @@ -989,6 +990,15 @@ def compute_approx_hourly_cost_dollars( ) return min_dollars, max_dollars + def _get_valid_jira_ticket(self, user_spark_opts: Mapping[str, str]) -> Optional[str]: + """Checks for and validates the 'jira_ticket' format.""" + ticket = user_spark_opts.get('jira_ticket') + if ticket and isinstance(ticket, str) and JIRA_TICKET_PATTERN.match(ticket): + log.info(f'Valid Jira ticket provided: {ticket}') + return ticket + log.warning(f'Jira ticket missing or invalid format: {ticket}') + return None + def get_spark_conf( self, cluster_manager: str, @@ -1046,6 +1056,22 @@ def get_spark_conf( # is str type. user_spark_opts = _convert_user_spark_opts_value_to_str(user_spark_opts) + if self.mandatory_default_spark_srv_conf.get('spark.jira_ticket.enabled') == 'true': + needs_jira_check = os.environ.get('USER', '') not in ['batch', 'TRON', ''] + if needs_jira_check: + valid_ticket = self._get_valid_jira_ticket(user_spark_opts) + if valid_ticket is None: + is_jupyter = _is_jupyterhub_job(user_spark_opts.get('spark.app.name', spark_app_base_name)) + error_msg = ( + 'Job requires a valid Jira ticket (format PROJ-1234) provided via spark-args.\n' + 'Reason: https://yelpwiki.yelpcorp.com/spaces/AML/pages/402885641/jira_ticket+in+spark-args \n' + 'Please add jira_ticket=YOUR-TICKET to your spark-args. \n' + 'For questions please reach out to #spark on slack. \n' + ) + raise RuntimeError(error_msg) + else: + log.debug('Jira ticket check not required for this job configuration.') + app_base_name = ( user_spark_opts.get('spark.app.name') or spark_app_base_name diff --git a/tests/spark_config_test.py b/tests/spark_config_test.py index 5c57773..c66ac89 100644 --- a/tests/spark_config_test.py +++ b/tests/spark_config_test.py @@ -1713,3 +1713,136 @@ def test_send_and_calculate_resources_cost( mock_clusterman_metrics.util.costs.estimate_cost_per_hour.assert_called_once_with( cluster='test-cluster', pool='test-pool', cpus=10, mem=2048, ) + + +class TestGetValidJiraTicket: + """Tests for the _get_valid_jira_ticket function.""" + + @pytest.fixture + def mock_spark_srv_conf_file(self): + pass + + @pytest.mark.parametrize( + 'ticket,expected_result', [ + ('CLOUD-123', 'CLOUD-123'), + ('PROJ-456', 'PROJ-456'), + ('ABC-789', 'ABC-789'), + ('LONGPROJECT-1234', 'LONGPROJECT-1234'), + ], + ) + def test_valid_jira_tickets(self, ticket, expected_result, mock_spark_srv_conf_file, mock_log): + """Test that valid Jira tickets are accepted and returned as is.""" + spark_conf_builder = spark_config.SparkConfBuilder() + result = spark_conf_builder._get_valid_jira_ticket({'jira_ticket': ticket}) + assert result == expected_result + mock_log.info.assert_called_once_with(f'Valid Jira ticket provided: {ticket}') + + @pytest.mark.parametrize( + 'ticket', [ + 'cloud-123', + 'proj-456', + 'PROJ-ABC', + 'CLOUD-ABC-1234', + '123-456', + 'PROJ123', + 'PROJ-', + '-123', + '', + ], + ) + def test_invalid_jira_ticket_formats(self, ticket, mock_spark_srv_conf_file, mock_log): + """Test that invalid Jira ticket formats are rejected.""" + spark_conf_builder = spark_config.SparkConfBuilder() + result = spark_conf_builder._get_valid_jira_ticket({'jira_ticket': ticket}) + assert result is None + mock_log.warning.assert_called_once_with(f'Jira ticket missing or invalid format: {ticket}') + + @pytest.mark.parametrize( + 'ticket', [ + None, + 123, + True, + ['PROJ-123'], + {'ticket': 'PROJ-123'}, + ], + ) + def test_invalid_jira_ticket_types(self, ticket, mock_spark_srv_conf_file, mock_log): + """Test that non-string Jira tickets are rejected.""" + spark_conf_builder = spark_config.SparkConfBuilder() + result = spark_conf_builder._get_valid_jira_ticket({'jira_ticket': ticket}) + assert result is None + mock_log.warning.assert_called_once_with(f'Jira ticket missing or invalid format: {ticket}') + + def test_missing_jira_ticket(self, mock_spark_srv_conf_file, mock_log): + """Test that missing Jira ticket key is handled correctly.""" + spark_conf_builder = spark_config.SparkConfBuilder() + result = spark_conf_builder._get_valid_jira_ticket({}) # Empty dict, no jira_ticket key + assert result is None + mock_log.warning.assert_called_once_with('Jira ticket missing or invalid format: None') + + @pytest.mark.parametrize( + 'mandatory_config,user,expected_exception', [ + ({'spark.jira_ticket.enabled': 'true'}, 'regular_user', True), + ({'spark.jira_ticket.enabled': 'true'}, 'batch', False), + ({'spark.jira_ticket.enabled': 'true'}, 'TRON', False), + ({'spark.jira_ticket.enabled': 'true'}, '', False), + ({'spark.jira_ticket.enabled': 'false'}, 'regular_user', False), + ], + ) + def test_jira_ticket_enforcement( + self, mandatory_config, user, expected_exception, + mock_spark_srv_conf_file, monkeypatch, + ): + """Test that Jira ticket enforcement works correctly based on configuration and user.""" + monkeypatch.setenv('USER', user) + with mock.patch.object(spark_config.SparkConfBuilder, '__init__', return_value=None): + spark_conf_builder = spark_config.SparkConfBuilder() + spark_conf_builder.mandatory_default_spark_srv_conf = mandatory_config + + spark_conf_builder.spark_srv_conf = {} + spark_conf_builder.spark_constants = {} + spark_conf_builder.default_spark_srv_conf = {} + spark_conf_builder.spark_costs = {} + spark_conf_builder.is_driver_on_k8s_tron = False + + with mock.patch.object(spark_conf_builder, '_get_valid_jira_ticket') as mock_get_valid_jira_ticket: + mock_get_valid_jira_ticket.return_value = None + + if expected_exception: + with pytest.raises(RuntimeError, match='Job requires a valid Jira ticket'): + spark_conf_builder.get_spark_conf( + cluster_manager='kubernetes', + spark_app_base_name='test_app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + ) + else: + # Should not raise an exception + with mock.patch.multiple( + spark_conf_builder, + _adjust_spark_requested_resources=mock.DEFAULT, + get_dra_configs=mock.DEFAULT, + compute_approx_hourly_cost_dollars=mock.DEFAULT, + _append_spark_prometheus_conf=mock.DEFAULT, + _append_event_log_conf=mock.DEFAULT, + _append_sql_partitions_conf=mock.DEFAULT, + update_spark_srv_configs=mock.DEFAULT, + ) as mocks: + # Set return values for mocked methods + for mock_method in mocks.values(): + mock_method.return_value = {} + + spark_conf_builder.get_spark_conf( + cluster_manager='kubernetes', + spark_app_base_name='test_app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + ) From f94e0a6133e7ab0ceb5bd0c464cb13179b557495 Mon Sep 17 00:00:00 2001 From: Siddharth Shettigar Date: Thu, 22 May 2025 00:08:27 -0700 Subject: [PATCH 2/4] mock_spark_srv_conf_file in new test class --- tests/spark_config_test.py | 56 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/tests/spark_config_test.py b/tests/spark_config_test.py index c66ac89..85b5034 100644 --- a/tests/spark_config_test.py +++ b/tests/spark_config_test.py @@ -1719,8 +1719,60 @@ class TestGetValidJiraTicket: """Tests for the _get_valid_jira_ticket function.""" @pytest.fixture - def mock_spark_srv_conf_file(self): - pass + def mock_spark_srv_conf_file(self, tmpdir, monkeypatch): + spark_run_conf = { + 'environments': { + 'testing': { + 'account_id': TEST_ACCOUNT_ID, + 'default_event_log_dir': 's3a://test/eventlog', + 'history_server': 'https://spark-history-testing', + }, + }, + 'spark_constants': { + 'target_mem_cpu_ratio': 7, + 'resource_configs': { + 'recommended': { + 'cpu': 4, + 'mem': 28, + }, + 'medium': { + 'cpu': 8, + 'mem': 56, + }, + 'max': { + 'cpu': 12, + 'mem': 110, + }, + }, + 'cost_factor': { + 'test-cluster': { + 'test-pool': 100, + }, + }, + 'adjust_executor_res_ratio_thresh': 99999, + 'default_resources_waiting_time_per_executor': 2, + 'default_clusterman_observed_scaling_time': 15, + 'high_cost_threshold_daily': 500, + 'defaults': { + 'spark.executor.cores': 4, + 'spark.executor.instances': 2, + 'spark.executor.memory': 28, + 'spark.task.cpus': 1, + 'spark.sql.shuffle.partitions': 128, + 'spark.dynamicAllocation.executorAllocationRatio': 0.8, + 'spark.dynamicAllocation.cachedExecutorIdleTimeout': '1500s', + 'spark.yelp.dra.minExecutorRatio': 0.25, + }, + 'mandatory_defaults': { + 'spark.kubernetes.allocation.batch.size': 512, + 'spark.kubernetes.decommission.script': '/opt/spark/kubernetes/dockerfiles/spark/decom.sh', + 'spark.logConf': 'true', + }, + }, + } + fp = tmpdir.join('tmp_spark_srv_config.yaml') + fp.write(yaml.dump(spark_run_conf)) + monkeypatch.setattr(utils, 'DEFAULT_SPARK_RUN_CONFIG', str(fp)) @pytest.mark.parametrize( 'ticket,expected_result', [ From 8c3751a585f12192aeccdf3130a4537c49275169 Mon Sep 17 00:00:00 2001 From: Siddharth Shettigar Date: Tue, 27 May 2025 09:33:51 -0700 Subject: [PATCH 3/4] update to pass jira ticket via paasta spark-run parameter --- service_configuration_lib/spark_config.py | 25 +- tests/spark_config_test.py | 457 ++++++++++++---------- 2 files changed, 257 insertions(+), 225 deletions(-) diff --git a/service_configuration_lib/spark_config.py b/service_configuration_lib/spark_config.py index 201f288..a217624 100644 --- a/service_configuration_lib/spark_config.py +++ b/service_configuration_lib/spark_config.py @@ -306,6 +306,7 @@ def _get_k8s_spark_env( include_self_managed_configs: bool = True, k8s_server_address: Optional[str] = None, user: Optional[str] = None, + jira_ticket: Optional[str] = None, ) -> Dict[str, str]: # RFC 1123: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names # technically only paasta instance can be longer than 63 chars. But we apply the normalization regardless. @@ -358,6 +359,9 @@ def _get_k8s_spark_env( 'spark.master': f'k8s://{k8s_server_address}', }) + if jira_ticket is not None: + spark_env['spark.kubernetes.executor.label.spark.yelp.com/jira_ticket'] = jira_ticket + return spark_env @@ -990,10 +994,10 @@ def compute_approx_hourly_cost_dollars( ) return min_dollars, max_dollars - def _get_valid_jira_ticket(self, user_spark_opts: Mapping[str, str]) -> Optional[str]: + def _get_valid_jira_ticket(self, jira_ticket: Optional[str]) -> Optional[str]: """Checks for and validates the 'jira_ticket' format.""" - ticket = user_spark_opts.get('jira_ticket') - if ticket and isinstance(ticket, str) and JIRA_TICKET_PATTERN.match(ticket): + ticket = jira_ticket + if ticket and JIRA_TICKET_PATTERN.match(ticket): log.info(f'Valid Jira ticket provided: {ticket}') return ticket log.warning(f'Jira ticket missing or invalid format: {ticket}') @@ -1016,6 +1020,7 @@ def get_spark_conf( spark_opts_from_env: Optional[Mapping[str, str]] = None, aws_region: Optional[str] = None, service_account_name: Optional[str] = None, + jira_ticket: Optional[str] = None, force_spark_resource_configs: bool = True, user: Optional[str] = None, ) -> Dict[str, str]: @@ -1056,17 +1061,16 @@ def get_spark_conf( # is str type. user_spark_opts = _convert_user_spark_opts_value_to_str(user_spark_opts) - if self.mandatory_default_spark_srv_conf.get('spark.jira_ticket.enabled') == 'true': + if self.mandatory_default_spark_srv_conf.get('spark.yelp.jira_ticket.enabled') == 'true': needs_jira_check = os.environ.get('USER', '') not in ['batch', 'TRON', ''] if needs_jira_check: - valid_ticket = self._get_valid_jira_ticket(user_spark_opts) + valid_ticket = self._get_valid_jira_ticket(jira_ticket) if valid_ticket is None: - is_jupyter = _is_jupyterhub_job(user_spark_opts.get('spark.app.name', spark_app_base_name)) error_msg = ( - 'Job requires a valid Jira ticket (format PROJ-1234) provided via spark-args.\n' - 'Reason: https://yelpwiki.yelpcorp.com/spaces/AML/pages/402885641/jira_ticket+in+spark-args \n' - 'Please add jira_ticket=YOUR-TICKET to your spark-args. \n' - 'For questions please reach out to #spark on slack. \n' + 'Job requires a valid Jira ticket (format PROJ-1234).\n' + 'Please pass the parameter as: paasta spark-run --jira-ticket=PROJ-1234 \n' + 'For more information: https://yelpwiki.yelpcorp.com/spaces/AML/pages/402885641 \n' + 'If you have questions, please reach out to #spark on Slack.\n' ) raise RuntimeError(error_msg) else: @@ -1159,6 +1163,7 @@ def get_spark_conf( include_self_managed_configs=not use_eks, k8s_server_address=k8s_server_address, user=user, + jira_ticket=jira_ticket, )) elif cluster_manager == 'local': spark_conf.update(_get_local_spark_env( diff --git a/tests/spark_config_test.py b/tests/spark_config_test.py index 85b5034..9bee293 100644 --- a/tests/spark_config_test.py +++ b/tests/spark_config_test.py @@ -12,7 +12,6 @@ from service_configuration_lib import spark_config from service_configuration_lib import utils - TEST_ACCOUNT_ID = '123456789' TEST_USER = 'UNIT_TEST_USER' @@ -21,6 +20,61 @@ TIME_RETURN_VALUE = 123.456 RANDOM_STRING_RETURN_VALUE = 'do1re2mi3fa4sol4' +BASE_SPARK_RUN_CONF = { + 'environments': { + 'testing': { + 'account_id': TEST_ACCOUNT_ID, + 'default_event_log_dir': 's3a://test/eventlog', + 'history_server': 'https://spark-history-testing', + }, + }, + 'spark_constants': { + 'target_mem_cpu_ratio': 7, + 'resource_configs': { + 'recommended': { + 'cpu': 4, + 'mem': 28, + }, + 'medium': { + 'cpu': 8, + 'mem': 56, + }, + 'max': { + 'cpu': 12, + 'mem': 110, + }, + }, + 'cost_factor': { + 'test-cluster': { + 'test-pool': 100, + }, + 'spark-pnw-prod': { + 'batch': 0.041, + 'stable_batch': 0.142, + }, + }, + 'adjust_executor_res_ratio_thresh': 99999, + 'default_resources_waiting_time_per_executor': 2, + 'default_clusterman_observed_scaling_time': 15, + 'high_cost_threshold_daily': 500, + 'defaults': { + 'spark.executor.cores': 4, + 'spark.executor.instances': 2, + 'spark.executor.memory': 28, + 'spark.task.cpus': 1, + 'spark.sql.shuffle.partitions': 128, + 'spark.dynamicAllocation.executorAllocationRatio': 0.8, + 'spark.dynamicAllocation.cachedExecutorIdleTimeout': '1500s', + 'spark.yelp.dra.minExecutorRatio': 0.25, + }, + 'mandatory_defaults': { + 'spark.kubernetes.allocation.batch.size': 512, + 'spark.kubernetes.decommission.script': '/opt/spark/kubernetes/dockerfiles/spark/decom.sh', + 'spark.logConf': 'true', + }, + }, +} + @pytest.fixture def mock_log(monkeypatch): @@ -180,60 +234,8 @@ class TestGetSparkConf: @pytest.fixture def mock_spark_srv_conf_file(self, tmpdir, monkeypatch): - spark_run_conf = { - 'environments': { - 'testing': { - 'account_id': TEST_ACCOUNT_ID, - 'default_event_log_dir': 's3a://test/eventlog', - 'history_server': 'https://spark-history-testing', - }, - }, - 'spark_constants': { - 'target_mem_cpu_ratio': 7, - 'resource_configs': { - 'recommended': { - 'cpu': 4, - 'mem': 28, - }, - 'medium': { - 'cpu': 8, - 'mem': 56, - }, - 'max': { - 'cpu': 12, - 'mem': 110, - }, - }, - 'cost_factor': { - 'test-cluster': { - 'test-pool': 100, - }, - 'spark-pnw-prod': { - 'batch': 0.041, - 'stable_batch': 0.142, - }, - }, - 'adjust_executor_res_ratio_thresh': 99999, - 'default_resources_waiting_time_per_executor': 2, - 'default_clusterman_observed_scaling_time': 15, - 'high_cost_threshold_daily': 500, - 'defaults': { - 'spark.executor.cores': 4, - 'spark.executor.instances': 2, - 'spark.executor.memory': 28, - 'spark.task.cpus': 1, - 'spark.sql.shuffle.partitions': 128, - 'spark.dynamicAllocation.executorAllocationRatio': 0.8, - 'spark.dynamicAllocation.cachedExecutorIdleTimeout': '1500s', - 'spark.yelp.dra.minExecutorRatio': 0.25, - }, - 'mandatory_defaults': { - 'spark.kubernetes.allocation.batch.size': 512, - 'spark.kubernetes.decommission.script': '/opt/spark/kubernetes/dockerfiles/spark/decom.sh', - 'spark.logConf': 'true', - }, - }, - } + # Use the base configuration + spark_run_conf = dict(BASE_SPARK_RUN_CONF) fp = tmpdir.join('tmp_spark_srv_config.yaml') fp.write(yaml.dump(spark_run_conf)) monkeypatch.setattr(utils, 'DEFAULT_SPARK_RUN_CONFIG', str(fp)) @@ -1715,186 +1717,211 @@ def test_send_and_calculate_resources_cost( ) -class TestGetValidJiraTicket: - """Tests for the _get_valid_jira_ticket function.""" +class TestJiraTicketFunctionality: + """Tests for the Jira ticket functionality in SparkConfBuilder.""" @pytest.fixture - def mock_spark_srv_conf_file(self, tmpdir, monkeypatch): - spark_run_conf = { - 'environments': { - 'testing': { - 'account_id': TEST_ACCOUNT_ID, - 'default_event_log_dir': 's3a://test/eventlog', - 'history_server': 'https://spark-history-testing', - }, - }, - 'spark_constants': { - 'target_mem_cpu_ratio': 7, - 'resource_configs': { - 'recommended': { - 'cpu': 4, - 'mem': 28, - }, - 'medium': { - 'cpu': 8, - 'mem': 56, - }, - 'max': { - 'cpu': 12, - 'mem': 110, - }, - }, - 'cost_factor': { - 'test-cluster': { - 'test-pool': 100, - }, - }, - 'adjust_executor_res_ratio_thresh': 99999, - 'default_resources_waiting_time_per_executor': 2, - 'default_clusterman_observed_scaling_time': 15, - 'high_cost_threshold_daily': 500, - 'defaults': { - 'spark.executor.cores': 4, - 'spark.executor.instances': 2, - 'spark.executor.memory': 28, - 'spark.task.cpus': 1, - 'spark.sql.shuffle.partitions': 128, - 'spark.dynamicAllocation.executorAllocationRatio': 0.8, - 'spark.dynamicAllocation.cachedExecutorIdleTimeout': '1500s', - 'spark.yelp.dra.minExecutorRatio': 0.25, - }, - 'mandatory_defaults': { - 'spark.kubernetes.allocation.batch.size': 512, - 'spark.kubernetes.decommission.script': '/opt/spark/kubernetes/dockerfiles/spark/decom.sh', - 'spark.logConf': 'true', - }, - }, - } + def mock_spark_srv_conf_file_with_jira_enabled(self, tmpdir, monkeypatch): + """Create a mock spark service config file with Jira ticket validation enabled.""" + # Use the base configuration and modify the jira ticket setting + spark_run_conf = dict(BASE_SPARK_RUN_CONF) + spark_run_conf['spark_constants']['mandatory_defaults']['spark.yelp.jira_ticket.enabled'] = 'true' + fp = tmpdir.join('tmp_spark_srv_config.yaml') + fp.write(yaml.dump(spark_run_conf)) + monkeypatch.setattr(utils, 'DEFAULT_SPARK_RUN_CONFIG', str(fp)) + + @pytest.fixture + def mock_spark_srv_conf_file_with_jira_disabled(self, tmpdir, monkeypatch): + """Create a mock spark service config file with Jira ticket validation disabled.""" + # Use the base configuration and modify the jira ticket setting + spark_run_conf = dict(BASE_SPARK_RUN_CONF) + spark_run_conf['spark_constants']['mandatory_defaults']['spark.yelp.jira_ticket.enabled'] = 'false' fp = tmpdir.join('tmp_spark_srv_config.yaml') fp.write(yaml.dump(spark_run_conf)) monkeypatch.setattr(utils, 'DEFAULT_SPARK_RUN_CONFIG', str(fp)) @pytest.mark.parametrize( - 'ticket,expected_result', [ - ('CLOUD-123', 'CLOUD-123'), - ('PROJ-456', 'PROJ-456'), - ('ABC-789', 'ABC-789'), - ('LONGPROJECT-1234', 'LONGPROJECT-1234'), + 'jira_ticket,expected_result', [ + ('PROJ-1234', 'PROJ-1234'), # Valid format + ('ABC-123', 'ABC-123'), # Valid format + ('LONGPROJ-9876', 'LONGPROJ-9876'), # Valid format with longer project name + ('proj-1234', None), # Invalid: lowercase project + ('PROJ1234', None), # Invalid: missing hyphen + ('PROJ-abc', None), # Invalid: non-numeric issue number + ('1234-PROJ', None), # Invalid: wrong order + ('', None), # Invalid: empty string + (None, None), # Invalid: None value ], ) - def test_valid_jira_tickets(self, ticket, expected_result, mock_spark_srv_conf_file, mock_log): - """Test that valid Jira tickets are accepted and returned as is.""" + def test_get_valid_jira_ticket(self, jira_ticket, expected_result, mock_log): + """Test the _get_valid_jira_ticket method with various inputs.""" spark_conf_builder = spark_config.SparkConfBuilder() - result = spark_conf_builder._get_valid_jira_ticket({'jira_ticket': ticket}) + result = spark_conf_builder._get_valid_jira_ticket(jira_ticket) assert result == expected_result - mock_log.info.assert_called_once_with(f'Valid Jira ticket provided: {ticket}') - @pytest.mark.parametrize( - 'ticket', [ - 'cloud-123', - 'proj-456', - 'PROJ-ABC', - 'CLOUD-ABC-1234', - '123-456', - 'PROJ123', - 'PROJ-', - '-123', - '', - ], - ) - def test_invalid_jira_ticket_formats(self, ticket, mock_spark_srv_conf_file, mock_log): - """Test that invalid Jira ticket formats are rejected.""" + if expected_result: + mock_log.info.assert_called_with(f'Valid Jira ticket provided: {jira_ticket}') + else: + mock_log.warning.assert_called_with(f'Jira ticket missing or invalid format: {jira_ticket}') + + def test_k8s_spark_env_with_jira_ticket(self): + """Test that _get_k8s_spark_env adds the Jira ticket label when provided.""" + jira_ticket = 'PROJ-1234' + result = spark_config._get_k8s_spark_env( + paasta_cluster='test-cluster', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + pod_template_path=None, + volumes=None, + paasta_pool='test-pool', + driver_ui_port=12345, + jira_ticket=jira_ticket, + ) + + assert 'spark.kubernetes.executor.label.spark.yelp.com/jira_ticket' in result + assert result['spark.kubernetes.executor.label.spark.yelp.com/jira_ticket'] == jira_ticket + + def test_k8s_spark_env_without_jira_ticket(self): + """Test that _get_k8s_spark_env doesn't add the Jira ticket label when not provided.""" + result = spark_config._get_k8s_spark_env( + paasta_cluster='test-cluster', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + pod_template_path=None, + volumes=None, + paasta_pool='test-pool', + driver_ui_port=12345, + jira_ticket=None, + ) + + assert 'spark.kubernetes.executor.label.spark.yelp.com/jira_ticket' not in result + + @mock.patch.dict(os.environ, {'USER': 'regular_user'}) + def test_get_spark_conf_with_valid_jira_ticket(self, mock_spark_srv_conf_file_with_jira_enabled): + """Test get_spark_conf with a valid Jira ticket when validation is enabled.""" spark_conf_builder = spark_config.SparkConfBuilder() - result = spark_conf_builder._get_valid_jira_ticket({'jira_ticket': ticket}) - assert result is None - mock_log.warning.assert_called_once_with(f'Jira ticket missing or invalid format: {ticket}') - @pytest.mark.parametrize( - 'ticket', [ - None, - 123, - True, - ['PROJ-123'], - {'ticket': 'PROJ-123'}, - ], - ) - def test_invalid_jira_ticket_types(self, ticket, mock_spark_srv_conf_file, mock_log): - """Test that non-string Jira tickets are rejected.""" + # This should not raise an exception + result = spark_conf_builder.get_spark_conf( + cluster_manager='kubernetes', + spark_app_base_name='test-app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + jira_ticket='PROJ-1234', + ) + + # Verify the Jira ticket is passed to _get_k8s_spark_env + assert 'spark.kubernetes.executor.label.spark.yelp.com/jira_ticket' in result + assert result['spark.kubernetes.executor.label.spark.yelp.com/jira_ticket'] == 'PROJ-1234' + + @mock.patch.dict(os.environ, {'USER': 'regular_user'}) + def test_get_spark_conf_with_invalid_jira_ticket(self, mock_spark_srv_conf_file_with_jira_enabled): + """Test get_spark_conf with an invalid Jira ticket when validation is enabled.""" spark_conf_builder = spark_config.SparkConfBuilder() - result = spark_conf_builder._get_valid_jira_ticket({'jira_ticket': ticket}) - assert result is None - mock_log.warning.assert_called_once_with(f'Jira ticket missing or invalid format: {ticket}') - def test_missing_jira_ticket(self, mock_spark_srv_conf_file, mock_log): - """Test that missing Jira ticket key is handled correctly.""" + # This should raise a RuntimeError + with pytest.raises(RuntimeError) as excinfo: + spark_conf_builder.get_spark_conf( + cluster_manager='kubernetes', + spark_app_base_name='test-app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + jira_ticket='invalid-ticket', + ) + + # Verify the error message + assert 'Job requires a valid Jira ticket (format PROJ-1234)' in str(excinfo.value) + assert 'paasta spark-run --jira-ticket=PROJ-1234' in str(excinfo.value) + + @mock.patch.dict(os.environ, {'USER': 'regular_user'}) + def test_get_spark_conf_without_jira_ticket(self, mock_spark_srv_conf_file_with_jira_enabled): + """Test get_spark_conf without a Jira ticket when validation is enabled.""" + spark_conf_builder = spark_config.SparkConfBuilder() + + # This should raise a RuntimeError + with pytest.raises(RuntimeError) as excinfo: + spark_conf_builder.get_spark_conf( + cluster_manager='kubernetes', + spark_app_base_name='test-app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + ) + + # Verify the error message + assert 'Job requires a valid Jira ticket (format PROJ-1234)' in str(excinfo.value) + + @mock.patch.dict(os.environ, {'USER': 'regular_user'}) + def test_get_spark_conf_with_jira_validation_disabled(self, mock_spark_srv_conf_file_with_jira_disabled): + """Test get_spark_conf without a Jira ticket when validation is disabled.""" spark_conf_builder = spark_config.SparkConfBuilder() - result = spark_conf_builder._get_valid_jira_ticket({}) # Empty dict, no jira_ticket key - assert result is None - mock_log.warning.assert_called_once_with('Jira ticket missing or invalid format: None') + + # This should not raise an exception + result = spark_conf_builder.get_spark_conf( + cluster_manager='kubernetes', + spark_app_base_name='test-app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + ) + + # Verify no Jira ticket label is added + assert 'spark.kubernetes.executor.label.spark.yelp.com/jira_ticket' not in result @pytest.mark.parametrize( - 'mandatory_config,user,expected_exception', [ - ({'spark.jira_ticket.enabled': 'true'}, 'regular_user', True), - ({'spark.jira_ticket.enabled': 'true'}, 'batch', False), - ({'spark.jira_ticket.enabled': 'true'}, 'TRON', False), - ({'spark.jira_ticket.enabled': 'true'}, '', False), - ({'spark.jira_ticket.enabled': 'false'}, 'regular_user', False), + 'user_env,should_check', [ + ('regular_user', True), + ('batch', False), + ('TRON', False), + ('', False), ], ) - def test_jira_ticket_enforcement( - self, mandatory_config, user, expected_exception, - mock_spark_srv_conf_file, monkeypatch, + def test_jira_ticket_check_for_different_users( + self, user_env, should_check, mock_spark_srv_conf_file_with_jira_enabled, mock_log, ): - """Test that Jira ticket enforcement works correctly based on configuration and user.""" - monkeypatch.setenv('USER', user) - with mock.patch.object(spark_config.SparkConfBuilder, '__init__', return_value=None): + """Test that Jira ticket validation is skipped for certain users.""" + with mock.patch.dict(os.environ, {'USER': user_env}): spark_conf_builder = spark_config.SparkConfBuilder() - spark_conf_builder.mandatory_default_spark_srv_conf = mandatory_config - - spark_conf_builder.spark_srv_conf = {} - spark_conf_builder.spark_constants = {} - spark_conf_builder.default_spark_srv_conf = {} - spark_conf_builder.spark_costs = {} - spark_conf_builder.is_driver_on_k8s_tron = False - - with mock.patch.object(spark_conf_builder, '_get_valid_jira_ticket') as mock_get_valid_jira_ticket: - mock_get_valid_jira_ticket.return_value = None - - if expected_exception: - with pytest.raises(RuntimeError, match='Job requires a valid Jira ticket'): - spark_conf_builder.get_spark_conf( - cluster_manager='kubernetes', - spark_app_base_name='test_app', - user_spark_opts={}, - paasta_cluster='test-cluster', - paasta_pool='test-pool', - paasta_service='test-service', - paasta_instance='test-instance', - docker_img='test-image', - ) - else: - # Should not raise an exception - with mock.patch.multiple( - spark_conf_builder, - _adjust_spark_requested_resources=mock.DEFAULT, - get_dra_configs=mock.DEFAULT, - compute_approx_hourly_cost_dollars=mock.DEFAULT, - _append_spark_prometheus_conf=mock.DEFAULT, - _append_event_log_conf=mock.DEFAULT, - _append_sql_partitions_conf=mock.DEFAULT, - update_spark_srv_configs=mock.DEFAULT, - ) as mocks: - # Set return values for mocked methods - for mock_method in mocks.values(): - mock_method.return_value = {} - - spark_conf_builder.get_spark_conf( - cluster_manager='kubernetes', - spark_app_base_name='test_app', - user_spark_opts={}, - paasta_cluster='test-cluster', - paasta_pool='test-pool', - paasta_service='test-service', - paasta_instance='test-instance', - docker_img='test-image', - ) + + if should_check: + # For regular users, validation should be enforced + with pytest.raises(RuntimeError): + spark_conf_builder.get_spark_conf( + cluster_manager='kubernetes', + spark_app_base_name='test-app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + ) + else: + # For special users, validation should be skipped + spark_conf_builder.get_spark_conf( + cluster_manager='kubernetes', + spark_app_base_name='test-app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + ) + mock_log.debug.assert_called_with('Jira ticket check not required for this job configuration.') From 66cf3de861a70219ec49edee176d93b48f00c769 Mon Sep 17 00:00:00 2001 From: Siddharth Shettigar Date: Wed, 28 May 2025 12:02:08 -0700 Subject: [PATCH 4/4] fix failing tests --- tests/spark_config_test.py | 4 +++- tests/utils_test.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/spark_config_test.py b/tests/spark_config_test.py index 9bee293..2aaab2e 100644 --- a/tests/spark_config_test.py +++ b/tests/spark_config_test.py @@ -1753,7 +1753,9 @@ def mock_spark_srv_conf_file_with_jira_disabled(self, tmpdir, monkeypatch): (None, None), # Invalid: None value ], ) - def test_get_valid_jira_ticket(self, jira_ticket, expected_result, mock_log): + def test_get_valid_jira_ticket( + self, jira_ticket, expected_result, mock_log, mock_spark_srv_conf_file_with_jira_disabled, + ): """Test the _get_valid_jira_ticket method with various inputs.""" spark_conf_builder = spark_config.SparkConfBuilder() result = spark_conf_builder._get_valid_jira_ticket(jira_ticket) diff --git a/tests/utils_test.py b/tests/utils_test.py index 9f49145..d08435d 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -154,6 +154,8 @@ def test_get_spark_driver_memory_overhead_mb(spark_conf, expected_mem_overhead): @pytest.fixture def mock_runtimeenv(): + # Clear the lru_cache before applying the mock + utils.get_runtime_env.cache_clear() with patch('builtins.open', mock_open(read_data=MOCK_ENV_NAME)) as m: yield m