Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,46 @@

from .log import logger

_DISFLUENCIES: frozenset[str] = frozenset(
{
"hm",
"hmm",
"mm",
"mmh",
"mmhm",
"mmhmm",
"mhmm",
"mm-hm",
"mm-hmm",
"mhm",
"huh",
"uh",
"uh-huh",
"uh-uh",
"um",
"umm",
"uhm",
"er",
"erm",
"ah",
"oh",
"eh",
}
)
_EM_DASH = "\u2014"
_PUNCT_STRIP = ".,!?;:\"'"


def _strip_em_dashes(s: str) -> str:
    """Return *s* with every em dash (``_EM_DASH``) replaced by one space.

    Surrounding whitespace is left untouched, so an em dash that already sits
    next to a space yields consecutive spaces in the result.
    """
    # Splitting on the dash and re-joining with a space is equivalent to a
    # direct character replacement.
    pieces = s.split(_EM_DASH)
    return " ".join(pieces)


def _is_pure_disfluency(s: str) -> bool:
    """Report whether *s* consists solely of disfluency tokens.

    The text is lowercased and split on whitespace; each token has leading and
    trailing ``_PUNCT_STRIP`` characters removed before being looked up in
    ``_DISFLUENCIES``.

    Returns:
        ``True`` when there is at least one token and every token is a known
        disfluency; ``False`` otherwise, including for empty or
        whitespace-only input.
    """
    words = s.lower().split()
    for word in words:
        if word.strip(_PUNCT_STRIP) not in _DISFLUENCIES:
            # A single non-filler token is enough to keep the text.
            return False
    # No tokens at all means there is nothing to classify as a disfluency.
    return bool(words)


@dataclass
class STTOptions:
Expand All @@ -63,6 +103,7 @@ class STTOptions:
speaker_labels: NotGivenOr[bool] = NOT_GIVEN
max_speakers: NotGivenOr[int] = NOT_GIVEN
domain: NotGivenOr[str] = NOT_GIVEN
filter_disfluencies: bool = False


