Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 114 additions & 86 deletions backend/fastrtc/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
Literal,
TypedDict,
cast,
List,
Set,
)

import anyio
Expand All @@ -28,7 +30,6 @@
from .websocket import WebSocketHandler

logger = logging.getLogger(__name__)

curr_dir = Path(__file__).parent


Expand All @@ -40,62 +41,18 @@ class Body(BaseModel):


class UIArgs(TypedDict):
"""
UI customization arguments for the Gradio Blocks UI of the Stream class
"""

title: NotRequired[str]
"""Title of the demo"""
subtitle: NotRequired[str]
"""Subtitle of the demo. Text will be centered and displayed below the title."""
icon: NotRequired[str]
"""Icon to display on the button instead of the wave animation. The icon should be a path/url to a .svg/.png/.jpeg file."""
icon_button_color: NotRequired[str]
"""Color of the icon button. Default is var(--color-accent) of the demo theme."""
pulse_color: NotRequired[str]
"""Color of the pulse animation. Default is var(--color-accent) of the demo theme."""
icon_radius: NotRequired[int]
"""Border radius of the icon button expressed as a percentage of the button size. Default is 50%."""
send_input_on: NotRequired[Literal["submit", "change"]]
"""When to send the input to the handler. Default is "change".
If "submit", the input will be sent when the submit event is triggered by the user.
If "change", the input will be sent whenever the user changes the input value.
"""
hide_title: NotRequired[bool]
"""If True, the title and subtitle will not be displayed."""
full_screen: NotRequired[bool]
"""If False, the component will be contained within its parent instead of full screen. Default is True."""


class Stream(WebRTCConnectionMixin):
"""
Define an audio or video stream with a built-in UI, mountable on a FastAPI app.

This class encapsulates the logic for handling real-time communication (WebRTC)
streams, including setting up peer connections, managing tracks, generating
a Gradio user interface, and integrating with FastAPI for API endpoints.
It supports different modes (send, receive, send-receive) and modalities
(audio, video, audio-video), and can optionally handle additional Gradio
input/output components alongside the stream. It also provides functionality
for telephone integration via the FastPhone method.

Attributes:
mode (Literal["send-receive", "receive", "send"]): The direction of the stream.
modality (Literal["video", "audio", "audio-video"]): The type of media stream.
rtp_params (dict[str, Any] | None): Parameters for RTP encoding.
event_handler (HandlerType): The main function to process stream data.
concurrency_limit (int): The maximum number of concurrent connections allowed.
time_limit (float | None): Time limit in seconds for the event handler execution.
allow_extra_tracks (bool): Whether to allow extra tracks beyond the specified modality.
additional_output_components (list[Component] | None): Extra Gradio output components.
additional_input_components (list[Component] | None): Extra Gradio input components.
additional_outputs_handler (Callable | None): Handler for additional outputs.
track_constraints (dict[str, Any] | None): Constraints for media tracks (e.g., resolution).
webrtc_component (WebRTC): The underlying Gradio WebRTC component instance.
rtc_configuration (dict[str, Any] | None): Configuration for the RTCPeerConnection (e.g., ICE servers).
_ui (Blocks): The Gradio Blocks UI instance.
"""

