Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import pytest

pytest_plugins = ("pytest_jupyter.jupyter_server", "jupyter_server.pytest_plugin")
pytest_plugins = ("pytest_jupyter.jupyter_server", "jupyter_server.pytest_plugin", "pytest_asyncio")


def pytest_configure(config):
"""Configure pytest settings."""
# Set asyncio fixture loop scope to function to avoid warnings
config.option.asyncio_default_fixture_loop_scope = "function"


@pytest.fixture
Expand Down
59 changes: 39 additions & 20 deletions jupyter_server_documents/kernels/kernel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,41 @@ async def stop_listening(self):
_listening_task: t.Optional[t.Awaitable] = Any(allow_none=True)

def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
"""Use the given session to send the message."""
"""
Handle incoming kernel messages and set up immediate cell execution state tracking.

This method processes incoming kernel messages and caches them for response mapping.
Importantly, it detects execute_request messages and immediately sets the corresponding
cell state to 'busy' to provide real-time feedback for queued cell executions.

This ensures that when multiple cells are executed simultaneously, all queued cells
show a '*' prompt immediately, not just the currently executing cell.

Args:
channel_name: The kernel channel name (shell, iopub, etc.)
msg: The raw kernel message as bytes
"""
# Cache the message ID and its socket name so that
# any response message can be mapped back to the
# source channel.
header = self.session.unpack(msg[0])
msg_id = header["msg_id"]
msg_id = header["msg_id"]
msg_type = header.get("msg_type")
metadata = self.session.unpack(msg[2])
cell_id = metadata.get("cellId")

# Clear cell outputs if cell is re-executedq
# Clear cell outputs if cell is re-executed
if cell_id:
existing = self.message_cache.get(cell_id=cell_id)
if existing and existing['msg_id'] != msg_id:
asyncio.create_task(self.output_processor.clear_cell_outputs(cell_id))

# IMPORTANT: Set cell to 'busy' immediately when execute_request is received
# This ensures queued cells show '*' prompt even before kernel starts processing them
if msg_type == "execute_request" and channel_name == "shell" and cell_id:
for yroom in self._yrooms:
yroom.set_cell_awareness_state(cell_id, "busy")

