Skip to content

Commit 1de776d

Browse files
committed
Move cell execution state to the document's awareness object
1 parent 8059059 commit 1de776d

File tree

4 files changed

+256
-54
lines changed

4 files changed

+256
-54
lines changed

jupyter_server_documents/kernels/kernel_client.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,45 @@ async def stop_listening(self):
100100
_listening_task: t.Optional[t.Awaitable] = Any(allow_none=True)
101101

102102
def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
103-
"""Use the given session to send the message."""
103+
"""
104+
Handle incoming kernel messages and set up immediate cell execution state tracking.
105+
106+
This method processes incoming kernel messages and caches them for response mapping.
107+
Importantly, it detects execute_request messages and immediately sets the corresponding
108+
cell state to 'busy' to provide real-time feedback for queued cell executions.
109+
110+
This ensures that when multiple cells are executed simultaneously, all queued cells
111+
show a '*' prompt immediately, not just the currently executing cell.
112+
113+
Args:
114+
channel_name: The kernel channel name (shell, iopub, etc.)
115+
msg: The raw kernel message as bytes
116+
"""
104117
# Cache the message ID and its socket name so that
105118
# any response message can be mapped back to the
106119
# source channel.
107120
header = self.session.unpack(msg[0])
108-
msg_id = header["msg_id"]
121+
msg_id = header["msg_id"]
122+
msg_type = header.get("msg_type")
109123
metadata = self.session.unpack(msg[2])
110124
cell_id = metadata.get("cellId")
111125

112-
# Clear cell outputs if cell is re-executedq
126+
# Clear cell outputs if cell is re-executed
113127
if cell_id:
114128
existing = self.message_cache.get(cell_id=cell_id)
115129
if existing and existing['msg_id'] != msg_id:
116130
asyncio.create_task(self.output_processor.clear_cell_outputs(cell_id))
117131

