Fix: unify WebRTC connection tracking for API + UI connections #404
Open: DeveloperViraj wants to merge 3 commits into gradio-app:main from DeveloperViraj:fix/stream-connections-unified
+9
−0
a1fe28d fix(stream): unify WebRTC connection tracking for API + UI connections (DeveloperViraj)
4d94d52 refactor(stream): add minimal get_all_connections() helper as per fee… (DeveloperViraj)
0de40c2 fix(stream): move get_all_connections outside __init__, restore UI va… (DeveloperViraj)
@@ -10,6 +10,8 @@
    Literal,
    TypedDict,
    cast,
    List,
    Set,
)

import anyio
@@ -28,7 +30,6 @@
from .websocket import WebSocketHandler

logger = logging.getLogger(__name__)

curr_dir = Path(__file__).parent
@@ -40,62 +41,18 @@ class Body(BaseModel):


class UIArgs(TypedDict):
    """
    UI customization arguments for the Gradio Blocks UI of the Stream class
    """

    title: NotRequired[str]
    """Title of the demo"""
    subtitle: NotRequired[str]
    """Subtitle of the demo. Text will be centered and displayed below the title."""
    icon: NotRequired[str]
    """Icon to display on the button instead of the wave animation. The icon should be a path/url to a .svg/.png/.jpeg file."""
    icon_button_color: NotRequired[str]
    """Color of the icon button. Default is var(--color-accent) of the demo theme."""
    pulse_color: NotRequired[str]
    """Color of the pulse animation. Default is var(--color-accent) of the demo theme."""
    icon_radius: NotRequired[int]
    """Border radius of the icon button expressed as a percentage of the button size. Default is 50%."""
    send_input_on: NotRequired[Literal["submit", "change"]]
    """When to send the input to the handler. Default is "change".
    If "submit", the input will be sent when the submit event is triggered by the user.
    If "change", the input will be sent whenever the user changes the input value.
    """
    hide_title: NotRequired[bool]
    """If True, the title and subtitle will not be displayed."""
    full_screen: NotRequired[bool]
    """If False, the component will be contained within its parent instead of full screen. Default is True."""


class Stream(WebRTCConnectionMixin):
    """
    Define an audio or video stream with a built-in UI, mountable on a FastAPI app.

    This class encapsulates the logic for handling real-time communication (WebRTC)
    streams, including setting up peer connections, managing tracks, generating
    a Gradio user interface, and integrating with FastAPI for API endpoints.
    It supports different modes (send, receive, send-receive) and modalities
    (audio, video, audio-video), and can optionally handle additional Gradio
    input/output components alongside the stream. It also provides functionality
    for telephone integration via the FastPhone method.

    Attributes:
        mode (Literal["send-receive", "receive", "send"]): The direction of the stream.
        modality (Literal["video", "audio", "audio-video"]): The type of media stream.
        rtp_params (dict[str, Any] | None): Parameters for RTP encoding.
        event_handler (HandlerType): The main function to process stream data.
        concurrency_limit (int): The maximum number of concurrent connections allowed.
        time_limit (float | None): Time limit in seconds for the event handler execution.
        allow_extra_tracks (bool): Whether to allow extra tracks beyond the specified modality.
        additional_output_components (list[Component] | None): Extra Gradio output components.
        additional_input_components (list[Component] | None): Extra Gradio input components.
        additional_outputs_handler (Callable | None): Handler for additional outputs.
        track_constraints (dict[str, Any] | None): Constraints for media tracks (e.g., resolution).
        webrtc_component (WebRTC): The underlying Gradio WebRTC component instance.
        rtc_configuration (dict[str, Any] | None): Configuration for the RTCPeerConnection (e.g., ICE servers).
        _ui (Blocks): The Gradio Blocks UI instance.
    """

    def __init__(
        self,
        handler: HandlerType,
@@ -115,31 +72,6 @@ def __init__(
        ui_args: UIArgs | None = None,
        verbose: bool = True,
    ):
        """
        Initialize the Stream instance.

        Args:
            handler: The function to handle incoming stream data and return output data.
            additional_outputs_handler: An optional function to handle updates to additional output components.
            mode: The direction of the stream ('send', 'receive', or 'send-receive').
            modality: The type of media ('video', 'audio', or 'audio-video').
            concurrency_limit: Maximum number of concurrent connections. 'default' maps to 1.
            time_limit: Maximum execution time for the handler function in seconds.
            allow_extra_tracks: If True, allows connections with tracks not matching the modality.
            rtp_params: Optional dictionary of RTP encoding parameters.
            rtc_configuration: Optional Callable or dictionary for RTCPeerConnection configuration (e.g., ICE servers).
                Required when deploying on Colab or Spaces.
            server_rtc_configuration: Optional dictionary for RTCPeerConnection configuration on the server side. Note
                that setting iceServers to be an empty list will mean no ICE servers will be used in the server.
            track_constraints: Optional dictionary of constraints for media tracks (e.g., resolution, frame rate).
            additional_inputs: Optional list of extra Gradio input components.
            additional_outputs: Optional list of extra Gradio output components. Requires `additional_outputs_handler`.
            ui_args: Optional dictionary to customize the default UI appearance (title, subtitle, icon, etc.).
            verbose: Whether to print verbose logging on startup.

        Raises:
            ValueError: If `additional_outputs` are provided without `additional_outputs_handler`.
        """
        WebRTCConnectionMixin.__init__(self)
        self.mode = mode
        self.modality = modality
@@ -153,7 +85,6 @@ def __init__(
            self.event_handler.needs_args = True  # type: ignore
        else:
            self.event_handler.needs_args = False  # type: ignore

        self.concurrency_limit = cast(
            (int),
            1 if concurrency_limit in ["default", None] else concurrency_limit,
@@ -167,7 +98,7 @@ def __init__(
        self.additional_input_components = additional_inputs
        self.additional_outputs_handler = additional_outputs_handler
        self.track_constraints = track_constraints
        self.webrtc_component: WebRTC
        self.webrtc_component: WebRTC | None = None
        self.rtc_configuration = rtc_configuration
        self.server_rtc_configuration = self.convert_to_aiortc_format(
            server_rtc_configuration
@@ -176,29 +107,126 @@ def __init__(
        self._ui = self._generate_default_ui(ui_args)
        self._ui.launch = self._wrap_gradio_launch(self._ui.launch)

    def has_webrtc_component(self) -> bool:
        return getattr(self, "webrtc_component", None) is not None

    def get_all_connections(self) -> List[str]:
        all_ids: Set[str] = set()
        try:
            conns = getattr(self, "connections", None)
            if conns is not None:
                if hasattr(conns, "keys"):
                    for k in conns.keys():
                        all_ids.add(str(k))
                else:
                    for k in conns:
                        all_ids.add(str(k))
        except Exception:
            pass
        try:
            if self.has_webrtc_component():
                wc = self.webrtc_component
                if hasattr(wc, "connections") and wc.connections is not None:
                    try:
                        for k in wc.connections.keys():
                            all_ids.add(str(k))
                    except Exception:
                        for k in wc.connections:
                            all_ids.add(str(k))
                for alt_name in ("_connections", "_conn_map", "active_connections"):

                    try:
                        alt = getattr(wc, alt_name, None)
                        if alt:
                            if hasattr(alt, "keys"):
                                for k in alt.keys():
                                    all_ids.add(str(k))
                            else:
                                for k in alt:
                                    all_ids.add(str(k))
                    except Exception:
                        continue
        except Exception:
            pass
        return list(all_ids)

    async def offer(self, body: Body, request: Request) -> dict[str, Any]:
        sdp = body.sdp
        candidate = body.candidate
        webrtc_id = body.webrtc_id
        if body.type == "offer" and sdp is not None:
            return await self.handle_offer(sdp, webrtc_id)
        elif body.type == "candidate" and candidate is not None:
            await self.add_ice_candidate(candidate, webrtc_id)
            return {"status": "ok"}
        return {"error": "Invalid request"}

    async def telephone_handler(self, websocket: WebSocket) -> None:
        handler = WebSocketHandler(self)
        await handler.handle(websocket)

    async def handle_incoming_call(self, request: Request) -> dict[str, Any]:
        data = await request.json()
        return await self.handle_call(data)

    async def websocket_offer(self, websocket: WebSocket) -> None:
        await websocket.accept()
        async for message in websocket.iter_text():
            await self.handle_ws_message(message, websocket)

    def convert_to_aiortc_format(
        self, rtc_config: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        if rtc_config is None:
            return None
        if callable(rtc_config):
            rtc_config = rtc_config()
        return rtc_config

    def _generate_default_ui(self, ui_args: UIArgs | None = None) -> Blocks:
        with gr.Blocks() as demo:
            self.webrtc_component = WebRTC(
                mode=self.mode,
                modality=self.modality,
                rtc_configuration=self.rtc_configuration,
                server_rtc_configuration=self.server_rtc_configuration,
                track_constraints=self.track_constraints,
                label="",
                show_share_button=False,
                **(ui_args or {}),
            )
            self.webrtc_component.stream(self.event_handler)
        return demo

    def _wrap_gradio_launch(self, launch_fn: Callable) -> Callable:
        def wrapped_launch(*args: Any, **kwargs: Any) -> Any:
            self.start_background_tasks()
            return launch_fn(*args, **kwargs)
        return wrapped_launch

    def _inject_startup_message(
        self, lifespan_context: Callable[..., AbstractAsyncContextManager]
    ) -> Callable[..., AbstractAsyncContextManager]:
        async def new_context(*args: Any, **kwargs: Any):
            async with lifespan_context(*args, **kwargs):
                if self.verbose:
                    logger.info("Stream mounted and ready.")
                yield
        return new_context

    def mount(
        self, app: FastAPI, path: str = "", tags: list[str | Enum] | None = None
    ) -> None:
        """
        Mount the stream's API endpoints onto a FastAPI application.

        This method adds the necessary routes (`/webrtc/offer`, `/telephone/handler`,
        `/telephone/incoming`, `/websocket/offer`) to the provided FastAPI app,
        prefixed with the optional `path`. It also injects a startup message
        into the app's lifespan.

        Args:
            app: The FastAPI application instance.
            path: An optional URL prefix for the mounted routes.
            tags: Optional tags to FastAPI endpoints.
        """
        from fastapi import APIRouter

        router = APIRouter(prefix=path)
        router.post("/webrtc/offer", tags=tags)(self.offer)
        router.websocket("/telephone/handler")(self.telephone_handler)
        router.post("/telephone/incoming", tags=tags)(self.handle_incoming_call)
        router.websocket("/websocket/offer")(self.websocket_offer)

        @router.get("/connections", tags=tags)
        async def get_connections():
            return self.get_all_connections()

        lifespan = self._inject_startup_message(app.router.lifespan_context)
        app.router.lifespan_context = lifespan
        app.include_router(router)
Does this ever happen?
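
For context on the question, the nested `try`/`except` and `hasattr` fallbacks in `get_all_connections` could probably be collapsed if both connection containers are plain mappings (or other iterables) keyed by connection ID. That is an assumption the diff does not confirm. The sketch below is only an illustration: `collect_connection_ids` and `FakeComponent` are hypothetical names, not part of the library.

```python
from typing import Any, Iterable, Optional


def collect_connection_ids(*sources: Optional[Iterable[Any]]) -> list[str]:
    """Merge connection IDs from several containers into one de-duplicated list.

    Assumes each source is None or an iterable of IDs. Iterating a dict yields
    its keys, so dicts, sets and lists are all handled by the same loop.
    """
    ids: set[str] = set()
    for source in sources:
        if source is None:
            continue
        ids.update(str(key) for key in source)
    # Sorted so the result is deterministic (the PR's helper returns list(set)).
    return sorted(ids)


class FakeComponent:
    """Hypothetical stand-in for a UI component that tracks its own connections."""

    def __init__(self) -> None:
        self.connections = {"ui-1": object(), "shared": object()}


# Hypothetical API-side connections tracked on the stream itself.
api_connections = {"api-1": object(), "shared": object()}
component = FakeComponent()

all_ids = collect_connection_ids(
    api_connections,
    getattr(component, "connections", None),
)
print(all_ids)
```

With this shape, a missing attribute is handled by the single `getattr(..., None)` default, so no broad `except Exception` is needed to answer "which IDs are currently tracked".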