|
| 1 | +"""Gateway kernel manager that integrates with our kernel monitoring system.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from jupyter_server.gateway.managers import GatewayMappingKernelManager |
| 5 | +from jupyter_server.gateway.managers import GatewayKernelManager as _GatewayKernelManager |
| 6 | +from jupyter_server.gateway.managers import GatewayKernelClient as _GatewayKernelClient |
| 7 | +from traitlets import default, Instance, Type |
| 8 | + |
| 9 | +from jupyter_server.services.kernels.v3.client import JupyterServerKernelClientMixin |
| 10 | + |
| 11 | + |
| 12 | +class GatewayKernelClient(JupyterServerKernelClientMixin, _GatewayKernelClient): |
| 13 | + """ |
| 14 | + Gateway kernel client that combines our monitoring capabilities with gateway support. |
| 15 | +
|
| 16 | + This client inherits from: |
| 17 | + - JupyterServerKernelClientMixin: Provides kernel monitoring capabilities, message caching, |
| 18 | + and execution state tracking that integrates with our kernel monitor system |
| 19 | + - GatewayKernelClient: Provides gateway communication capabilities for remote kernels |
| 20 | +
|
| 21 | + The combination allows remote gateway kernels to be monitored with the same level of |
| 22 | + detail as local kernels, including heartbeat monitoring, execution state tracking, |
| 23 | + and kernel lifecycle management. |
| 24 | + """ |
| 25 | + |
| 26 | + async def _test_kernel_communication(self, timeout: float = 10.0) -> bool: |
| 27 | + """Skip kernel_info test for gateway kernels. |
| 28 | +
|
| 29 | + Gateway kernels handle communication differently and the kernel_info |
| 30 | + test can hang due to message routing differences. |
| 31 | +
|
| 32 | + Returns: |
| 33 | + bool: Always returns True for gateway kernels |
| 34 | + """ |
| 35 | + return True |
| 36 | + |
| 37 | + def _send_message(self, channel_name: str, msg: list[bytes]): |
| 38 | + # Send to gateway channel |
| 39 | + try: |
| 40 | + channel = getattr(self, f"{channel_name}_channel", None) |
| 41 | + if channel and hasattr(channel, 'send'): |
| 42 | + # Convert raw message to gateway format |
| 43 | + header = self.session.unpack(msg[0]) |
| 44 | + parent_header = self.session.unpack(msg[1]) |
| 45 | + metadata = self.session.unpack(msg[2]) |
| 46 | + content = self.session.unpack(msg[3]) |
| 47 | + |
| 48 | + full_msg = { |
| 49 | + 'header': header, |
| 50 | + 'parent_header': parent_header, |
| 51 | + 'metadata': metadata, |
| 52 | + 'content': content, |
| 53 | + 'buffers': msg[4:] if len(msg) > 4 else [], |
| 54 | + 'channel': channel_name, |
| 55 | + 'msg_id': header.get('msg_id'), |
| 56 | + 'msg_type': header.get('msg_type') |
| 57 | + } |
| 58 | + |
| 59 | + channel.send(full_msg) |
| 60 | + except Exception as e: |
| 61 | + self.log.warn(f"Error handling incoming message on gateway: {e}") |
| 62 | + |
| 63 | + async def _monitor_channel_messages(self, channel_name: str, channel): |
| 64 | + """Monitor a gateway channel for incoming messages.""" |
| 65 | + try: |
| 66 | + while channel.is_alive(): |
| 67 | + try: |
| 68 | + # Get message from gateway channel queue |
| 69 | + message = await channel.get_msg() |
| 70 | + |
| 71 | + # Update execution state from status messages |
| 72 | + # Gateway messages are already deserialized dicts |
| 73 | + self._update_execution_state_from_status( |
| 74 | + channel_name, |
| 75 | + message, |
| 76 | + parent_msg_id=message.get("parent_header", {}).get("msg_id"), |
| 77 | + execution_state=message.get("content", {}).get("execution_state") |
| 78 | + ) |
| 79 | + |
| 80 | + # Serialize message to standard format for listeners |
| 81 | + # Gateway messages are dicts, convert to list[bytes] format |
| 82 | + # session.serialize() returns: [b'<IDS|MSG>', signature, header, parent_header, metadata, content, buffers...] |
| 83 | + serialized = self.session.serialize(message) |
| 84 | + |
| 85 | + # Skip delimiter (index 0) and signature (index 1) to get [header, parent_header, metadata, content, ...] |
| 86 | + if serialized and len(serialized) >= 6: # Need delimiter + signature + 4 message parts |
| 87 | + msg_list = serialized[2:] |
| 88 | + else: |
| 89 | + self.log.warning(f"Gateway message too short: {len(serialized) if serialized else 0} parts") |
| 90 | + continue |
| 91 | + |
| 92 | + # Route to listeners |
| 93 | + await self._route_to_listeners(channel_name, msg_list) |
| 94 | + |
| 95 | + except asyncio.TimeoutError: |
| 96 | + # No message available, continue loop |
| 97 | + continue |
| 98 | + except Exception as e: |
| 99 | + self.log.debug(f"Error processing gateway message in {channel_name}: {e}") |
| 100 | + continue |
| 101 | + |
| 102 | + await asyncio.sleep(0.01) |
| 103 | + |
| 104 | + except asyncio.CancelledError: |
| 105 | + pass |
| 106 | + except Exception as e: |
| 107 | + self.log.error(f"Gateway channel monitoring failed for {channel_name}: {e}") |
| 108 | + |
| 109 | + |
| 110 | +class GatewayKernelManager(_GatewayKernelManager): |
| 111 | + """ |
| 112 | + Gateway kernel manager that uses our enhanced gateway kernel client. |
| 113 | +
|
| 114 | + This manager inherits from jupyter_server's GatewayKernelManager and configures it |
| 115 | + to use our GatewayKernelClient, which provides: |
| 116 | +
|
| 117 | + - Gateway communication capabilities for remote kernels |
| 118 | + - Kernel monitoring integration (heartbeat, execution state tracking) |
| 119 | + - Message ID encoding with channel and src_id using simple string operations |
| 120 | + - Full compatibility with our kernel monitor extension |
| 121 | + - Pre-created kernel client instance stored as a property |
| 122 | + - Automatic client connection/disconnection on kernel start/shutdown |
| 123 | +
|
| 124 | + When jupyter_server is configured to use a gateway, this manager ensures that |
| 125 | + remote kernels receive the same level of monitoring as local kernels. |
| 126 | + """ |
| 127 | + # Configure the manager to use our enhanced gateway client |
| 128 | + client_class = GatewayKernelClient |
| 129 | + client_factory = GatewayKernelClient |
| 130 | + |
| 131 | + kernel_client = Instance( |
| 132 | + 'jupyter_client.client.KernelClient', |
| 133 | + allow_none=True, |
| 134 | + help="""Pre-created kernel client instance. Created on initialization.""" |
| 135 | + ) |
| 136 | + |
| 137 | + def __init__(self, **kwargs): |
| 138 | + """Initialize the kernel manager and create a kernel client instance.""" |
| 139 | + super().__init__(**kwargs) |
| 140 | + |
| 141 | + # Create a kernel client instance immediately |
| 142 | + self.kernel_client = self.client(session=self.session) |
| 143 | + |
| 144 | + async def post_start_kernel(self, **kwargs): |
| 145 | + """After kernel starts, connect the kernel client. |
| 146 | +
|
| 147 | + This method is called after the kernel has been successfully started. |
| 148 | + It loads the latest connection info (with ports set by provisioner) |
| 149 | + and connects the kernel client to the kernel. |
| 150 | +
|
| 151 | + Note: If you override this method, make sure to call super().post_start_kernel(**kwargs) |
| 152 | + to ensure the kernel client connects properly. |
| 153 | + """ |
| 154 | + await super().post_start_kernel(**kwargs) |
| 155 | + |
| 156 | + try: |
| 157 | + # Load latest connection info from kernel manager |
| 158 | + # The provisioner has now set the real ports |
| 159 | + self.kernel_client.load_connection_info(self.get_connection_info(session=True)) |
| 160 | + |
| 161 | + # Connect the kernel client |
| 162 | + success = await self.kernel_client.connect() |
| 163 | + |
| 164 | + if not success: |
| 165 | + raise RuntimeError(f"Failed to connect kernel client for kernel {self.kernel_id}") |
| 166 | + |
| 167 | + self.log.info(f"Successfully connected kernel client for kernel {self.kernel_id}") |
| 168 | + |
| 169 | + except Exception as e: |
| 170 | + self.log.error(f"Failed to connect kernel client: {e}") |
| 171 | + # Re-raise to fail the kernel start |
| 172 | + raise |
| 173 | + |
| 174 | + async def cleanup_resources(self, restart=False): |
| 175 | + """Cleanup resources, disconnecting the kernel client if not restarting. |
| 176 | +
|
| 177 | + Parameters |
| 178 | + ---------- |
| 179 | + restart : bool |
| 180 | + If True, the kernel is being restarted and we should keep the client |
| 181 | + connected but clear its state. If False, fully disconnect. |
| 182 | + """ |
| 183 | + if self.kernel_client: |
| 184 | + if restart: |
| 185 | + # On restart, clear client state but keep connection |
| 186 | + # The connection will be refreshed in post_start_kernel after restart |
| 187 | + self.log.debug(f"Clearing kernel client state for restart of kernel {self.kernel_id}") |
| 188 | + self.kernel_client.last_shell_status_time = None |
| 189 | + self.kernel_client.last_control_status_time = None |
| 190 | + # Disconnect before restart - will reconnect after |
| 191 | + await self.kernel_client.stop_listening() |
| 192 | + self.kernel_client.stop_channels() |
| 193 | + else: |
| 194 | + # On shutdown, fully disconnect the client |
| 195 | + self.log.debug(f"Disconnecting kernel client for kernel {self.kernel_id}") |
| 196 | + await self.kernel_client.stop_listening() |
| 197 | + self.kernel_client.stop_channels() |
| 198 | + |
| 199 | + await super().cleanup_resources(restart=restart) |
| 200 | + |
| 201 | + |
| 202 | +class GatewayMultiKernelManager(GatewayMappingKernelManager): |
| 203 | + """Custom kernel manager that uses enhanced monitoring kernel manager with v3 API.""" |
| 204 | + |
| 205 | + @default("kernel_manager_class") |
| 206 | + def _default_kernel_manager_class(self): |
| 207 | + return "jupyter_server.gateway.v3.managers.GatewayKernelManager" |
| 208 | + |
| 209 | + def start_watching_activity(self, kernel_id): |
| 210 | + pass |
| 211 | + |
| 212 | + def stop_buffering(self, kernel_id): |
| 213 | + pass |
| 214 | + |
0 commit comments