Skip to content

Commit 1a46b8d

Browse files
committed
Add completed_cells tracking during notebook execution
1 parent 576fecf commit 1a46b8d

File tree

6 files changed

+349
-4
lines changed

6 files changed

+349
-4
lines changed

jupyter_scheduler/executors.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,38 @@
1111
import nbformat
1212
from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor
1313

14-
from jupyter_scheduler.models import DescribeJob, JobFeature, Status
14+
from jupyter_scheduler.models import DescribeJob, JobFeature, Status, UpdateJob
1515
from jupyter_scheduler.orm import Job, create_session
1616
from jupyter_scheduler.parameterize import add_parameters
1717
from jupyter_scheduler.utils import get_utc_timestamp
1818

1919

20+
class TrackingExecutePreprocessor(ExecutePreprocessor):
    """ExecutePreprocessor that records execution progress on the job's database row.

    Each time a cell has been processed, the current value of
    ``code_cells_executed`` is written to the ``completed_cells`` column of
    the ``Job`` row whose ``job_id`` matches the one given at construction.
    """

    def __init__(self, db_session, job_id, **kwargs):
        """Remember the session factory and job id; pass other kwargs to the parent.

        ``db_session`` is invoked with no arguments and used as a context
        manager whenever progress is written.
        """
        super().__init__(**kwargs)
        self.job_id = job_id
        self.db_session = db_session

    def preprocess_cell(self, cell, resources, index):
        """Process one cell through the parent class, then persist the progress count.

        Returns the ``(cell, resources)`` pair produced by the parent
        implementation. Any exception raised while writing to the database is
        not caught here.
        """
        processed_cell, processed_resources = super().preprocess_cell(cell, resources, index)

        # One short-lived session per cell: update the matching row and commit.
        with self.db_session() as session:
            job_rows = session.query(Job).filter(Job.job_id == self.job_id)
            job_rows.update({"completed_cells": self.code_cells_executed})
            session.commit()

        return processed_cell, processed_resources
