diff --git a/app/packages/core/src/subscription/useExecutionStoreSubscribe.ts b/app/packages/core/src/subscription/useExecutionStoreSubscribe.ts new file mode 100644 index 00000000000..8327fa93814 --- /dev/null +++ b/app/packages/core/src/subscription/useExecutionStoreSubscribe.ts @@ -0,0 +1,180 @@ +import { resolveOperatorURI } from "@fiftyone/operators/src/operators"; +import * as fos from "@fiftyone/state"; +import { getEventSource } from "@fiftyone/utilities/src/fetch"; +import { EventSourceMessage } from "@microsoft/fetch-event-source"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { useRecoilCallback } from "recoil"; +// Timeout duration (in milliseconds) after which the subscription is considered unhealthy if no 'ping' is received. +const UNHEALTHY_SUBSCRIPTION_TIMEOUT = 30 * 1000; + +const EXECUTION_STORE_SUBSCRIBE_PATH = "/operators/subscribe-execution-store"; + +export type ExecutionStoreSubscribeCallback = ( + key: string, + value: T, + metadata: Record +) => void; + +/** + * Custom hook to subscribe to an execution store using Server-Sent Events (SSE). + * + * @param operatorUri - The URI of the execution store subscription operator. + * The operator must be a SSE operator. + * @param callback - Function to call when a new message is received. + * @param datasetId - Optional ID of the dataset to subscribe to. If not provided, the subscription will be global. + * @returns An object containing: + * - isSubscriptionHealthy: indicates the current health of the subscription. + * - unsubscribe: function to terminate the subscription. + * - resetSubscription: function to restart the subscription. 
+ */ +export const useExecutionStoreSubscribe = ({ + operatorUri, + callback, + datasetId, +}: { + operatorUri: string; + callback: ExecutionStoreSubscribeCallback; + datasetId?: string; +}) => { + // Use a ref to ensure the callback is stable over renders + const callbackRef = useRef(callback); + callbackRef.current = callback; + + const abortControllerRef = useRef(new AbortController()); + + const [isSubscriptionHealthy, setIsSubscriptionHealthy] = useState(false); + + const setSubscriptionNotHealthyTimerRef = useRef(-1); + + /** + * Handler for incoming SSE messages. + * Clears any existing timeout and sets a new one when a 'ping' is received. + * Otherwise, parses the event data and calls the provided callback. + */ + const onMessageHandler = useCallback((event: EventSourceMessage) => { + window.clearTimeout(setSubscriptionNotHealthyTimerRef.current); + + if (event.event === "ping") { + setIsSubscriptionHealthy(true); + + setSubscriptionNotHealthyTimerRef.current = window.setTimeout(() => { + setIsSubscriptionHealthy(false); + }, UNHEALTHY_SUBSCRIPTION_TIMEOUT); + + return; + } + + if (!event.data) { + console.error("No data in execution store subscribe event"); + return; + } + + try { + const { key, value, metadata } = JSON.parse(event.data); + callbackRef.current(key, value, metadata); + } catch (error) { + console.error( + "Error parsing execution store subscribe event data:", + error + ); + } + }, []); + + /** + * Handler for SSE errors. + * Logs the error and marks the subscription as unhealthy. + */ + const onErrorHandler = useCallback((error: Error) => { + console.error("SSE connection error:", error); + setIsSubscriptionHealthy(false); + }, []); + + /** + * Handler for SSE connection closure. + * Marks the subscription as unhealthy. + */ + const onCloseHandler = useCallback(() => { + setIsSubscriptionHealthy(false); + }, []); + + const setupSubscription = useRecoilCallback( + ({ snapshot }) => + () => { + const datasetName = datasetId + ? 
snapshot.getLoadable(fos.datasetName).getValue() + : undefined; + + const resolvedOperatorUri = resolveOperatorURI(operatorUri); + + const data = { + dataset_id: datasetId, + operator_uri: resolvedOperatorUri, + dataset_name: datasetName, + }; + + try { + getEventSource( + EXECUTION_STORE_SUBSCRIBE_PATH, + { + onmessage: onMessageHandler, + onerror: onErrorHandler, + onclose: onCloseHandler, + }, + abortControllerRef.current.signal, + data + ); + setIsSubscriptionHealthy(true); + } catch (error) { + console.error("Error subscribing to execution store:", error); + } + }, + [operatorUri, datasetId, onMessageHandler, onErrorHandler, onCloseHandler] + ); + + const unsubscribe = useCallback(() => { + abortControllerRef.current.abort(); + setIsSubscriptionHealthy(false); + }, []); + + /** + * Effect to initialize the subscription when dependencies change. + * It also cleans up the subscription when the component unmounts. + */ + useEffect(() => { + if (!abortControllerRef.current.signal.aborted) { + setupSubscription(); + } else { + abortControllerRef.current = new AbortController(); + setupSubscription(); + } + + return () => { + unsubscribe(); + }; + }, [setupSubscription, unsubscribe]); + + const resetSubscription = useCallback(() => { + abortControllerRef.current.abort(); + abortControllerRef.current = new AbortController(); + setIsSubscriptionHealthy(false); + setupSubscription(); + }, [setupSubscription]); + + return { + /** + * Indicates the current health of the subscription. + */ + isSubscriptionHealthy, + + /** + * Unsubscribes from the SSE connection by aborting the current AbortController. + * Also marks the subscription as unhealthy. 
+ */ + unsubscribe, + + /** + * Resets the subscription by aborting the current connection, + */ + resetSubscription, + }; +}; diff --git a/fiftyone/factory/repo_factory.py b/fiftyone/factory/repo_factory.py index d13a58ea16f..6a186da32e4 100644 --- a/fiftyone/factory/repo_factory.py +++ b/fiftyone/factory/repo_factory.py @@ -20,7 +20,9 @@ ExecutionStoreRepo, MongoExecutionStoreRepo, ) - +from fiftyone.operators.store.notification_service import ( + ChangeStreamNotificationService, +) _db: Database = None @@ -57,11 +59,20 @@ def delegated_operation_repo() -> DelegatedOperationRepo: def execution_store_repo( dataset_id: Optional[ObjectId] = None, collection_name: Optional[str] = None, + notification_service: Optional[ChangeStreamNotificationService] = None, ) -> ExecutionStoreRepo: - collection = _get_db()[ - collection_name or MongoExecutionStoreRepo.COLLECTION_NAME - ] - return MongoExecutionStoreRepo( - collection=collection, - dataset_id=dataset_id, + final_collection_name = ( + collection_name + if collection_name + else MongoExecutionStoreRepo.COLLECTION_NAME ) + es_repo_key = f"{final_collection_name}-{dataset_id}" + + if es_repo_key not in RepositoryFactory.repos: + RepositoryFactory.repos[es_repo_key] = MongoExecutionStoreRepo( + collection=_get_db()[final_collection_name], + dataset_id=dataset_id, + notification_service=notification_service, + ) + + return RepositoryFactory.repos[es_repo_key] diff --git a/fiftyone/factory/repos/execution_store.py b/fiftyone/factory/repos/execution_store.py index d7ef2d7e78f..632abb9b68f 100644 --- a/fiftyone/factory/repos/execution_store.py +++ b/fiftyone/factory/repos/execution_store.py @@ -6,9 +6,10 @@ | """ +import logging from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from bson import ObjectId @@ -17,35 +18,17 @@ StoreDocument, KeyPolicy, ) +from fiftyone.operators.store.notification_service import ( 
+ ChangeStreamNotificationService, + default_notification_service, + is_notification_service_disabled, +) -# -# TODO: update these doc strings to match fiftyone patterns! -# - - -class ExecutionStoreRepo(ABC): - """Abstract base class for execution store repositories. - - Each instance operates in a context: - - If a `dataset_id` is provided, it operates on stores associated with that dataset. - - If no `dataset_id` is provided, it operates on stores not associated with any dataset. - - To operate on all stores across all contexts, use the ``XXX_global()`` - methods that this class provides. - """ +logger = logging.getLogger(__name__) - def __init__(self, dataset_id: Optional[ObjectId] = None, is_cache=False): - """Initialize the execution store repository. - Args: - dataset_id (Optional[ObjectId]): the dataset ID to operate on - is_cache (False): whether the store is a cache store - """ - if dataset_id is not None and not isinstance(dataset_id, ObjectId): - raise ValueError( - f"dataset_id must be an ObjectId, got {type(dataset_id).__name__}" - ) - self._dataset_id = dataset_id +class AbstractExecutionStoreRepo(ABC): + """Abstract base class for execution store repositories.""" @abstractmethod def create_store( @@ -65,16 +48,6 @@ def create_store( """ pass - @abstractmethod - def clear_cache(self, store_name=None) -> None: - """Clear all keys with either a ``ttl`` or ``policy="eviction"``. - - Args: - store_name (str, optional): the name of the store to clear. If None, - all stores will be queried for deletion. - """ - pass - @abstractmethod def get_store(self, store_name: str) -> Optional[StoreDocument]: """Get a store from the store collection. @@ -97,6 +70,7 @@ def has_store(self, store_name: str) -> bool: Returns: bool: True if the store exists, False otherwise """ + pass @abstractmethod def list_stores(self) -> List[str]: @@ -138,6 +112,7 @@ def set_key( policy: str = "persist", ) -> KeyDocument: """Set a key in a store. 
+ Args: store_name (str): The name of the store to set the key in. key (str): The key to set. @@ -147,7 +122,6 @@ def set_key( policy (str): The eviction policy for the key. One of: - ``"persist"`` (default): Key is persistent until deleted. - ``"evict"``: Key is eligible for eviction or cache clearing. - Returns: KeyDocument: The created or updated key document. """ @@ -158,13 +132,11 @@ def set_cache_key( self, store_name: str, key: str, value: Any, ttl: Optional[int] = None ) -> KeyDocument: """Set a cache key in a store. - Args: store_name (str): the name of the store to set the cache key in key (str): the cache key to set value (Any): the value to set ttl (Optional[int]): the TTL of the cache key - Returns: KeyDocument: the created or updated cache key document """ @@ -194,6 +166,7 @@ def get_key(self, store_name: str, key: str) -> Optional[KeyDocument]: Returns: Optional[KeyDocument]: the key document, or None if the key does not exist """ + pass @abstractmethod def update_ttl(self, store_name: str, key: str, ttl: int) -> bool: @@ -220,6 +193,7 @@ def delete_key(self, store_name: str, key: str) -> bool: Returns: bool: True if the key was deleted, False otherwise """ + pass @abstractmethod def list_keys(self, store_name: str) -> List[str]: @@ -254,6 +228,19 @@ def cleanup(self) -> int: """ pass + @abstractmethod + def clear_cache(self, store_name: Optional[str] = None) -> int: + """Clear all keys with either a ``ttl`` or ``policy="evict"``. + + Args: + store_name (Optional[str]): the name of the store to clear. If None, + all stores will be queried for deletion. + + Returns: + int: the number of documents deleted + """ + pass + @abstractmethod def has_store_global(self, store_name: str) -> bool: """Check if a store exists in the global store collection. @@ -297,26 +284,114 @@ def delete_store_global(self, store_name: str) -> int: pass +class ExecutionStoreRepo(AbstractExecutionStoreRepo): + """Base class for execution store repositories. 
+ + Each instance operates in a context: + - If a `dataset_id` is provided, it operates on stores associated with that dataset. + - If no `dataset_id` is provided, it operates on stores not associated with any dataset. + + To operate on all stores across all contexts, use the ``XXX_global()`` + methods that this class provides. + """ + + def __init__( + self, + dataset_id: Optional[ObjectId] = None, + notification_service: Optional[ChangeStreamNotificationService] = None, + ): + """Initialize the execution store repository. + + Args: + dataset_id (Optional[ObjectId]): the dataset ID to operate on + notification_service: the notification service to use. If not + provided, the default notification service will be used. + """ + if dataset_id is not None and not isinstance(dataset_id, ObjectId): + raise ValueError( + f"dataset_id must be an ObjectId, got {type(dataset_id).__name__}" + ) + self._dataset_id = dataset_id + + if not is_notification_service_disabled(): + if notification_service is None: + self._notification_service = default_notification_service + else: + self._notification_service = notification_service + else: + logger.warning("Execution store notification service is disabled") + + def subscribe( + self, + store_name: str, + callback: Callable[[str], None], + ) -> str: + """Subscribe to changes in a store. + + Args: + store_name (str): the name of the store to subscribe to + callback (Callable[[str], None]): the callback to call when a change occurs + + Returns: + str: the subscription ID + + Raises: + ValueError: if no notification service is available + """ + if not self._notification_service: + raise ValueError( + "Cannot subscribe when execution store notification service is disabled" + ) + + return self._notification_service.subscribe( + store_name, callback, str(self._dataset_id) + ) + + def unsubscribe(self, subscription_id: str) -> bool: + """Unsubscribe from changes in a store. 
+ + Args: + subscription_id (str): the subscription ID to unsubscribe + + Returns: + bool: True if the subscription was removed, False otherwise + + Raises: + ValueError: if no notification service is available + """ + if not self._notification_service: + raise ValueError( + "Cannot unsubscribe without execution store notification service" + ) + + return self._notification_service.unsubscribe(subscription_id) + + class MongoExecutionStoreRepo(ExecutionStoreRepo): """MongoDB implementation of the execution store repository.""" COLLECTION_NAME = "execution_store" def __init__( - self, collection, dataset_id: Optional[ObjectId] = None, is_cache=False + self, + collection, + dataset_id: Optional[ObjectId] = None, + notification_service: Optional[ChangeStreamNotificationService] = None, ): if dataset_id is not None and not isinstance(dataset_id, ObjectId): raise ValueError( f"dataset_id must be an ObjectId, got {type(dataset_id).__name__}" ) - super().__init__(dataset_id, is_cache) self._collection = collection + super().__init__(dataset_id, notification_service) + self._create_indexes() def _create_indexes(self): indices = [idx["name"] for idx in self._collection.list_indexes()] expires_at_name = "expires_at" store_name_name = "store_name" + updated_at_name = "updated_at" key_name = "key" full_key_name = "unique_store_index" dataset_id_name = "dataset_id" @@ -332,10 +407,12 @@ def _create_indexes(self): name=full_key_name, unique=True, ) + for name in [ store_name_name, key_name, dataset_id_name, + updated_at_name, policy_name, ]: if name not in indices: @@ -411,6 +488,9 @@ def count_stores(self) -> int: return result[0]["total_stores"] if result else 0 def delete_store(self, store_name: str) -> int: + if self._notification_service: + self._notification_service.unsubscribe_all(store_name) + result = self._collection.delete_many( { "store_name": store_name, @@ -446,6 +526,7 @@ def set_key( policy=policy, ) ) + on_insert_fields = { "store_name": store_name, "key": key, 
@@ -603,8 +684,12 @@ def delete_store_global(self, store_name: str) -> int: class InMemoryExecutionStoreRepo(ExecutionStoreRepo): """In-memory implementation of execution store repository.""" - def __init__(self, dataset_id: Optional[ObjectId] = None): - super().__init__(dataset_id) + def __init__( + self, + dataset_id: Optional[ObjectId] = None, + notification_service: Optional[ChangeStreamNotificationService] = None, + ): + super().__init__(dataset_id, notification_service) self._docs = {} def _doc_key(self, store_name: str, key: str) -> tuple: @@ -697,9 +782,11 @@ def set_key( updated_at=now, expires_at=expiration, dataset_id=self._dataset_id, - policy="evict" - if policy == "evict" or ttl is not None - else "persist", + policy=( + "evict" + if policy == "evict" or ttl is not None + else "persist" + ), ) ) composite_key = self._doc_key(store_name, key) diff --git a/fiftyone/operators/__init__.py b/fiftyone/operators/__init__.py index 1a517f5e52e..3e63aff5422 100644 --- a/fiftyone/operators/__init__.py +++ b/fiftyone/operators/__init__.py @@ -24,6 +24,7 @@ from .store import ExecutionStore from .categories import Categories from .cache import execution_cache +from .sse import SseOperator, SseOperatorConfig # This enables Sphinx refs to directly use paths imported here __all__ = [k for k, v in globals().items() if not k.startswith("_")] diff --git a/fiftyone/operators/executor.py b/fiftyone/operators/executor.py index bdd3c544168..5b4bf0fa730 100644 --- a/fiftyone/operators/executor.py +++ b/fiftyone/operators/executor.py @@ -364,6 +364,11 @@ async def execute_or_delegate_operator( error_message=str(error), ) + if hasattr(operator, "IS_SSE_OPERATOR"): + return ExecutionResult( + result=result, executor=executor, is_sse=True + ) + return ExecutionResult(result=result, executor=executor) @@ -1025,6 +1030,8 @@ class ExecutionResult(object): delegated (False): whether execution was delegated outputs_schema (None): a JSON dict representing the output schema of the 
operator + is_sse (False): whether execution was from an operator handling + server-sent events (SSE) """ def __init__( @@ -1036,6 +1043,7 @@ def __init__( validation_ctx=None, delegated=False, outputs_schema=None, + is_sse=False, ): self.result = result self.executor = executor @@ -1044,6 +1052,7 @@ def __init__( self.validation_ctx = validation_ctx self.delegated = delegated self.outputs_schema = outputs_schema + self.is_sse = is_sse @property def is_generator(self): diff --git a/fiftyone/operators/message.py b/fiftyone/operators/message.py index 1fc3f274275..f39ef51ba1f 100644 --- a/fiftyone/operators/message.py +++ b/fiftyone/operators/message.py @@ -5,8 +5,11 @@ | `voxel51.com `_ | """ + +import dataclasses from enum import Enum import json +from typing import Optional, Any, Dict class MessageType(Enum): @@ -33,3 +36,65 @@ def to_json(self): def to_json_line(self): return json.dumps(self.to_json()) + "\n" + + +@dataclasses.dataclass +class MessageMetadata: + """Metadata for a store notification message.""" + + operation_type: Optional[str] = None + dataset_id: Optional[str] = None + timestamp: Optional[str] = None + + +@dataclasses.dataclass +class MessageData: + """Data structure for messages sent by the notification service.""" + + key: str + value: Any + metadata: MessageMetadata + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MessageData": + """Create a MessageData instance from a dictionary. + + Args: + data: Dictionary containing message data + + Returns: + MessageData instance + """ + metadata = MessageMetadata(**data.get("metadata", {})) + return cls( + key=data.get("key", ""), value=data.get("value"), metadata=metadata + ) + + @classmethod + def from_json(cls, json_str: str) -> "MessageData": + """Create a MessageData instance from a JSON string. 
+ + Args: + json_str: JSON string containing message data + + Returns: + MessageData instance + """ + data = json.loads(json_str) + return cls.from_dict(data) + + def to_dict(self) -> Dict[str, Any]: + """Convert the MessageData instance to a dictionary. + + Returns: + Dictionary representation + """ + return dataclasses.asdict(self) + + def to_json(self) -> str: + """Convert the MessageData instance to a JSON string. + + Returns: + JSON string representation + """ + return json.dumps(self.to_dict()) diff --git a/fiftyone/operators/remote_notifier.py b/fiftyone/operators/remote_notifier.py new file mode 100644 index 00000000000..8cc7ee4464c --- /dev/null +++ b/fiftyone/operators/remote_notifier.py @@ -0,0 +1,224 @@ +""" +FiftyOne operator server SSE notifier for execution store events. + +This module provides an SSE notifier that listens for notification requests +targeting a specific execution store. When a broadcast is sent to a store, +all connected SSE clients subscribed to that store will receive the message. + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +import asyncio +import json +import logging +import time +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Dict, Optional, Set, Tuple + +from sse_starlette.sse import EventSourceResponse + +logger = logging.getLogger(__name__) + + +class RemoteNotifier(ABC): + @abstractmethod + async def broadcast_to_store(self, store_name: str, message: str) -> None: + """ + Broadcast a message to all remote subscribers of the given store. + + Args: + store_name: The name of the store to which the message should be broadcast. + message: The message payload to send to the subscribers. + """ + pass + + +class SseNotifier(RemoteNotifier): + """ + Handles the logic for broadcasting messages and managing client subscriptions + for Server-Sent Events (SSE) notifications. 
+ """ + + def __init__(self) -> None: + # Maps store names to a set of tuples (queue, dataset_id) + self.store_queues: Dict[ + str, Set[Tuple[asyncio.Queue, Optional[str]]] + ] = {} + + async def broadcast_to_store(self, store_name: str, message: str) -> None: + """ + Broadcast a message to all connected SSE clients subscribed to the specified store. + Handles disconnected clients gracefully without raising exceptions. + + Args: + store_name: The name of the store to broadcast to. + message: The message to broadcast. + """ + if store_name in self.store_queues: + # Try to extract dataset_id from message for filtering + dataset_id = None + try: + msg_data = json.loads(message) + dataset_id = msg_data.get("metadata", {}).get("dataset_id") + except Exception: + # If we can't parse the message, continue without dataset filtering + pass + + logger.debug( + "Broadcasting message to store '%s'%s: %s", + store_name, + f" for dataset {dataset_id}" if dataset_id else "", + message, + ) + + # Create a copy of the queues to avoid modification during iteration + queue_items = list(self.store_queues[store_name]) + queues_to_remove = set() + + for queue, client_dataset_id in queue_items: + # Filter by dataset_id if both are specified + if ( + client_dataset_id is not None + and dataset_id is not None + and dataset_id != client_dataset_id + ): + continue + + try: + # Use put_nowait to avoid blocking on full queues + # This prevents one slow client from blocking others + queue.put_nowait(message) + except asyncio.QueueFull: + logger.debug( + f"Queue full for client in store '{store_name}', dropping message" + ) + except Exception as e: + # If we encounter an error with this queue, mark it for removal + logger.debug( + f"Error sending to client in store '{store_name}': {e}" + ) + queues_to_remove.add((queue, client_dataset_id)) + + # Clean up any problematic queues + for queue_item in queues_to_remove: + self._unregister_queue( + store_name, queue_item[0], queue_item[1] + ) + else: + 
logger.debug( + "No subscribers found for store '%s'. Message not sent.", + store_name, + ) + + async def get_event_source_response( + self, store_name: str, dataset_id: Optional[str] = None + ) -> EventSourceResponse: + """ + Creates an EventSourceResponse for a client subscribing to a specific store. + It registers a new queue for the client and produces an async generator to stream events. + + Args: + store_name: The name of the store to subscribe to. + dataset_id: Optional dataset ID to filter events by. + + Returns: + An EventSourceResponse for streaming events to the client. + """ + queue: asyncio.Queue = asyncio.Queue() + + if store_name not in self.store_queues: + self.store_queues[store_name] = set() + self.store_queues[store_name].add((queue, dataset_id)) + + logger.debug( + "New SSE connection for store: %s%s", + store_name, + f" and dataset: {dataset_id}" if dataset_id else "", + ) + logger.debug( + "Total SSE connections for store %s: %s", + store_name, + len(self.store_queues.get(store_name, set())), + ) + + await self.sync_current_state_for_client(queue, store_name, dataset_id) + + async def event_generator() -> AsyncGenerator[str, None]: + try: + while True: + message = await queue.get() + yield message + queue.task_done() + except asyncio.CancelledError: + logger.debug( + "SSE client disconnected from store: %s", store_name + ) + finally: + self._unregister_queue(store_name, queue, dataset_id) + logger.debug( + "Total SSE connections for store %s: %s", + store_name, + len(self.store_queues.get(store_name, set())), + ) + + return EventSourceResponse(event_generator()) + + async def sync_current_state_for_client( + self, + queue: asyncio.Queue, + store_name: str, + dataset_id: Optional[str] = None, + ) -> None: + """ + Broadcast the current state of the store to all connected clients. 
+ """ + # note: unfortunate dependency on the notification service + from fiftyone.operators.store.notification_service import ( + default_notification_service, + ) + + # wait until the notification service is started, with a timeout of 10 seconds + start_time = time.time() + while not default_notification_service.is_running: + if time.time() - start_time > 10: + raise TimeoutError( + "Notification service failed to start within 10 seconds" + ) + await asyncio.sleep(0.5) + + asyncio.run_coroutine_threadsafe( + default_notification_service._broadcast_current_state_for_store( + store_name, + dataset_id, + lambda msg: queue.put_nowait(msg.to_json()), + ), + default_notification_service.dedicated_event_loop, + ) + + def _unregister_queue( + self, + store_name: str, + queue: asyncio.Queue, + dataset_id: Optional[str] = None, + ) -> None: + """ + Remove the client's queue from the store. Clean up if no queues remain. + + Args: + store_name: The name of the store to unregister from. + queue: The queue to unregister. + dataset_id: Optional dataset ID associated with the queue. + """ + if store_name in self.store_queues: + self.store_queues[store_name].discard((queue, dataset_id)) + if not self.store_queues[store_name]: + del self.store_queues[store_name] + logger.debug( + "No more subscribers for store: %s. 
Cleaned up.", + store_name, + ) + + +default_sse_notifier = SseNotifier() diff --git a/fiftyone/operators/server.py b/fiftyone/operators/server.py index 1334cd5ce28..ffc7f04e988 100644 --- a/fiftyone/operators/server.py +++ b/fiftyone/operators/server.py @@ -7,6 +7,7 @@ """ import types +from sse_starlette.sse import EventSourceResponse from starlette.endpoints import HTTPEndpoint from starlette.exceptions import HTTPException from starlette.requests import Request @@ -229,6 +230,38 @@ async def post(self, request: Request, data: dict) -> dict: return result.to_dict() if result else {} +class SubscribeToExecutionStoreAsOperator(HTTPEndpoint): + @route + async def post(self, request: Request, data: dict) -> EventSourceResponse: + dataset_name = data.get("dataset_name", None) + dataset_ids = [dataset_name] + operator_uri = data.get("operator_uri", None) + if operator_uri is None: + raise ValueError("Operator URI must be provided") + + registry = await PermissionedOperatorRegistry.from_exec_request( + request, dataset_ids=dataset_ids + ) + if registry.can_execute(operator_uri) is False: + return create_permission_error(operator_uri) + + if registry.operator_exists(operator_uri) is False: + error_detail = { + "message": "Operator '%s' does not exist" % operator_uri, + "loading_errors": registry.list_errors(), + } + raise HTTPException(status_code=404, detail=error_detail) + + execution_result = await execute_or_delegate_operator( + operator_uri, data + ) + + if execution_result.is_sse: + return execution_result.result + else: + return execution_result.to_json() + + OperatorRoutes = [ ("/operators", ListOperators), ("/operators/execute", ExecuteOperator), @@ -236,4 +269,8 @@ async def post(self, request: Request, data: dict) -> dict: ("/operators/resolve-type", ResolveType), ("/operators/resolve-placements", ResolvePlacements), ("/operators/resolve-execution-options", ResolveExecutionOptions), + ( + "/operators/subscribe-execution-store", + 
SubscribeToExecutionStoreAsOperator, + ), ] diff --git a/fiftyone/operators/sse.py b/fiftyone/operators/sse.py new file mode 100644 index 00000000000..53c4611c782 --- /dev/null +++ b/fiftyone/operators/sse.py @@ -0,0 +1,62 @@ +""" +FiftyOne SSE operators. + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +import logging + +import fiftyone.operators as foo +from fiftyone.operators.executor import ExecutionContext +from fiftyone.operators.remote_notifier import default_sse_notifier + +logger = logging.getLogger(__name__) + + +class SseOperatorConfig(foo.OperatorConfig): + def __init__( + self, + name, + label, + store_name, + description=None, + icon=None, + light_icon=None, + dark_icon=None, + ): + super().__init__( + name, + label=label, + description=description, + icon=icon, + light_icon=light_icon, + dark_icon=dark_icon, + ) + self.store_name = store_name + + +class SseOperator(foo.Operator): + IS_SSE_OPERATOR = True + + @property + def subscription_config(self): + raise NotImplementedError( + "Subscriptions must define subscription_config" + ) + + @property + def config(self): + return self.subscription_config + + async def execute(self, ctx: ExecutionContext): + if not self.subscription_config: + raise ValueError("subscription_config must be defined") + + dataset_id = ctx.request_params.get("dataset_id", None) + + return await default_sse_notifier.get_event_source_response( + self.subscription_config.store_name, + dataset_id, + ) diff --git a/fiftyone/operators/store/notification_service.py b/fiftyone/operators/store/notification_service.py new file mode 100644 index 00000000000..13b8323520b --- /dev/null +++ b/fiftyone/operators/store/notification_service.py @@ -0,0 +1,651 @@ +""" +Notification service for ExecutionStore using MongoDB Change Streams. + +| Copyright 2017-2025, Voxel51, Inc. 
+| `voxel51.com `_ +| +""" + +import asyncio +import logging +import os +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from threading import Thread +from typing import Callable, Dict, List, Optional, Set + +from bson import ObjectId +from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase +from pymongo.errors import OperationFailure + +import fiftyone.core.odm as foo +from fiftyone.operators.message import MessageData, MessageMetadata +from fiftyone.operators.remote_notifier import ( + RemoteNotifier, + default_sse_notifier, +) +from fiftyone.operators.store.subscription_registry import ( + LocalSubscriptionRegistry, + default_subscription_registry, +) + +logger = logging.getLogger(__name__) + +POLL_INTERVAL_SECONDS = int( + os.getenv("FIFTYONE_EXECUTION_STORE_POLL_INTERVAL_SECONDS", 5) +) + + +class ChangeStreamNotificationService(ABC): + """Abstract base class for change stream notification services.""" + + @abstractmethod + def subscribe( + self, + store_name: str, + callback: Callable[[str], None], + dataset_id: Optional[str] = None, + ) -> str: + """Register a local subscriber for a specific store. + + Args: + store_name: The name of the store to subscribe to. + callback: The callback to call when a change occurs. + dataset_id: Optional dataset ID to filter changes by. + + Returns: + The subscription id. + """ + pass + + @abstractmethod + def unsubscribe(self, subscription_id: str): + """Unsubscribe local subscribers from a specific store. + + Args: + subscription_id: The subscription id to unsubscribe from. + """ + pass + + @abstractmethod + def unsubscribe_all(self, store_name: str): + """Unsubscribe from all changes in a store. + + Args: + store_name (str): the name of the store to unsubscribe from + """ + pass + + @abstractmethod + def notify(self, store_name: str, message_data: MessageData) -> None: + """Notify local subscribers and remote listeners of a change. 
+ + Args: + store_name: The name of the store that changed. + message: The message to notify subscribers with. + """ + pass + + @abstractmethod + async def start(self) -> None: + """Start watching for database changes.""" + pass + + @abstractmethod + async def stop(self) -> None: + """Stop watching for database changes.""" + pass + + +class MongoChangeStreamNotificationService(ChangeStreamNotificationService): + def __init__( + self, + collection_name: str, + remote_notifier: RemoteNotifier = None, + registry: LocalSubscriptionRegistry = None, + ): + self._subscription_registry = ( + default_subscription_registry if registry is None else registry + ) + self._remote_notifier = remote_notifier + + self._collection_name = collection_name + + # we init this in start(), which runs in a dedicated event loop + # init-ing it in ctor or in global might cause wrong binding + self.dedicated_event_loop: Optional[asyncio.AbstractEventLoop] = None + self._async_db: AsyncIOMotorDatabase = None + self._collection_async: AsyncIOMotorCollection = None + self._last_poll_time: datetime = None + self.is_running: bool = False + + # Track keys per store for polling + self._last_keys: Dict[str, set] = {} + + # Reference to running task + self._task: Optional[asyncio.Task] = None + + # Event to signal the task to stop, + # will be initialized in the start method + self._stop_event = None + + self._background_tasks: Set[asyncio.Task] = set() + + async def _get_current_stores(self) -> List[str]: + return await self._collection_async.distinct("store_name") + + def subscribe( + self, + store_name: str, + callback: Callable[[str], None], + dataset_id: Optional[str] = None, + ) -> str: + """Register a local subscriber for a specific store. + + Args: + store_name: The name of the store to subscribe to. + callback: The callback to call when a change occurs. + dataset_id: Optional dataset ID to filter changes by. + + Returns: + The subscription id. 
+ """ + log_message = f"Subscribing to store {store_name}" + if dataset_id: + log_message += f" for dataset {dataset_id}" + logger.debug(log_message) + + subscription_id = self._subscription_registry.subscribe( + store_name, callback, dataset_id + ) + + # we need to broadcast the current state as soon as the subscriber + # is registered, so that the subscriber has the latest state + # of the store. otherwise, the subscriber will only receive + # next state changes after the first change occurs + + if self.dedicated_event_loop is None: + logger.warning( + "Event loop for notification service is not set, " + "cannot broadcast current state" + ) + else: + # note: subscribe is usually called from the main thread, + # so we need to use asyncio.run_coroutine_threadsafe + asyncio.run_coroutine_threadsafe( + self._broadcast_current_state_for_store( + store_name, dataset_id, callback + ), + self.dedicated_event_loop, + ) + + return subscription_id + + def unsubscribe(self, subscription_id: str): + """Unsubscribe from a specific store. + + Args: + subscription_id: The subscription id to unsubscribe from. + """ + self._subscription_registry.unsubscribe(subscription_id) + + def unsubscribe_all(self, store_name: str): + """Unsubscribe from all changes in a store. 
+ + Args: + store_name (str): the name of the store to unsubscribe from + """ + self._subscription_registry.unsubscribe_all(store_name) + + async def start( + self, dedicated_event_loop: asyncio.AbstractEventLoop + ) -> None: + """Start watching the collection for changes using change streams or polling.""" + self.dedicated_event_loop = dedicated_event_loop + + self._stop_event = asyncio.Event() + + self._async_db = foo.get_async_db_conn() + self._collection_async = self._async_db[self._collection_name] + + self._task = asyncio.create_task(self._run()) + + try: + await self._task + except asyncio.CancelledError: + logger.debug("Change stream/polling task cancelled.") + + async def notify(self, store_name: str, message_data: MessageData) -> None: + """ + Notify local subscribers and remote listeners of a change. + Handles exceptions gracefully to prevent failures when clients disconnect. + + Args: + store_name: The name of the store that changed + message_data: The message data to notify subscribers with + """ + try: + # Get dataset_id for filtering + message_dataset_id = message_data.metadata.dataset_id + except Exception as e: + logger.warning(f"Error accessing dataset_id for filtering: {e}") + message_dataset_id = None + + # Notify local subscribers + # snapshot to avoid concurrent mutations + subscribers = list( + self._subscription_registry.get_subscribers( + store_name=store_name + ).items() + ) + + for subscription_id, (callback, subscriber_dataset_id) in subscribers: + # Filter by dataset_id if specified in the subscription + if ( + subscriber_dataset_id is not None + and message_dataset_id is not None + and subscriber_dataset_id != message_dataset_id + ): + continue + + try: + callback(message_data) + except Exception as e: + logger.warning( + f"Error notifying local subscriber {subscription_id}: {e}" + ) + # Consider removing problematic subscribers + try: + self._subscription_registry.unsubscribe(subscription_id) + except Exception: + # If unsubscribe 
fails, just continue + pass + + # Notify remote listeners + if self._remote_notifier: + task = asyncio.create_task( + self._remote_notifier.broadcast_to_store( + store_name, message_data.to_json() + ) + ) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + async def _run(self) -> None: + """Run the change stream/polling task.""" + logger.debug("Starting change stream/polling task") + self.is_running = True + + try: + # First attempt to use change streams + await self._run_change_stream() + except OperationFailure: + logger.warning( + f"Mongo change stream is not available. Falling back to polling." + ) + # Try polling as a fallback for any error + await self._start_polling() + + async def _run_change_stream(self) -> None: + """Run the change stream watcher.""" + # Watch all changes in the collection - filtering happens at subscriber level + pipeline = [] + + # full_document="updateLookup" is required to get the full document in the change stream + # https://motor.readthedocs.io/en/stable/api-asyncio/asyncio_motor_change_stream.html + async with self._collection_async.watch( + pipeline, full_document="updateLookup" + ) as stream: + if self._stop_event and self._stop_event.is_set(): + return + + try: + while stream.alive and ( + not self._stop_event or not self._stop_event.is_set() + ): + try: + change = await stream.next() + + if change: + await self._handle_change(change) + + except StopAsyncIteration: + break + except asyncio.CancelledError: + break + except Exception as e: + logger.exception( + f"Error processing change stream: {e}" + ) + await asyncio.sleep(1) + finally: + await stream.close() + + async def _broadcast_current_state_for_store( + self, + store_name: str, + dataset_id: Optional[str] = None, + callback: Optional[Callable] = None, + ) -> None: + """Broadcast the current state for a specific store to a single subscriber. 
+ + Args: + store_name: The name of the store to broadcast state for + dataset_id: Optional dataset ID to filter by + callback: The callback function to send messages to + """ + if not callback: + return + + logger.debug( + f"broadcasting current state for store {store_name} with dataset_id {dataset_id}" + ) + + query = { + "store_name": store_name, + "key": {"$ne": "__store__"}, + } + + if dataset_id is not None: + query["dataset_id"] = ( + ObjectId(dataset_id) + if isinstance(dataset_id, str) + else dataset_id + ) + + docs = await self._collection_async.find(query).to_list() + for doc in docs: + message_data = MessageData( + key=doc["key"], + value=doc["value"], + metadata=MessageMetadata( + operation_type="initial", + dataset_id=( + str(doc.get("dataset_id")) + if doc.get("dataset_id") is not None + else None + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + ) + try: + callback(message_data) + except Exception as e: + logger.warning( + f"Error sending initial state to subscriber: {e}" + ) + break + + async def stop(self) -> None: + """Signal stop watching the collection for changes. + Assume this is called from thread safe context. + """ + if self._stop_event: + self._stop_event.set() + + # Cancel all in-flight background tasks + for task in self._background_tasks: + task.cancel() + await asyncio.gather(*self._background_tasks, return_exceptions=True) + + # Cancel the the main watcher/polling task + if self._task is not None: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + self._subscription_registry.empty_subscribers() + self.is_running = False + + async def _handle_change(self, change: dict) -> None: + """ + Process a change document from MongoDB change stream. + + Args: + change: The change document from MongoDB. 
+ """ + try: + # from mongodb docs: + # https://www.mongodb.com/docs/manual/changeStreams/ + # and pymongo docs: + # https://pymongo.readthedocs.io/en/stable/api/pymongo/change_stream.html + # + # operationType can be: insert/update/replace/delete + # fullDocument is the most current majority-committed version of the document + # documentKey contains a document _id (not the same as a full document) + # updateDescription contains a description of the fields that were updated/removed (optional) + + operation_type = change["operationType"] + + # Get the store name from the document + if "fullDocument" in change and change["fullDocument"]: + store_name = change["fullDocument"].get("store_name") + dataset_id = change["fullDocument"].get("dataset_id") + key = change["fullDocument"].get("key") + value = change["fullDocument"].get("value") + else: + # For delete operations, we need to use the document key + # important todo: right now, we won't detect deletes + # because we need to do a minor refactor to embed store_name + # and dataset_id and key in the documentKey._id. 
+ # Right now, we'll know document is deleted but not + # useful metadata about it + doc_id = change["documentKey"].get("_id", {}) + if isinstance(doc_id, dict): + store_name = doc_id.get("store_name") + dataset_id = doc_id.get("dataset_id") + key = doc_id.get("key") + value = None + else: + # If we can't get the store name, skip this change + return + + if not store_name: + # If we can't get the store name, skip this change + return + + time_of_change = change["wallTime"].isoformat() + + message_data = MessageData( + key=key, + value=value, + metadata=MessageMetadata( + operation_type=operation_type, + dataset_id=( + str(dataset_id) if dataset_id is not None else None + ), + timestamp=time_of_change, + ), + ) + + # Directly notify subscribers with message_data + await self.notify(store_name, message_data) + except Exception as e: + logger.exception(f"Error handling change: {e}") + + async def _start_polling(self) -> None: + """Fallback to polling the collection periodically + if change streams are not available.""" + logger.debug("Starting polling-based notification service") + + while not self._stop_event or not self._stop_event.is_set(): + try: + await self._poll() + except Exception as e: + logger.exception(f"Error during polling: {e}") + # If there's an error during polling, wait a bit before retrying + # to avoid rapid retries in case of persistent errors + if self._stop_event and self._stop_event.is_set(): + break + + # Wait for the next polling interval + try: + await asyncio.sleep(POLL_INTERVAL_SECONDS) + except asyncio.CancelledError: + logger.debug("Polling task cancelled during sleep") + break + + async def _poll(self) -> None: + """Check for changes since the last poll.""" + now = datetime.now(timezone.utc) + store_names = await self._get_current_stores() + + if self._last_poll_time is None: + self._last_poll_time = now + + for store_name in store_names: + self._last_keys[store_name] = set( + await self._collection_async.distinct( + "key", 
{"store_name": store_name} + ) + ) + return + + for store_name in store_names: + current_keys = set( + await self._collection_async.distinct( + "key", {"store_name": store_name} + ) + ) + previous_keys = self._last_keys.get(store_name, set()) + + # Detect deleted keys + deleted_keys = previous_keys - current_keys + for key in deleted_keys: + if key == "__store__": + continue + message_data = MessageData( + key=key, + value=None, + metadata=MessageMetadata( + operation_type="delete", + dataset_id=None, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + ) + await self.notify(store_name, message_data) + + # Detect inserts and updates + query = { + "store_name": store_name, + "updated_at": {"$gt": self._last_poll_time}, + } + + docs = await self._collection_async.find(query).to_list() + for doc in docs: + key = doc["key"] + if key == "__store__": + continue + value = doc["value"] + dataset_id = doc.get("dataset_id") + event = "insert" if key not in previous_keys else "update" + + message_data = MessageData( + key=key, + value=value, + metadata=MessageMetadata( + operation_type=event, + dataset_id=( + str(dataset_id) if dataset_id is not None else None + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + ) + await self.notify(store_name, message_data) + + self._last_keys[store_name] = current_keys + self._last_poll_time = now + + +class MongoChangeStreamNotificationServiceLifecycleManager: + def __init__( + self, notification_service: MongoChangeStreamNotificationService + ): + self._notification_service = notification_service + self._notification_service_loop: Optional[ + asyncio.AbstractEventLoop + ] = None + self._notification_thread: Optional[Thread] = None + + def start_in_dedicated_thread(self) -> None: + """Create a dedicated event loop in a new thread + and start the notification service.""" + if self._notification_thread and self._notification_thread.is_alive(): + logger.info("Notification service daemon already running") + return + + 
logger.info("Starting execution store notification service daemon...") + + def run_service_in_thread(): + self._notification_service_loop = asyncio.new_event_loop() + + asyncio.set_event_loop(self._notification_service_loop) + + try: + self._notification_service_loop.run_until_complete( + self._notification_service.start( + self._notification_service_loop + ) + ) + except Exception: + logger.exception("Notification service failed to start") + + self._notification_thread = Thread( + target=run_service_in_thread, daemon=True + ) + self._notification_thread.start() + + async def stop(self) -> None: + if ( + self._notification_service + and not self._notification_service_loop.is_closed() + ): + try: + logger.info("Stopping notification service gracefully") + + fut = asyncio.run_coroutine_threadsafe( + self._notification_service.stop(), + self._notification_service_loop, + ) + fut.result(timeout=5) + except Exception: + logger.exception( + "Failed to stop notification service gracefully" + ) + finally: + self._notification_thread.join(timeout=5) + + if self._notification_thread.is_alive(): + logger.warning( + "Notification thread did not stop; forcing exit" + ) + + try: + self._notification_service_loop.close() + except Exception as e: + logger.warning( + f"Failed to close notification service loop: {e}" + ) + + self._notification_service_loop = None + self._notification_thread = None + + logger.info("Notification service stopped!") + + +def is_notification_service_disabled() -> bool: + """Check if the notification service is disabled.""" + return ( + os.getenv( + "FIFTYONE_EXECUTION_STORE_NOTIFICATION_SERVICE_DISABLED", "false" + ).lower() + == "true" + ) + + +default_notification_service = MongoChangeStreamNotificationService( + collection_name="execution_store", + remote_notifier=default_sse_notifier, +) diff --git a/fiftyone/operators/store/service.py b/fiftyone/operators/store/service.py index 9cc0650f1ef..9e11ddf6222 100644 --- 
a/fiftyone/operators/store/service.py +++ b/fiftyone/operators/store/service.py @@ -7,9 +7,16 @@ """ from bson import ObjectId -from typing import Any, Optional +from typing import Any, Callable, Optional, TYPE_CHECKING from fiftyone.operators.store.models import StoreDocument, KeyDocument +from fiftyone.operators.store.notification_service import ( + ChangeStreamNotificationService, +) + +if TYPE_CHECKING: + # so that we can import type without circular imports + from fiftyone.factory.repos.execution_store import ExecutionStoreRepo class ExecutionStoreService(object): @@ -35,6 +42,7 @@ class ExecutionStoreService(object): dataset_id (None): a dataset ID (ObjectId) to scope operations to collection_name (None): a collection name to use for the execution store. If `repo` is provided, this argument is ignored + notification_service (None): an optional notification service for the repository """ def __init__( @@ -42,6 +50,7 @@ def __init__( repo: Optional["ExecutionStoreRepo"] = None, dataset_id: Optional[ObjectId] = None, collection_name: str = None, + notification_service: Optional[ChangeStreamNotificationService] = None, ): from fiftyone.factory.repo_factory import ( @@ -53,6 +62,7 @@ def __init__( repo = RepositoryFactory.execution_store_repo( dataset_id=dataset_id, collection_name=collection_name, + notification_service=notification_service, ) self._dataset_id = dataset_id self._repo: ExecutionStoreRepo = repo @@ -254,7 +264,7 @@ def cleanup(self) -> None: """Deletes all stores associated with the current context.""" self._repo.cleanup() - def has_store_global(self, store_name) -> bool: + def has_store_global(self, store_name: str) -> bool: """Determines whether a store with the given name exists across all datasets and the global context. 
@@ -282,7 +292,7 @@ def count_stores_global(self) -> int: """ return self._repo.count_stores_global() - def delete_store_global(self, store_name) -> int: + def delete_store_global(self, store_name: str) -> int: """Deletes the specified store across all datasets and the global context. @@ -293,3 +303,28 @@ def delete_store_global(self, store_name) -> int: the number of stores deleted """ return self._repo.delete_store_global(store_name) + + def subscribe( + self, store_name: str, callback: Callable[[str], None] + ) -> str: + """Subscribe to changes in a store. + + Args: + store_name (str): the name of the store to subscribe to + callback (Callable[[str], None]): the callback to call when a change occurs + + Returns: + str: the subscription ID + """ + return self._repo.subscribe(store_name, callback) + + def unsubscribe(self, subscription_id: str) -> bool: + """Unsubscribe from changes in a store. + + Args: + subscription_id (str): the subscription ID to unsubscribe + + Returns: + bool: True if the subscription was removed, False otherwise + """ + return self._repo.unsubscribe(subscription_id) diff --git a/fiftyone/operators/store/store.py b/fiftyone/operators/store/store.py index 862cc818022..6e12b2f7117 100644 --- a/fiftyone/operators/store/store.py +++ b/fiftyone/operators/store/store.py @@ -6,11 +6,14 @@ | """ -from datetime import datetime -from typing import Any, Optional +from typing import Any, Callable, Optional from bson import ObjectId +from fiftyone.operators.message import MessageData +from fiftyone.operators.store.notification_service import ( + default_notification_service, +) from fiftyone.operators.store.service import ExecutionStoreService @@ -39,13 +42,30 @@ def create( store_name: str, dataset_id: Optional[ObjectId] = None, default_policy: str = "persist", - collection_name: Optional[str] = None, ) -> "ExecutionStore": + """Creates a new execution store. 
+ + Args: + store_name: the name of the store + dataset_id: an optional dataset ID to scope the store to + + Returns: + an ExecutionStore instance + """ + from fiftyone.factory.repos.execution_store import ( + MongoExecutionStoreRepo, + ) + + # Create store service with notification service + store_service = ExecutionStoreService( + dataset_id=dataset_id, + collection_name=MongoExecutionStoreRepo.COLLECTION_NAME, + notification_service=default_notification_service, + ) + return ExecutionStore( store_name, - ExecutionStoreService( - dataset_id=dataset_id, collection_name=collection_name - ), + store_service, default_policy, ) @@ -184,3 +204,25 @@ def list_keys(self) -> list[str]: a list of keys in the store """ return self._store_service.list_keys(self.store_name) + + def subscribe(self, callback: Callable[[MessageData], None]) -> str: + """Subscribes to changes in the store. + + Args: + callback: a function that will be called when a change occurs in the store + + Returns: + a subscription ID + """ + return self._store_service.subscribe(self.store_name, callback) + + def unsubscribe(self, subscription_id: str) -> bool: + """Unsubscribes from changes in the store. + + Args: + subscription_id: the subscription ID to unsubscribe from + + Returns: + True if the subscription was removed, False otherwise + """ + return self._store_service.unsubscribe(subscription_id) diff --git a/fiftyone/operators/store/subscription_registry.py b/fiftyone/operators/store/subscription_registry.py new file mode 100644 index 00000000000..c68e5bf08a3 --- /dev/null +++ b/fiftyone/operators/store/subscription_registry.py @@ -0,0 +1,124 @@ +""" +Subscription registry class. + +| Copyright 2017-2025, Voxel51, Inc. 
+| `voxel51.com `_ +| +""" + +import logging +import threading +import uuid +from abc import ABC, abstractmethod +from typing import Callable, Dict, Optional, Tuple + +logger = logging.getLogger(__name__) + + +class LocalSubscriptionRegistry(ABC): + """Abstract base class for subscription registry.""" + + @abstractmethod + def subscribe( + self, + store_name: str, + callback: Callable[[str], None], + dataset_id: Optional[str] = None, + ) -> str: + """ + Registers a subscription for a given store. + Returns a unique subscription id. + + Args: + store_name: The name of the store to subscribe to. + callback: The callback to call when a change occurs. + dataset_id: Optional dataset ID to filter changes by. + """ + pass + + @abstractmethod + def unsubscribe(self, subscription_id: str) -> bool: + """ + Unsubscribes a subscription by its id. + Returns True if a subscription was removed. + """ + pass + + @abstractmethod + def unsubscribe_all(self, store_name: str): + """ + Unsubscribes all subscriptions for a given store. + """ + pass + + @abstractmethod + def empty_subscribers(self) -> None: + """ + Empties all subscribers. + """ + pass + + @abstractmethod + def get_subscribers( + self, store_name: str + ) -> Dict[str, Tuple[Callable[[str], None], Optional[str]]]: + """ + Retrieves all subscriptions for a given store. + Returns a dictionary mapping subscription id to a tuple of (callback, dataset_id). 
+ """ + pass + + +class InLocalMemorySubscriptionRegistry(LocalSubscriptionRegistry): + def __init__(self): + # Maps store_name -> {subscription_id -> (callback, dataset_id)} + self._registry: Dict[ + str, Dict[str, Tuple[Callable[[str], None], Optional[str]]] + ] = {} + + # registry might either be accessed / modified by main + # thread (store.subscribe()) + # or from notification service daemon thread to get + # list of subscribers for a store + self._lock = threading.Lock() + + def subscribe( + self, + store_name: str, + callback: Callable[[str], None], + dataset_id: Optional[str] = None, + ) -> str: + sub_id = str(uuid.uuid4()) + + with self._lock: + if store_name not in self._registry: + self._registry[store_name] = {} + self._registry[store_name][sub_id] = (callback, dataset_id) + + return sub_id + + def unsubscribe(self, subscription_id: str) -> bool: + with self._lock: + for store, subs in self._registry.items(): + if subscription_id in subs: + del subs[subscription_id] + return True + return False + + def unsubscribe_all(self, store_name: str): + with self._lock: + if store_name in self._registry: + del self._registry[store_name] + + def get_subscribers( + self, store_name: str + ) -> Dict[str, Tuple[Callable[[str], None], Optional[str]]]: + with self._lock: + return self._registry.get(store_name, {}).copy() + + def empty_subscribers(self) -> None: + with self._lock: + self._registry = {} + + +default_subscription_registry = InLocalMemorySubscriptionRegistry() diff --git a/fiftyone/server/app.py b/fiftyone/server/app.py index e9a72191ef6..3e169c6732e 100644 --- a/fiftyone/server/app.py +++ b/fiftyone/server/app.py @@ -5,11 +5,15 @@ | `voxel51.com `_ | """ + +import asyncio +import logging import os import pathlib import stat import eta.core.utils as etau +import strawberry as gql from starlette.applications import Starlette from starlette.datastructures import Headers from starlette.middleware import Middleware @@ -23,10 +27,14 @@ from starlette.routing 
import Mount, Route from starlette.staticfiles import NotModifiedResponse, PathLike, StaticFiles from starlette.types import Scope -import strawberry as gql import fiftyone as fo import fiftyone.constants as foc +from fiftyone.operators.store.notification_service import ( + MongoChangeStreamNotificationServiceLifecycleManager, + default_notification_service, + is_notification_service_disabled, +) from fiftyone.server.constants import SCALAR_OVERRIDES from fiftyone.server.context import GraphQL from fiftyone.server.extensions import EndSession @@ -34,6 +42,8 @@ from fiftyone.server.query import Query from fiftyone.server.routes import routes +logger = logging.getLogger(__name__) + etau.ensure_dir(os.path.join(os.path.dirname(__file__), "static")) @@ -143,3 +153,36 @@ async def dispatch( ), ], ) + + +@app.on_event("startup") +async def startup_event(): + if is_notification_service_disabled(): + logger.info("Execution Store notification service is disabled") + return + + app.state.lifecycle_manager = ( + MongoChangeStreamNotificationServiceLifecycleManager( + default_notification_service + ) + ) + app.state.lifecycle_manager.start_in_dedicated_thread() + + +@app.on_event("shutdown") +async def shutdown_event(): + if hasattr(app.state, "lifecycle_manager") and app.state.lifecycle_manager: + logger.info("Shutting down notification service...") + try: + await asyncio.wait_for( + app.state.lifecycle_manager.stop(), timeout=5 + ) + logger.info("Notification service shutdown complete") + except asyncio.TimeoutError: + logger.warning( + "Notification service shutdown timed out after 5 seconds" + ) + except Exception as e: + logger.exception( + f"Error during notification service shutdown: {e}" + ) diff --git a/tests/unittests/operators/notification_service_tests.py b/tests/unittests/operators/notification_service_tests.py new file mode 100644 index 00000000000..13611ce88c8 --- /dev/null +++ b/tests/unittests/operators/notification_service_tests.py @@ -0,0 +1,353 @@ +""" 
+Unit tests for the notification service. + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +import asyncio +import datetime +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pymongo.errors import OperationFailure + +from fiftyone.operators.message import MessageData, MessageMetadata +from fiftyone.operators.store.notification_service import ( + MongoChangeStreamNotificationService, +) + + +class TestMongoChangeStreamNotificationService(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.remote_notifier = MagicMock() + self.remote_notifier.broadcast_to_store = AsyncMock() + + self.collection = MagicMock() + self.db = MagicMock() + self.db.__getitem__.return_value = self.collection + + with patch("fiftyone.core.odm.get_async_db_conn") as mock_get_db: + mock_get_db.return_value = self.db + self.notification_service = MongoChangeStreamNotificationService( + collection_name="test_collection", + remote_notifier=self.remote_notifier, + ) + self.notification_service._collection_async = self.collection + + def tearDown(self): + self.loop.close() + + def test_subscribe(self): + callback = MagicMock() + subscription_id = self.notification_service.subscribe( + "test_store", callback, dataset_id="test_dataset" + ) + + self.assertIsNotNone(subscription_id) + subscribers = ( + self.notification_service._subscription_registry.get_subscribers( + "test_store" + ) + ) + self.assertIn(subscription_id, subscribers) + self.assertEqual(subscribers[subscription_id][0], callback) + self.assertEqual(subscribers[subscription_id][1], "test_dataset") + + def test_unsubscribe(self): + callback = MagicMock() + subscription_id = self.notification_service.subscribe( + "test_store", callback + ) + + self.notification_service.unsubscribe(subscription_id) + + subscribers = ( + self.notification_service._subscription_registry.get_subscribers( + "test_store" + ) + ) + 
self.assertNotIn(subscription_id, subscribers) + + def test_unsubscribe_all(self): + callback1 = MagicMock() + callback2 = MagicMock() + + self.notification_service.subscribe("test_store", callback1) + self.notification_service.subscribe("test_store", callback2) + + self.notification_service.unsubscribe_all("test_store") + + subscribers = ( + self.notification_service._subscription_registry.get_subscribers( + "test_store" + ) + ) + self.assertEqual(len(subscribers), 0) + + def test_notify_with_dataset_id_filtering(self): + # Subscribe with dataset_id filter + callback1 = MagicMock() + callback2 = MagicMock() + self.notification_service.subscribe( + "test_store", callback1, dataset_id="dataset1" + ) + self.notification_service.subscribe( + "test_store", callback2, dataset_id="dataset2" + ) + + # Notification for dataset1 + metadata = MessageMetadata(dataset_id="dataset1") + message_data = MessageData( + key="test_key", value="test_value", metadata=metadata + ) + + with patch("asyncio.create_task"): + asyncio.run( + self.notification_service.notify("test_store", message_data) + ) + + callback1.assert_called_once_with(message_data) + callback2.assert_not_called() + self.remote_notifier.broadcast_to_store.assert_called_once_with( + "test_store", message_data.to_json() + ) + + def test_notify_without_dataset_id(self): + # Subscribe without dataset_id filter + callback = MagicMock() + self.notification_service.subscribe("test_store", callback) + + # Notification without dataset_id + metadata = MessageMetadata() + message_data = MessageData( + key="test_key", value="test_value", metadata=metadata + ) + + with patch("asyncio.create_task"): + asyncio.run( + self.notification_service.notify("test_store", message_data) + ) + + callback.assert_called_once_with(message_data) + self.remote_notifier.broadcast_to_store.assert_called_once_with( + "test_store", message_data.to_json() + ) + + +@pytest.fixture +def event_loop(): + """Create an instance of the default event loop for 
@pytest.fixture
def event_loop():
    # NOTE(review): the "def" line of this fixture fell outside the visible
    # diff window; header reconstructed from the visible body (standard
    # pytest-asyncio per-test event loop fixture) — TODO confirm against
    # the original file.
    """Create a fresh event loop for each test case."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


class TestMongoChangeStreamNotificationServiceAsync:
    """Async tests for ``MongoChangeStreamNotificationService`` using mocked
    Mongo collections and a mocked remote notifier."""

    @pytest.fixture
    def notification_service(self, event_loop):
        """Build a service wired to fully mocked Mongo objects.

        Yields:
            a ``(service, collection, remote_notifier)`` tuple
        """
        notifier = MagicMock()
        notifier.broadcast_to_store = AsyncMock()

        mock_collection = MagicMock()
        mock_db = MagicMock()
        mock_db.__getitem__.return_value = mock_collection

        with patch("fiftyone.core.odm.get_async_db_conn") as get_db:
            get_db.return_value = mock_db
            service = MongoChangeStreamNotificationService(
                collection_name="test_collection",
                remote_notifier=notifier,
            )
            # Inject the async collection directly rather than connecting
            service._collection_async = mock_collection
            yield service, mock_collection, notifier

    @pytest.mark.asyncio
    async def test_handle_change_insert(self, notification_service):
        """An insert change invokes the subscriber with the full document."""
        service, _collection, _notifier = notification_service
        callback = MagicMock()
        service.subscribe("test_store", callback)

        change = {
            "operationType": "insert",
            "fullDocument": {
                "store_name": "test_store",
                "dataset_id": "test_dataset",
                "key": "test_key",
                "value": "test_value",
            },
            "wallTime": datetime.datetime.now(),
        }

        with patch("asyncio.create_task"):
            await service._handle_change(change)

        callback.assert_called_once()
        message = callback.call_args[0][0]

        assert message.key == "test_key"
        assert message.value == "test_value"
        assert message.metadata.dataset_id == "test_dataset"
        assert message.metadata.operation_type == "insert"

    @pytest.mark.asyncio
    async def test_handle_change_delete(self, notification_service):
        """A delete change (no fullDocument) yields a ``None``-valued message
        built from the document key."""
        service, _collection, _notifier = notification_service
        callback = MagicMock()
        service.subscribe("test_store", callback)

        # Deletes carry no fullDocument; identity comes from documentKey
        change = {
            "operationType": "delete",
            "documentKey": {
                "_id": {
                    "store_name": "test_store",
                    "dataset_id": "test_dataset",
                    "key": "test_key",
                }
            },
            "wallTime": datetime.datetime.now(),
        }

        with patch("asyncio.create_task"):
            await service._handle_change(change)

        callback.assert_called_once()
        message = callback.call_args[0][0]

        assert message.key == "test_key"
        assert message.value is None
        assert message.metadata.dataset_id == "test_dataset"
        assert message.metadata.operation_type == "delete"

    @pytest.mark.asyncio
    async def test_run_with_change_stream(self, notification_service):
        """``_run`` should start the change-stream path exactly once."""
        service, _collection, _notifier = notification_service

        with patch.object(service, "_run_change_stream") as run_stream:
            run_stream.return_value = None

            task = asyncio.create_task(service._run())
            await asyncio.sleep(0.1)
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

            run_stream.assert_called_once()

    @pytest.mark.asyncio
    async def test_run_with_fallback_to_polling(self, notification_service):
        """When the change stream is unavailable, ``_run`` falls back to
        polling."""
        service, _collection, _notifier = notification_service

        with patch.object(
            service, "_run_change_stream"
        ) as run_stream, patch.object(
            service, "_start_polling"
        ) as start_polling:
            # A failing change stream must trigger the polling fallback
            run_stream.side_effect = OperationFailure(
                "Change stream not available"
            )

            # Polling resolves immediately so the task can be cancelled
            done = asyncio.Future()
            done.set_result(None)
            start_polling.return_value = done

            task = asyncio.create_task(service._run())
            await asyncio.sleep(0.1)
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

            run_stream.assert_called_once()
            start_polling.assert_called_once()

    @pytest.mark.asyncio
    async def test_poll(self, notification_service):
        """A poll cycle queries updated docs since the last poll and refreshes
        the cached key set and poll timestamp."""
        service, mock_collection, _notifier = notification_service

        with patch.object(
            MongoChangeStreamNotificationService, "_get_current_stores"
        ) as get_stores:
            get_stores.return_value = ["test_store"]

            # Keys currently present in the store
            current_keys = {"key1", "key2", "key3"}
            mock_collection.distinct = AsyncMock(return_value=current_keys)

            # Previous poll state: key4 existed, key3 did not
            service._last_poll_time = datetime.datetime.now(
                datetime.timezone.utc
            )
            service._last_keys = {"test_store": {"key1", "key2", "key4"}}

            cursor = MagicMock()
            cursor.to_list = AsyncMock(
                return_value=[
                    {
                        "key": "key1",
                        "value": "value1",
                        "dataset_id": "dataset1",
                    },
                    {
                        "key": "key2",
                        "value": "value2",
                        "dataset_id": "dataset2",
                    },
                    {
                        "key": "key3",
                        "value": "value3",
                        "dataset_id": "dataset3",
                    },
                ]
            )
            mock_collection.find.return_value = cursor

            previous_poll_time = service._last_poll_time
            # Let a measurable amount of time pass before polling
            await asyncio.sleep(0.001)

            await service._poll()

            mock_collection.distinct.assert_called_with(
                "key", {"store_name": "test_store"}
            )

            # Timestamps can differ at microsecond precision, so inspect
            # the find() query piecewise instead of comparing it wholesale
            assert mock_collection.find.call_count == 1
            query = mock_collection.find.call_args[0][0]
            assert query["store_name"] == "test_store"
            assert "$gt" in query["updated_at"]

            # Cached key set and poll timestamp must both be refreshed
            assert service._last_keys["test_store"] == current_keys
            assert service._last_poll_time is not None
            assert service._last_poll_time != previous_poll_time
"""
Unit tests for subscription registry.

