Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions livekit-agents/livekit/agents/stt/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ class SpeechEvent:
request_id: str = ""
alternatives: list[SpeechData] = field(default_factory=list)
recognition_usage: RecognitionUsage | None = None
speech_start_time: float | None = None
"""server-reported wall-clock time of speech onset, when the provider sends
a separate speech-start signal carrying onset timing."""


@dataclass
Expand Down Expand Up @@ -311,6 +314,7 @@ def __init__(
self._resampler: rtc.AudioResampler | None = None

self._start_time_offset: float = 0.0
self._start_time: float = time.time()

@property
def start_time_offset(self) -> float:
Expand All @@ -322,6 +326,24 @@ def start_time_offset(self, value: float) -> None:
raise ValueError("start_time_offset must be non-negative")
self._start_time_offset = value

@property
def start_time(self) -> float:
"""Wall-clock anchor for the stream. Seeded to `time.time()` when the
stream is initialized (and re-seeded on each retry). Plugins may
override this via the setter to anchor it at a more accurate moment
(e.g., when the first audio frame is sent to the provider) so that
server-provided stream-relative timestamps (like
`SpeechEvent.speech_start_time`) can be converted to wall-clock
accurately.
"""
return self._start_time

@start_time.setter
def start_time(self, value: float) -> None:
if value < 0:
raise ValueError("start_time must be non-negative")
self._start_time = value

def _report_connection_acquired(self, acquire_time: float, connection_reused: bool) -> None:
"""Report connection timing as an STTMetrics event with zero usage."""
self._stt.emit(
Expand Down Expand Up @@ -351,6 +373,7 @@ async def _main_task(self) -> None:
while self._num_retries <= max_retries:
try:
self._start_time_offset += time.time() - last_start_time
self._start_time = time.time()
last_start_time = time.time()
return await self._run()
except APIError as e:
Expand Down
9 changes: 5 additions & 4 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,10 +1642,11 @@ def _interrupt_by_audio_activity(

# region recognition hooks

def on_start_of_speech(self, ev: vad.VADEvent | None) -> None:
speech_start_time = time.time()
if ev:
speech_start_time = speech_start_time - ev.speech_duration - ev.inference_duration
def on_start_of_speech(
self,
ev: vad.VADEvent | None,
speech_start_time: float,
) -> None:
self._session._update_user_state("speaking", last_speaking_time=speech_start_time)
if self._audio_recognition:
self._audio_recognition.on_start_of_speech(
Expand Down
15 changes: 9 additions & 6 deletions livekit-agents/livekit/agents/voice/audio_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class _PreemptiveGenerationInfo:

class RecognitionHooks(Protocol):
def on_interruption(self, ev: inference.OverlappingSpeechEvent) -> None: ...
def on_start_of_speech(self, ev: vad.VADEvent | None) -> None: ...
def on_start_of_speech(self, ev: vad.VADEvent | None, speech_start_time: float) -> None: ...
def on_vad_inference_done(self, ev: vad.VADEvent) -> None: ...
def on_end_of_speech(self, ev: vad.VADEvent | None) -> None: ...
def on_interim_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None) -> None: ...
Expand Down Expand Up @@ -852,12 +852,15 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None:
self._run_eou_detection(chat_ctx)

elif ev.type == stt.SpeechEventType.START_OF_SPEECH and self._turn_detection_mode == "stt":
with trace.use_span(self._ensure_user_turn_span()):
self._hooks.on_start_of_speech(None)
# If the plugin provided a server onset timestamp, use it;
# otherwise fall back to message arrival time.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can add a condition where:

self._speech_start_time = ev.speech_start_time if ev.speech_start_time < self._speech_start_time else self._speech_start_time

for when the vad detects activity before the stt as well

Copy link
Copy Markdown
Contributor Author

@gsharp-aai gsharp-aai Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Open to this! Just want to flag it changes behavior from the current PR. Two shapes:

  1. Fallback only (current PR as-is): _speech_start_time is only set from STT when VAD hasn't already set it. VAD wins when it fires, preserving current behavior.
  2. Earlier of VAD or STT (your suggestion): every STT SOS compares both and picks the earlier onset, even when VAD already fired.

I leaned toward #1 since local VAD's back-date is usually more accurate than the server timestamp (no network delay, no clock skew) plus less of a behavioral change (in relation to what currently exists), but happy to flip to #2 if you think the "STT caught it earlier" case is common enough to trust by default.

Let me know which shape the team prefers!

if self._speech_start_time is None:
self._speech_start_time = ev.speech_start_time or time.time()

with trace.use_span(self._ensure_user_turn_span(start_time=self._speech_start_time)):
self._hooks.on_start_of_speech(None, speech_start_time=self._speech_start_time)

self._speaking = True
if self._speech_start_time is None:
self._speech_start_time = time.time()
self._last_speaking_time = time.time()

if self._end_of_turn_task is not None:
Expand All @@ -872,7 +875,7 @@ async def _on_vad_event(self, ev: vad.VADEvent) -> None:
self._vad_speech_started = True

with trace.use_span(self._ensure_user_turn_span(start_time=speech_start_time)):
self._hooks.on_start_of_speech(ev)
self._hooks.on_start_of_speech(ev, speech_start_time=speech_start_time)

self._speaking = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dataclasses
import json
import os
import time
import weakref
from dataclasses import dataclass
from typing import Literal
Expand Down Expand Up @@ -360,6 +361,7 @@ async def _run(self) -> None:

async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
nonlocal closing_ws
anchored = False

samples_per_buffer = self._opts.sample_rate // round(1 / self._opts.buffer_size_seconds)
audio_bstream = utils.audio.AudioByteStream(
Expand All @@ -378,6 +380,13 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
frames = audio_bstream.write(data.data.tobytes())

for frame in frames:
if not anchored:
# Anchor the stream's wall-clock to the moment just
# before the first frame is sent — aligned with the
# server's stream-relative zero used by
# SpeechStarted.timestamp.
self.start_time = time.time()
anchored = True
self._speech_duration += frame.duration
await ws.send_bytes(frame.data.tobytes())

Expand Down Expand Up @@ -518,7 +527,21 @@ def _process_stream_event(self, data: dict) -> None:
return

if message_type == "SpeechStarted":
self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH))
# SpeechStarted can arrive well after actual speech onset. The
# `timestamp` field carries the server VAD's onset time in stream-
# relative ms. Convert to wall-clock by adding self.start_time
# (the stream's wall-clock anchor) so the framework records an
# accurate _speech_start_time instead of message arrival.
timestamp_ms = data.get("timestamp")
speech_start_time: float | None = None
if timestamp_ms is not None:
speech_start_time = self.start_time + timestamp_ms / 1000
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.START_OF_SPEECH,
speech_start_time=speech_start_time,
)
)
return