44+
45+
2046
class ExecutionManager(ABC):
2147
"""Base execution manager.
2248
Clients are expected to override this class
@@ -132,8 +158,12 @@ def execute(self):
132158
nb = add_parameters(nb, job.parameters)
133159

134160
staging_dir = os.path.dirname(self.staging_paths["input"])
135-
ep = ExecutePreprocessor(
136-
kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir
161+
ep = TrackingExecutePreprocessor(
162+
db_session=self.db_session,
163+
job_id=self.job_id,
164+
kernel_name=nb.metadata.kernelspec["name"],
165+
store_widget_state=True,
166+
cwd=staging_dir
137167
)
138168

139169
try:

jupyter_scheduler/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ class DescribeJob(BaseModel):
148148
downloaded: bool = False
149149
package_input_folder: Optional[bool] = None
150150
packaged_files: Optional[List[str]] = []
151+
completed_cells: Optional[int] = None
151152

152153
class Config:
153154
orm_mode = True
@@ -193,6 +194,7 @@ class UpdateJob(BaseModel):
193194
status: Optional[Status] = None
194195
name: Optional[str] = None
195196
compute_type: Optional[str] = None
197+
completed_cells: Optional[int] = None
196198

197199

198200
class DeleteJob(BaseModel):

jupyter_scheduler/orm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class Job(CommonColumns, Base):
103103
url = Column(String(256), default=generate_jobs_url)
104104
pid = Column(Integer)
105105
idempotency_token = Column(String(256))
106+
completed_cells = Column(Integer)
106107
# All new columns added to this table must be nullable to ensure compatibility during database migrations.
107108
# Any default values specified for new columns will be ignored during the migration process.
108109

jupyter_scheduler/tests/test_execution_manager.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import shutil
22
from pathlib import Path
33
from typing import Tuple
4+
from unittest.mock import MagicMock, patch
45

56
import pytest
7+
import nbformat
68

7-
from jupyter_scheduler.executors import DefaultExecutionManager
9+
from jupyter_scheduler.executors import DefaultExecutionManager, TrackingExecutePreprocessor
810
from jupyter_scheduler.orm import Job
911

1012

@@ -58,3 +60,128 @@ def test_add_side_effects_files(
5860

5961
job = jp_scheduler_db.query(Job).filter(Job.job_id == job_id).one()
6062
assert side_effect_file_name in job.packaged_files
63+
64+
65+
@pytest.fixture
def mock_cell():
    """Provide a notebook code cell whose source is a single print statement."""
    return nbformat.v4.new_code_cell(source="print('test')")
70+
71+
72+
@pytest.fixture
def mock_resources():
    """Provide a minimal resources dict holding only a metadata path."""
    resources = {"metadata": {"path": "/test/path"}}
    return resources
76+
77+
78+
def test_tracking_execute_preprocessor_initialization():
    """The constructor keeps db_session and job_id and forwards kernel_name."""
    session_factory = MagicMock()
    expected_job_id = "test-job-id"

    tracker = TrackingExecutePreprocessor(
        db_session=session_factory, job_id=expected_job_id, kernel_name="python3"
    )

    assert tracker.db_session == session_factory
    assert tracker.job_id == expected_job_id
    assert tracker.kernel_name == "python3"
92+
93+
94+
def test_tracking_execute_preprocessor_updates_database(mock_cell, mock_resources):
    """preprocess_cell writes the executed-cell count to the job row and commits."""
    session = MagicMock()
    session_factory = MagicMock()
    # Calling the factory and entering the resulting context yields `session`.
    session_factory.return_value.__enter__.return_value = session

    with patch.object(TrackingExecutePreprocessor, "execute_cell") as execute_cell, \
            patch.object(TrackingExecutePreprocessor, "_check_assign_resources"):
        tracker = TrackingExecutePreprocessor(
            db_session=session_factory, job_id="test-job-id", kernel_name="python3"
        )

        # No kernel runs in this test, so set the state preprocess_cell relies on.
        tracker.code_cells_executed = 3
        tracker.resources = mock_resources
        execute_cell.return_value = mock_cell

        result_cell, result_resources = tracker.preprocess_cell(mock_cell, mock_resources, 0)

        # The patched execute_cell must have been reached exactly once.
        execute_cell.assert_called_once_with(mock_cell, 0, store_history=True)

        # Exactly one query/filter/update chain with the preset count, then a commit.
        session.query.assert_called_once_with(Job)
        update_call = session.query.return_value.filter.return_value.update
        update_call.assert_called_once_with({"completed_cells": 3})
        session.commit.assert_called_once()

        # The cell and resources come back to the caller.
        assert result_cell == mock_cell
        assert result_resources == mock_resources
133+
134+
135+
def test_tracking_execute_preprocessor_handles_database_errors(mock_cell, mock_resources):
    """An exception raised by the database update surfaces from preprocess_cell."""
    session = MagicMock()
    session_factory = MagicMock()
    session_factory.return_value.__enter__.return_value = session

    # Arrange for the update step of the query chain to fail.
    failing_update = session.query.return_value.filter.return_value.update
    failing_update.side_effect = Exception("DB Error")

    with patch.object(TrackingExecutePreprocessor, "execute_cell") as execute_cell, \
            patch.object(TrackingExecutePreprocessor, "_check_assign_resources"):
        tracker = TrackingExecutePreprocessor(
            db_session=session_factory, job_id="test-job-id", kernel_name="python3"
        )

        tracker.code_cells_executed = 1
        tracker.resources = mock_resources
        execute_cell.return_value = mock_cell

        # The failure is not swallowed: the caller sees the original error.
        with pytest.raises(Exception, match="DB Error"):
            tracker.preprocess_cell(mock_cell, mock_resources, 0)
161+
162+
163+
def test_tracking_execute_preprocessor_uses_correct_job_id(mock_cell, mock_resources):
    """The filter applied to the Job query is built from the preprocessor's job_id."""
    session = MagicMock()
    session_factory = MagicMock()
    session_factory.return_value.__enter__.return_value = session

    job_id = "specific-job-id-123"

    with patch.object(TrackingExecutePreprocessor, "execute_cell") as execute_cell, \
            patch.object(TrackingExecutePreprocessor, "_check_assign_resources"):
        tracker = TrackingExecutePreprocessor(
            db_session=session_factory, job_id=job_id, kernel_name="python3"
        )

        tracker.code_cells_executed = 2
        tracker.resources = mock_resources
        execute_cell.return_value = mock_cell

        tracker.preprocess_cell(mock_cell, mock_resources, 0)

        # Inspect the first positional argument handed to .filter(...).
        positional_args, _ = session.query.return_value.filter.call_args
        criterion = positional_args[0]
        # Accept the id either in the rendered expression or as its right-hand value.
        assert job_id in str(criterion) or criterion.right.value == job_id

jupyter_scheduler/tests/test_handlers.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ async def test_get_jobs_for_single_job(jp_fetch):
131131
url="url_a",
132132
create_time=1664305872620,
133133
update_time=1664305872620,
134+
completed_cells=5,
134135
)
135136
response = await jp_fetch("scheduler", "jobs", job_id, method="GET")
136137