| Copyright 2017-2025, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""

import threading
import unittest
from unittest.mock import Mock

from fiftyone.operators.store.subscription_registry import (
    InLocalMemorySubscriptionRegistry,
)


class TestInLocalMemorySubscriptionRegistry(unittest.TestCase):
    """Behavioral tests for the in-memory subscription registry."""

    def setUp(self):
        """Create a fresh registry and two mock callbacks before each test."""
        self.registry = InLocalMemorySubscriptionRegistry()
        self.callback1 = Mock()
        self.callback2 = Mock()

    def test_subscribe_new_store(self):
        """Subscribing to a new store returns a valid subscription ID."""
        sub_id = self.registry.subscribe("store1", self.callback1)

        self.assertIsInstance(sub_id, str)
        self.assertGreater(len(sub_id), 0)

        subs = self.registry.get_subscribers("store1")
        self.assertEqual(len(subs), 1)
        self.assertIn(sub_id, subs)
        # No dataset ID was given, so the second tuple slot is None
        self.assertEqual(subs[sub_id], (self.callback1, None))

    def test_subscribe_with_dataset_id(self):
        """Subscribing with a dataset ID records it alongside the callback."""
        sub_id = self.registry.subscribe(
            "store1", self.callback1, "dataset123"
        )

        subs = self.registry.get_subscribers("store1")
        self.assertEqual(subs[sub_id], (self.callback1, "dataset123"))

    def test_multiple_subscriptions_same_store(self):
        """Multiple subscriptions to one store get distinct IDs and coexist."""
        first = self.registry.subscribe("store1", self.callback1)
        second = self.registry.subscribe("store1", self.callback2)

        self.assertNotEqual(first, second)

        subs = self.registry.get_subscribers("store1")
        self.assertEqual(len(subs), 2)
        self.assertIn(first, subs)
        self.assertIn(second, subs)

    def test_unsubscribe_existing(self):
        """Unsubscribing an existing subscription removes it and returns
        True."""
        sub_id = self.registry.subscribe("store1", self.callback1)

        self.assertTrue(self.registry.unsubscribe(sub_id))
        self.assertNotIn(sub_id, self.registry.get_subscribers("store1"))

    def test_unsubscribe_nonexistent(self):
        """Unsubscribing an unknown ID returns False."""
        self.assertFalse(self.registry.unsubscribe("nonexistent-id"))

    def test_unsubscribe_all(self):
        """unsubscribe_all clears one store without touching others."""
        self.registry.subscribe("store1", self.callback1)
        self.registry.subscribe("store1", self.callback2)
        self.registry.subscribe("store2", self.callback1)

        self.registry.unsubscribe_all("store1")

        self.assertEqual(len(self.registry.get_subscribers("store1")), 0)
        self.assertEqual(len(self.registry.get_subscribers("store2")), 1)

    def test_empty_subscribers(self):
        """empty_subscribers clears every store."""
        self.registry.subscribe("store1", self.callback1)
        self.registry.subscribe("store2", self.callback2)

        self.registry.empty_subscribers()

        for store in ("store1", "store2"):
            self.assertEqual(len(self.registry.get_subscribers(store)), 0)

    def test_get_subscribers_nonexistent_store(self):
        """An unknown store yields an empty subscriber mapping."""
        self.assertEqual(self.registry.get_subscribers("nonexistent"), {})

    def test_thread_safety(self):
        """Concurrent subscribes from several threads must not lose entries."""

        def worker():
            # 50 subscriptions per thread
            for _ in range(50):
                self.registry.subscribe("store1", self.callback1)

        threads = [threading.Thread(target=worker) for _ in range(5)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # 5 threads x 50 subscriptions each
        self.assertEqual(len(self.registry.get_subscribers("store1")), 250)


if __name__ == "__main__":
    unittest.main()