33import typing as t
44from datetime import datetime , timezone
55
6- from traitlets import HasTraits , Type
76from jupyter_client .asynchronous .client import AsyncKernelClient
87from jupyter_client .channels import AsyncZMQSocketChannel
98from jupyter_client .channelsabc import ChannelABC
9+ from traitlets import HasTraits , Type
10+
11+ from .message_utils import encode_channel_in_message_dict , parse_msg_id
1012from .states import ExecutionStates
11- from .message_utils import parse_msg_id , encode_channel_in_message_dict
1213
1314
1415class NamedAsyncZMQSocketChannel (AsyncZMQSocketChannel ):
1516 """Prepends the channel name to all message IDs to this socket."""
17+
1618 channel_name = "unknown"
17-
19+
1820 def send (self , msg ):
1921 """Send a message with automatic channel encoding."""
2022 msg = encode_channel_in_message_dict (msg , self .channel_name )
21- return super ().send (msg )
22-
23-
23+ return super ().send (msg )
24+
25+
2426class ShellChannel (NamedAsyncZMQSocketChannel ):
2527 """Shell channel that automatically encodes 'shell' in outgoing msg_ids."""
28+
2629 channel_name = "shell"
2730
2831
29- class ControlChannel (AsyncZMQSocketChannel ):
32+ class ControlChannel (NamedAsyncZMQSocketChannel ):
3033 """Control channel that automatically encodes 'control' in outgoing msg_ids."""
34+
3135 channel_name = "control"
3236
3337
34- class StdinChannel (AsyncZMQSocketChannel ):
38+ class StdinChannel (NamedAsyncZMQSocketChannel ):
3539 """Stdin channel that automatically encodes 'stdin' in outgoing msg_ids."""
40+
3641 channel_name = "stdin"
3742
3843
@@ -77,7 +82,9 @@ class JupyterServerKernelClientMixin(HasTraits):
7782 # Connection test configuration
7883 connection_test_timeout : float = 120.0 # Total timeout for connection test in seconds
7984 connection_test_check_interval : float = 0.1 # How often to check for messages in seconds
80- connection_test_retry_interval : float = 10.0 # How often to retry kernel_info requests in seconds
85+ connection_test_retry_interval : float = (
86+ 10.0 # How often to retry kernel_info requests in seconds
87+ )
8188
8289 # Override channel classes to use our custom ones with automatic encoding
8390 shell_channel_class = Type (ShellChannel )
@@ -105,8 +112,8 @@ def __init__(self, *args, **kwargs):
105112 def add_listener (
106113 self ,
107114 callback : t .Callable [[str , list [bytes ]], None ],
108- msg_types : t .Optional [t . List [ t . Tuple [str , str ]]] = None ,
109- exclude_msg_types : t .Optional [t . List [ t . Tuple [str , str ]]] = None
115+ msg_types : t .Optional [list [ tuple [str , str ]]] = None ,
116+ exclude_msg_types : t .Optional [list [ tuple [str , str ]]] = None ,
110117 ):
111118 """Add a listener to be called when messages are received.
112119
@@ -128,8 +135,8 @@ def add_listener(
128135
129136 # Store the listener with its filter configuration
130137 self ._listeners [callback ] = {
131- ' msg_types' : set (msg_types ) if msg_types else None ,
132- ' exclude_msg_types' : set (exclude_msg_types ) if exclude_msg_types else None
138+ " msg_types" : set (msg_types ) if msg_types else None ,
139+ " exclude_msg_types" : set (exclude_msg_types ) if exclude_msg_types else None ,
133140 }
134141
135142 def remove_listener (self , callback : t .Callable [[str , list [bytes ]], None ]):
@@ -188,7 +195,7 @@ def _send_message(self, channel_name: str, msg: list[bytes]):
188195 channel = getattr (self , f"{ channel_name } _channel" , None )
189196 channel .session .send_raw (channel .socket , msg )
190197
191- except Exception as e :
198+ except Exception :
192199 self .log .warn ("Error handling incoming message." )
193200
194201 def handle_incoming_message (self , channel_name : str , msg : list [bytes ]):
@@ -225,7 +232,6 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
225232
226233 self ._send_message (channel_name , msg )
227234
228-
229235 def handle_outgoing_message (self , channel_name : str , msg : list [bytes ]):
230236 """Public API for manufacturing messages to send to kernel client listeners.
231237
@@ -246,17 +252,19 @@ async def _route_to_listeners(self, channel_name: str, msg: list[bytes]):
246252
247253 # Validate message format before routing
248254 if not msg or len (msg ) < 4 :
249- self .log .warning (f"Cannot route malformed message on { channel_name } : { len (msg ) if msg else 0 } parts (expected at least 4)" )
255+ self .log .warning (
256+ f"Cannot route malformed message on { channel_name } : { len (msg ) if msg else 0 } parts (expected at least 4)"
257+ )
250258 return
251259
252260 # Extract message type for filtering
253261 msg_type = None
254262 try :
255263 header = self .session .unpack (msg [0 ]) if msg and len (msg ) > 0 else {}
256- msg_type = header .get (' msg_type' , ' unknown' )
264+ msg_type = header .get (" msg_type" , " unknown" )
257265 except Exception as e :
258266 self .log .debug (f"Error extracting message type: { e } " )
259- msg_type = ' unknown'
267+ msg_type = " unknown"
260268
261269 # Create tasks for listeners that match the filter
262270 tasks = []
@@ -269,7 +277,9 @@ async def _route_to_listeners(self, channel_name: str, msg: list[bytes]):
269277 if tasks :
270278 await asyncio .gather (* tasks , return_exceptions = True )
271279
272- def _should_route_to_listener (self , msg_type : str , channel_name : str , filter_config : dict ) -> bool :
280+ def _should_route_to_listener (
281+ self , msg_type : str , channel_name : str , filter_config : dict
282+ ) -> bool :
273283 """Determine if a message should be routed to a listener based on its filter configuration.
274284
275285 Args:
@@ -280,8 +290,8 @@ def _should_route_to_listener(self, msg_type: str, channel_name: str, filter_con
280290 Returns:
281291 bool: True if the message should be routed to the listener, False otherwise
282292 """
283- msg_types = filter_config .get (' msg_types' )
284- exclude_msg_types = filter_config .get (' exclude_msg_types' )
293+ msg_types = filter_config .get (" msg_types" )
294+ exclude_msg_types = filter_config .get (" exclude_msg_types" )
285295
286296 # If msg_types is specified (inclusion filter)
287297 if msg_types is not None :
@@ -303,7 +313,13 @@ async def _call_listener(self, listener: t.Callable, channel_name: str, msg: lis
303313 except Exception as e :
304314 self .log .error (f"Error in listener { listener } : { e } " )
305315
306- def _update_execution_state_from_status (self , channel_name : str , msg_dict : dict , parent_msg_id : str = None , execution_state : str = None ):
316+ def _update_execution_state_from_status (
317+ self ,
318+ channel_name : str ,
319+ msg_dict : dict ,
320+ parent_msg_id : str = None ,
321+ execution_state : str = None ,
322+ ):
307323 """Update execution state from a status message if it originated from shell channel.
308324
309325 This method checks if a status message on the iopub channel originated from a shell
@@ -366,7 +382,9 @@ def _update_execution_state_from_status(self, channel_name: str, msg_dict: dict,
366382 if isinstance (content , bytes ):
367383 content = self .session .unpack (content )
368384 execution_state = content .get ("execution_state" )
369- self .log .debug (f"Ignoring status message - cannot parse parent channel (state would be: { execution_state } )" )
385+ self .log .debug (
386+ f"Ignoring status message - cannot parse parent channel (state would be: { execution_state } )"
387+ )
370388 except Exception as e :
371389 self .log .debug (f"Error updating execution state from status message: { e } " )
372390
@@ -391,10 +409,7 @@ async def broadcast_state(self):
391409 return
392410
393411 # Create status message
394- msg_dict = self .session .msg (
395- "status" ,
396- content = {"execution_state" : self .execution_state }
397- )
412+ msg_dict = self .session .msg ("status" , content = {"execution_state" : self .execution_state })
398413
399414 # Serialize the message
400415 # session.serialize() returns:
@@ -404,7 +419,9 @@ async def broadcast_state(self):
404419 # Skip delimiter (index 0) and signature (index 1) to get message parts
405420 # Result: [header, parent_header, metadata, content, buffers...]
406421 if len (serialized ) < 6 : # Need delimiter + signature + 4 message parts minimum
407- self .log .warning (f"broadcast_state: serialized message too short: { len (serialized )} parts" )
422+ self .log .warning (
423+ f"broadcast_state: serialized message too short: { len (serialized )} parts"
424+ )
408425 return
409426
410427 msg_parts = serialized [2 :] # Skip delimiter and signature
@@ -422,7 +439,7 @@ async def start_listening(self):
422439 self ._listening = True
423440
424441 # Monitor each channel for incoming messages
425- for channel_name in [' iopub' , ' shell' , ' stdin' , ' control' ]:
442+ for channel_name in [" iopub" , " shell" , " stdin" , " control" ]:
426443 channel = getattr (self , f"{ channel_name } _channel" , None )
427444 if channel and channel .is_alive ():
428445 task = asyncio .create_task (self ._monitor_channel_messages (channel_name , channel ))
@@ -433,12 +450,12 @@ async def start_listening(self):
433450 async def stop_listening (self ):
434451 """Stop listening for messages."""
435452 # Stop monitoring tasks
436- if hasattr (self , ' _monitoring_tasks' ):
453+ if hasattr (self , " _monitoring_tasks" ):
437454 for task in self ._monitoring_tasks :
438455 task .cancel ()
439456 self ._monitoring_tasks = []
440457
441- self .log .info (f "Stopped listening" )
458+ self .log .info ("Stopped listening" )
442459
443460 async def _monitor_channel_messages (self , channel_name : str , channel : ChannelABC ):
444461 """Monitor a channel for incoming messages and route them to listeners."""
@@ -466,7 +483,9 @@ async def _monitor_channel_messages(self, channel_name: str, channel: ChannelABC
466483 if msg_list and len (msg_list ) >= 5 :
467484 await self ._route_to_listeners (channel_name , msg_list [1 :])
468485 else :
469- self .log .warning (f"Received malformed message on { channel_name } : { len (msg_list ) if msg_list else 0 } parts" )
486+ self .log .warning (
487+ f"Received malformed message on { channel_name } : { len (msg_list ) if msg_list else 0 } parts"
488+ )
470489
471490 except Exception as e :
472491 # Log the error instead of silently ignoring it
@@ -514,7 +533,7 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool:
514533 await asyncio .gather (
515534 self ._send_kernel_info_shell (),
516535 self ._send_kernel_info_control (),
517- return_exceptions = True
536+ return_exceptions = True ,
518537 )
519538 except Exception as e :
520539 self .log .debug (f"Error sending initial kernel_info requests: { e } " )
@@ -531,25 +550,35 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool:
531550
532551 # Check if we've received any status messages since connection attempt
533552 # This indicates the kernel is connected, even if busy executing something
534- if self .last_shell_status_time and self .last_shell_status_time > connection_attempt_time :
553+ if (
554+ self .last_shell_status_time
555+ and self .last_shell_status_time > connection_attempt_time
556+ ):
535557 self .log .info ("Kernel communication test succeeded: received shell status message" )
536558 return True
537559
538- if self .last_control_status_time and self .last_control_status_time > connection_attempt_time :
539- self .log .info ("Kernel communication test succeeded: received control status message" )
560+ if (
561+ self .last_control_status_time
562+ and self .last_control_status_time > connection_attempt_time
563+ ):
564+ self .log .info (
565+ "Kernel communication test succeeded: received control status message"
566+ )
540567 return True
541568
542569 # Send kernel_info requests at regular intervals
543570 time_since_last_request = time .time () - last_kernel_info_time
544571 if time_since_last_request >= self .connection_test_retry_interval :
545- self .log .debug (f"Sending kernel_info requests to shell and control channels (elapsed: { elapsed :.1f} s)" )
572+ self .log .debug (
573+ f"Sending kernel_info requests to shell and control channels (elapsed: { elapsed :.1f} s)"
574+ )
546575
547576 try :
548577 # Send kernel_info to both channels in parallel (no reply expected)
549578 await asyncio .gather (
550579 self ._send_kernel_info_shell (),
551580 self ._send_kernel_info_control (),
552- return_exceptions = True
581+ return_exceptions = True ,
553582 )
554583 last_kernel_info_time = time .time ()
555584 except Exception as e :
@@ -564,7 +593,7 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool:
564593 async def _send_kernel_info_shell (self ):
565594 """Send kernel_info request on shell channel (no reply expected)."""
566595 try :
567- if hasattr (self , ' kernel_info' ):
596+ if hasattr (self , " kernel_info" ):
568597 # Send without waiting for reply
569598 self .kernel_info ()
570599 except Exception as e :
@@ -573,8 +602,8 @@ async def _send_kernel_info_shell(self):
573602 async def _send_kernel_info_control (self ):
574603 """Send kernel_info request on control channel (no reply expected)."""
575604 try :
576- if hasattr (self .control_channel , ' send' ):
577- msg = self .session .msg (' kernel_info_request' )
605+ if hasattr (self .control_channel , " send" ):
606+ msg = self .session .msg (" kernel_info_request" )
578607 # Channel wrapper will automatically encode channel in msg_id
579608 self .control_channel .send (msg )
580609 except Exception as e :
@@ -617,7 +646,7 @@ async def connect(self) -> bool:
617646 await self .start_listening ()
618647
619648 # Unpause heartbeat channel if method exists
620- if hasattr (self .hb_channel , ' unpause' ):
649+ if hasattr (self .hb_channel , " unpause" ):
621650 self .hb_channel .unpause ()
622651
623652 # Wait for heartbeat
@@ -631,7 +660,9 @@ async def connect(self) -> bool:
631660
632661 # Test kernel communication (handles retries internally)
633662 if not await self ._test_kernel_communication ():
634- self .log .error (f"Kernel communication test failed after { self .connection_test_timeout } s timeout" )
663+ self .log .error (
664+ f"Kernel communication test failed after { self .connection_test_timeout } s timeout"
665+ )
635666 return False
636667
637668 # Mark connection as ready and process queued messages
0 commit comments