@@ -140,6 +141,7 @@ async def test_get_jobs_for_single_job(jp_fetch):
140141
assert body["job_id"] == job_id
141142
assert body["input_filename"]
142143
assert body["job_files"]
144+
assert body["completed_cells"] == 5
143145

144146

145147
@pytest.mark.parametrize(
@@ -320,6 +322,28 @@ async def test_patch_jobs(jp_fetch):
320322
mock_update_job.assert_called_once_with(job_id, UpdateJob(**body))
321323

322324

325+
async def test_patch_jobs_with_completed_cells(jp_fetch):
    """A PATCH with name and completed_cells returns 204 and reaches update_job."""
    job_id = "542e0fac-1274-4a78-8340-a850bdb559c8"
    payload = {"name": "updated job", "completed_cells": 10}

    with patch("jupyter_scheduler.scheduler.Scheduler.update_job") as mock_update_job:
        response = await jp_fetch(
            "scheduler", "jobs", job_id, method="PATCH", body=json.dumps(payload)
        )

        assert response.code == 204
        # The handler passes the job id and an UpdateJob built from the body.
        mock_update_job.assert_called_once_with(job_id, UpdateJob(**payload))
334+
335+
336+
async def test_patch_jobs_completed_cells_only(jp_fetch):
    """A PATCH carrying only completed_cells returns 204 and reaches update_job."""
    job_id = "542e0fac-1274-4a78-8340-a850bdb559c8"
    payload = {"completed_cells": 15}

    with patch("jupyter_scheduler.scheduler.Scheduler.update_job") as mock_update_job:
        response = await jp_fetch(
            "scheduler", "jobs", job_id, method="PATCH", body=json.dumps(payload)
        )

        assert response.code == 204
        # The handler passes the job id and an UpdateJob built from the body.
        mock_update_job.assert_called_once_with(job_id, UpdateJob(**payload))
345+
346+
323347
async def test_patch_jobs_for_stop_job(jp_fetch):
324348
with patch("jupyter_scheduler.scheduler.Scheduler.stop_job") as mock_stop_job:
325349
job_id = "542e0fac-1274-4a78-8340-a850bdb559c8"
@@ -677,3 +701,73 @@ async def test_delete_job_definition_for_unexpected_error(jp_fetch):
677701
assert expected_http_error(
678702
e, 500, "Unexpected error occurred while deleting the job definition."
679703
)
704+
705+
706+
# Model validation tests for completed_cells field
707+
def test_describe_job_completed_cells_validation():
    """DescribeJob accepts 5, None and 0 for completed_cells and rejects "invalid"."""
    # Fields shared by every case; completed_cells is supplied per case below.
    base_fields = {
        "name": "test_job",
        "input_filename": "test.ipynb",
        "runtime_environment_name": "test_env",
        "job_id": "test-job-id",
        "url": "http://test.com/jobs/test-job-id",
        "create_time": 1234567890,
        "update_time": 1234567890,
    }

    # A positive integer is stored as given.
    assert DescribeJob(**base_fields, completed_cells=5).completed_cells == 5

    # None is allowed and stays None.
    assert DescribeJob(**base_fields, completed_cells=None).completed_cells is None

    # Zero is a valid value.
    assert DescribeJob(**base_fields, completed_cells=0).completed_cells == 0

    # A non-numeric string fails validation.
    with pytest.raises(ValidationError):
        DescribeJob(**base_fields, completed_cells="invalid")
737+
738+
739+
def test_update_job_completed_cells_validation():
    """UpdateJob validates completed_cells and omits it from dict(exclude_none=True) when None."""
    # A positive integer is stored as given.
    assert UpdateJob(completed_cells=10).completed_cells == 10

    # None is allowed and stays None.
    assert UpdateJob(completed_cells=None).completed_cells is None

    # Zero is a valid value.
    assert UpdateJob(completed_cells=0).completed_cells == 0

    # A non-numeric string fails validation.
    with pytest.raises(ValidationError):
        UpdateJob(completed_cells="invalid")

    # With exclude_none, a None completed_cells is left out of the dict.
    without_count = UpdateJob(name="test", completed_cells=None).dict(exclude_none=True)
    assert "completed_cells" not in without_count
    assert without_count["name"] == "test"

    # A concrete completed_cells value is kept alongside the other fields.
    with_count = UpdateJob(name="test", completed_cells=5).dict(exclude_none=True)
    assert with_count["completed_cells"] == 5
    assert with_count["name"] == "test"

0 commit comments

Comments
 (0)