Skip to content

Commit 2100b20

Browse files
committed
Improve code formatting and consistency in v3 kernel client
- Fix import ordering and add blank lines for better readability - Improve code formatting with consistent spacing and line breaks - Remove unused exception variable and fix minor style issues - Ensure consistent inheritance for channel classes - Update type annotations to use modern Python syntax
1 parent aaf8d5b commit 2100b20

File tree

1 file changed

+74
-43
lines changed

1 file changed

+74
-43
lines changed

jupyter_server/services/kernels/v3/client.py

Lines changed: 74 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,41 @@
33
import typing as t
44
from datetime import datetime, timezone
55

6-
from traitlets import HasTraits, Type
76
from jupyter_client.asynchronous.client import AsyncKernelClient
87
from jupyter_client.channels import AsyncZMQSocketChannel
98
from jupyter_client.channelsabc import ChannelABC
9+
from traitlets import HasTraits, Type
10+
11+
from .message_utils import encode_channel_in_message_dict, parse_msg_id
1012
from .states import ExecutionStates
11-
from .message_utils import parse_msg_id, encode_channel_in_message_dict
1213

1314

1415
class NamedAsyncZMQSocketChannel(AsyncZMQSocketChannel):
1516
"""Prepends the channel name to all message IDs to this socket."""
17+
1618
channel_name = "unknown"
17-
19+
1820
def send(self, msg):
1921
"""Send a message with automatic channel encoding."""
2022
msg = encode_channel_in_message_dict(msg, self.channel_name)
21-
return super().send(msg)
22-
23-
23+
return super().send(msg)
24+
25+
2426
class ShellChannel(NamedAsyncZMQSocketChannel):
2527
"""Shell channel that automatically encodes 'shell' in outgoing msg_ids."""
28+
2629
channel_name = "shell"
2730

2831

29-
class ControlChannel(AsyncZMQSocketChannel):
32+
class ControlChannel(NamedAsyncZMQSocketChannel):
3033
"""Control channel that automatically encodes 'control' in outgoing msg_ids."""
34+
3135
channel_name = "control"
3236

3337

34-
class StdinChannel(AsyncZMQSocketChannel):
38+
class StdinChannel(NamedAsyncZMQSocketChannel):
3539
"""Stdin channel that automatically encodes 'stdin' in outgoing msg_ids."""
40+
3641
channel_name = "stdin"
3742

3843

@@ -77,7 +82,9 @@ class JupyterServerKernelClientMixin(HasTraits):
7782
# Connection test configuration
7883
connection_test_timeout: float = 120.0 # Total timeout for connection test in seconds
7984
connection_test_check_interval: float = 0.1 # How often to check for messages in seconds
80-
connection_test_retry_interval: float = 10.0 # How often to retry kernel_info requests in seconds
85+
connection_test_retry_interval: float = (
86+
10.0 # How often to retry kernel_info requests in seconds
87+
)
8188

8289
# Override channel classes to use our custom ones with automatic encoding
8390
shell_channel_class = Type(ShellChannel)
@@ -105,8 +112,8 @@ def __init__(self, *args, **kwargs):
105112
def add_listener(
106113
self,
107114
callback: t.Callable[[str, list[bytes]], None],
108-
msg_types: t.Optional[t.List[t.Tuple[str, str]]] = None,
109-
exclude_msg_types: t.Optional[t.List[t.Tuple[str, str]]] = None
115+
msg_types: t.Optional[list[tuple[str, str]]] = None,
116+
exclude_msg_types: t.Optional[list[tuple[str, str]]] = None,
110117
):
111118
"""Add a listener to be called when messages are received.
112119
@@ -128,8 +135,8 @@ def add_listener(
128135

129136
# Store the listener with its filter configuration
130137
self._listeners[callback] = {
131-
'msg_types': set(msg_types) if msg_types else None,
132-
'exclude_msg_types': set(exclude_msg_types) if exclude_msg_types else None
138+
"msg_types": set(msg_types) if msg_types else None,
139+
"exclude_msg_types": set(exclude_msg_types) if exclude_msg_types else None,
133140
}
134141

135142
def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]):
@@ -188,7 +195,7 @@ def _send_message(self, channel_name: str, msg: list[bytes]):
188195
channel = getattr(self, f"{channel_name}_channel", None)
189196
channel.session.send_raw(channel.socket, msg)
190197

191-
except Exception as e:
198+
except Exception:
192199
self.log.warn("Error handling incoming message.")
193200

194201
def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
@@ -225,7 +232,6 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
225232

226233
self._send_message(channel_name, msg)
227234

228-
229235
def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
230236
"""Public API for manufacturing messages to send to kernel client listeners.
231237
@@ -246,17 +252,19 @@ async def _route_to_listeners(self, channel_name: str, msg: list[bytes]):
246252

