Skip to content

Commit ddc1433

Browse files
authored
Fix Gemini audio (#165)
1 parent f5d37ea commit ddc1433

File tree

8 files changed

+371
-437
lines changed

8 files changed

+371
-437
lines changed

agents-core/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ classifiers = [
2121

2222
requires-python = ">=3.10"
2323
dependencies = [
24-
"getstream[webrtc,telemetry]>=2.5.11",
24+
"getstream[webrtc,telemetry]>=2.5.14",
2525
"python-dotenv>=1.1.1",
2626
"pillow>=10.4.0", # Compatible with moondream SDK (<11.0.0)
2727
"numpy>=1.24.0",
@@ -91,5 +91,5 @@ include = ["vision_agents"]
9191
#]
9292
# getstream = { git = "https://github.com/GetStream/stream-py.git", branch = "audio-more" }
9393
# for local development
94-
#getstream = { git = "https://github.com/GetStream/stream-py.git", rev = "85bd8ef00859ef6ed5ef4ffe7b7f40ae12d12973" }
94+
#getstream = { path = "../../stream-py/", editable = true }
9595
# aiortc = { path = "../stream-py/", editable = true }

agents-core/vision_agents/core/agents/agents.py

Lines changed: 52 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
LLMResponseCompletedEvent,
3232
RealtimeUserSpeechTranscriptionEvent,
3333
RealtimeAgentSpeechTranscriptionEvent,
34+
RealtimeAudioOutputEvent,
3435
)
3536
from ..llm.llm import AudioLLM, LLM, VideoLLM
3637
from ..llm.realtime import Realtime
@@ -70,7 +71,6 @@
7071
tracer: Tracer = trace.get_tracer("agents")
7172

7273

73-
7474
class Agent:
7575
"""
7676
Agent class makes it easy to build your own video AI.
@@ -227,7 +227,9 @@ def __init__(
227227

228228
async def _finish_llm_turn(self):
229229
if self._pending_turn is None or self._pending_turn.response is None:
230-
raise ValueError("Finish LLM turn should only be called after self._pending_turn is set")
230+
raise ValueError(
231+
"Finish LLM turn should only be called after self._pending_turn is set"
232+
)
231233
turn = self._pending_turn
232234
self._pending_turn = None
233235
event = turn.response
@@ -252,6 +254,7 @@ def setup_event_handling(self):
252254
self.events.subscribe(self._on_turn_event)
253255

254256
if self.stt:
257+
255258
@self.stt.events.subscribe
256259
async def on_turn_ended(event: TurnEndedEvent):
257260
logger.info("Received TurnEndedEvent %s", event)
@@ -322,10 +325,11 @@ async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
322325

323326
# if turn detection is disabled, treat the transcript event as an end of turn
324327
if not self.turn_detection_enabled:
325-
self.events.send(TurnEndedEvent(
326-
participant = event.participant,
327-
))
328-
328+
self.events.send(
329+
TurnEndedEvent(
330+
participant=event.participant,
331+
)
332+
)
329333

330334
# TODO: chat event handling needs work
331335

@@ -634,7 +638,12 @@ async def _apply(self, function_name: str, *args, **kwargs):
634638
):
635639
func = getattr(subclass, function_name)
636640
if func is not None:
637-
await func(*args, **kwargs)
641+
try:
642+
await func(*args, **kwargs)
643+
except Exception as e:
644+
self.logger.exception(
645+
f"Error calling {function_name} on {subclass.__class__.__name__}: {e}"
646+
)
638647

639648
def _end_tracing(self):
640649
if self._root_span is not None:
@@ -879,7 +888,10 @@ async def _reply_to_audio_consumer(self) -> None:
879888
pcm, participant, conversation=self.conversation
880889
)
881890

882-
if participant and getattr(participant, "user_id", None) != self.agent_user.id:
891+
if (
892+
participant
893+
and getattr(participant, "user_id", None) != self.agent_user.id
894+
):
883895
# first forward to processors
884896
# Extract audio bytes for processors using the proper PCM data structure
885897
# PCM data has: format, sample_rate, samples, pts, dts, time_base
@@ -1044,6 +1056,8 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
10441056
self.logger.info(
10451057
f"👉 Turn started - participant speaking {participant_id} : {event.confidence}"
10461058
)
1059+
if self._audio_track is not None:
1060+
await self._audio_track.flush()
10471061
else:
10481062
# Agent itself started speaking - this is normal
10491063
participant_id = (
@@ -1078,9 +1092,15 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
10781092
self._pending_user_transcripts[participant.user_id] = ""
10791093
# cancel the old task if the text changed in the meantime
10801094

1081-
if self._pending_turn is not None and self._pending_turn.input != transcript:
1082-
logger.debug("Eager turn and completed turn didn't match. Cancelling in flight response. %s vs %s ",
1083-
self._pending_turn.input, transcript)
1095+
if (
1096+
self._pending_turn is not None
1097+
and self._pending_turn.input != transcript
1098+
):
1099+
logger.debug(
1100+
"Eager turn and completed turn didn't match. Cancelling in flight response. %s vs %s ",
1101+
self._pending_turn.input,
1102+
transcript,
1103+
)
10841104
if self._pending_turn.task:
10851105
self._pending_turn.task.cancel()
10861106

@@ -1092,18 +1112,22 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
10921112
input=transcript,
10931113
participant=event.participant,
10941114
started_at=datetime.datetime.now(),
1095-
turn_finished=not event.eager_end_of_turn
1115+
turn_finished=not event.eager_end_of_turn,
10961116
)
10971117
self._pending_turn = llm_turn
1098-
task = asyncio.create_task(self.simple_response(transcript, event.participant))
1118+
task = asyncio.create_task(
1119+
self.simple_response(transcript, event.participant)
1120+
)
10991121
llm_turn.task = task
11001122
elif self._pending_turn.input == transcript:
11011123
# same text as pending turn
11021124
is_finished = not event.eager_end_of_turn
11031125
now = datetime.datetime.now()
11041126
elapsed = now - self._pending_turn.started_at
1105-
logger.debug("Marking eager turn as completed. Eager turn detection saved %.2f",
1106-
elapsed.total_seconds() * 1000)
1127+
logger.debug(
1128+
"Marking eager turn as completed. Eager turn detection saved %.2f",
1129+
elapsed.total_seconds() * 1000,
1130+
)
11071131

11081132
if is_finished:
11091133
self._pending_turn.turn_finished = True
@@ -1113,8 +1137,9 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
11131137
@property
11141138
def turn_detection_enabled(self):
11151139
# return true if either turn detection or stt provide turn detection capabilities
1116-
return self.turn_detection is not None or (self.stt is not None and self.stt.turn_detection)
1117-
1140+
return self.turn_detection is not None or (
1141+
self.stt is not None and self.stt.turn_detection
1142+
)
11181143

11191144
@property
11201145
def publish_audio(self) -> bool:
@@ -1246,30 +1271,17 @@ def _validate_configuration(self):
12461271
def _prepare_rtc(self):
12471272
# Variables are now initialized in __init__
12481273

1249-
# Set up audio track if TTS is available
12501274
if self.publish_audio:
1251-
if _is_audio_llm(self.llm):
1252-
self._audio_track = self.llm.output_audio_track
1253-
self.logger.info("🎵 Using Realtime provider output track for audio")
1254-
elif self.audio_publishers:
1255-
# Get the first audio publisher to create the track
1256-
audio_publisher = self.audio_publishers[0]
1257-
self._audio_track = audio_publisher.publish_audio_track()
1258-
self.logger.info("🎵 Audio track initialized from audio publisher")
1259-
else:
1260-
# Default to WebRTC-friendly format unless configured differently
1261-
framerate = 48000
1262-
stereo = True
1263-
self._audio_track = self.edge.create_audio_track(
1264-
framerate=framerate, stereo=stereo
1265-
)
1266-
# Inform TTS of desired output format so it can resample accordingly
1267-
if self.tts:
1268-
channels = 2 if stereo else 1
1269-
self.tts.set_output_format(
1270-
sample_rate=framerate,
1271-
channels=channels,
1272-
)
1275+
framerate = 48000
1276+
stereo = True
1277+
self._audio_track = self.edge.create_audio_track(
1278+
framerate=framerate, stereo=stereo
1279+
)
1280+
1281+
@self.events.subscribe
1282+
async def forward_audio(event: RealtimeAudioOutputEvent):
1283+
if self._audio_track is not None:
1284+
await self._audio_track.write(event.data)
12731285

12741286
# Set up video track if video publishers are available
12751287
if self.publish_video:

agents-core/vision_agents/core/edge/edge_transport.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def create_audio_track(self) -> OutputAudioTrack:
3535
pass
3636

3737
@abc.abstractmethod
38-
def close(self):
38+
async def close(self):
3939
pass
4040

4141
@abc.abstractmethod

agents-core/vision_agents/core/edge/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,5 @@ class OutputAudioTrack(Protocol):
4545
async def write(self, data: PcmData) -> None: ...
4646

4747
def stop(self) -> None: ...
48+
49+
async def flush(self) -> None: ...

agents-core/vision_agents/core/llm/llm.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from vision_agents.core.agents.conversation import Conversation
2525

2626
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant
27-
from getstream.video.rtc import AudioStreamTrack, PcmData
27+
from getstream.video.rtc import PcmData
2828
from vision_agents.core.processors import Processor
2929
from vision_agents.core.utils.utils import parse_instructions
3030
from vision_agents.core.events.manager import EventManager
@@ -426,13 +426,6 @@ async def simple_audio_response(
426426
participant: Optional participant information for the audio source.
427427
"""
428428

429-
@property
430-
@abc.abstractmethod
431-
def output_audio_track(self) -> AudioStreamTrack:
432-
"""
433-
An output audio track from the LLM.
434-
"""
435-
436429

437430
class VideoLLM(LLM, metaclass=abc.ABCMeta):
438431
"""

plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Optional, List, Dict, Any
55

66
import aiortc
7-
from getstream.video.rtc.audio_track import AudioStreamTrack
87
from getstream.video.rtc.track_util import PcmData
98
from google import genai
109
from google.genai.live import AsyncSession
@@ -103,19 +102,12 @@ def __init__(
103102
self.client = client
104103
self.config: LiveConnectConfigDict = self._create_config(config)
105104
self.logger = logging.getLogger(__name__)
106-
# Gemini generates at 24k. webrtc automatically translates it to 48khz
107-
self._output_audio_track = AudioStreamTrack(
108-
sample_rate=24000, channels=1, format="s16"
109-
)
105+
110106
self._video_forwarder: Optional[VideoForwarder] = None
111107
self._session_context: Optional[Any] = None
112108
self._session: Optional[AsyncSession] = None
113109
self._receive_task: Optional[asyncio.Task[Any]] = None
114110

115-
@property
116-
def output_audio_track(self) -> AudioStreamTrack:
117-
return self._output_audio_track
118-
119111
async def simple_response(
120112
self,
121113
text: str,
@@ -315,7 +307,6 @@ async def _receive_loop(self):
315307
self._emit_audio_output_event(
316308
audio_data=pcm,
317309
)
318-
await self._output_audio_track.write(pcm)
319310
elif (
320311
hasattr(typed_part, "function_call")
321312
and typed_part.function_call

plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def _get_subscription_config(self):
347347
]
348348
)
349349

350-
def close(self):
350+
async def close(self):
351351
# Note: Not calling super().close() as it's an abstract method with trivial body
352352
pass
353353

0 commit comments

Comments
 (0)