Skip to content

Commit 48f9d3f

Browse files
committed
Refactor cell tracking using hook from ExecutePreprocessor
1 parent 70a70a6 commit 48f9d3f

File tree

4 files changed

+233
-141
lines changed

4 files changed

+233
-141
lines changed

jupyter_scheduler/executors.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,12 @@
1111
import nbformat
1212
from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor
1313

14-
from jupyter_scheduler.models import DescribeJob, JobFeature, Status, UpdateJob
14+
from jupyter_scheduler.models import DescribeJob, JobFeature, Status
1515
from jupyter_scheduler.orm import Job, create_session
1616
from jupyter_scheduler.parameterize import add_parameters
1717
from jupyter_scheduler.utils import get_utc_timestamp
1818

1919

20-
class TrackingExecutePreprocessor(ExecutePreprocessor):
21-
"""Custom ExecutePreprocessor that tracks completed cells and updates the database"""
22-
23-
def __init__(self, db_session, job_id, **kwargs):
24-
super().__init__(**kwargs)
25-
self.db_session = db_session
26-
self.job_id = job_id
27-
28-
def preprocess_cell(self, cell, resources, index):
29-
"""
30-
Override to track completed cells in the database.
31-
Calls the superclass implementation and then updates the database.
32-
"""
33-
# Call the superclass implementation
34-
cell, resources = super().preprocess_cell(cell, resources, index)
35-
36-
# Update the database with the current count of completed cells
37-
with self.db_session() as session:
38-
session.query(Job).filter(Job.job_id == self.job_id).update(
39-
{"completed_cells": self.code_cells_executed}
40-
)
41-
session.commit()
42-
43-
return cell, resources
44-
45-
4620
class ExecutionManager(ABC):
4721
"""Base execution manager.
4822
Clients are expected to override this class
@@ -158,14 +132,14 @@ def execute(self):
158132
nb = add_parameters(nb, job.parameters)
159133

160134
staging_dir = os.path.dirname(self.staging_paths["input"])
161-
ep = TrackingExecutePreprocessor(
162-
db_session=self.db_session,
163-
job_id=self.job_id,
164-
kernel_name=nb.metadata.kernelspec["name"],
165-
store_widget_state=True,
166-
cwd=staging_dir
135+
136+
ep = ExecutePreprocessor(
137+
kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir
167138
)
168139

140+
if self.supported_features().get(JobFeature.track_cell_execution, False):
141+
ep.on_cell_executed = self.__update_completed_cells_hook(ep)
142+
169143
try:
170144
ep.preprocess(nb, {"metadata": {"path": staging_dir}})
171145
except CellExecutionError as e:
@@ -174,6 +148,16 @@ def execute(self):
174148
self.add_side_effects_files(staging_dir)
175149
self.create_output_files(job, nb)
176150

151+
def __update_completed_cells_hook(self, ep: ExecutePreprocessor):
152+
"""Returns a hook that runs on every cell execution, regardless of success or failure. Updates the completed_cells for the job."""
153+
def update_completed_cells(cell, cell_index, execute_reply):
154+
with self.db_session() as session:
155+
session.query(Job).filter(Job.job_id == self.job_id).update(
156+
{"completed_cells": ep.code_cells_executed}
157+
)
158+
session.commit()
159+
return update_completed_cells
160+
177161
def add_side_effects_files(self, staging_dir: str):
178162
"""Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""
179163
input_notebook = os.path.relpath(self.staging_paths["input"])
@@ -203,6 +187,7 @@ def create_output_files(self, job: DescribeJob, notebook_node):
203187
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
204188
f.write(output)
205189

190+
@classmethod
206191
def supported_features(cls) -> Dict[JobFeature, bool]:
207192
return {
208193
JobFeature.job_name: True,
@@ -218,8 +203,10 @@ def supported_features(cls) -> Dict[JobFeature, bool]:
218203
JobFeature.output_filename_template: False,
219204
JobFeature.stop_job: True,
220205
JobFeature.delete_job: True,
206+
JobFeature.track_cell_execution: True,
221207
}
222208

209+
@classmethod
223210
def validate(cls, input_path: str) -> bool:
224211
with open(input_path, encoding="utf-8") as f:
225212
nb = nbformat.read(f, as_version=4)

jupyter_scheduler/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,4 @@ class JobFeature(str, Enum):
297297
output_filename_template = "output_filename_template"
298298
stop_job = "stop_job"
299299
delete_job = "delete_job"
300+
track_cell_execution = "track_cell_execution"

jupyter_scheduler/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def create_job(self, model: CreateJob) -> str:
442442
raise InputUriError(model.input_uri)
443443

444444
input_path = os.path.join(self.root_dir, model.input_uri)
445-
if not self.execution_manager_class.validate(self.execution_manager_class, input_path):
445+
if not self.execution_manager_class.validate(input_path):
446446
raise SchedulerError(
447447
"""There is no kernel associated with the notebook. Please open
448448
the notebook, select a kernel, and re-submit the job to execute.

0 commit comments

Comments
 (0)