247253
# Validate message format before routing
248254
if not msg or len(msg) < 4:
249-
self.log.warning(f"Cannot route malformed message on {channel_name}: {len(msg) if msg else 0} parts (expected at least 4)")
255+
self.log.warning(
256+
f"Cannot route malformed message on {channel_name}: {len(msg) if msg else 0} parts (expected at least 4)"
257+
)
250258
return
251259

252260
# Extract message type for filtering
253261
msg_type = None
254262
try:
255263
header = self.session.unpack(msg[0]) if msg and len(msg) > 0 else {}
256-
msg_type = header.get('msg_type', 'unknown')
264+
msg_type = header.get("msg_type", "unknown")
257265
except Exception as e:
258266
self.log.debug(f"Error extracting message type: {e}")
259-
msg_type = 'unknown'
267+
msg_type = "unknown"
260268

261269
# Create tasks for listeners that match the filter
262270
tasks = []
@@ -269,7 +277,9 @@ async def _route_to_listeners(self, channel_name: str, msg: list[bytes]):
269277
if tasks:
270278
await asyncio.gather(*tasks, return_exceptions=True)
271279

272-
def _should_route_to_listener(self, msg_type: str, channel_name: str, filter_config: dict) -> bool:
280+
def _should_route_to_listener(
281+
self, msg_type: str, channel_name: str, filter_config: dict
282+
) -> bool:
273283
"""Determine if a message should be routed to a listener based on its filter configuration.
274284
275285
Args:
@@ -280,8 +290,8 @@ def _should_route_to_listener(self, msg_type: str, channel_name: str, filter_con
280290
Returns:
281291
bool: True if the message should be routed to the listener, False otherwise
282292
"""
283-
msg_types = filter_config.get('msg_types')
284-
exclude_msg_types = filter_config.get('exclude_msg_types')
293+
msg_types = filter_config.get("msg_types")
294+
exclude_msg_types = filter_config.get("exclude_msg_types")
285295

