1 change: 1 addition & 0 deletions examples/voice_agents/README.md
@@ -23,6 +23,7 @@ session = AgentSession(
### 🚀 Getting Started

- [`basic_agent.py`](./basic_agent.py) - A fundamental voice agent using LiveKit Inference with metrics collection
- [`audio_turn_detector.py`](./audio_turn_detector.py) - A voice agent that plugs in a custom audio-native turn detector fed with the buffered current-turn PCM

### 🛠️ Tool Integration & Function Calling

113 changes: 113 additions & 0 deletions examples/voice_agents/audio_turn_detector.py
@@ -0,0 +1,113 @@
import logging

from dotenv import load_dotenv

from livekit.agents import (
Agent,
AgentServer,
AgentSession,
AudioTurnContext,
AudioTurnDetector,
JobContext,
JobProcess,
TurnHandlingOptions,
cli,
inference,
)
from livekit.agents.utils.audio import calculate_audio_duration
from livekit.plugins import silero

logger = logging.getLogger("audio-turn-detector")
logger.setLevel(logging.INFO)

load_dotenv()


class HeuristicAudioTurnDetector(AudioTurnDetector):
"""Example audio-native turn detector.

This intentionally uses a simple duration heuristic so the example does not depend on an external model.
A production detector can inspect ``turn_ctx.audio`` and run any external model there.
"""

def __init__(
self,
*,
short_utterance_seconds: float = 0.8,
long_utterance_seconds: float = 1.8,
) -> None:
self._short_utterance_seconds = short_utterance_seconds
self._long_utterance_seconds = long_utterance_seconds

@property
def model(self) -> str:
return "heuristic-audio-turn-detector"

@property
def provider(self) -> str:
return "examples"

async def unlikely_threshold(self, language) -> float | None:
return 0.5

async def supports_language(self, language) -> bool:
return True

async def predict_end_of_turn_audio(
self, turn_ctx: AudioTurnContext, *, timeout: float | None = None
) -> float:
duration = calculate_audio_duration(turn_ctx.audio)
logger.info(
"audio turn detector evaluated current turn",
extra={
"duration_seconds": round(duration, 3),
"language": turn_ctx.language,
"transcript": turn_ctx.transcript,
},
)

if not turn_ctx.transcript.strip():
return 0.0
if duration <= self._short_utterance_seconds:
return 0.95
if duration >= self._long_utterance_seconds:
return 0.15
return 0.55


server = AgentServer()


def prewarm(proc: JobProcess) -> None:
proc.userdata["vad"] = silero.VAD.load()


server.setup_fnc = prewarm


@server.rtc_session()
async def entrypoint(ctx: JobContext) -> None:
session = AgentSession(
stt=inference.STT("deepgram/nova-3", language="multi"),
llm=inference.LLM("openai/gpt-4.1-mini"),
tts=inference.TTS("cartesia/sonic-3"),
vad=ctx.proc.userdata["vad"],
turn_handling=TurnHandlingOptions(
turn_detection=HeuristicAudioTurnDetector(),
endpointing={
"min_delay": 0.05,
"max_delay": 0.6,
},
),
)

await session.start(
agent=Agent(
instructions=("You are a concise voice assistant. Reply briefly and naturally.")
),
room=ctx.room,
)


if __name__ == "__main__":
cli.run_app(server)
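
The example's docstring notes that a production detector would inspect `turn_ctx.audio` and run a real model over it. A minimal sketch of that step, assuming numpy is available: it flattens the buffered frames into a single int16 PCM array; `turn_audio_to_pcm` and `my_eou_model.score_end_of_turn` are hypothetical names, not part of this PR.

```python
import numpy as np

from livekit.agents import AudioTurnContext


def turn_audio_to_pcm(turn_ctx: AudioTurnContext) -> tuple[np.ndarray, int]:
    """Concatenate the buffered current-turn frames into one int16 PCM array."""
    if not turn_ctx.audio:
        return np.zeros(0, dtype=np.int16), 0
    sample_rate = turn_ctx.audio[0].sample_rate
    # each rtc.AudioFrame exposes its samples as 16-bit PCM via `data`
    # (interleaved if the input is multi-channel)
    chunks = [np.frombuffer(frame.data, dtype=np.int16) for frame in turn_ctx.audio]
    return np.concatenate(chunks), sample_rate


# inside a custom detector, the flattened audio could then be scored by any model:
#
#     pcm, sample_rate = turn_audio_to_pcm(turn_ctx)
#     probability = my_eou_model.score_end_of_turn(pcm, sample_rate)  # hypothetical model
```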
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/__init__.py
@@ -112,6 +112,8 @@
mock_tools,
)
from .voice.turn import (
AudioTurnContext,
AudioTurnDetector,
EndpointingOptions,
InterruptionOptions,
PreemptiveGenerationOptions,
@@ -234,6 +236,8 @@ def __getattr__(name: str) -> typing.Any:
"AMDCategory",
"AMDResult",
"TurnHandlingOptions",
"AudioTurnContext",
"AudioTurnDetector",
"EndpointingOptions",
"InterruptionOptions",
"PreemptiveGenerationOptions",
53 changes: 48 additions & 5 deletions livekit-agents/livekit/agents/voice/audio_recognition.py
@@ -7,7 +7,7 @@
from collections import deque
from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol, TypeGuard, cast

from opentelemetry import trace
from opentelemetry.sdk.trace import ReadableSpan
@@ -30,7 +30,12 @@
from . import io
from ._utils import _set_participant_attributes
from .endpointing import BaseEndpointing
from .turn import TurnDetectionMode as TurnDetectionMode
from .turn import (
AudioTurnContext,
AudioTurnDetector,
TurnDetectionMode as TurnDetectionMode,
_TurnDetector,
)

if TYPE_CHECKING:
from .agent_session import AgentSession
Expand All @@ -40,6 +45,19 @@
_EOU_MAX_HISTORY_TURNS = 6


def _copy_audio_frame(frame: rtc.AudioFrame) -> rtc.AudioFrame:
return rtc.AudioFrame(
data=bytes(frame.data),
sample_rate=frame.sample_rate,
num_channels=frame.num_channels,
samples_per_channel=frame.samples_per_channel,
)


def _is_audio_turn_detector(detector: object) -> TypeGuard[AudioTurnDetector]:
return hasattr(detector, "predict_end_of_turn_audio")


@dataclass
class _EndOfTurnInfo:
skip_reply: bool
@@ -145,6 +163,7 @@ def __init__(
self._end_of_turn_task: asyncio.Task[None] | None = None
self._endpointing: BaseEndpointing = endpointing
self._turn_detector = turn_detection if not isinstance(turn_detection, str) else None
self._audio_turn_detector_enabled = _is_audio_turn_detector(self._turn_detector)
self._stt = stt
self._vad = vad
self._stt_model = stt_model
Expand All @@ -167,6 +186,7 @@ def __init__(
self._audio_interim_transcript = ""
# used for STTs that support preflight mode, so it could start preemptive generation earlier
self._audio_preflight_transcript = ""
self._turn_audio_frames: list[rtc.AudioFrame] = []
self._last_language: LanguageCode | None = None

self._stt_pipeline: _STTPipeline | None = None
@@ -203,6 +223,7 @@ def update_options(

if is_given(turn_detection):
self._turn_detector = turn_detection if not isinstance(turn_detection, str) else None
self._audio_turn_detector_enabled = _is_audio_turn_detector(self._turn_detector)

mode = turn_detection if isinstance(turn_detection, str) else None
if self._turn_detection_mode != mode:
@@ -434,6 +455,8 @@ def push_audio(self, frame: rtc.AudioFrame, *, skip_stt: bool = False) -> None:
self._input_started_at = time.time() - frame.duration

self._sample_rate = frame.sample_rate
if self._audio_turn_detector_enabled:
self._turn_audio_frames.append(_copy_audio_frame(frame))
if not skip_stt and self._stt_pipeline is not None:
self._stt_pipeline.audio_ch.send_nowait(frame)

@@ -566,6 +589,7 @@ def clear_user_turn(self) -> None:
self._audio_interim_transcript = ""
self._audio_preflight_transcript = ""
self._final_transcript_confidence = []
self._turn_audio_frames = []
self._user_turn_committed = False

# reset stt to clear the buffer from previous user turn
@@ -945,9 +969,27 @@ async def _bounce_eou_task(
end_of_turn_probability = 0.0
unlikely_threshold: float | None = None
try:
end_of_turn_probability = await turn_detector.predict_end_of_turn(
chat_ctx
)
if _is_audio_turn_detector(turn_detector):
if not self._turn_audio_frames:
logger.debug(
"audio turn detector skipped because no turn audio was buffered"
)
else:
end_of_turn_probability = (
await turn_detector.predict_end_of_turn_audio(
AudioTurnContext(
audio=list(self._turn_audio_frames),
transcript=self._audio_transcript,
chat_ctx=chat_ctx.copy(),
language=self._last_language,
)
)
)
else:
text_turn_detector = cast(_TurnDetector, turn_detector)
end_of_turn_probability = (
await text_turn_detector.predict_end_of_turn(chat_ctx)
)
unlikely_threshold = await turn_detector.unlikely_threshold(
self._last_language
)
@@ -1043,6 +1085,7 @@ async def _bounce_eou_task(
# clear the transcript if the user turn was committed
self._audio_transcript = ""
self._final_transcript_confidence = []
self._turn_audio_frames = []
self._last_final_transcript_time = None
# concurrent user speech might have changed it
# only reset if there is no new speech
46 changes: 44 additions & 2 deletions livekit-agents/livekit/agents/voice/turn.py
@@ -1,9 +1,12 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Protocol

from typing_extensions import TypedDict

from livekit import rtc

from ..language import LanguageCode
from ..llm import ChatContext
from ..types import NOT_GIVEN, NotGivenOr
@@ -28,15 +31,54 @@ async def predict_end_of_turn(
) -> float: ...


TurnDetectionMode = Literal["stt", "vad", "realtime_llm", "manual"] | _TurnDetector
@dataclass(slots=True)
class AudioTurnContext:
"""Current-turn audio and conversation context for audio-native detectors.

Attributes:
audio: Buffered audio frames for the current user turn.
transcript: The current turn transcript accumulated from STT.
chat_ctx: Conversation context with the current user transcript appended.
language: The most recent detected language, if available from STT.
"""

audio: list[rtc.AudioFrame]
transcript: str
chat_ctx: ChatContext
language: LanguageCode | None = None


class AudioTurnDetector(Protocol):
"""Protocol for custom audio-native turn detectors."""

@property
def model(self) -> str:
return "unknown"

@property
def provider(self) -> str:
return "unknown"

async def unlikely_threshold(self, language: LanguageCode | None) -> float | None: ...
async def supports_language(self, language: LanguageCode | None) -> bool: ...

async def predict_end_of_turn_audio(
self, turn_ctx: AudioTurnContext, *, timeout: float | None = None
) -> float: ...


TurnDetectionMode = (
Literal["stt", "vad", "realtime_llm", "manual"] | _TurnDetector | AudioTurnDetector
)
"""
The mode of turn detection to use.

- "stt": use speech-to-text result to detect the end of the user's turn
- "vad": use VAD to detect the start and end of the user's turn
- "realtime_llm": use server-side turn detection provided by the realtime LLM
- "manual": manually manage the turn detection
- _TurnDetector: use the default mode with the provided turn detector
- _TurnDetector: use a text/chat-context turn detector
- AudioTurnDetector: use an audio-native turn detector

(default) If not provided, automatically choose the best mode based on
available models (realtime_llm -> vad -> stt -> manual)
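
The protocol's `predict_end_of_turn_audio` accepts an optional `timeout`, though the dispatcher in `audio_recognition.py` above does not currently pass one; a detector backed by a slow or remote model can still enforce its own budget. A hedged sketch of that pattern, where `RemoteAudioTurnDetector` and `_score_remotely` are placeholders rather than anything provided by this PR:

```python
import asyncio

from livekit.agents import AudioTurnContext, AudioTurnDetector


class RemoteAudioTurnDetector(AudioTurnDetector):
    """Sketch of a detector that defers scoring to a slow or remote model."""

    @property
    def model(self) -> str:
        return "remote-eou-sketch"

    @property
    def provider(self) -> str:
        return "example"

    async def unlikely_threshold(self, language) -> float | None:
        return 0.5

    async def supports_language(self, language) -> bool:
        return True

    async def predict_end_of_turn_audio(
        self, turn_ctx: AudioTurnContext, *, timeout: float | None = None
    ) -> float:
        try:
            # bound the (placeholder) remote call so a stalled backend
            # cannot hold up endpointing indefinitely
            return await asyncio.wait_for(self._score_remotely(turn_ctx), timeout)
        except asyncio.TimeoutError:
            return 0.5  # neutral fallback: let the endpointing delays decide

    async def _score_remotely(self, turn_ctx: AudioTurnContext) -> float:
        raise NotImplementedError  # placeholder for an actual inference call
```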
3 changes: 2 additions & 1 deletion makefile
@@ -109,6 +109,7 @@ unit-tests:
tests/test_aio_itertools.py \
tests/test_room.py \
tests/test_room_io.py \
tests/test_audio_turn_detection.py \
tests/test_audio_recognition_handoff.py \
tests/test_utils/test_audio_array_buffer.py \
tests/test_utils/test_bounded_dict.py \
@@ -355,4 +356,4 @@ doctor: ## Check development environment health
else \
echo "$(BOLD)$(RED)⚠️ Found $$ISSUES issue(s). Please fix the errors above.$(RESET)"; \
exit 1; \
fi
fi
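
The makefile change references a new `tests/test_audio_turn_detection.py`, which is not shown in this diff. A minimal sketch of what a test for the example detector could look like; the import path, `ChatContext.empty()`, and the async test setup are assumptions about the repo's test conventions, not part of this PR:

```python
from livekit import rtc
from livekit.agents import AudioTurnContext
from livekit.agents.llm import ChatContext

# assumes the repo root is importable so the example module can be reached
from examples.voice_agents.audio_turn_detector import HeuristicAudioTurnDetector


def _silence(duration_s: float, sample_rate: int = 16000) -> rtc.AudioFrame:
    samples = int(duration_s * sample_rate)
    return rtc.AudioFrame(
        data=b"\x00\x00" * samples,
        sample_rate=sample_rate,
        num_channels=1,
        samples_per_channel=samples,
    )


# assumes pytest-asyncio (or equivalent) is configured for async tests
async def test_short_utterance_scores_high() -> None:
    detector = HeuristicAudioTurnDetector()
    turn_ctx = AudioTurnContext(
        audio=[_silence(0.5)],  # 0.5 s is below the short-utterance threshold
        transcript="yes",
        chat_ctx=ChatContext.empty(),
    )
    assert await detector.predict_end_of_turn_audio(turn_ctx) >= 0.9
```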