class STT(stt.STT):
Expand All @@ -89,6 +130,7 @@ def __init__(
speaker_labels: NotGivenOr[bool] = NOT_GIVEN,
max_speakers: NotGivenOr[int] = NOT_GIVEN,
domain: NotGivenOr[str] = NOT_GIVEN,
filter_disfluencies: bool = False,
http_session: aiohttp.ClientSession | None = None,
buffer_size_seconds: float = 0.05,
base_url: str = "wss://streaming.assemblyai.com",
Expand Down Expand Up @@ -164,6 +206,7 @@ def __init__(
speaker_labels=speaker_labels,
max_speakers=max_speakers,
domain=domain,
filter_disfluencies=filter_disfluencies,
)
self._session = http_session
self._streams = weakref.WeakSet[SpeechStream]()
Expand Down Expand Up @@ -219,6 +262,7 @@ def update_options(
prompt: NotGivenOr[str] = NOT_GIVEN,
keyterms_prompt: NotGivenOr[list[str]] = NOT_GIVEN,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
filter_disfluencies: NotGivenOr[bool] = NOT_GIVEN,
# Deprecated — use min_turn_silence instead
min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN,
) -> None:
Expand All @@ -244,6 +288,8 @@ def update_options(
self._opts.keyterms_prompt = keyterms_prompt
if is_given(vad_threshold):
self._opts.vad_threshold = vad_threshold
if is_given(filter_disfluencies):
self._opts.filter_disfluencies = filter_disfluencies

for stream in self._streams:
stream.update_options(
Expand All @@ -254,6 +300,7 @@ def update_options(
prompt=prompt,
keyterms_prompt=keyterms_prompt,
vad_threshold=vad_threshold,
filter_disfluencies=filter_disfluencies,
)


Expand Down Expand Up @@ -306,6 +353,7 @@ def update_options(
prompt: NotGivenOr[str] = NOT_GIVEN,
keyterms_prompt: NotGivenOr[list[str]] = NOT_GIVEN,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
filter_disfluencies: NotGivenOr[bool] = NOT_GIVEN,
# Deprecated — use min_turn_silence instead
min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN,
) -> None:
Expand All @@ -331,6 +379,8 @@ def update_options(
self._opts.keyterms_prompt = keyterms_prompt
if is_given(vad_threshold):
self._opts.vad_threshold = vad_threshold
if is_given(filter_disfluencies):
self._opts.filter_disfluencies = filter_disfluencies

# Send UpdateConfiguration message over the active websocket
config_msg: dict = {"type": "UpdateConfiguration"}
Expand Down Expand Up @@ -545,6 +595,11 @@ def _process_stream_event(self, data: dict) -> None:
speaker_label = data.get("speaker_label")
speaker_id = speaker_label if speaker_label and speaker_label != "UNKNOWN" else None

filter_enabled = self._opts.filter_disfluencies
if filter_enabled:
transcript = _strip_em_dashes(transcript)
utterance = _strip_em_dashes(utterance)

# transcript (final) and words (interim) are cumulative
# utterance (preflight) is chunk based
start_time: float = 0
Expand All @@ -554,7 +609,11 @@ def _process_stream_event(self, data: dict) -> None:
# https://www.assemblyai.com/docs/api-reference/streaming-api/streaming-api#receive.receiveTurn.words
timed_words: list[TimedString] = [
TimedString(
text=word.get("text", ""),
text=(
_strip_em_dashes(word.get("text", ""))
if filter_enabled
else word.get("text", "")
),
start_time=word.get("start", 0) / 1000 + self.start_time_offset,
end_time=word.get("end", 0) / 1000 + self.start_time_offset,
start_time_offset=self.start_time_offset,
Expand All @@ -570,22 +629,25 @@ def _process_stream_event(self, data: dict) -> None:
end_time = timed_words[-1].end_time or end_time
confidence = sum(word.confidence or 0.0 for word in timed_words) / len(timed_words)

interim_event = stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[
stt.SpeechData(
language=language,
text=interim_text,
start_time=start_time,
end_time=end_time,
words=timed_words,
confidence=confidence,
speaker_id=speaker_id,
)
],
)
self._event_ch.send_nowait(interim_event)
logger.debug("interim transcript end_of_turn_confidence=%s", end_of_turn_confidence)
if filter_enabled and _is_pure_disfluency(interim_text):
logger.debug("filtered pure-disfluency interim transcript")
else:
interim_event = stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[
stt.SpeechData(
language=language,
text=interim_text,
start_time=start_time,
end_time=end_time,
words=timed_words,
confidence=confidence,
speaker_id=speaker_id,
)
],
)
self._event_ch.send_nowait(interim_event)
logger.debug("interim transcript end_of_turn_confidence=%s", end_of_turn_confidence)

if utterance:
if self._last_preflight_start_time == 0.0:
Expand All @@ -602,33 +664,46 @@ def _process_stream_event(self, data: dict) -> None:
len(utterance_words), 1
)

final_event = stt.SpeechEvent(
type=stt.SpeechEventType.PREFLIGHT_TRANSCRIPT,
alternatives=[
stt.SpeechData(
language=language,
text=utterance,
start_time=self._last_preflight_start_time,
end_time=end_time,
words=utterance_words,
confidence=utterance_confidence,
speaker_id=speaker_id,
)
],
)
self._event_ch.send_nowait(final_event)
logger.debug("preflight transcript end_of_turn_confidence=%s", end_of_turn_confidence)
if filter_enabled and _is_pure_disfluency(utterance):
logger.debug("filtered pure-disfluency preflight transcript")
else:
final_event = stt.SpeechEvent(
type=stt.SpeechEventType.PREFLIGHT_TRANSCRIPT,
alternatives=[
stt.SpeechData(
language=language,
text=utterance,
start_time=self._last_preflight_start_time,
end_time=end_time,
words=utterance_words,
confidence=utterance_confidence,
speaker_id=speaker_id,
)
],
)
self._event_ch.send_nowait(final_event)
logger.debug(
"preflight transcript end_of_turn_confidence=%s", end_of_turn_confidence
)
self._last_preflight_start_time = end_time

if end_of_turn and (
not (is_given(self._opts.format_turns) and self._opts.format_turns) or turn_is_formatted
):
final_text = transcript
if filter_enabled and _is_pure_disfluency(transcript):
# emit empty text to preserve _final_transcript_received.set() in
# AudioRecognition (prevents commit_user_turn hang) while suppressing
# downstream hooks, preemptive generation, and EOU detection
final_text = ""
logger.debug("filtered pure-disfluency final transcript")

final_event = stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[
stt.SpeechData(
language=language,
text=transcript,
text=final_text,
start_time=start_time,
end_time=end_time,
words=timed_words,
Expand Down
Loading