Skip to content

Commit 8bd91ed

Browse files
authored
add pre-connected audio buffer (livekit#2171)
1 parent 15e232b commit 8bd91ed

File tree

5 files changed

+233
-10
lines changed

5 files changed

+233
-10
lines changed

examples/voice_agents/basic_agent.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ async def entrypoint(ctx: JobContext):
7171
ctx.log_context_fields = {
7272
"room": ctx.room.name,
7373
}
74-
await ctx.connect()
7574

7675
session = AgentSession(
7776
vad=ctx.proc.userdata["vad"],
@@ -98,9 +97,6 @@ async def log_usage():
9897
# shutdown callbacks are triggered when the session is over
9998
ctx.add_shutdown_callback(log_usage)
10099

101-
# wait for a participant to join the room
102-
await ctx.wait_for_participant()
103-
104100
await session.start(
105101
agent=MyAgent(),
106102
room=ctx.room,
@@ -111,6 +107,9 @@ async def log_usage():
111107
room_output_options=RoomOutputOptions(transcription_enabled=True),
112108
)
113109

110+
# join the room when agent is ready
111+
await ctx.connect()
112+
114113

115114
if __name__ == "__main__":
116115
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm))

livekit-agents/livekit/agents/voice/room_io/_input.py

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22

33
import asyncio
44
from abc import ABC, abstractmethod
5-
from collections.abc import AsyncIterator
5+
from collections.abc import AsyncIterator, Iterable
66
from typing import Generic, TypeVar, Union
77

88
from typing_extensions import override
99

1010
import livekit.rtc as rtc
11+
from livekit.rtc._proto.track_pb2 import AudioTrackFeature
1112

1213
from ...log import logger
1314
from ...utils import aio, log_exceptions
1415
from ..io import AudioInput, VideoInput
16+
from ._pre_connect_audio import PreConnectAudioHandler
1517

1618
T = TypeVar("T", bound=Union[rtc.AudioFrame, rtc.VideoFrame])
1719

@@ -126,14 +128,15 @@ async def _forward_task(
126128
self,
127129
old_task: asyncio.Task | None,
128130
stream: rtc.VideoStream | rtc.AudioStream,
129-
track_source: rtc.TrackSource.ValueType,
131+
publication: rtc.RemoteTrackPublication,
132+
participant: rtc.RemoteParticipant,
130133
) -> None:
131134
if old_task:
132135
await aio.cancel_and_wait(old_task)
133136

134137
extra = {
135-
"participant": self._participant_identity,
136-
"source": rtc.TrackSource.Name(track_source),
138+
"participant": participant.identity,
139+
"source": rtc.TrackSource.Name(publication.source),
137140
}
138141
logger.debug("start reading stream", extra=extra)
139142
async for event in stream:
@@ -172,7 +175,7 @@ def _on_track_available(
172175
self._stream = self._create_stream(track)
173176
self._publication = publication
174177
self._forward_atask = asyncio.create_task(
175-
self._forward_task(self._forward_atask, self._stream, publication.source)
178+
self._forward_task(self._forward_atask, self._stream, publication, participant)
176179
)
177180
return True
178181

@@ -202,13 +205,15 @@ def __init__(
202205
sample_rate: int,
203206
num_channels: int,
204207
noise_cancellation: rtc.NoiseCancellationOptions | None,
208+
pre_connect_audio_handler: PreConnectAudioHandler | None,
205209
) -> None:
206210
_ParticipantInputStream.__init__(
207211
self, room=room, track_source=rtc.TrackSource.SOURCE_MICROPHONE
208212
)
209213
self._sample_rate = sample_rate
210214
self._num_channels = num_channels
211215
self._noise_cancellation = noise_cancellation
216+
self._pre_connect_audio_handler = pre_connect_audio_handler
212217

213218
@override
214219
def _create_stream(self, track: rtc.Track) -> rtc.AudioStream:
@@ -219,6 +224,78 @@ def _create_stream(self, track: rtc.Track) -> rtc.AudioStream:
219224
noise_cancellation=self._noise_cancellation,
220225
)
221226

