Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
SpeechCreatedEvent,
UserInputTranscribedEvent,
UserStateChangedEvent,
UserTurnExceededEvent,
avatar,
io,
room_io,
Expand Down Expand Up @@ -116,6 +117,7 @@
InterruptionOptions,
PreemptiveGenerationOptions,
TurnHandlingOptions,
UserTurnLimitOptions,
)
from .worker import (
AgentServer,
Expand Down Expand Up @@ -237,6 +239,8 @@ def __getattr__(name: str) -> typing.Any:
"EndpointingOptions",
"InterruptionOptions",
"PreemptiveGenerationOptions",
"UserTurnLimitOptions",
"UserTurnExceededEvent",
]

# Cleanup docs of unexported modules
Expand Down
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/voice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SpeechCreatedEvent,
UserInputTranscribedEvent,
UserStateChangedEvent,
UserTurnExceededEvent,
)
from .room_io import (
_ParticipantAudioOutput,
Expand Down Expand Up @@ -47,6 +48,7 @@
"AgentStateChangedEvent",
"FunctionToolsExecutedEvent",
"AgentFalseInterruptionEvent",
"UserTurnExceededEvent",
"TranscriptSynchronizer",
"io",
"room_io",
Expand Down
21 changes: 21 additions & 0 deletions livekit-agents/livekit/agents/voice/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..log import logger
from ..types import NOT_GIVEN, FlushSentinel, NotGivenOr
from ..utils import is_given, misc
from .events import UserTurnExceededEvent
from .speech_handle import SpeechHandle
from .turn import TurnHandlingOptions, _migrate_turn_handling

Expand Down Expand Up @@ -254,6 +255,26 @@ async def on_user_turn_completed(
"""
pass

async def on_user_turn_exceeded(self, ev: UserTurnExceededEvent) -> None:
"""Called when the user turn has exceeded the configured limit.

The user has been speaking for too long without the agent successfully
responding. By default, generates a reply using the current turn's
transcript (previous turns are already in the chat context).

Override to customize (e.g., use session.say() with a canned message,
or skip the interruption entirely).
"""
await self.session.generate_reply(
user_input=ev.transcript,
instructions=(
"The user has been speaking too long without giving a chance to reply. "
"Politely cut in with a short reply or notice. Keep it short since the user cannot interrupt it."
),
allow_interruptions=False,
tool_choice="none",
)

def stt_node(
self, audio: AsyncIterable[rtc.AudioFrame], model_settings: ModelSettings
) -> (
Expand Down
108 changes: 86 additions & 22 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@
from .endpointing import create_endpointing
from .events import (
AgentFalseInterruptionEvent,
AgentStateChangedEvent,
ErrorEvent,
FunctionToolsExecutedEvent,
MetricsCollectedEvent,
SessionUsageUpdatedEvent,
SpeechCreatedEvent,
UserInputTranscribedEvent,
UserTurnExceededEvent,
)
from .generation import (
ToolExecutionOutput,
Expand Down Expand Up @@ -169,6 +171,9 @@ def __init__(self, agent: Agent, sess: AgentSession) -> None:
self._drain_blocked_tasks: list[asyncio.Task[Any]] = []
self._mcp_tools: list[mcp.MCPToolset] = []

self._user_turn_exceeded_atask: asyncio.Task[None] | None = None
self._user_turn_exceeded_locked: bool = False

self._on_enter_task: asyncio.Task | None = None
self._on_exit_task: asyncio.Task | None = None

Expand Down Expand Up @@ -1383,31 +1388,34 @@ async def _scheduling_task(self) -> None:
if self._scheduling_paused and len(to_wait) == 0:
break

async def _wait_for_inactive(self) -> None:
async def _wait_for_inactive(
self, *, wait_for_agent: bool = True, wait_for_user: bool = True
) -> None:
agent_active = True
user_active = True
while agent_active or user_active:
if self._current_speech is None and not self._speech_q:
agent_active = False
else:
agent_active = True
if (speech := self._current_speech) and speech._generations:
await speech._wait_for_generation()
await asyncio.sleep(0)

if self._user_silence_event.is_set():
user_active = False
else:
user_active = True
await self._user_silence_event.wait()
while (wait_for_agent and agent_active) or (wait_for_user and user_active):
if wait_for_agent:
if (
self._audio_recognition
and (eou_task := self._audio_recognition._end_of_turn_task)
and not eou_task.done()
):
await eou_task

if (
self._audio_recognition
and (eou_task := self._audio_recognition._end_of_turn_task)
and not eou_task.done()
):
user_active = True
await eou_task
if self._current_speech is None and not self._speech_q:
agent_active = False
else:
agent_active = True
if (speech := self._current_speech) and speech._generations:
await speech._wait_for_generation()
await asyncio.sleep(0)

if wait_for_user:
if self._user_silence_event.is_set():
user_active = False
else:
user_active = True
await self._user_silence_event.wait()

# -- Realtime Session events --

Expand Down Expand Up @@ -2079,6 +2087,62 @@ async def _user_turn_completed_task(
)
self._session.emit("metrics_collected", MetricsCollectedEvent(metrics=eou_metrics))

def on_user_turn_exceeded(self, ev: UserTurnExceededEvent) -> None:
if self._user_turn_exceeded_locked:
return # user callback is executing, drop

# cancel previous wait phase (if still waiting for EOU result)
if self._user_turn_exceeded_atask is not None:
self._user_turn_exceeded_atask.cancel()

self._user_turn_exceeded_atask = self._create_speech_task(
self._user_turn_exceeded_task(ev),
name="AgentActivity._user_turn_exceeded_task",
)

@utils.log_exceptions(logger=logger)
async def _user_turn_exceeded_task(self, ev: UserTurnExceededEvent) -> None:
agent_speaking_fut = asyncio.Future[None]()

def _on_agent_state_changed(state_ev: AgentStateChangedEvent) -> None:
if state_ev.new_state == "speaking" and not agent_speaking_fut.done():
agent_speaking_fut.set_result(None)

if self._session.agent_state == "speaking":
agent_speaking_fut.set_result(None)
else:
self._session.on("agent_state_changed", _on_agent_state_changed)

# wait for the EOU-triggered agent response (cancellable by the new user turn exceeded event)
wait_inactive = asyncio.ensure_future(
self._wait_for_inactive(wait_for_agent=True, wait_for_user=False)
)
try:
done, _ = await asyncio.wait(
(agent_speaking_fut, wait_inactive), return_when=asyncio.FIRST_COMPLETED
)
if agent_speaking_fut in done:
# agent started speaking, skip the user turn exceeded event
return
finally:
self._session.off("agent_state_changed", _on_agent_state_changed)
if not wait_inactive.done():
wait_inactive.cancel()

# custom callback, locked - don't cancel user's callback
logger.debug(
"user turn limit exceeded",
extra={"num_words": ev.accumulated_word_count, "duration": ev.duration},
)
self._user_turn_exceeded_locked = True
try:
await self._agent.on_user_turn_exceeded(ev)
except Exception:
logger.exception("error in on_user_turn_exceeded callback")
finally:
self._user_turn_exceeded_locked = False
self._user_turn_exceeded_atask = None

# AudioRecognition is calling this method to retrieve the chat context before running the TurnDetector model # noqa: E501
def retrieve_chat_ctx(self) -> llm.ChatContext:
return self._agent.chat_ctx
Expand Down
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
_resolve_endpointing,
_resolve_interruption,
_resolve_preemptive_generation,
_resolve_user_turn_limit,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -353,6 +354,7 @@ def __init__(
endpointing = _resolve_endpointing(turn_handling.get("endpointing"))
interruption = _resolve_interruption(turn_handling.get("interruption"))
preemptive_gen = _resolve_preemptive_generation(turn_handling.get("preemptive_generation"))
user_turn_limit = _resolve_user_turn_limit(turn_handling.get("user_turn_limit"))
raw_turn_detection = turn_handling.get("turn_detection", None)

# This is the "global" chat_context, it holds the entire conversation history
Expand All @@ -363,6 +365,7 @@ def __init__(
interruption=interruption,
turn_detection=raw_turn_detection,
preemptive_generation=preemptive_gen,
user_turn_limit=user_turn_limit,
),
max_tool_steps=max_tool_steps,
user_away_timeout=user_away_timeout,
Expand Down
54 changes: 53 additions & 1 deletion livekit-agents/livekit/agents/voice/audio_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from livekit import rtc

from .. import inference, llm, stt, utils, vad
from .. import inference, llm, stt, tokenize, utils, vad
from ..inference.interruption import (
_AgentSpeechEndedSentinel,
_AgentSpeechStartedSentinel,
Expand All @@ -30,6 +30,7 @@
from . import io
from ._utils import _set_participant_attributes
from .endpointing import BaseEndpointing
from .events import UserTurnExceededEvent
from .turn import TurnDetectionMode as TurnDetectionMode

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,6 +62,13 @@ class _PreemptiveGenerationInfo:
started_speaking_at: float | None


@dataclass
class _UserTurnTracker:
words: int = 0
transcript: str = ""
started_at: float | None = None


class RecognitionHooks(Protocol):
def on_interruption(self, ev: inference.OverlappingSpeechEvent) -> None: ...
def on_start_of_speech(self, ev: vad.VADEvent | None) -> None: ...
Expand All @@ -70,6 +78,7 @@ def on_interim_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None) -
def on_final_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None = None) -> None: ...
def on_end_of_turn(self, info: _EndOfTurnInfo) -> bool: ...
def on_preemptive_generation(self, info: _PreemptiveGenerationInfo) -> None: ...
def on_user_turn_exceeded(self, ev: UserTurnExceededEvent) -> None: ...
def retrieve_chat_ctx(self) -> llm.ChatContext: ...


Expand Down Expand Up @@ -189,6 +198,10 @@ def __init__(

self._vad_speech_started: bool = False

# user turn limit tracking — accumulates across turns until agent speaks
self._turn_tracker = _UserTurnTracker()
self._word_tokenizer = tokenize.basic.WordTokenizer()

def update_options(
self,
*,
Expand Down Expand Up @@ -239,6 +252,9 @@ def on_start_of_agent_speech(self, started_at: float) -> None:
self._agent_speaking = True
self._endpointing.on_start_of_agent_speech(started_at=started_at)

# reset user turn tracker when agent starts speaking
self._turn_tracker = _UserTurnTracker()

if self.adaptive_interruption_active:
self._interruption_ch.send_nowait(_AgentSpeechStartedSentinel()) # type: ignore[union-attr]

Expand Down Expand Up @@ -755,6 +771,9 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None:
# and using that timestamp for _last_speaking_time
self._last_speaking_time = time.time()

# check user turn limit after accumulating transcript
self._check_user_turn_limit(transcript)

if self._vad_base_turn_detection or self._user_turn_committed:
if transcript_changed:
self._hooks.on_preemptive_generation(
Expand Down Expand Up @@ -1066,6 +1085,39 @@ async def _bounce_eou_task(
)
)

def _check_user_turn_limit(self, transcript: str) -> None:
"""Check if the user turn exceeds configured limits.
Called when a final transcript event is received."""
opts = self._session.options.turn_handling["user_turn_limit"]
max_words = opts.get("max_words")
max_duration = opts.get("max_duration")

if max_words is None and max_duration is None:
return

now = time.time()
if self._turn_tracker.started_at is None:
self._turn_tracker.started_at = self._speech_start_time or now

words = self._word_tokenizer.tokenize(transcript)
self._turn_tracker.words += len(words)
self._turn_tracker.transcript = f"{self._turn_tracker.transcript} {transcript}".strip()

duration = now - self._turn_tracker.started_at
time_exceeded = max_duration is not None and duration >= max_duration
words_exceeded = max_words is not None and self._turn_tracker.words >= max_words

if not time_exceeded and not words_exceeded:
return

ev = UserTurnExceededEvent(
transcript=self.current_transcript,
accumulated_transcript=self._turn_tracker.transcript,
accumulated_word_count=self._turn_tracker.words,
duration=duration,
)
self._hooks.on_user_turn_exceeded(ev)

@utils.log_exceptions(logger=logger)
async def _stt_consumer(
self,
Expand Down
14 changes: 14 additions & 0 deletions livekit-agents/livekit/agents/voice/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ class SpeechCreatedEvent(BaseModel):
created_at: float = Field(default_factory=time.time)


class UserTurnExceededEvent(BaseModel):
type: Literal["user_turn_exceeded"] = "user_turn_exceeded"
transcript: str
"""Transcript from the current (uncommitted) user turn only.
Previous turns in the accumulation window are already in the chat context."""
accumulated_transcript: str
"""Full transcript since the start of user speaking."""
accumulated_word_count: int
"""Total word count since the start of user speaking."""
duration: float
"""Duration of the user turn in seconds."""
created_at: float = Field(default_factory=time.time)


class ErrorEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
type: Literal["error"] = "error"
Expand Down
Loading
Loading