if message_type == "Termination":
Expand Down
106 changes: 106 additions & 0 deletions tests/test_plugin_assemblyai_stt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Tests for AssemblyAI STT plugin configuration options."""

from __future__ import annotations

import time
from unittest.mock import MagicMock, patch

from livekit.agents.stt import SpeechEventType
from livekit.agents.types import NOT_GIVEN


Expand Down Expand Up @@ -78,3 +84,103 @@ async def test_vad_threshold_partial_update():

assert stt._opts.vad_threshold == 0.8
assert stt._opts.max_turn_silence == 500


# ---------------------------------------------------------------------------
# SpeechStarted → speech_start_time conversion
#
# The plugin anchors the stream's wall-clock via the base-class `start_time`
# property (which it overrides in send_task on the first ws.send_bytes). The
# server emits SpeechStarted with a stream-relative `timestamp` in ms, which
# the plugin converts to wall-clock by `self.start_time + timestamp_ms/1000`
# and surfaces via `SpeechEvent.speech_start_time`.
# ---------------------------------------------------------------------------


def _make_stream_for_unit_test():
"""Construct a SpeechStream without triggering the _main_task WebSocket
loop. Patches asyncio.create_task during __init__ so the stream doesn't
try to open a real connection; also closes the coroutines that would
otherwise be scheduled, to avoid un-awaited coroutine warnings."""
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS
from livekit.plugins.assemblyai import STT
from livekit.plugins.assemblyai.stt import SpeechStream

stt = STT(api_key="test-key")

def _fake_create_task(coro, *args, **kwargs):
# Close the coroutine so we don't get RuntimeWarning about it never
# being awaited. Return a benign mock so callers that chain
# add_done_callback / cancel don't break.
coro.close()
task = MagicMock()
return task

with patch("livekit.agents.stt.stt.asyncio.create_task", side_effect=_fake_create_task):
stream = SpeechStream(
stt=stt,
opts=stt._opts,
conn_options=DEFAULT_API_CONNECT_OPTIONS,
api_key="test-key",
http_session=MagicMock(),
base_url="wss://streaming.assemblyai.com",
)
return stream


async def test_speech_started_uses_start_time_anchor():
"""The SpeechStarted handler converts server timestamp_ms to wall-clock
using self.start_time, and emits a SpeechEvent with the correct
speech_start_time."""
stream = _make_stream_for_unit_test()

# Override the stream anchor to a known value — simulates what send_task
# would do on the first ws.send_bytes.
anchor = 1_700_000_000.0
stream.start_time = anchor

# Simulate the server sending a SpeechStarted message 500ms into the stream.
stream._process_stream_event({"type": "SpeechStarted", "timestamp": 500})

ev = stream._event_ch.recv_nowait()
assert ev.type == SpeechEventType.START_OF_SPEECH
assert ev.speech_start_time == anchor + 0.5


async def test_speech_started_timestamp_zero_still_anchored():
"""A timestamp of 0 is a valid onset at stream start (not treated as
'missing field'). Tests the earlier fix for the `is not None` check."""
stream = _make_stream_for_unit_test()

anchor = 1_700_000_000.0
stream.start_time = anchor
stream._process_stream_event({"type": "SpeechStarted", "timestamp": 0})

ev = stream._event_ch.recv_nowait()
assert ev.type == SpeechEventType.START_OF_SPEECH
assert ev.speech_start_time == anchor


async def test_speech_started_without_timestamp_leaves_field_none():
"""If the server omits `timestamp` entirely, the plugin should emit
START_OF_SPEECH with speech_start_time=None so the framework falls back
to message-arrival time (pre-PR behavior)."""
stream = _make_stream_for_unit_test()

# Server sends SpeechStarted with no timestamp field.
stream._process_stream_event({"type": "SpeechStarted"})

ev = stream._event_ch.recv_nowait()
assert ev.type == SpeechEventType.START_OF_SPEECH
assert ev.speech_start_time is None


async def test_start_time_has_default_before_plugin_override():
"""Even without any plugin override, the base-class default
(time.time() seeded in __init__) is available for the SpeechStarted
conversion to use."""
stream = _make_stream_for_unit_test()

# start_time should already be a recent wall-clock value from the base
# class __init__, without any explicit override.
assert time.time() - stream.start_time < 5.0
115 changes: 115 additions & 0 deletions tests/test_stt_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Unit tests for base STT `RecognizeStream` fields (start_time, etc.)."""