227+
@override
async def _forward_task(
    self,
    old_task: asyncio.Task | None,
    stream: rtc.AudioStream,
    publication: rtc.RemoteTrackPublication,
    participant: rtc.RemoteParticipant,
) -> None:
    """Push any pre-connect audio for the track, then forward the live stream.

    When a pre-connect audio handler is configured and the publication
    advertises ``TF_PRECONNECT_BUFFER``, the buffered frames are awaited,
    resampled via ``_resample_frames`` and sent to ``self._data_ch`` while the
    input is attached. Failures while reading the buffer are logged and never
    prevent the regular forwarding done by the parent implementation.

    Args:
        old_task: previous forward task, passed through to the parent
            implementation (which cancels and awaits it).
        stream: audio stream created for the subscribed track.
        publication: publication the stream belongs to.
        participant: participant owning the publication; used for logging.
    """
    track = publication.track
    if (
        self._pre_connect_audio_handler
        and track
        and AudioTrackFeature.TF_PRECONNECT_BUFFER in publication.audio_features
    ):
        # Resolve the sid once. ``publication.track`` is optional and the code
        # below awaits; re-reading it inside the log calls (in particular in
        # the ``except`` handlers) could raise AttributeError if the track is
        # cleared in the meantime, which would abort forwarding entirely.
        track_sid = track.sid
        # Defined before the ``try`` so every handler can reference it safely.
        duration = 0
        try:
            frames = await self._pre_connect_audio_handler.wait_for_data(track_sid)
            for frame in self._resample_frames(frames):
                if self._attached:
                    await self._data_ch.send(frame)
                    duration += frame.duration
            if frames:
                logger.debug(
                    "pre-connect audio buffer pushed",
                    extra={
                        "duration": duration,
                        "track_id": track_sid,
                        "participant": participant.identity,
                    },
                )

        except asyncio.TimeoutError:
            logger.warning(
                "timeout waiting for pre-connect audio buffer",
                extra={
                    "duration": duration,
                    "track_id": track_sid,
                    "participant": participant.identity,
                },
            )

        except Exception as e:
            logger.error(
                "error reading pre-connect audio buffer",
                extra={
                    "error": e,
                    "track_id": track_sid,
                    "participant": participant.identity,
                },
            )

    await super()._forward_task(old_task, stream, publication, participant)
278+
279+
def _resample_frames(self, frames: Iterable[rtc.AudioFrame]) -> Iterable[rtc.AudioFrame]:
280+
resampler: rtc.AudioResampler | None = None
281+
for frame in frames:
282+
if (
283+
not resampler
284+
and self._sample_rate is not None
285+
and frame.sample_rate != self._sample_rate
286+
):
287+
resampler = rtc.AudioResampler(
288+
input_rate=frame.sample_rate, output_rate=self._sample_rate
289+
)
290+
291+
if resampler:
292+
yield from resampler.push(frame)
293+
else:
294+
yield frame
295+
296+
if resampler:
297+
yield from resampler.flush()
298+
222299

223300
class _ParticipantVideoInputStream(_ParticipantInputStream[rtc.VideoFrame], VideoInput):
224301
def __init__(self, room: rtc.Room) -> None:
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import asyncio
2+
import contextlib
3+
import time
4+
from dataclasses import dataclass, field
5+
6+
from livekit import rtc
7+
8+
from ..agent import logger, utils
9+
10+
PRE_CONNECT_AUDIO_BUFFER_STREAM = "lk.agent.pre-connect-audio-buffer"
11+
12+
13+
@dataclass
class _PreConnectAudioBuffer:
    """Audio frames collected from one pre-connect byte stream."""

    # ``time.time()`` value taken when the buffer was created, i.e. before the
    # stream content is read; used to judge how old the buffer is
    timestamp: float
    # frames decoded from the stream, appended in arrival order
    frames: list[rtc.AudioFrame] = field(default_factory=list)
