diff --git a/README.md b/README.md
index 41048d1..53a5b42 100644
--- a/README.md
+++ b/README.md
@@ -490,6 +490,9 @@ Keyword Arguments:
 * `tasks` (list[Task], optional): A list of job tasks. If not provided, a single task with a notebook task will be created,
   along with a disposable notebook. The latest Spark version and a single-worker cluster will be used to run this ephemeral job.
+* `environments` (list[JobEnvironment], optional): A list of job environments to be used when running tasks on
+  serverless compute. Required for running Spark Python tasks using serverless. When running Databricks Notebook tasks
+  with serverless, the specified environments will override the notebook environment.
 
 Usage:
 
 ```python
diff --git a/src/databricks/labs/pytester/fixtures/compute.py b/src/databricks/labs/pytester/fixtures/compute.py
index a4ea066..6ea90d6 100644
--- a/src/databricks/labs/pytester/fixtures/compute.py
+++ b/src/databricks/labs/pytester/fixtures/compute.py
@@ -13,7 +13,7 @@
     CreateInstancePoolResponse,
     Library,
 )
-from databricks.sdk.service.jobs import Job, JobSettings, NotebookTask, SparkPythonTask, Task
+from databricks.sdk.service.jobs import Job, JobEnvironment, JobSettings, NotebookTask, SparkPythonTask, Task
 from databricks.sdk.service.pipelines import CreatePipelineResponse, PipelineLibrary, NotebookLibrary, PipelineCluster
 from databricks.sdk.service.sql import (
     CreateWarehouseRequestWarehouseType,
@@ -189,6 +189,9 @@ def make_job(
     * `tasks` (list[Task], optional): A list of job tasks. If not provided, a single task with a notebook task will be created,
       along with a disposable notebook. The latest Spark version and a single-worker cluster will be used to run this ephemeral job.
+    * `environments` (list[JobEnvironment], optional): A list of job environments to be used when running tasks on
+      serverless compute. Required for running Spark Python tasks using serverless. When running Databricks Notebook tasks
+      with serverless, the specified environments will override the notebook environment.
 
     Usage:
 
     ```python
@@ -209,6 +212,7 @@ def create(  # pylint: disable=too-many-arguments
         libraries: list[Library] | None = None,
         tags: dict[str, str] | None = None,
         tasks: list[Task] | None = None,
+        environments: list[JobEnvironment] | None = None,
     ) -> Job:
         if notebook_path is not None:
             warnings.warn(
@@ -250,11 +254,11 @@ def create(  # pylint: disable=too-many-arguments
             path = path or make_notebook(content=content)
             task.notebook_task = NotebookTask(notebook_path=str(path))
             tasks = [task]
-        response = ws.jobs.create(name=name, tasks=tasks, tags=tags)
+        response = ws.jobs.create(name=name, tasks=tasks, tags=tags, environments=environments)
         log_workspace_link(name, f"job/{response.job_id}", anchor=False)
         job = ws.jobs.get(response.job_id)
         if isinstance(response, Mock):  # For testing
-            job = Job(settings=JobSettings(name=name, tasks=tasks, tags=tags))
+            job = Job(settings=JobSettings(name=name, tasks=tasks, tags=tags, environments=environments))
         return job
 
     yield from factory("job", create, lambda item: ws.jobs.delete(item.job_id))
diff --git a/tests/unit/fixtures/test_compute.py b/tests/unit/fixtures/test_compute.py
index 84042a8..e04cd77 100644
--- a/tests/unit/fixtures/test_compute.py
+++ b/tests/unit/fixtures/test_compute.py
@@ -1,5 +1,6 @@
 from databricks.labs.blueprint.paths import WorkspacePath
-from databricks.sdk.service.jobs import SparkPythonTask
+from databricks.sdk.service.compute import Environment
+from databricks.sdk.service.jobs import JobEnvironment, SparkPythonTask
 
 from databricks.labs.pytester.fixtures.compute import (
     make_cluster_policy,
@@ -46,6 +47,8 @@ def test_make_job_no_args() -> None:
     assert tasks[0].new_cluster.spark_conf is None
     assert tasks[0].libraries is None
     assert tasks[0].timeout_seconds == 0
+    environments = job.settings.environments
+    assert environments is None
 
 
 def test_make_job_with_name() -> None:
@@ -116,6 +119,14 @@ def test_make_job_with_tasks() -> None:
     assert job.settings.tasks == ["CustomTasks"]
 
 
+def test_make_job_with_environment() -> None:
+    environment = Environment(environment_version="4")
+    job_environment = JobEnvironment(environment_key="job_environment", spec=environment)
+    _, job = call_stateful(make_job, environments=[job_environment])
+    assert job.settings.environments is not None
+    assert job.settings.environments[0] == job_environment
+
+
 def test_make_pipeline_no_args() -> None:
     ctx, pipeline = call_stateful(make_pipeline)
     assert ctx is not None
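
For downstream test authors, here is a minimal sketch of how the new `environments` keyword argument could be exercised through the `make_job` fixture. The test name, the `environment_key` value, and the `dependencies` list are illustrative and not part of this change; `Environment` and `JobEnvironment` come from the Databricks SDK, as in the patch above:

```python
from databricks.sdk.service.compute import Environment
from databricks.sdk.service.jobs import JobEnvironment


def test_job_with_serverless_environment(make_job) -> None:
    # Hypothetical usage: define a serverless job environment and pass it
    # through the fixture. The environment_key is what a task would
    # reference to run against this spec.
    job_environment = JobEnvironment(
        environment_key="test_environment",  # illustrative key
        spec=Environment(environment_version="4", dependencies=["pytest"]),
    )
    job = make_job(environments=[job_environment])
    # The fixture forwards environments straight to jobs.create, so they
    # surface unchanged on the created job's settings.
    assert job.settings.environments == [job_environment]
```

Note that, per the docstring added above, a Spark Python task running on serverless must reference one of the supplied environments via its `environment_key`; the sketch only verifies that the fixture plumbs the environments through.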