from __future__ import annotations

import asyncio
import time

import pytest

from livekit.agents import APIConnectionError
from livekit.agents.stt import (
STT,
RecognizeStream,
SpeechData,
SpeechEvent,
SpeechEventType,
STTCapabilities,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS
from livekit.agents.utils.audio import AudioBuffer


class _DummyStream(RecognizeStream):
"""Minimal RecognizeStream for unit tests — does not hit the network."""

def __init__(
self,
*,
stt: STT,
fail_first_run: bool = False,
) -> None:
super().__init__(stt=stt, conn_options=DEFAULT_API_CONNECT_OPTIONS)
self._fail_first_run = fail_first_run
self._run_count = 0

async def _run(self) -> None:
self._run_count += 1
if self._fail_first_run and self._run_count == 1:
raise APIConnectionError("fake failure to trigger retry")
# emit a final and exit so _main_task can complete normally
self._event_ch.send_nowait(
SpeechEvent(
type=SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[SpeechData(language="", text="hello")],
)
)


class _DummySTT(STT):
def __init__(self) -> None:
super().__init__(capabilities=STTCapabilities(streaming=True, interim_results=False))

async def _recognize_impl(self, buffer: AudioBuffer, *, language, conn_options) -> SpeechEvent:
raise NotImplementedError

def stream(self, *, language=None, conn_options=DEFAULT_API_CONNECT_OPTIONS) -> _DummyStream:
return _DummyStream(stt=self)


async def test_start_time_seeded_on_init() -> None:
"""start_time is initialized to approximately time.time() when the stream is created."""
stt = _DummySTT()
before = time.time()
stream = stt.stream()
after = time.time()

assert before <= stream.start_time <= after
await stream.aclose()


async def test_start_time_setter_accepts_valid_values() -> None:
"""Plugins can override start_time by assigning to the public property."""
stt = _DummySTT()
stream = stt.stream()

new_anchor = time.time() + 10.0
stream.start_time = new_anchor
assert stream.start_time == new_anchor

await stream.aclose()


async def test_start_time_setter_rejects_negative() -> None:
"""start_time setter validates non-negative, matching start_time_offset behavior."""
stt = _DummySTT()
stream = stt.stream()

with pytest.raises(ValueError, match="start_time must be non-negative"):
stream.start_time = -1.0

await stream.aclose()


async def test_start_time_reseeded_on_retry() -> None:
"""When _main_task retries after an APIError, start_time is re-seeded so plugin
overrides from the previous connection don't leak into the new one."""
stt = _DummySTT()
stream = _DummyStream(stt=stt, fail_first_run=True)

# Simulate a plugin overriding start_time during the first (failing) _run()
# by assigning a sentinel value before the task picks up.
sentinel = 1.0
stream.start_time = sentinel

# Let the main task run: it should retry past the first-run APIError, and
# on each attempt re-seed start_time to a fresh time.time() value before
# _run() is called.
await asyncio.wait_for(stream._task, timeout=5.0)

# After the retry, start_time must have been re-seeded (not equal to sentinel).
assert stream.start_time != sentinel
# And it should be a recent wall-clock value.
assert time.time() - stream.start_time < 5.0

await stream.aclose()