286296
# If msg_types is specified (inclusion filter)
287297
if msg_types is not None:
@@ -303,7 +313,13 @@ async def _call_listener(self, listener: t.Callable, channel_name: str, msg: lis
303313
except Exception as e:
304314
self.log.error(f"Error in listener {listener}: {e}")
305315

306-
def _update_execution_state_from_status(self, channel_name: str, msg_dict: dict, parent_msg_id: str = None, execution_state: str = None):
316+
def _update_execution_state_from_status(
317+
self,
318+
channel_name: str,
319+
msg_dict: dict,
320+
parent_msg_id: str = None,
321+
execution_state: str = None,
322+
):
307323
"""Update execution state from a status message if it originated from shell channel.
308324
309325
This method checks if a status message on the iopub channel originated from a shell
@@ -366,7 +382,9 @@ def _update_execution_state_from_status(self, channel_name: str, msg_dict: dict,
366382
if isinstance(content, bytes):
367383
content = self.session.unpack(content)
368384
execution_state = content.get("execution_state")
369-
self.log.debug(f"Ignoring status message - cannot parse parent channel (state would be: {execution_state})")
385+
self.log.debug(
386+
f"Ignoring status message - cannot parse parent channel (state would be: {execution_state})"
387+
)
370388
except Exception as e:
371389
self.log.debug(f"Error updating execution state from status message: {e}")
372390

@@ -391,10 +409,7 @@ async def broadcast_state(self):
391409
return
392410

393411
# Create status message
394-
msg_dict = self.session.msg(
395-
"status",
396-
content={"execution_state": self.execution_state}
397-
)
412+
msg_dict = self.session.msg("status", content={"execution_state": self.execution_state})
398413

399414
# Serialize the message
400415
# session.serialize() returns:
@@ -404,7 +419,9 @@ async def broadcast_state(self):
404419
# Skip delimiter (index 0) and signature (index 1) to get message parts
405420
# Result: [header, parent_header, metadata, content, buffers...]
406421
if len(serialized) < 6: # Need delimiter + signature + 4 message parts minimum
407-
self.log.warning(f"broadcast_state: serialized message too short: {len(serialized)} parts")
422+
self.log.warning(
423+
f"broadcast_state: serialized message too short: {len(serialized)} parts"
424+
)
408425
return
409426

410427
msg_parts = serialized[2:] # Skip delimiter and signature
@@ -422,7 +439,7 @@ async def start_listening(self):
422439
self._listening = True
423440

424441
# Monitor each channel for incoming messages
425-
for channel_name in ['iopub', 'shell', 'stdin', 'control']:
442+
for channel_name in ["iopub", "shell", "stdin", "control"]:
426443
channel = getattr(self, f"{channel_name}_channel", None)
427444
if channel and channel.is_alive():
428445
task = asyncio.create_task(self._monitor_channel_messages(channel_name, channel))
@@ -433,12 +450,12 @@ async def start_listening(self):
433450
async def stop_listening(self):
434451
"""Stop listening for messages."""
435452
# Stop monitoring tasks
436-
if hasattr(self, '_monitoring_tasks'):
453+
if hasattr(self, "_monitoring_tasks"):
437454
for task in self._monitoring_tasks:
438455
task.cancel()
439456
self._monitoring_tasks = []
440457

441-
self.log.info(f"Stopped listening")
458+
self.log.info("Stopped listening")
442459

443460
async def _monitor_channel_messages(self, channel_name: str, channel: ChannelABC):
444461
"""Monitor a channel for incoming messages and route them to listeners."""
@@ -466,7 +483,9 @@ async def _monitor_channel_messages(self, channel_name: str, channel: ChannelABC
466483
if msg_list and len(msg_list) >= 5:
467484
await self._route_to_listeners(channel_name, msg_list[1:])
468485
else:
469-
self.log.warning(f"Received malformed message on {channel_name}: {len(msg_list) if msg_list else 0} parts")
486+
self.log.warning(
487+
f"Received malformed message on {channel_name}: {len(msg_list) if msg_list else 0} parts"
488+
)
470489

471490
except Exception as e:
472491
# Log the error instead of silently ignoring it
@@ -514,7 +533,7 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool:
514533
await asyncio.gather(
515534
self._send_kernel_info_shell(),
516535
self._send_kernel_info_control(),
517-
return_exceptions=True
536+
return_exceptions=True,
518537
)
519538
except Exception as e:
520539
self.log.debug(f"Error sending initial kernel_info requests: {e}")
@@ -531,25 +550,35 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool:
531550

532551
# Check if we've received any status messages since connection attempt
533552
# This indicates the kernel is connected, even if busy executing something
534-
if self.last_shell_status_time and self.last_shell_status_time > connection_attempt_time:
553+
if (
554+
self.last_shell_status_time
555+
and self.last_shell_status_time > connection_attempt_time
556+
):
535557
self.log.info("Kernel communication test succeeded: received shell status message")
536558
return True
537559

538-
if self.last_control_status_time and self.last_control_status_time > connection_attempt_time:
539-
self.log.info("Kernel communication test succeeded: received control status message")
560+
if (
561+
self.last_control_status_time
562+
and self.last_control_status_time > connection_attempt_time
563+
):
564+
self.log.info(
565+
"Kernel communication test succeeded: received control status message"
566+
)
540567
return True
541568

542569
# Send kernel_info requests at regular intervals
543570
time_since_last_request = time.time() - last_kernel_info_time
544571
if time_since_last_request >= self.connection_test_retry_interval:
545-
self.log.debug(f"Sending kernel_info requests to shell and control channels (elapsed: {elapsed:.1f}s)")
572+
self.log.debug(
573+
f"Sending kernel_info requests to shell and control channels (elapsed: {elapsed:.1f}s)"
574+
)
546575

547576
try:
548577
# Send kernel_info to both channels in parallel (no reply expected)
549578
await asyncio.gather(
550579
self._send_kernel_info_shell(),
551580
self._send_kernel_info_control(),
552-
return_exceptions=True
581+
return_exceptions=True,
553582
)
554583
last_kernel_info_time = time.time()
555584
except Exception as e:
@@ -564,7 +593,7 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool:
564593
async def _send_kernel_info_shell(self):
565594
"""Send kernel_info request on shell channel (no reply expected)."""
566595
try:
567-
if hasattr(self, 'kernel_info'):
596+
if hasattr(self, "kernel_info"):
568597
# Send without waiting for reply
569598
self.kernel_info()
570599
except Exception as e:
@@ -573,8 +602,8 @@ async def _send_kernel_info_shell(self):
573602
async def _send_kernel_info_control(self):
574603
"""Send kernel_info request on control channel (no reply expected)."""
575604
try:
576-
if hasattr(self.control_channel, 'send'):
577-
msg = self.session.msg('kernel_info_request')
605+
if hasattr(self.control_channel, "send"):
606+
msg = self.session.msg("kernel_info_request")
578607
# Channel wrapper will automatically encode channel in msg_id
579608
self.control_channel.send(msg)
580609
except Exception as e:
@@ -617,7 +646,7 @@ async def connect(self) -> bool:
617646
await self.start_listening()
618647

619648
# Unpause heartbeat channel if method exists
620-
if hasattr(self.hb_channel, 'unpause'):
649+
if hasattr(self.hb_channel, "unpause"):
621650
self.hb_channel.unpause()
622651

623652
# Wait for heartbeat
@@ -631,7 +660,9 @@ async def connect(self) -> bool:
631660

632661
# Test kernel communication (handles retries internally)
633662
if not await self._test_kernel_communication():
634-
self.log.error(f"Kernel communication test failed after {self.connection_test_timeout}s timeout")
663+
self.log.error(
664+
f"Kernel communication test failed after {self.connection_test_timeout}s timeout"
665+
)
635666
return False
636667

637668
# Mark connection as ready and process queued messages

0 commit comments

Comments
 (0)