def __init__(
self,
handler: HandlerType,
Expand All @@ -115,31 +72,6 @@ def __init__(
ui_args: UIArgs | None = None,
verbose: bool = True,
):
"""
Initialize the Stream instance.

Args:
handler: The function to handle incoming stream data and return output data.
additional_outputs_handler: An optional function to handle updates to additional output components.
mode: The direction of the stream ('send', 'receive', or 'send-receive').
modality: The type of media ('video', 'audio', or 'audio-video').
concurrency_limit: Maximum number of concurrent connections. 'default' maps to 1.
time_limit: Maximum execution time for the handler function in seconds.
allow_extra_tracks: If True, allows connections with tracks not matching the modality.
rtp_params: Optional dictionary of RTP encoding parameters.
rtc_configuration: Optional Callable or dictionary for RTCPeerConnection configuration (e.g., ICE servers).
Required when deploying on Colab or Spaces.
server_rtc_configuration: Optional dictionary for RTCPeerConnection configuration on the server side. Note
that setting iceServers to be an empty list will mean no ICE servers will be used in the server.
track_constraints: Optional dictionary of constraints for media tracks (e.g., resolution, frame rate).
additional_inputs: Optional list of extra Gradio input components.
additional_outputs: Optional list of extra Gradio output components. Requires `additional_outputs_handler`.
ui_args: Optional dictionary to customize the default UI appearance (title, subtitle, icon, etc.).
verbose: Whether to print verbose logging on startup.

Raises:
ValueError: If `additional_outputs` are provided without `additional_outputs_handler`.
"""
WebRTCConnectionMixin.__init__(self)
self.mode = mode
self.modality = modality
Expand All @@ -153,7 +85,6 @@ def __init__(
self.event_handler.needs_args = True # type: ignore
else:
self.event_handler.needs_args = False # type: ignore

self.concurrency_limit = cast(
(int),
1 if concurrency_limit in ["default", None] else concurrency_limit,
Expand All @@ -167,7 +98,7 @@ def __init__(
self.additional_input_components = additional_inputs
self.additional_outputs_handler = additional_outputs_handler
self.track_constraints = track_constraints
self.webrtc_component: WebRTC
self.webrtc_component: WebRTC | None = None
self.rtc_configuration = rtc_configuration
self.server_rtc_configuration = self.convert_to_aiortc_format(
server_rtc_configuration
Expand All @@ -176,29 +107,126 @@ def __init__(
self._ui = self._generate_default_ui(ui_args)
self._ui.launch = self._wrap_gradio_launch(self._ui.launch)

def has_webrtc_component(self) -> bool:
return getattr(self, "webrtc_component", None) is not None

def get_all_connections(self) -> List[str]:
all_ids: Set[str] = set()
try:
conns = getattr(self, "connections", None)
if conns is not None:
if hasattr(conns, "keys"):
for k in conns.keys():
all_ids.add(str(k))
else:
for k in conns:
all_ids.add(str(k))
except Exception:
pass
try:
if self.has_webrtc_component():
wc = self.webrtc_component
if hasattr(wc, "connections") and wc.connections is not None:
try:
for k in wc.connections.keys():
all_ids.add(str(k))
except Exception:
for k in wc.connections:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this ever happen?

all_ids.add(str(k))
for alt_name in ("_connections", "_conn_map", "active_connections"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When are these naming conventions used? I cannot find them in the current code base.

try:
alt = getattr(wc, alt_name, None)
if alt:
if hasattr(alt, "keys"):
for k in alt.keys():
all_ids.add(str(k))
else:
for k in alt:
all_ids.add(str(k))
except Exception:
continue
except Exception:
pass
return list(all_ids)

async def offer(self, body: Body, request: Request) -> dict[str, Any]:
sdp = body.sdp
candidate = body.candidate
webrtc_id = body.webrtc_id
if body.type == "offer" and sdp is not None:
return await self.handle_offer(sdp, webrtc_id)
elif body.type == "candidate" and candidate is not None:
await self.add_ice_candidate(candidate, webrtc_id)
return {"status": "ok"}
return {"error": "Invalid request"}

async def telephone_handler(self, websocket: WebSocket) -> None:
handler = WebSocketHandler(self)
await handler.handle(websocket)

async def handle_incoming_call(self, request: Request) -> dict[str, Any]:
data = await request.json()
return await self.handle_call(data)

async def websocket_offer(self, websocket: WebSocket) -> None:
await websocket.accept()
async for message in websocket.iter_text():
await self.handle_ws_message(message, websocket)

def convert_to_aiortc_format(
self, rtc_config: dict[str, Any] | None
) -> dict[str, Any] | None:
if rtc_config is None:
return None
if callable(rtc_config):
rtc_config = rtc_config()
return rtc_config

def _generate_default_ui(self, ui_args: UIArgs | None = None) -> Blocks:
with gr.Blocks() as demo:
self.webrtc_component = WebRTC(
mode=self.mode,
modality=self.modality,
rtc_configuration=self.rtc_configuration,
server_rtc_configuration=self.server_rtc_configuration,
track_constraints=self.track_constraints,
label="",
show_share_button=False,
**(ui_args or {}),
)
self.webrtc_component.stream(self.event_handler)
return demo

def _wrap_gradio_launch(self, launch_fn: Callable) -> Callable:
def wrapped_launch(*args: Any, **kwargs: Any) -> Any:
self.start_background_tasks()
return launch_fn(*args, **kwargs)
return wrapped_launch

def _inject_startup_message(
self, lifespan_context: Callable[..., AbstractAsyncContextManager]
) -> Callable[..., AbstractAsyncContextManager]:
async def new_context(*args: Any, **kwargs: Any):
async with lifespan_context(*args, **kwargs):
if self.verbose:
logger.info("Stream mounted and ready.")
yield
return new_context

def mount(
self, app: FastAPI, path: str = "", tags: list[str | Enum] | None = None
) -> None:
"""
Mount the stream's API endpoints onto a FastAPI application.

This method adds the necessary routes (`/webrtc/offer`, `/telephone/handler`,
`/telephone/incoming`, `/websocket/offer`) to the provided FastAPI app,
prefixed with the optional `path`. It also injects a startup message
into the app's lifespan.

Args:
app: The FastAPI application instance.
path: An optional URL prefix for the mounted routes.
tags: Optional tags to FastAPI endpoints.
"""
from fastapi import APIRouter

router = APIRouter(prefix=path)
router.post("/webrtc/offer", tags=tags)(self.offer)
router.websocket("/telephone/handler")(self.telephone_handler)
router.post("/telephone/incoming", tags=tags)(self.handle_incoming_call)
router.websocket("/websocket/offer")(self.websocket_offer)

@router.get("/connections", tags=tags)
async def get_connections():
return self.get_all_connections()

lifespan = self._inject_startup_message(app.router.lifespan_context)
app.router.lifespan_context = lifespan
app.include_router(router)
Expand Down