17+
18+
19+
class PreConnectAudioHandler:
    """Receive pre-connect audio byte streams and hand the frames out per track.

    Incoming streams on ``PRE_CONNECT_AUDIO_BUFFER_STREAM`` are decoded in
    background tasks into ``_PreConnectAudioBuffer`` objects, stored as futures
    keyed by the ``trackId`` stream attribute. ``wait_for_data`` resolves the
    future for a given track.
    """

    def __init__(self, room: rtc.Room, *, timeout: float, max_delta_s: float = 1.0):
        """Create the handler; nothing is registered until ``register`` is called.

        Args:
            room: room on which the byte-stream handler is (un)registered.
            timeout: seconds after which an unfinished stream read is cancelled,
                and the maximum time ``wait_for_data`` waits for a buffer.
            max_delta_s: maximum age in seconds of an already completed buffer;
                older buffers are discarded by ``wait_for_data``.
        """
        self._room = room
        self._timeout = timeout
        self._max_delta_s = max_delta_s

        # track id -> buffer
        self._buffers: dict[str, asyncio.Future[_PreConnectAudioBuffer]] = {}
        self._tasks: set[asyncio.Task] = set()

        self._registered_after_connect = False

    def register(self) -> None:
        """Register the byte-stream handler for the pre-connect audio topic.

        Each incoming stream is read by ``_read_audio_task`` in its own task,
        which is cancelled if it has not finished after ``timeout`` seconds.
        A ``ValueError`` from the room's registration call is logged as
        "already registered" and ignored.
        """

        def _handler(reader: rtc.ByteStreamReader, participant_id: str):
            task = asyncio.create_task(self._read_audio_task(reader, participant_id))
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)

            def _on_timeout():
                logger.warning(
                    "pre-connect audio received but not completed in time",
                    extra={"participant": participant_id},
                )
                if not task.done():
                    task.cancel()

            # ``create_task`` above already requires a running loop, so use it
            # directly rather than the legacy ``get_event_loop`` lookup.
            timeout_handle = asyncio.get_running_loop().call_later(self._timeout, _on_timeout)
            task.add_done_callback(lambda _: timeout_handle.cancel())

        try:
            if self._room.isconnected():
                # remembered so wait_for_data can warn only if the buffer is used
                self._registered_after_connect = True
            self._room.register_byte_stream_handler(PRE_CONNECT_AUDIO_BUFFER_STREAM, _handler)
        except ValueError:
            logger.warning(
                f"pre-connect audio handler for {PRE_CONNECT_AUDIO_BUFFER_STREAM} "
                "already registered, ignoring"
            )

    async def aclose(self) -> None:
        """Unregister the byte-stream handler and cancel all pending read tasks."""
        self._room.unregister_byte_stream_handler(PRE_CONNECT_AUDIO_BUFFER_STREAM)
        await utils.aio.cancel_and_wait(*self._tasks)

    async def wait_for_data(self, track_id: str) -> list[rtc.AudioFrame]:
        """Return the pre-connect frames for ``track_id``.

        If a buffer is already complete it is returned immediately, or an empty
        list when it is older than ``max_delta_s``. Otherwise the call waits up
        to ``timeout`` seconds for the buffer to arrive.

        Args:
            track_id: sid of the track the buffer belongs to.

        Returns:
            The buffered frames, or ``[]`` for a stale buffer.

        Raises:
            asyncio.TimeoutError: no buffer arrived within ``timeout`` seconds.
            Exception: whatever error the reading task stored on the future.
        """
        # the handler is enabled by default, log a warning only if the buffer is actually used
        if self._registered_after_connect:
            logger.warning(
                "pre-connect audio handler registered after room connection, "
                "start RoomIO before ctx.connect() to ensure seamless audio buffer.",
                extra={"track_id": track_id},
            )

        fut = self._buffers.setdefault(track_id, asyncio.Future())

        try:
            if fut.done():
                buf = fut.result()
                if (delta := time.time() - buf.timestamp) > self._max_delta_s:
                    logger.warning(
                        "pre-connect audio buffer is too old",
                        extra={"track_id": track_id, "delta_time": delta},
                    )
                    return []
                return buf.frames

            buf = await asyncio.wait_for(fut, self._timeout)
            return buf.frames
        finally:
            # Remove only the future this call consumed. The entry may already
            # be gone (another waiter on the same track) or may have been
            # replaced by _read_audio_task for a newer stream; an unconditional
            # pop would raise KeyError or discard that newer buffer.
            if self._buffers.get(track_id) is fut:
                del self._buffers[track_id]

    @utils.log_exceptions(logger=logger)
    async def _read_audio_task(self, reader: rtc.ByteStreamReader, participant_id: str):
        """Decode one pre-connect byte stream and resolve the track's future.

        The stream must carry ``trackId``, ``sampleRate`` and ``channels``
        attributes. On success the future for the track is resolved with the
        collected buffer; on error the exception is stored on the future.

        Args:
            reader: byte stream to read the raw audio from.
            participant_id: sender of the stream; used for logging only.
        """
        if not (track_id := reader.info.attributes.get("trackId")):
            logger.warning(
                "pre-connect audio received but no trackId", extra={"participant": participant_id}
            )
            return

        if (fut := self._buffers.get(track_id)) and fut.done():
            # reset the buffer if it's already set
            self._buffers.pop(track_id)
        fut = self._buffers.setdefault(track_id, asyncio.Future())

        buf = _PreConnectAudioBuffer(timestamp=time.time())
        try:
            sample_rate = int(reader.info.attributes["sampleRate"])
            num_channels = int(reader.info.attributes["channels"])

            duration = 0
            audio_stream = utils.audio.AudioByteStream(sample_rate, num_channels)
            async for chunk in reader:
                for frame in audio_stream.push(chunk):
                    buf.frames.append(frame)
                    duration += frame.duration

            for frame in audio_stream.flush():
                buf.frames.append(frame)
                duration += frame.duration

            logger.debug(
                "pre-connect audio received",
                extra={"duration": duration, "track_id": track_id, "participant": participant_id},
            )

            # the waiter may have timed out and cancelled the future already
            with contextlib.suppress(asyncio.InvalidStateError):
                fut.set_result(buf)
        except Exception as e:
            with contextlib.suppress(asyncio.InvalidStateError):
                fut.set_exception(e)