132+
# IMPORTANT: Set cell to 'busy' immediately when execute_request is received
133+
# This ensures queued cells show '*' prompt even before kernel starts processing them
134+
if msg_type == "execute_request" and channel_name == "shell" and cell_id:
135+
for yroom in self._yrooms:
136+
awareness = yroom.get_awareness()
137+
if awareness is not None:
138+
cell_states = awareness.get_local_state().get("cell_execution_states", {})
139+
cell_states[cell_id] = "busy"
140+
awareness.set_local_state_field("cell_execution_states", cell_states)
141+
118142
self.message_cache.add({
119143
"msg_id": msg_id,
120144
"channel": channel_name,
@@ -240,27 +264,27 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona
240264
metadata["metadata"]["language_info"] = language_info
241265

242266
case "status":
243-
# Unpack cell-specific information and determine execution state
267+
# Handle kernel status messages and update cell execution states
268+
# This provides real-time feedback about cell execution progress
244269
content = self.session.unpack(dmsg["content"])
245270
execution_state = content.get("execution_state")
271+
246272
# Update status across all collaborative rooms
247273
for yroom in self._yrooms:
248-
# If this status came from the shell channel, update
249-
# the notebook status.
250-
if parent_msg_data["channel"] == "shell":
251-
awareness = yroom.get_awareness()
252-
if awareness is not None:
274+
awareness = yroom.get_awareness()
275+
if awareness is not None:
276+
# If this status came from the shell channel, update
277+
# the notebook kernel status.
278+
if parent_msg_data and parent_msg_data.get("channel") == "shell":
253279
# Update the kernel execution state at the top document level
254280
awareness.set_local_state_field("kernel", {"execution_state": execution_state})
255-
# Specifically update the running cell's execution state if cell_id is provided
256-
if cell_id:
257-
notebook = await yroom.get_jupyter_ydoc()
258-
_, target_cell = notebook.find_cell(cell_id)
259-
if target_cell:
260-
# Adjust state naming convention from 'busy' to 'running' as per JupyterLab expectation
261-
# https://github.com/jupyterlab/jupyterlab/blob/0ad84d93be9cb1318d749ffda27fbcd013304d50/packages/cells/src/widget.ts#L1670-L1678
262-
state = 'running' if execution_state == 'busy' else execution_state
263-
target_cell["execution_state"] = state
281+
282+
# Store cell execution state for persistence across client connections
283+
# This ensures that cell execution states survive page refreshes
284+
if cell_id:
285+
cell_states = awareness.get_local_state().get("cell_execution_states", {})
286+
cell_states[cell_id] = execution_state
287+
awareness.set_local_state_field("cell_execution_states", cell_states)
264288
break
265289

266290
case "execute_input":
@@ -278,8 +302,7 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona
278302
case "stream" | "display_data" | "execute_result" | "error" | "update_display_data" | "clear_output":
279303
if cell_id:
280304
# Process specific output messages through an optional processor
281-
if self.output_processor and cell_id:
282-
cell_id = parent_msg_data.get('cell_id')
305+
if self.output_processor:
283306
content = self.session.unpack(dmsg["content"])
284307
self.output_processor.process_output(dmsg['msg_type'], cell_id, content)
285308

jupyter_server_documents/rooms/yroom.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,23 @@ def get_awareness(self) -> pycrdt.Awareness:
369369
"""
370370
return self._awareness
371371

372+
def get_cell_execution_states(self) -> dict:
373+
"""
374+
Returns the persistent cell execution states for this room.
375+
These states survive client disconnections but are not saved to disk.
376+
"""
377+
if not hasattr(self, '_cell_execution_states'):
378+
self._cell_execution_states = {}
379+
return self._cell_execution_states
380+
381+
def set_cell_execution_state(self, cell_id: str, execution_state: str) -> None:
382+
"""
383+
Sets the execution state for a specific cell.
384+
This state persists across client disconnections.
385+
"""
386+
if not hasattr(self, '_cell_execution_states'):
387+
self._cell_execution_states = {}
388+
self._cell_execution_states[cell_id] = execution_state
372389

373390
def add_message(self, client_id: str, message: bytes) -> None:
374391
"""
@@ -512,7 +529,7 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
512529
return
513530

514531
self.clients.mark_synced(client_id)
515-
532+
516533
# Send SyncStep1 message
517534
try:
518535
assert isinstance(new_client.websocket, WebSocketHandler)

jupyter_server_documents/tests/kernels/test_kernel_client_integration.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ def mock_yroom_with_notebook(self):
2020
# Create a real YDoc and YNotebook
2121
ydoc = pycrdt.Doc()
2222
awareness = MagicMock(spec=pycrdt.Awareness) # Mock awareness instead of using real one
23+
24+
# Mock the local state to track changes
25+
local_state = {}
26+
awareness.get_local_state = MagicMock(return_value=local_state)
27+
28+
# Mock set_local_state_field to actually update the local_state dict
29+
def mock_set_local_state_field(field, value):
30+
local_state[field] = value
31+
32+
awareness.set_local_state_field = MagicMock(side_effect=mock_set_local_state_field)
33+
2334
ynotebook = YNotebook(ydoc, awareness)
2435

2536
# Add a simple notebook structure with one cell
@@ -54,6 +65,18 @@ def mock_yroom_with_notebook(self):
5465
yroom.get_jupyter_ydoc = AsyncMock(return_value=ynotebook)
5566
yroom.get_awareness = MagicMock(return_value=awareness)
5667

68+
# Add persistent cell execution state storage
69+
yroom._cell_execution_states = {}
70+
71+
def mock_get_cell_execution_states():
72+
return yroom._cell_execution_states
73+
74+
def mock_set_cell_execution_state(cell_id, execution_state):
75+
yroom._cell_execution_states[cell_id] = execution_state
76+
77+
yroom.get_cell_execution_states = MagicMock(side_effect=mock_get_cell_execution_states)
78+
yroom.set_cell_execution_state = MagicMock(side_effect=mock_set_cell_execution_state)
79+
5780
return yroom, ynotebook
5881

5982
@pytest.fixture
@@ -109,7 +132,7 @@ async def test_execute_input_updates_execution_count(self, kernel_client_with_yr
109132

110133
@pytest.mark.asyncio
111134
async def test_status_message_updates_cell_execution_state(self, kernel_client_with_yroom):
112-
"""Test that status messages update cell execution state in YDoc."""
135+
"""Test that status messages update cell execution state in YRoom for persistence and awareness for real-time updates."""
113136
client, yroom, ynotebook = kernel_client_with_yroom
114137

115138
# Mock message cache to return cell_id and channel
@@ -129,11 +152,16 @@ async def test_status_message_updates_cell_execution_state(self, kernel_client_w
129152
# Process the message
130153
await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter
131154

132-
# Verify the cell execution state was updated to 'running' (converted from 'busy')
133-
cells = ynotebook.get_cell_list()
134-
target_cell = next((cell for cell in cells if cell.get("id") == cell_id), None)
135-
assert target_cell is not None
136-
assert target_cell.get("execution_state") == "running"
155+
# Verify the cell execution state was stored in YRoom for persistence
156+
cell_states = yroom.get_cell_execution_states()
157+
assert cell_states[cell_id] == "busy"
158+
159+
# Verify it's also in awareness for real-time updates
160+
awareness = yroom.get_awareness()
161+
local_state = awareness.get_local_state()
162+
assert local_state is not None
163+
assert "cell_execution_states" in local_state
164+
assert local_state["cell_execution_states"][cell_id] == "busy"
137165

138166
@pytest.mark.asyncio
139167
async def test_kernel_info_reply_updates_language_info(self, kernel_client_with_yroom):
@@ -304,12 +332,16 @@ async def test_complete_execution_flow(self, kernel_client_with_yroom):
304332
)
305333
await client.handle_document_related_message(msg_parts[1:])
306334

307-
# Verify final state of the cell in YDoc
335+
# Verify final state of the cell in YDoc and awareness
308336
cells = ynotebook.get_cell_list()
309337
target_cell = next((cell for cell in cells if cell.get("id") == cell_id), None)
310338
assert target_cell is not None
311339
assert target_cell.get("execution_count") == 1
312-
assert target_cell.get("execution_state") == "idle"
340+
341+
# Verify execution state is stored in awareness, not YDoc
342+
awareness = yroom.get_awareness()
343+
cell_execution_states = awareness.get_local_state().get("cell_execution_states", {})
344+
assert cell_execution_states.get(cell_id) == "idle"
313345

314346
# Verify output processor was called for the result
315347
client.output_processor.process_output.assert_called_with(
@@ -383,15 +415,12 @@ def mock_get(msg_id):
383415
)
384416
await client.handle_document_related_message(msg_parts2[1:])
385417

386-
# Verify both cells have correct states
387-
cells = ynotebook.get_cell_list()
388-
cell1 = next((cell for cell in cells if cell.get("id") == "test-cell-1"), None)
389-
cell2 = next((cell for cell in cells if cell.get("id") == "test-cell-2"), None)
418+
# Verify both cells have correct states in awareness
419+
awareness = yroom.get_awareness()
420+
cell_execution_states = awareness.get_local_state().get("cell_execution_states", {})
390421

391-
assert cell1 is not None
392-
assert cell1.get("execution_state") == "running" # 'busy' -> 'running'
393-
assert cell2 is not None
394-
assert cell2.get("execution_state") == "idle"
422+
assert cell_execution_states.get("test-cell-1") == "busy" # 'busy' state
423+
assert cell_execution_states.get("test-cell-2") == "idle"
395424

396425
@pytest.mark.asyncio
397426
async def test_message_without_cell_id_skips_cell_updates(self, kernel_client_with_yroom):

0 commit comments

Comments
 (0)