diff --git a/jupyter_server_documents/app.py b/jupyter_server_documents/app.py index 6b2b631..1d53aa1 100644 --- a/jupyter_server_documents/app.py +++ b/jupyter_server_documents/app.py @@ -101,5 +101,5 @@ def _link_jupyter_server_extension(self, server_app): async def stop_extension(self): self.log.info("Stopping `jupyter_server_documents` server extension.") if self.yroom_manager: - await self.yroom_manager.stop() + self.yroom_manager.stop() self.log.info("`jupyter_server_documents` server extension is shut down. Goodbye!") diff --git a/jupyter_server_documents/rooms/yroom.py b/jupyter_server_documents/rooms/yroom.py index 4a07062..f80fbf8 100644 --- a/jupyter_server_documents/rooms/yroom.py +++ b/jupyter_server_documents/rooms/yroom.py @@ -5,6 +5,7 @@ from ..websockets import YjsClientGroup import pycrdt +import uuid from pycrdt import YMessageType, YSyncMessageType as YSyncMessageSubtype from jupyter_server_documents.ydocs import ydocs as jupyter_ydoc_classes from jupyter_ydoc.ybasedoc import YBaseDoc @@ -37,10 +38,10 @@ class YRoom: The `YRoomFileAPI` instance for this room. This is set to `None` only if `self.room_id == "JupyterLab:globalAwareness"`. - The file API provides `load_ydoc_content()` for loading the YDoc content - from the `ContentsManager`, accepts & handles save requests via - `file_api.schedule_save()`, and automatically watches the file for - out-of-band changes. + The file API provides `load_content_into()` for loading the content + from the `ContentsManager` into the JupyterYDoc. It accepts & handles save + requests via `file_api.schedule_save()`, and automatically watches the file + for out-of-band changes. """ events_api: YRoomEventsAPI | None @@ -53,6 +54,14 @@ class YRoom: _jupyter_ydoc: YBaseDoc | None """JupyterYDoc""" + _jupyter_ydoc_observers: dict[str, callable[[str, Any], Any]] + """ + Dictionary of JupyterYDoc observers added by consumers of this room. + + Added to via `observe_jupyter_ydoc()`. 
Removed from via + `unobserve_jupyter_ydoc()`. + """ + _ydoc: pycrdt.Doc """Ydoc""" _awareness: pycrdt.Awareness @@ -76,9 +85,15 @@ class YRoom: _ydoc_subscription: pycrdt.Subscription """Subscription to YDoc changes.""" - _on_stop: callable[[], Any] | None + _stopped: bool """ - Callback to run after stopping, provided in the constructor. + Whether the YRoom is stopped. Set to `True` when `stop()` is called and set + to `False` when `restart()` is called. + """ + + _updated: bool + """ + See `self.updated` for more info. """ _fileid_manager: BaseFileIdManager @@ -93,8 +108,7 @@ def __init__( loop: asyncio.AbstractEventLoop, fileid_manager: BaseFileIdManager, contents_manager: AsyncContentsManager | ContentsManager, - on_stop: callable[[], Any] | None = None, - event_logger: EventLogger + event_logger: EventLogger, ): # Bind instance attributes self.room_id = room_id @@ -102,40 +116,41 @@ def __init__( self._loop = loop self._fileid_manager = fileid_manager self._contents_manager = contents_manager - self._on_stop = on_stop + self._jupyter_ydoc_observers = {} + self._stopped = False + self._updated = False - # Initialize YjsClientGroup, YDoc, YAwareness, JupyterYDoc + # Initialize YjsClientGroup, YDoc, and Awareness self._client_group = YjsClientGroup(room_id=room_id, log=self.log, loop=self._loop) - self._ydoc = pycrdt.Doc() - self._awareness = pycrdt.Awareness(ydoc=self._ydoc) + self._ydoc = self._init_ydoc() + self._awareness = self._init_awareness(ydoc=self._ydoc) # If this room is providing global awareness, set unused optional # attributes to `None`. if self.room_id == "JupyterLab:globalAwareness": - self.file_api = None self._jupyter_ydoc = None + self.file_api = None self.events_api = None else: - # Otherwise, initialize optional attributes for document rooms + # Otherwise, initialize optional attributes for document rooms. 
# Initialize JupyterYDoc - self._jupyter_ydoc = self._init_jupyter_ydoc() + self._jupyter_ydoc = self._init_jupyter_ydoc( + ydoc=self._ydoc, + awareness=self._awareness + ) # Initialize YRoomFileAPI, start loading content self.file_api = YRoomFileAPI( room_id=self.room_id, - jupyter_ydoc=self._jupyter_ydoc, log=self.log, loop=self._loop, fileid_manager=self._fileid_manager, contents_manager=self._contents_manager, - on_outofband_change=self.reload_ydoc, + on_outofband_change=self.handle_outofband_change, on_outofband_move=self.handle_outofband_move, on_inband_deletion=self.handle_inband_deletion ) - self.file_api.load_ydoc_content() - - # Attach Jupyter YDoc observer to automatically save on change - self._jupyter_ydoc.observe(self._on_jupyter_ydoc_update) + self.file_api.load_content_into(self._jupyter_ydoc) # Initialize YRoomEventsAPI self.events_api = YRoomEventsAPI( @@ -145,15 +160,6 @@ def __init__( log=self.log, ) - # Start observers on `self.ydoc` and `self.awareness` to ensure new - # updates are always broadcast to all clients. - self._awareness_subscription = self._awareness.observe( - self._on_awareness_update - ) - self._ydoc_subscription = self._ydoc.observe( - self._on_ydoc_update - ) - # Initialize message queue and start background task that routes new # messages in the message queue to the appropriate handler method. self._message_queue = asyncio.Queue() @@ -170,23 +176,48 @@ def __init__( # Emit 'load' event once content is loaded assert self.file_api async def emit_load_event(): - await self.file_api.ydoc_content_loaded.wait() + await self.file_api.until_content_loaded self.events_api.emit_room_event("load") self._loop.create_task(emit_load_event()) - def _init_jupyter_ydoc(self) -> YBaseDoc: + def _init_ydoc(self) -> pycrdt.Doc: + """ + Initializes a YDoc, automatically binding its `_on_ydoc_update()` + observer to `self._ydoc_subscription`. The observer can be removed via + `ydoc.unobserve(self._ydoc_subscription)`. 
""" - Initializes a Jupyter YDoc (instance of `pycrdt.YBaseDoc`). This - should not be called in global awareness rooms, and requires - `self._ydoc` and `self._awareness` to be set prior. + self._ydoc = pycrdt.Doc() + self._ydoc_subscription = self._ydoc.observe( + self._on_ydoc_update + ) + return self._ydoc + + + def _init_awareness(self, ydoc: pycrdt.Doc) -> pycrdt.Awareness: + """ + Initializes an Awareness instance, automatically binding its + `_on_awareness_update()` observer to `self._awareness_subscription`. + The observer can be removed via + `awareness.unobserve(self._awareness_subscription)`. + """ + self._awareness = pycrdt.Awareness(ydoc=ydoc) + self._awareness_subscription = self._awareness.observe( + self._on_awareness_update + ) + return self._awareness + - Raises `AssertionError` if the room ID is "JupyterLab:globalAwareness" - or if either `self._ydoc` or `self._awareness` are not set. + def _init_jupyter_ydoc(self, ydoc: pycrdt.Doc, awareness: pycrdt.Awareness) -> YBaseDoc: + """ + Initializes a Jupyter YDoc (instance of `pycrdt.YBaseDoc`), + automatically attaching its `_on_jupyter_ydoc_update()` observer. + The observer can be removed via `jupyter_ydoc.unobserve()`. + + Raises `AssertionError` if the room ID is "JupyterLab:globalAwareness", + as a JupyterYDoc is not needed for global awareness rooms. 
""" assert self.room_id != "JupyterLab:globalAwareness" - assert self._ydoc - assert self._awareness # Get Jupyter YDoc class, defaulting to `YFile` if the file type is # unrecognized @@ -197,8 +228,9 @@ def _init_jupyter_ydoc(self) -> YBaseDoc: ) # Initialize Jupyter YDoc and return it - jupyter_ydoc = JupyterYDocClass(ydoc=self._ydoc, awareness=self._awareness) - return jupyter_ydoc + self._jupyter_ydoc = JupyterYDocClass(ydoc=ydoc, awareness=awareness) + self._jupyter_ydoc.observe(self._on_jupyter_ydoc_update) + return self._jupyter_ydoc @property @@ -211,32 +243,32 @@ def clients(self) -> YjsClientGroup: return self._client_group - async def get_jupyter_ydoc(self): + async def get_jupyter_ydoc(self) -> YBaseDoc: """ Returns a reference to the room's JupyterYDoc (`jupyter_ydoc.ybasedoc.YBaseDoc`) after waiting for its content to be loaded from the ContentsManager. """ - if self.file_api: - await self.file_api.ydoc_content_loaded.wait() if self.room_id == "JupyterLab:globalAwareness": message = "There is no Jupyter ydoc for global awareness scenario" self.log.error(message) raise Exception(message) + if self.file_api: + await self.file_api.until_content_loaded return self._jupyter_ydoc - async def get_ydoc(self): + async def get_ydoc(self) -> pycrdt.Doc: """ Returns a reference to the room's YDoc (`pycrdt.Doc`) after waiting for its content to be loaded from the ContentsManager. """ if self.file_api: - await self.file_api.ydoc_content_loaded.wait() + await self.file_api.until_content_loaded return self._ydoc - def get_awareness(self): + def get_awareness(self) -> pycrdt.Awareness: """ Returns a reference to the room's awareness (`pycrdt.Awareness`). 
""" @@ -264,7 +296,7 @@ async def _process_message_queue(self) -> None: # Wait for content to be loaded before processing any messages in the # message queue if self.file_api: - await self.file_api.ydoc_content_loaded.wait() + await self.file_api.until_content_loaded # Begin processing messages from the message queue while True: @@ -275,60 +307,73 @@ async def _process_message_queue(self) -> None: if queue_item is None: break - # Otherwise, process the new message + # Otherwise, process & handle the new message client_id, message = queue_item - - # Determine message type & subtype from header - message_type = message[0] - sync_message_subtype = "*" - # message subtypes only exist on sync messages, hence this condition - if message_type == YMessageType.SYNC and len(message) >= 2: - sync_message_subtype = message[1] - - # Determine if message is invalid - # NOTE: In Python 3.12+, we can drop list(...) call - # according to https://docs.python.org/3/library/enum.html#enum.EnumType.__contains__ - invalid_message_type = message_type not in list(YMessageType) - invalid_sync_message_type = message_type == YMessageType.SYNC and sync_message_subtype not in list(YSyncMessageSubtype) - invalid_message = invalid_message_type or invalid_sync_message_type - - # Handle invalid messages by logging a warning and ignoring - if invalid_message: - self.log.warning( - "Ignoring an unrecognized message with header " - f"'{message_type},{sync_message_subtype}' from client " - f"'{client_id}'. Messages must have one of the following " - "headers: '0,0' (SyncStep1), '0,1' (SyncStep2), " - "'0,2' (SyncUpdate), or '1,*' (AwarenessUpdate)." 
- ) - # Handle Awareness messages - elif message_type == YMessageType.AWARENESS: - self.log.debug(f"Received AwarenessUpdate from '{client_id}'.") - self.handle_awareness_update(client_id, message) - self.log.debug(f"Handled AwarenessUpdate from '{client_id}'.") - # Handle Sync messages - elif sync_message_subtype == YSyncMessageSubtype.SYNC_STEP1: - self.log.info(f"Received SS1 from '{client_id}'.") - self.handle_sync_step1(client_id, message) - self.log.info(f"Handled SS1 from '{client_id}'.") - elif sync_message_subtype == YSyncMessageSubtype.SYNC_STEP2: - self.log.info(f"Received SS2 from '{client_id}'.") - self.handle_sync_step2(client_id, message) - self.log.info(f"Handled SS2 from '{client_id}'.") - elif sync_message_subtype == YSyncMessageSubtype.SYNC_UPDATE: - self.log.info(f"Received SyncUpdate from '{client_id}'.") - self.handle_sync_update(client_id, message) - self.log.info(f"Handled SyncUpdate from '{client_id}'.") + self.handle_message(client_id, message) # Finally, inform the asyncio Queue that the task was complete # This is required for `self._message_queue.join()` to unblock once # queue is empty in `self.stop()`. self._message_queue.task_done() - self.log.info( + self.log.debug( "Stopped `self._process_message_queue()` background task " f"for YRoom '{self.room_id}'." ) + + def handle_message(self, client_id: str, message: bytes) -> None: + """ + Handles all messages from every client received in the message queue by + calling the appropriate handler based on the message type. 
This method + routes the message to one of the following methods: + + - `handle_sync_step1()` + - `handle_sync_step2()` + - `handle_sync_update()` + - `handle_awareness_update()` + """ + + # Determine message type & subtype from header + message_type = message[0] + sync_message_subtype = "*" + # message subtypes only exist on sync messages, hence this condition + if message_type == YMessageType.SYNC and len(message) >= 2: + sync_message_subtype = message[1] + + # Determine if message is invalid + # NOTE: In Python 3.12+, we can drop list(...) call + # according to https://docs.python.org/3/library/enum.html#enum.EnumType.__contains__ + invalid_message_type = message_type not in list(YMessageType) + invalid_sync_message_type = message_type == YMessageType.SYNC and sync_message_subtype not in list(YSyncMessageSubtype) + invalid_message = invalid_message_type or invalid_sync_message_type + + # Handle invalid messages by logging a warning and ignoring + if invalid_message: + self.log.warning( + "Ignoring an unrecognized message with header " + f"'{message_type},{sync_message_subtype}' from client " + f"'{client_id}'. Messages must have one of the following " + "headers: '0,0' (SyncStep1), '0,1' (SyncStep2), " + "'0,2' (SyncUpdate), or '1,*' (AwarenessUpdate)." 
+ ) + # Handle Awareness messages + elif message_type == YMessageType.AWARENESS: + self.log.debug(f"Received AwarenessUpdate from '{client_id}'.") + self.handle_awareness_update(client_id, message) + self.log.debug(f"Handled AwarenessUpdate from '{client_id}'.") + # Handle Sync messages + elif sync_message_subtype == YSyncMessageSubtype.SYNC_STEP1: + self.log.info(f"Received SS1 from '{client_id}'.") + self.handle_sync_step1(client_id, message) + self.log.info(f"Handled SS1 from '{client_id}'.") + elif sync_message_subtype == YSyncMessageSubtype.SYNC_STEP2: + self.log.info(f"Received SS2 from '{client_id}'.") + self.handle_sync_step2(client_id, message) + self.log.info(f"Handled SS2 from '{client_id}'.") + elif sync_message_subtype == YSyncMessageSubtype.SYNC_UPDATE: + self.log.info(f"Received SyncUpdate from '{client_id}'.") + self.handle_sync_update(client_id, message) + self.log.info(f"Handled SyncUpdate from '{client_id}'.") def handle_sync_step1(self, client_id: str, message: bytes) -> None: @@ -447,6 +492,37 @@ def _on_ydoc_update(self, event: TransactionEvent) -> None: self._broadcast_message(message, message_type="SyncUpdate") + def observe_jupyter_ydoc(self, observer: callable[[str, Any], Any]) -> str: + """ + Adds an observer callback to the JupyterYDoc that fires on change. + The callback should accept 2 arguments: + + 1. `updated_key: str`: the key of the shared type that was updated, e.g. + "cells", "state", or "metadata". + + 2. `event: Any`: The `pycrdt` event corresponding to the shared type. + For example, if "state" refers to a `pycrdt.Map`, `event` will take the + type `pycrdt.MapEvent`. + + Consumers should use this method instead of calling `observe()` directly + on the `jupyter_ydoc.YBaseDoc` instance, because JupyterYDocs generally + only allow for a single observer. + + Returns an `observer_id: str` that can be passed to + `unobserve_jupyter_ydoc()` to remove the observer. 
+ """ + # Store the ID as a `str` so it matches the declared key & return types + observer_id = str(uuid.uuid4()) + self._jupyter_ydoc_observers[observer_id] = observer + return observer_id + + + def unobserve_jupyter_ydoc(self, observer_id: str): + """ + Removes an observer from the JupyterYDoc previously added by + `observe_jupyter_ydoc()`, given the returned `observer_id`. + """ + self._jupyter_ydoc_observers.pop(observer_id, None) + + def _on_jupyter_ydoc_update(self, updated_key: str, event: Any) -> None: """ This method is an observer on `self._jupyter_ydoc` which saves the file @@ -472,8 +548,7 @@ def _on_jupyter_ydoc_update(self, updated_key: str, event: Any) -> None: # Do nothing if the content is still loading. Clients cannot make # updates until the content is loaded, so this safely prevents an extra # save upon loading/reloading the YDoc. - content_loading = not self.file_api.ydoc_content_loaded.is_set() - if content_loading: + if not self.file_api.content_loaded: return # Do nothing if the event updates the 'state' dictionary with no effect @@ -483,10 +558,16 @@ def _on_jupyter_ydoc_update(self, updated_key: str, event: Any) -> None: map_event = cast(pycrdt.MapEvent, event) if should_ignore_state_update(map_event): return + + # Otherwise, a change was made. + # Call all observers added by consumers first. + for observer in self._jupyter_ydoc_observers.values(): + observer(updated_key, event) - # Otherwise, save the file + # Then set `updated=True` and save the file. + self._updated = True self.file_api.schedule_save() - + def handle_awareness_update(self, client_id: str, message: bytes) -> None: # Apply the AwarenessUpdate message @@ -579,157 +660,166 @@ def _on_awareness_update(self, type: str, changes: tuple[dict[str, Any], Any]) - def reload_ydoc(self) -> None: """ - Reloads the YDoc from the `ContentsManager`. This method: - - - Is called in response to out-of-band changes. - - - Disconnects all clients with close code 4000. - - - Empties the message queue, as the updates can no longer be applied. 
- - - Resets `self._ydoc`, `self._awareness`, and `self._jupyter_ydoc`. - - - Resets `self.file_api` to reload the YDoc from the `ContentsManager`. - - This method is deliberately synchronous so it cannot interrupted by - another coroutine. + Alias for `self.restart(close_code=4000, immediately=True)`. + + TODO: Use a designated close code to distinguish YDoc reloads from + out-of-band changes. """ - # Do nothing if this is a global awareness room, since the YDoc is never - # used anyways. - if self.room_id == "JupyterLab:globalAwareness": - return - - # Stop the existing `YRoomFileAPI` immediately - assert self.file_api - self.file_api.stop() - - # Disconnect all clients with close code 4000. - # This is a special code defined by our extension, informing each client - # to purge their existing YDoc & re-connect. - self.clients.close_all(4000) - - # Empty message queue - while not self._message_queue.empty(): - self._message_queue.get_nowait() - self._message_queue.task_done() + self.restart(close_code=4000, immediately=True) - # Remove existing observers - self._ydoc.unobserve(self._ydoc_subscription) - self._awareness.unobserve(self._awareness_subscription) - self._jupyter_ydoc.unobserve() - - # Reset YDoc, YAwareness, JupyterYDoc to empty states - self._ydoc = pycrdt.Doc() - self._awareness = pycrdt.Awareness(ydoc=self._ydoc) - self._jupyter_ydoc = self._init_jupyter_ydoc() - - # Reset `YRoomFileAPI` & reload the document - self.file_api = YRoomFileAPI( - room_id=self.room_id, - jupyter_ydoc=self._jupyter_ydoc, - log=self.log, - loop=self._loop, - fileid_manager=self._fileid_manager, - contents_manager=self._contents_manager, - on_outofband_change=self.reload_ydoc, - on_outofband_move=self.handle_outofband_move, - on_inband_deletion=self.handle_inband_deletion - ) - self.file_api.load_ydoc_content() - - # Add observers to new YDoc, YAwareness, and JupyterYDoc instances - self._awareness_subscription = self._awareness.observe( - self._on_awareness_update - ) - 
self._ydoc_subscription = self._ydoc.observe( - self._on_ydoc_update - ) - self._jupyter_ydoc.observe(self._on_jupyter_ydoc_update) + + def handle_outofband_change(self) -> None: + """ + Handles an out-of-band change by restarting the YRoom immediately, + closing all Websockets with close code 4000. - # Emit 'overwrite' event as the YDoc content has been overwritten - if self.events_api: - self.events_api.emit_room_event("overwrite") + See `restart()` for more info. + """ + self.restart(close_code=4000, immediately=True) + - def handle_outofband_move(self) -> None: """ - Handles an out-of-band move/deletion by stopping the YRoom immediately - with close code 4001. + Handles an out-of-band move/deletion by stopping the YRoom immediately, + closing all Websockets with close code 4001. + + See `stop()` for more info. """ - self.stop_immediately(close_code=4001) + self.stop(close_code=4001, immediately=True) def handle_inband_deletion(self) -> None: """ - Handles an in-band file deletion by stopping the YRoom immediately with - close code 4002. + Handles an in-band file deletion by stopping the YRoom immediately, + closing all Websockets with close code 4002. + + See `stop()` for more info. """ - self.stop_immediately(close_code=4002) + self.stop(close_code=4002, immediately=True) - def stop_immediately(self, close_code: int) -> None: + def stop(self, close_code: int = 1001, immediately: bool = False) -> None: """ - Stops the YRoom immediately, closing all Websockets with the given - `close_code`. This is similar to `self.stop()` with some key - differences: + Stops the YRoom. This method: + + - Disconnects all clients with the given `close_code`, + defaulting to `1001` (server shutting down) if not given. - - This does not apply any pending YDoc updates from other clients. - - This does not save the file before exiting. + - Removes all observers and stops the `_process_message_queue()` + background task. 
- This should be reserved for scenarios where it does not make sense to - apply pending updates or save the file, e.g. when the file has been - deleted from disk. + - If `immediately=False` (default), this method will finish applying all + pending updates in the message queue and save the YDoc before returning. + Otherwise, if `immediately=True`, this method will drop all pending + updates and not save the YDoc before returning. + + - Clears the YDoc, Awareness, and JupyterYDoc, freeing their memory to + the server. This deletes the YDoc history. """ - # Disconnect all clients with given `close_code` + self.log.info(f"Stopping YRoom '{self.room_id}'.") + + # Disconnect all clients with the given close code self.clients.stop(close_code=close_code) # Remove all observers self._ydoc.unobserve(self._ydoc_subscription) self._awareness.unobserve(self._awareness_subscription) - - # Purge the message queue immediately, dropping all queued messages + if self._jupyter_ydoc: + self._jupyter_ydoc.unobserve() + + # Empty the message queue based on `immediately` argument while not self._message_queue.empty(): - self._message_queue.get_nowait() - self._message_queue.task_done() + if immediately: + self._message_queue.get_nowait() + self._message_queue.task_done() + else: + # Apply the pending message, then mark it done so the queue's + # unfinished-task count stays balanced as in the branch above. + client_id, message = self._message_queue.get_nowait() + self.handle_message(client_id, message) + self._message_queue.task_done() - # Enqueue `None` to stop the `_process_message_queue()` background task + # Stop the `_process_message_queue` task by enqueueing `None` self._message_queue.put_nowait(None) + + # Return early if the room is not a document room, as no more action is + # needed. + if not self.file_api or not self._jupyter_ydoc: + # Mark the room as stopped here too, since this early return skips + # the assignment at the end of this method. + self._stopped = True + return - # Stop FileAPI immediately (without saving) - if self.file_api: - self.file_api.stop() + # Otherwise, stop the file API. 
+ self.file_api.stop() - # Finally, run the provided callback (if any) and return - if self._on_stop: - self._on_stop() + # Clear the YDoc, saving the previous content unless `immediately=True` + if not immediately: + prev_jupyter_ydoc = self._jupyter_ydoc + self._loop.create_task( + self.file_api.save(prev_jupyter_ydoc) + ) + self._reset_ydoc() + self._stopped = True + + def _reset_ydoc(self) -> None: + """ + Deletes and re-initializes the YDoc, awareness, and JupyterYDoc. This + frees the memory occupied by their histories. + """ + self._ydoc = self._init_ydoc() + self._awareness = self._init_awareness(ydoc=self._ydoc) + self._jupyter_ydoc = self._init_jupyter_ydoc( + ydoc=self._ydoc, + awareness=self._awareness + ) + + @property + def stopped(self) -> bool: + """ + Returns whether the room is stopped. + """ + return self._stopped + - async def stop(self) -> None: + @property + def updated(self) -> bool: + """ + Returns whether the room has been updated since the last restart, or + since initialization if the room was not restarted. + + This initializes to `False` and is set to `True` whenever a meaningful + update that needs to be saved occurs. This is reset to `False` when + `restart()` is called. + """ + return self._updated + + + def restart(self, close_code: int = 1001, immediately: bool = False) -> None: """ - Stops the YRoom gracefully by disconnecting all clients with close code - 1001, applying all pending updates, and saving the YDoc before exiting. + Restarts the YRoom. This method re-initializes & reloads the YDoc, + Awareness, and the JupyterYDoc. After this method is called, this + instance behaves as if it were just initialized. + + If the YRoom was not stopped beforehand, then `self.stop(close_code, + immediately)` with the given arguments. Otherwise, `close_code` and + `immediately` are ignored. """ - # First, disconnect all clients by stopping the client group. 
- self.clients.stop() + # Stop if not stopped already + if not self._stopped: + self.stop(close_code=close_code, immediately=immediately) - # Remove all observers, as updates no longer need to be broadcast - self._ydoc.unobserve(self._ydoc_subscription) - self._awareness.unobserve(self._awareness_subscription) - if self._jupyter_ydoc: - self._jupyter_ydoc.unobserve() + # Reset internal state + self._stopped = False + self._updated = False - # Finish processing all messages, then enqueue `None` to stop the - # `_process_message_queue()` background task. - await self._message_queue.join() - self._message_queue.put_nowait(None) + # Restart client group + self.clients.restart() - # Stop FileAPI, saving the content before doing so - if self.file_api: - await self.file_api.stop_then_save() + # Restart `YRoomFileAPI` & reload the document. Global awareness rooms + # have no file API or JupyterYDoc, and `stop()` returns before resetting + # their YDoc & awareness, so re-create those here to re-attach the + # observers that `stop()` removed. + if self.file_api and self._jupyter_ydoc: + self.file_api.restart() + self.file_api.load_content_into(self._jupyter_ydoc) + else: + self._ydoc = self._init_ydoc() + self._awareness = self._init_awareness(ydoc=self._ydoc) + + # Restart `_process_message_queue()` task + self._loop.create_task(self._process_message_queue()) - # Finally, run the provided callback (if any) and return - if self._on_stop: - self._on_stop() + self.log.info(f"Restarted YRoom '{self.room_id}'.") + def should_ignore_state_update(event: pycrdt.MapEvent) -> bool: """ diff --git a/jupyter_server_documents/rooms/yroom_file_api.py b/jupyter_server_documents/rooms/yroom_file_api.py index c257191..d93d584 100644 --- a/jupyter_server_documents/rooms/yroom_file_api.py +++ b/jupyter_server_documents/rooms/yroom_file_api.py @@ -1,9 +1,3 @@ -""" -WIP. - -This file just contains interfaces to be filled out later. 
-""" - from __future__ import annotations from typing import TYPE_CHECKING import asyncio @@ -14,21 +8,23 @@ from tornado.web import HTTPError if TYPE_CHECKING: - from typing import Any, Callable, Literal + from typing import Any, Callable, Coroutine, Literal from jupyter_server_fileid.manager import BaseFileIdManager from jupyter_server.services.contents.manager import AsyncContentsManager, ContentsManager class YRoomFileAPI: """ Provides an API to 1 file from Jupyter Server's ContentsManager for a YRoom, - given the the room's JupyterYDoc and ID in the constructor. + given the room ID in the constructor. + + - To load the content, consumers call `load_content_into()` with a + JupyterYDoc. This also starts the `_watch_file()` loop. - To load the content, consumers should call `file_api.load_ydoc_content()`, - then `await file_api.ydoc_content_loaded` before performing any operations - on the YDoc. + - Consumers should `await file_api.until_content_loaded` before performing + any operations on the YDoc. - To save a JupyterYDoc to the file, call - `file_api.schedule_save(jupyter_ydoc)`. + - To save a JupyterYDoc to the file, call + `file_api.schedule_save()` after calling `load_content_into()`. """ # See `filemanager.py` in `jupyter_server` for references on supported file @@ -38,14 +34,24 @@ class YRoomFileAPI: file_type: Literal["file", "notebook"] file_id: str log: logging.Logger - jupyter_ydoc: YBaseDoc _fileid_manager: BaseFileIdManager + """ + Stores a reference to the Jupyter Server's File ID Manager. + """ + _contents_manager: AsyncContentsManager | ContentsManager + """ + Stores a reference to the Jupyter Server's ContentsManager. + + NOTE: any calls made on this attribute should acquire & release the + `_content_lock`. See `_content_lock` for more info. 
+ """ + _loop: asyncio.AbstractEventLoop _save_scheduled: bool - _ydoc_content_loading: bool - _ydoc_content_loaded: asyncio.Event + _content_loading: bool + _content_load_event: asyncio.Event _last_modified: datetime | None """ @@ -74,13 +80,29 @@ class YRoomFileAPI: The callback to run when an in-band move file deletion is detected. """ - _save_loop_task: asyncio.Task + _watch_file_task: asyncio.Task | None + """ + The task running the `_watch_file()` loop that processes scheduled saves and + watches for in-band & out-of-band changes. + """ + + _stopped: bool + """ + Whether the FileAPI has been stopped, i.e. when the `_watch_file()` task is + not running. + """ + + _content_lock: asyncio.Lock + """ + An `asyncio.Lock` that ensures `ContentsManager` calls reading/writing for a + single file do not overlap. This prevents file corruption scenarios like + dual-writes or dirty-reads. + """ def __init__( self, *, room_id: str, - jupyter_ydoc: YBaseDoc, log: logging.Logger, fileid_manager: BaseFileIdManager, contents_manager: AsyncContentsManager | ContentsManager, @@ -92,7 +114,6 @@ def __init__( # Bind instance attributes self.room_id = room_id self.file_format, self.file_type, self.file_id = room_id.split(":") - self.jupyter_ydoc = jupyter_ydoc self.log = log self._loop = loop self._fileid_manager = fileid_manager @@ -103,13 +124,12 @@ def __init__( self._save_scheduled = False self._last_path = None self._last_modified = None + self._stopped = False - # Initialize loading & loaded states - self._ydoc_content_loading = False - self._ydoc_content_loaded = asyncio.Event() - - # Start processing scheduled saves in a loop running concurrently - self._save_loop_task = self._loop.create_task(self._watch_file()) + # Initialize content-related primitives + self._content_loading = False + self._content_load_event = asyncio.Event() + self._content_lock = asyncio.Lock() def get_path(self) -> str | None: @@ -124,42 +144,42 @@ def get_path(self) -> str | None: @property - def 
ydoc_content_loaded(self) -> asyncio.Event: + def content_loaded(self) -> bool: """ - Returns an `asyncio.Event` that is set when the YDoc content is loaded. + Immediately returns whether the YDoc content is loaded. - To suspend a coroutine until the content is loaded: + To have a coroutine wait until the content is loaded, call `await + file_api.until_content_loaded` instead. + """ + return self._content_load_event.is_set() - ``` - await file_api.ydoc_content_loaded.wait() - ``` - To synchronously (i.e. immediately) check if the content is loaded: - - ``` - file_api.ydoc_content_loaded.is_set() - ``` + @property + def until_content_loaded(self) -> Coroutine[Any, Any, Literal[True]]: + """ + Returns an awaitable that resolves when the content is loaded. """ + return self._content_load_event.wait() - return self._ydoc_content_loaded - - def load_ydoc_content(self) -> None: + def load_content_into(self, jupyter_ydoc: YBaseDoc) -> None: """ - Loads the file from disk asynchronously into `self.jupyter_ydoc`. + Loads the file content into the given JupyterYDoc. - Consumers should `await file_api.ydoc_content_loaded` before performing + Consumers should `await file_api.until_content_loaded` before performing any operations on the YDoc. + + This method starts the `_watch_file()` task after the content is loaded. """ # If already loaded/loading, return immediately. # Otherwise, set loading to `True` and start the loading task. 
- if self._ydoc_content_loaded.is_set() or self._ydoc_content_loading: + if self._content_load_event.is_set() or self._content_loading: return - self._ydoc_content_loading = True - self._loop.create_task(self._load_ydoc_content()) + self._content_loading = True + self._loop.create_task(self._load_content(jupyter_ydoc)) - async def _load_ydoc_content(self) -> None: + async def _load_content(self, jupyter_ydoc: YBaseDoc) -> None: # Get the path specified on the file ID path = self.get_path() if not path: @@ -168,26 +188,32 @@ async def _load_ydoc_content(self) -> None: # Load the content of the file from the path self.log.info(f"Loading content for room ID '{self.room_id}', found at path: '{path}'.") - file_data = await ensure_async(self._contents_manager.get( - path, - type=self.file_type, - format=self.file_format - )) + async with self._content_lock: + file_data = await ensure_async(self._contents_manager.get( + path, + type=self.file_type, + format=self.file_format + )) # Set JupyterYDoc content and set `dirty = False` to hide the "unsaved # changes" icon in the UI - self.jupyter_ydoc.source = file_data['content'] - self.jupyter_ydoc.dirty = False + jupyter_ydoc.source = file_data['content'] + jupyter_ydoc.dirty = False # Set `_last_modified` timestamp self._last_modified = file_data['last_modified'] - # Finally, set loaded event to inform consumers that the YDoc is ready + # Set loaded event to inform consumers that the YDoc is ready # Also set loading to `False` for consistency and log success - self._ydoc_content_loaded.set() - self._ydoc_content_loading = False + self._content_load_event.set() + self._content_loading = False self.log.info(f"Loaded content for room ID '{self.room_id}'.") + # Start _watch_file() task + self._watch_file_task = self._loop.create_task( + self._watch_file(jupyter_ydoc) + ) + def schedule_save(self) -> None: """ @@ -197,18 +223,16 @@ def schedule_save(self) -> None: """ self._save_scheduled = True - async def _watch_file(self) -> 
None: + async def _watch_file(self, jupyter_ydoc: YBaseDoc) -> None: """ - Defines a background task that continuously saves the YDoc every 500ms, - checking for out-of-band changes before doing so. + Defines a background task that processes scheduled saves to the YDoc + every 500ms, checking for in-band & out-of-band changes before doing so. - Note that consumers must call `self.schedule_save()` for the next tick + This task is started by a call to `load_ydoc_content()`. + Consumers must call `self.schedule_save()` for the next tick of this task to save. """ - # Wait for content to be loaded before processing scheduled saves - await self._ydoc_content_loaded.wait() - while True: try: await asyncio.sleep(0.5) @@ -218,7 +242,7 @@ async def _watch_file(self) -> None: # cancelled halfway and corrupting the file. We need to # store a reference to the shielded task to prevent it from # being garbage collected (see `asyncio.shield()` docs). - save_task = self._save_jupyter_ydoc() + save_task = self.save(jupyter_ydoc) await asyncio.shield(save_task) except asyncio.CancelledError: break @@ -231,7 +255,7 @@ async def _watch_file(self) -> None: # occurs repeatedly. await asyncio.sleep(5) - self.log.info( + self.log.debug( "Stopped `self._watch_file()` background task " f"for YRoom '{self.room_id}'." ) @@ -285,9 +309,10 @@ async def _check_file(self): # If this raises `HTTPError(404)`, that indicates the file was # moved/deleted out-of-band. 
try: - file_data = await ensure_async(self._contents_manager.get( - path=path, format=file_format, type=file_type, content=False - )) + async with self._content_lock: + file_data = await ensure_async(self._contents_manager.get( + path=path, format=file_format, type=file_type, content=False + )) except HTTPError as e: # If not 404, re-raise the exception as it is unknown if (e.status_code != 404): @@ -317,18 +342,20 @@ async def _check_file(self): self._on_outofband_change() - async def _save_jupyter_ydoc(self): + async def save(self, jupyter_ydoc: YBaseDoc): """ - Saves the JupyterYDoc to disk immediately. + Saves the given JupyterYDoc to disk. This method works even if the + FileAPI is stopped. - This is a private method. Consumers should call - `file_api.schedule_save()` to save the YDoc on the next tick of - the `self._watch_file()` background task. + This method should only be called by consumers if the YDoc needs to be + saved while the FileAPI is stopped, e.g. when the parent room is + stopping. In all other cases, consumers should call `schedule_save()` + instead. 
""" try: # Build arguments to `CM.save()` path = self.get_path() - content = self.jupyter_ydoc.source + content = jupyter_ydoc.source file_format = self.file_format file_type = self.file_type if self.file_type in SAVEABLE_FILE_TYPES else "file" @@ -338,14 +365,15 @@ async def _save_jupyter_ydoc(self): self._save_scheduled = False # Save the YDoc via the ContentsManager - file_data = await ensure_async(self._contents_manager.save( - { - "format": file_format, - "type": file_type, - "content": content, - }, - path - )) + async with self._content_lock: + file_data = await ensure_async(self._contents_manager.save( + { + "format": file_format, + "type": file_type, + "content": content, + }, + path + )) # Set most recent `last_modified` timestamp if file_data['last_modified']: @@ -354,7 +382,7 @@ async def _save_jupyter_ydoc(self): # Set `dirty` to `False` to hide the "unsaved changes" icon in the # JupyterLab tab for this YDoc in the frontend. - self.jupyter_ydoc.dirty = False + jupyter_ydoc.dirty = False except Exception as e: self.log.error("An exception occurred when saving JupyterYDoc.") self.log.exception(e) @@ -365,19 +393,40 @@ def stop(self) -> None: Gracefully stops the `YRoomFileAPI`. This immediately halts the background task saving the YDoc to the `ContentsManager`. - To save the YDoc after stopping, call `await file_api.stop_then_save()` - instead. + To save the YDoc after stopping, call `await + file_api.save_immediately()` after calling this method. """ - self._save_loop_task.cancel() + if self._watch_file_task: + self._watch_file_task.cancel() + self._stopped = True + @property + def stopped(self) -> bool: + """ + Returns whether the FileAPI has been stopped via the `stop()` method. + """ + return self._stopped - async def stop_then_save(self) -> None: + def restart(self) -> None: """ - Gracefully stops the YRoomFileAPI by calling `self.stop()`, then saves - the content of `self.jupyter_ydoc` before exiting. 
+ Restarts the instance by stopping if the room is not stopped, then + clearing its internal state. + + Consumers should call `load_content_into()` again after this method to + restart the `_watch_file()` task. """ - self.stop() - await self._save_jupyter_ydoc() + # Stop if not stopped already + if not self.stopped: + self.stop() + + # Reset instance attributes + self._stopped = False + self._content_load_event = asyncio.Event() + self._content_loading = False + self._save_scheduled = False + self._last_modified = None + self._last_path = None + self.log.info(f"Restarted FileAPI for room '{self.room_id}'.") # see https://github.com/jupyterlab/jupyter-collaboration/blob/main/projects/jupyter-server-ydoc/jupyter_server_ydoc/loaders.py#L146-L149 diff --git a/jupyter_server_documents/rooms/yroom_manager.py b/jupyter_server_documents/rooms/yroom_manager.py index e6d7576..d0e4926 100644 --- a/jupyter_server_documents/rooms/yroom_manager.py +++ b/jupyter_server_documents/rooms/yroom_manager.py @@ -11,7 +11,36 @@ from jupyter_events import EventLogger class YRoomManager(): + """ + A singleton that manages all `YRoom` instances in the server extension. + + This manager automatically restarts updated `YRoom` instances if they have + had no connected clients or active kernel for >10 seconds. This deletes the + YDoc history to free its memory to the server. + """ + _rooms_by_id: dict[str, YRoom] + """ + Dictionary of active `YRoom` instances, keyed by room ID. Rooms are never + deleted from this dictionary. + + TODO: Delete a room if its file was deleted in/out-of-band or moved + out-of-band. See #116. + """ + + _inactive_rooms: set[str] + """ + Set of room IDs that were marked inactive on the last iteration of + `_watch_rooms()`. If a room is inactive and its ID is present in this set, + then the room should be restarted as it has been inactive for >10 seconds. 
+ """ + + _get_fileid_manager: callable[[], BaseFileIdManager] + contents_manager: AsyncContentsManager | ContentsManager + event_logger: EventLogger + loop: asyncio.AbstractEventLoop + log: logging.Logger + _watch_rooms_task: asyncio.Task | None def __init__( self, @@ -31,7 +60,15 @@ def __init__( # Initialize dictionary of YRooms, keyed by room ID self._rooms_by_id = {} - + + # Initialize set of inactive rooms tracked by `self._watch_rooms()` + self._inactive_rooms = set() + + # Start `self._watch_rooms()` background task to automatically stop + # empty rooms + # TODO: Do not enable this until #120 is addressed. + # self._watch_rooms_task = self.loop.create_task(self._watch_rooms()) + @property def fileid_manager(self) -> BaseFileIdManager: @@ -40,13 +77,17 @@ def fileid_manager(self) -> BaseFileIdManager: def get_room(self, room_id: str) -> YRoom | None: """ - Retrieves a YRoom given a room ID. If the YRoom does not exist, this - method will initialize a new YRoom. + Returns the `YRoom` instance for a given room ID. If the instance does + not exist, this method will initialize one and return it. Otherwise, + this method returns the instance from its cache. """ + # First, ensure the room is not considered inactive. 
+ self._inactive_rooms.discard(room_id) - # If room exists, then return it immediately - if room_id in self._rooms_by_id: - return self._rooms_by_id[room_id] + # If room exists, return the room + yroom = self._rooms_by_id.get(room_id, None) + if yroom: + return yroom # Otherwise, create a new room try: @@ -57,7 +98,6 @@ def get_room(self, room_id: str) -> YRoom | None: loop=self.loop, fileid_manager=self.fileid_manager, contents_manager=self.contents_manager, - on_stop=lambda: self._handle_yroom_stop(room_id), event_logger=self.event_logger, ) self._rooms_by_id[room_id] = yroom @@ -70,19 +110,20 @@ def get_room(self, room_id: str) -> YRoom | None: return None - def _handle_yroom_stop(self, room_id: str) -> None: + def has_room(self, room_id: str) -> bool: """ - Callback that is run when the YRoom is stopped. This ensures the room is - removed from the dictionary for garbage collection, even if the room was - stopped directly without `YRoomManager.delete_room()`. + Returns whether a `YRoom` instance with a matching `room_id` already + exists. """ - self._rooms_by_id.pop(room_id, None) + return room_id in self._rooms_by_id - - async def delete_room(self, room_id: str) -> None: + + def delete_room(self, room_id: str) -> None: """ - Gracefully deletes a YRoom given a room ID. This stops the YRoom first, - which finishes applying all updates & saves the content automatically. + Gracefully deletes a YRoom given a room ID. This stops the YRoom, + closing all Websockets with close code 1001 (server shutting down), + applying remaining updates, and saving the final content of the YDoc in + a background task. Returns `True` if the room was deleted successfully. Returns `False` if an exception was raised. 
@@ -93,41 +134,116 @@ async def delete_room(self, room_id: str) -> None: self.log.info(f"Stopping YRoom '{room_id}'.") try: - await yroom.stop() - self.log.info(f"Stopped YRoom '{room_id}'.") + yroom.stop() return True except Exception as e: - self.log.error(f"Exception raised when stopping YRoom '{room_id}:") - self.log.exception(e) + self.log.exception( + f"Exception raised when stopping YRoom '{room_id}: " + ) return False + async def _watch_rooms(self) -> None: + """ + Background task that checks all `YRoom` instances every 10 seconds, + restarting any updated rooms that have been inactive for >10 seconds. + This frees the memory occupied by the room's YDoc history, discarding it + in the process. + + - For rooms providing notebooks: This task restarts the room if it has + been updated, has no connected clients, and its kernel execution status + is either 'idle' or 'dead'. - async def stop(self) -> None: + - For all other rooms: This task restarts the room if it has been + updated and has no connected clients. + """ + while True: + # Check every 10 seconds + await asyncio.sleep(10) + + # Get all room IDs, except for the global awareness room + room_ids = set(self._rooms_by_id.keys()) + room_ids.discard("JupyterLab:globalAwareness") + + # Check all rooms and restart it if inactive for >10 seconds. + for room_id in room_ids: + self._check_room(room_id) + + + def _check_room(self, room_id: str) -> None: + """ + Checks a room for inactivity. + + - Rooms that have not been updated are not restarted, as there is no + YDoc history to free. + + - If a room is inactive and not in `_inactive_rooms`, this method adds + the room to `_inactive_rooms`. + + - If a room is inactive and is listed in `_inactive_rooms`, this method + restarts the room, as it has been inactive for 2 consecutive iterations + of `_watch_rooms()`. + """ + # Do nothing if the room has any connected clients. 
+ room = self._rooms_by_id[room_id] + if room.clients.count != 0: + self._inactive_rooms.discard(room_id) + return + + # Do nothing if the room contains a notebook with kernel execution state + # neither 'idle' nor 'dead'. + # In this case, the notebook kernel may still be running code cells, so + # the room should not be closed. + awareness = room.get_awareness().get_local_state() or {} + execution_state = awareness.get("kernel", {}).get("execution_state", None) + if execution_state not in { "idle", "dead", None }: + self._inactive_rooms.discard(room_id) + return + + # Do nothing if the room has not been updated. This prevents empty rooms + # from being restarted every 10 seconds. + if not room.updated: + self._inactive_rooms.discard(room_id) + return + + # The room is updated (with history) & inactive if this line is reached. + # Restart the room if was marked as inactive in the last iteration, + # otherwise mark it as inactive. + if room_id in self._inactive_rooms: + self.log.info( + f"Room '{room_id}' has been inactive for >10 seconds. " + "Restarting the room to free memory occupied by its history." + ) + room.restart() + self._inactive_rooms.discard(room_id) + else: + self._inactive_rooms.add(room_id) + + + def stop(self) -> None: """ Gracefully deletes each `YRoom`. See `delete_room()` for more info. """ + # First, stop all background tasks + if self._watch_rooms_task: + self._watch_rooms_task.cancel() + + # Get all room IDs. If there are none, return early. room_ids = list(self._rooms_by_id.keys()) room_count = len(room_ids) - if room_count == 0: return + # Otherwise, delete all rooms. self.log.info( f"Stopping `YRoomManager` and deleting all {room_count} YRooms." ) - - # Delete rooms in parallel. - # Note that we do not use `asyncio.TaskGroup` here because that cancels - # all other tasks when any task raises an exception. 
- deletion_tasks = [] + failures = 0 for room_id in room_ids: - dt = asyncio.create_task(self.delete_room(room_id)) - deletion_tasks.append(dt) - - # Use returned values to log success/failure of room deletion - results: list[bool] = await asyncio.gather(*deletion_tasks) - failures = results.count(False) + result = self.delete_room(room_id) + if not result: + failures += 1 + # Log the aggregate status before returning. if failures: self.log.error( "An exception occurred when stopping `YRoomManager`. " diff --git a/jupyter_server_documents/session_manager.py b/jupyter_server_documents/session_manager.py index e8e080d..473a0cd 100644 --- a/jupyter_server_documents/session_manager.py +++ b/jupyter_server_documents/session_manager.py @@ -30,20 +30,47 @@ def yroom_manager(self) -> YRoomManager: """The Jupyter Server's YRoom Manager.""" return self.serverapp.web_app.settings["yroom_manager"] - def get_kernel_client(self, kernel_id) -> DocumentAwareKernelClient: + _room_ids: dict[str, str] + """ + Dictionary of room IDs, keyed by session ID. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._room_ids = {} + + def get_kernel_client(self, kernel_id: str) -> DocumentAwareKernelClient: """Get the kernel client for a running kernel.""" kernel_manager = self.kernel_manager.get_kernel(kernel_id) kernel_client = kernel_manager.main_client return kernel_client - # The `type` argument here is for future proofing this API. - # Today, only notebooks have ydocs with kernels connected. - # In the future, we will include consoles here. - def get_yroom(self, path, type) -> YRoom: - """Get the yroom for a given path.""" + def get_yroom(self, session_id: str) -> YRoom: + """ + Get the `YRoom` for a session given its ID. The session must have + been created first via a call to `create_session()`. 
+ """ + room_id = self._room_ids.get(session_id, None) + yroom = self.yroom_manager.get_room(room_id) if room_id else None + if not yroom: + raise LookupError(f"No room found for session '{session_id}'.") + return yroom + + + def _init_session_yroom(self, session_id: str, path: str) -> YRoom: + """ + Returns a `YRoom` for a session identified by the given `session_id` and + `path`. This should be called only in `create_session()`. + + This method stores the new room ID & session ID in `self._room_ids`. The + `YRoom` for a session can be retrieved via `self.get_yroom()` after this + method is called. + """ file_id = self.file_id_manager.index(path) room_id = f"json:notebook:{file_id}" yroom = self.yroom_manager.get_room(room_id) + self._room_ids[session_id] = room_id + return yroom async def create_session( @@ -57,51 +84,104 @@ async def create_session( """ After creating a session, connects the yroom to the kernel client. """ - output = await super().create_session( + session_model = await super().create_session( path, name, type, kernel_name, kernel_id ) + session_id = session_model["id"] if kernel_id is None: - kernel_id = output["kernel"]["id"] - - # Connect this session's yroom to the kernel. - if type == "notebook": - # If name or path is None, we cannot map to a yroom, - # so just move on. - if name is None or path is None: - self.log.debug("`name` or `path` was not given, so a yroom was not set up for this session.") - return output - # When JupyterLab creates a session, it uses a fake path - # which is the relative path + UUID, i.e. the notebook - # name is incorrect temporarily. It later makes multiple - # updates to the session to correct the path. - # - # Here, we create the true path to store in the fileID service - # by dropping the UUID and appending the file name. - real_path = os.path.join(os.path.split(path)[0], name) - yroom = self.get_yroom(real_path, type) - # TODO: we likely have a race condition here... need to - # think about it more. 
Currently, the kernel client gets - # created after the kernel starts fully. We need the - # kernel client instantiated _before_ trying to connect - # the yroom. - kernel_client = self.get_kernel_client(kernel_id) - await kernel_client.add_yroom(yroom) - self.log.info(f"Connected yroom {yroom.room_id} to kernel {kernel_id}. yroom: {yroom}") - else: - self.log.debug(f"Document type {type} is not supported by YRoom.") - return output + kernel_id = session_model["kernel"]["id"] + + # If the type is not 'notebook', return the session model immediately + if type != "notebook": + self.log.warning( + f"Document type '{type}' is not recognized by YDocSessionManager." + ) + return session_model + + # If name or path is None, we cannot map to a yroom, + # so just move on. + if name is None or path is None: + self.log.warning(f"`name` or `path` was not given for new session at '{path}'.") + return session_model + + # Otherwise, get a `YRoom` and add it to this session's kernel client. + + # When JupyterLab creates a session, it uses a fake path + # which is the relative path + UUID, i.e. the notebook + # name is incorrect temporarily. It later makes multiple + # updates to the session to correct the path. + # + # Here, we create the true path to store in the fileID service + # by dropping the UUID and appending the file name. + real_path = os.path.join(os.path.split(path)[0], name) + + # Get YRoom for this session and store its ID in `self._room_ids` + yroom = self._init_session_yroom(session_id, real_path) + + # Add YRoom to this session's kernel client + # TODO: we likely have a race condition here... need to + # think about it more. Currently, the kernel client gets + # created after the kernel starts fully. We need the + # kernel client instantiated _before_ trying to connect + # the yroom. + kernel_client = self.get_kernel_client(kernel_id) + await kernel_client.add_yroom(yroom) + self.log.info(f"Connected yroom {yroom.room_id} to kernel {kernel_id}. 
yroom: {yroom}") + return session_model + + + async def update_session(self, session_id: str, **update) -> None: + """ + Updates the session identified by `session_id` using the keyword + arguments passed to this method. Each keyword argument should correspond + to a column in the sessions table. + + This class calls the `update_session()` parent method, then updates the + kernel client if `update` contains `kernel_id`. + """ + # Apply update and return early if `kernel_id` was not updated + if "kernel_id" not in update: + return await super().update_session(session_id, **update) + + # Otherwise, first remove the YRoom from the old kernel client and add + # the YRoom to the new kernel client, before applying the update. + old_session_info = (await self.get_session(session_id=session_id) or {}) + old_kernel_id = old_session_info.get("kernel_id", None) + new_kernel_id = update.get("kernel_id", None) + self.log.info( + f"Updating session '{session_id}' from kernel '{old_kernel_id}' " + f"to kernel '{new_kernel_id}'." + ) + yroom = self.get_yroom(session_id) + if old_kernel_id: + old_kernel_client = self.get_kernel_client(old_kernel_id) + await old_kernel_client.remove_yroom(yroom=yroom) + if new_kernel_id: + new_kernel_client = self.get_kernel_client(new_kernel_id) + await new_kernel_client.add_yroom(yroom=yroom) + + # Apply update and return + return await super().update_session(session_id, **update) + async def delete_session(self, session_id): """ Deletes the session and disconnects the yroom from the kernel client. 
""" session = await self.get_session(session_id=session_id) - kernel_id, path, type = session["kernel"]["id"], session["path"], session["type"] - yroom = self.get_yroom(path, type) + kernel_id = session["kernel"]["id"] + + # Remove YRoom from session's kernel client + yroom = self.get_yroom(session_id) kernel_client = self.get_kernel_client(kernel_id) await kernel_client.remove_yroom(yroom) + + # Remove room ID stored for the session + self._room_ids.pop(session_id, None) + + # Delete the session via the parent method await super().delete_session(session_id) \ No newline at end of file diff --git a/jupyter_server_documents/tests/test_yroom_file_api.py b/jupyter_server_documents/tests/test_yroom_file_api.py index d2e29dd..545ad63 100644 --- a/jupyter_server_documents/tests/test_yroom_file_api.py +++ b/jupyter_server_documents/tests/test_yroom_file_api.py @@ -73,8 +73,8 @@ async def plaintext_file_api(mock_plaintext_file: str, jp_contents_manager: Asyn async def test_load_plaintext_file(plaintext_file_api: Awaitable[YRoomFileAPI], mock_plaintext_file: str): file_api = await plaintext_file_api jupyter_ydoc = file_api.jupyter_ydoc - file_api.load_ydoc_content() - await file_api.ydoc_content_loaded.wait() + file_api.load_content_into(jupyter_ydoc) + await file_api.until_content_loaded # Assert that `get_jupyter_ydoc()` returns a `jupyter_ydoc.YUnicode` object # for plaintext files diff --git a/jupyter_server_documents/websockets/clients.py b/jupyter_server_documents/websockets/clients.py index c0f662a..0a18144 100644 --- a/jupyter_server_documents/websockets/clients.py +++ b/jupyter_server_documents/websockets/clients.py @@ -157,9 +157,10 @@ def get_all(self, synced_only: bool = True) -> list[YjsClient]: return all_clients - def is_empty(self) -> bool: - """Returns whether the client group is empty.""" - return len(self.synced) == 0 and len(self.desynced) == 0 + @property + def count(self) -> int: + """Returns the number of clients synced / syncing to the room.""" + 
return len(self.synced) + len(self.desynced) async def _clean_desynced(self) -> None: while True: @@ -197,10 +198,11 @@ def close_all(self, close_code: int): for client in clients: client.websocket.close(code=close_code) - def stop(self, close_code: int = 1001): + def stop(self, close_code: int = 1001) -> None: """ Closes all Websocket connections with the given close code, removes all - clients from this group, and ignores any future calls to `add()`. + clients from this group. Future calls to `add()` are ignored until the + client group is restarted via `restart()`. If a close code is not specified, it defaults to 1001 (indicates server shutting down). @@ -210,4 +212,25 @@ def stop(self, close_code: int = 1001): # Set `_stopped` to `True` to ignore future calls to `add()` self._stopped = True + + @property + def stopped(self) -> bool: + """ + Returns whether the client group is stopped. + """ + + return self._stopped + + def restart(self, close_code: int = 1001) -> None: + """ + Restarts the client group by setting `stopped` to `False`. Future calls + to `add()` will *not* be ignored after this method is called. + + If the client group is not stopped, `self.stop(close_code)` will be + called with the given argument. Otherwise, `close_code` will be ignored. + """ + if not self.stopped: + self.stop(close_code=close_code) + + self._stopped = False \ No newline at end of file