livekit-agents/livekit/agents/voice/room_io/room_io.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ..events import AgentStateChangedEvent, UserInputTranscribedEvent
2020
from ..io import AudioInput, AudioOutput, TextOutput, VideoInput
2121
from ..transcription import TranscriptSynchronizer
22+
from ._pre_connect_audio import PreConnectAudioHandler
2223

2324
if TYPE_CHECKING:
2425
from ..agent_session import AgentSession
@@ -70,6 +71,10 @@ class RoomInputOptions:
7071
participant_identity: NotGivenOr[str] = NOT_GIVEN
7172
"""The participant to link to. If not provided, link to the first participant.
7273
Can be overridden by the `participant` argument of RoomIO constructor or `set_participant`."""
74+
pre_connect_audio: bool = True
75+
"""Pre-connect audio enabled or not."""
76+
pre_connect_audio_timeout: float = 3.0
77+
"""The pre-connect audio will be ignored if it doesn't arrive within this time."""
7378

7479

7580
@dataclass
@@ -125,8 +130,17 @@ def __init__(
125130
self._tasks: set[asyncio.Task] = set()
126131
self._update_state_task: asyncio.Task | None = None
127132

133+
self._pre_connect_audio_handler: PreConnectAudioHandler | None = None
134+
128135
async def start(self) -> None:
129136
# -- create inputs --
137+
if self._input_options.pre_connect_audio:
138+
self._pre_connect_audio_handler = PreConnectAudioHandler(
139+
room=self._room,
140+
timeout=self._input_options.pre_connect_audio_timeout,
141+
)
142+
self._pre_connect_audio_handler.register()
143+
130144
if self._input_options.text_enabled:
131145
try:
132146
self._room.register_text_stream_handler(TOPIC_CHAT, self._on_user_text_input)
@@ -144,6 +158,7 @@ async def start(self) -> None:
144158
sample_rate=self._input_options.audio_sample_rate,
145159
num_channels=self._input_options.audio_num_channels,
146160
noise_cancellation=self._input_options.noise_cancellation,
161+
pre_connect_audio_handler=self._pre_connect_audio_handler,
147162
)
148163

149164
# -- create outputs --
@@ -209,6 +224,9 @@ async def aclose(self) -> None:
209224
if self._init_atask:
210225
await utils.aio.cancel_and_wait(self._init_atask)
211226

227+
if self._pre_connect_audio_handler:
228+
await self._pre_connect_audio_handler.aclose()
229+
212230
if self._audio_input:
213231
await self._audio_input.aclose()
214232
if self._video_input:

livekit-agents/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ classifiers = [
2424
]
2525
dependencies = [
2626
"click~=8.1",
27-
"livekit>=1.0.6,<2",
27+
"livekit>=1.0.7,<2",
2828
"livekit-api>=1.0.2,<2",
2929
"livekit-protocol~=1.0",
3030
"protobuf>=3",

0 commit comments

Comments
 (0)