self.message_cache.add({
"msg_id": msg_id,
"channel": channel_name,
Expand Down Expand Up @@ -240,27 +260,27 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona
metadata["metadata"]["language_info"] = language_info

case "status":
# Unpack cell-specific information and determine execution state
# Handle kernel status messages and update cell execution states
# This provides real-time feedback about cell execution progress
content = self.session.unpack(dmsg["content"])
execution_state = content.get("execution_state")

# Update status across all collaborative rooms
for yroom in self._yrooms:
# If this status came from the shell channel, update
# the notebook status.
if parent_msg_data["channel"] == "shell":
awareness = yroom.get_awareness()
if awareness is not None:
awareness = yroom.get_awareness()
if awareness is not None:
# If this status came from the shell channel, update
# the notebook kernel status.
if parent_msg_data and parent_msg_data.get("channel") == "shell":
# Update the kernel execution state at the top document level
awareness.set_local_state_field("kernel", {"execution_state": execution_state})
# Specifically update the running cell's execution state if cell_id is provided
if cell_id:
notebook = await yroom.get_jupyter_ydoc()
_, target_cell = notebook.find_cell(cell_id)
if target_cell:
# Adjust state naming convention from 'busy' to 'running' as per JupyterLab expectation
# https://github.com/jupyterlab/jupyterlab/blob/0ad84d93be9cb1318d749ffda27fbcd013304d50/packages/cells/src/widget.ts#L1670-L1678
state = 'running' if execution_state == 'busy' else execution_state
target_cell["execution_state"] = state

# Store cell execution state for persistence across client connections
# This ensures that cell execution states survive page refreshes
if cell_id:
for yroom in self._yrooms:
yroom.set_cell_execution_state(cell_id, execution_state)
yroom.set_cell_awareness_state(cell_id, execution_state)
break

case "execute_input":
Expand All @@ -278,8 +298,7 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona
case "stream" | "display_data" | "execute_result" | "error" | "update_display_data" | "clear_output":
if cell_id:
# Process specific output messages through an optional processor
if self.output_processor and cell_id:
cell_id = parent_msg_data.get('cell_id')
if self.output_processor:
content = self.session.unpack(dmsg["content"])
self.output_processor.process_output(dmsg['msg_type'], cell_id, content)

Expand Down
36 changes: 34 additions & 2 deletions jupyter_server_documents/rooms/yroom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations # see PEP-563 for motivation behind this
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, cast, Any
import asyncio
import uuid
import pycrdt
Expand Down Expand Up @@ -369,6 +369,38 @@ def get_awareness(self) -> pycrdt.Awareness:
"""
return self._awareness

def get_cell_execution_states(self) -> dict:
"""
Returns the persistent cell execution states for this room.
These states survive client disconnections but are not saved to disk.
"""
if not hasattr(self, '_cell_execution_states'):
self._cell_execution_states: dict[str, str] = {}
return self._cell_execution_states

def set_cell_execution_state(self, cell_id: str, execution_state: str) -> None:
"""
Sets the execution state for a specific cell.
This state persists across client disconnections.
"""
if not hasattr(self, '_cell_execution_states'):
self._cell_execution_states = {}
self._cell_execution_states[cell_id] = execution_state

def set_cell_awareness_state(self, cell_id: str, execution_state: str) -> None:
"""
Sets the execution state for a specific cell in the awareness system.
This provides real-time updates to all connected clients.
"""
awareness = self.get_awareness()
if awareness is not None:
local_state = awareness.get_local_state()
if local_state is not None:
cell_states = local_state.get("cell_execution_states", {})
else:
cell_states = {}
cell_states[cell_id] = execution_state
awareness.set_local_state_field("cell_execution_states", cell_states)

def add_message(self, client_id: str, message: bytes) -> None:
"""
Expand Down Expand Up @@ -512,7 +544,7 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
return

self.clients.mark_synced(client_id)

# Send SyncStep1 message
try:
assert isinstance(new_client.websocket, WebSocketHandler)
Expand Down
Empty file.
23 changes: 23 additions & 0 deletions jupyter_server_documents/tests/kernels/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Configuration for kernel tests."""

import pytest
from unittest.mock import MagicMock


@pytest.fixture
def mock_logger():
"""Create a mock logger for testing."""
return MagicMock()


@pytest.fixture
def mock_session():
"""Create a mock session for testing."""
session = MagicMock()
session.msg_header.return_value = {"msg_id": "test-msg-id"}
session.msg.return_value = {"test": "message"}
session.serialize.return_value = ["", "serialized", "msg"]
session.deserialize.return_value = {"msg_type": "test", "content": b"test"}
session.unpack.return_value = {"test": "data"}
session.feed_identities.return_value = ([], [b"test", b"message"])
return session
105 changes: 105 additions & 0 deletions jupyter_server_documents/tests/kernels/test_kernel_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pytest
from unittest.mock import MagicMock, patch

from jupyter_server_documents.kernels.kernel_client import DocumentAwareKernelClient
from jupyter_server_documents.kernels.message_cache import KernelMessageCache
from jupyter_server_documents.outputs import OutputProcessor


class TestDocumentAwareKernelClient:
"""Test cases for DocumentAwareKernelClient."""

def test_default_message_cache(self):
"""Test that message cache is created by default."""
client = DocumentAwareKernelClient()
assert isinstance(client.message_cache, KernelMessageCache)

def test_default_output_processor(self):
"""Test that output processor is created by default."""
client = DocumentAwareKernelClient()
assert isinstance(client.output_processor, OutputProcessor)

@pytest.mark.asyncio
async def test_stop_listening_no_task(self):
"""Test that stop_listening does nothing when no task exists."""
client = DocumentAwareKernelClient()
client._listening_task = None

# Should not raise an exception
await client.stop_listening()

def test_add_listener(self):
"""Test adding a listener."""
client = DocumentAwareKernelClient()

def test_listener(channel, msg):
pass

client.add_listener(test_listener)

assert test_listener in client._listeners

def test_remove_listener(self):
"""Test removing a listener."""
client = DocumentAwareKernelClient()

def test_listener(channel, msg):
pass

client.add_listener(test_listener)
client.remove_listener(test_listener)

assert test_listener not in client._listeners

@pytest.mark.asyncio
async def test_add_yroom(self):
"""Test adding a YRoom."""
client = DocumentAwareKernelClient()

mock_yroom = MagicMock()
await client.add_yroom(mock_yroom)

assert mock_yroom in client._yrooms

@pytest.mark.asyncio
async def test_remove_yroom(self):
"""Test removing a YRoom."""
client = DocumentAwareKernelClient()

mock_yroom = MagicMock()
client._yrooms.add(mock_yroom)

await client.remove_yroom(mock_yroom)

assert mock_yroom not in client._yrooms

def test_send_kernel_info_creates_message(self):
"""Test that send_kernel_info creates a kernel info message."""
client = DocumentAwareKernelClient()

# Mock session
from jupyter_client.session import Session
client.session = Session()

with patch.object(client, 'handle_incoming_message') as mock_handle:
client.send_kernel_info()

# Verify that handle_incoming_message was called with shell channel
mock_handle.assert_called_once()
args, kwargs = mock_handle.call_args
assert args[0] == "shell" # Channel name
assert isinstance(args[1], list) # Message list

@pytest.mark.asyncio
async def test_handle_outgoing_message_control_channel(self):
"""Test that control channel messages bypass document handling."""
client = DocumentAwareKernelClient()

msg = [b"test", b"message"]

with patch.object(client, 'handle_document_related_message') as mock_handle_doc:
with patch.object(client, 'send_message_to_listeners') as mock_send:
await client.handle_outgoing_message("control", msg)

mock_handle_doc.assert_not_called()
mock_send.assert_called_once_with("control", msg)
Loading
Loading