diff --git a/Cargo.lock b/Cargo.lock index e600c6c..84578bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -129,12 +129,74 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "multer", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.13.1" @@ -618,6 +680,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "env_home" version = "0.1.0" @@ -772,11 +843,38 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "slab", +] [[package]] name = "getrandom" @@ -887,12 +985,77 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + 
[[package]] name = "httparse" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "bytes", + "http", + "http-body", + "hyper", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "iana-time-zone" version = "0.1.65" @@ -1319,6 +1482,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -1344,6 +1513,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1405,6 +1580,23 @@ dependencies = [ "syn", ] +[[package]] +name = "multer" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http", + "httparse", + "memchr", + "mime", + "spin", + "version_check", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -1823,6 +2015,12 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.32" @@ -2297,6 +2495,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_spanned" version = "0.6.9" @@ -2306,6 +2515,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2336,6 +2557,12 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + [[package]] name = "smallvec" version = "1.15.1" @@ -2363,6 +2590,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.9.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -2416,6 +2649,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.2" @@ -2668,12 +2907,41 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2892,6 +3160,7 @@ version = "0.6.6" dependencies = [ "anyhow", "async-trait", + "axum", "chrono", "clap", "clap_mangen", diff --git a/Cargo.toml b/Cargo.toml index d52058b..d5f8e71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ categories = 
["multimedia::audio", "accessibility"] [dependencies] # Async runtime tokio = { version = "1", features = ["full", "signal", "sync", "time", "process", "io-util"] } +axum = { version = "0.7", features = ["multipart"] } # CLI clap = { version = "4", features = ["derive"] } diff --git a/src/cli.rs b/src/cli.rs index 88763b0..3735798 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -154,8 +154,21 @@ pub struct Cli { #[arg(long, value_name = "KEY", help_heading = "Hotkey")] pub model_modifier: Option, - // -- Audio -- + // -- Service -- + + /// Enable local OpenAI-compatible STT HTTP service alongside daemon loop + #[arg(long, help_heading = "Service")] + pub service: bool, + + /// Service bind host override (default: 127.0.0.1) + #[arg(long, value_name = "HOST", help_heading = "Service")] + pub service_host: Option, + /// Service bind port override (default: 8427) + #[arg(long, value_name = "PORT", help_heading = "Service")] + pub service_port: Option, + + // -- Audio -- /// Audio input device name (or "default" for system default) #[arg(long, value_name = "DEVICE", help_heading = "Audio")] pub audio_device: Option, @@ -298,7 +311,6 @@ pub struct Cli { pub append_text: Option, // -- VAD -- - /// Enable Voice Activity Detection (filter silence before transcription) #[arg(long, help_heading = "VAD")] pub vad: bool, diff --git a/src/config.rs b/src/config.rs index 25d4bd0..c3ad0b9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -295,6 +295,17 @@ on_transcription = true # [profiles.code] # post_process_command = "ollama run llama3.2:1b 'Format as code comment...'" # output_mode = "clipboard" + +# [service] +# Run local OpenAI-compatible STT API in parallel with daemon hotkey flow. +# Keep loopback-only unless you intentionally front it with a trusted proxy. 
+# +# enabled = false +# host = "127.0.0.1" +# port = 8427 +# max_upload_bytes = 10485760 +# request_timeout_ms = 90000 +# allowed_languages = ["de", "en"] "#; /// Hotkey activation mode @@ -363,6 +374,10 @@ pub struct Config { #[serde(default)] pub meeting: MeetingConfig, + /// Local HTTP service configuration (OpenAI-compatible STT API) + #[serde(default)] + pub service: ServiceConfig, + /// Optional path to state file for external integrations (e.g., Waybar) /// When set, the daemon writes current state ("idle", "recording", "transcribing") /// to this file whenever state changes. @@ -419,6 +434,35 @@ pub struct HotkeyConfig { pub profile_modifiers: HashMap, } +/// Local HTTP service configuration. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ServiceConfig { + /// Enable local HTTP service mode alongside daemon hotkey loop. + #[serde(default)] + pub enabled: bool, + + /// Listener host. Default is loopback only. + #[serde(default = "default_service_host")] + pub host: String, + + /// Listener port. + #[serde(default = "default_service_port")] + pub port: u16, + + /// Maximum request payload size in bytes. + #[serde(default = "default_service_max_upload_bytes")] + pub max_upload_bytes: usize, + + /// Request timeout in milliseconds. + #[serde(default = "default_service_request_timeout_ms")] + pub request_timeout_ms: u64, + + /// Allowed language set used for constrained auto-detection when callers + /// do not explicitly pin a language. 
+ #[serde(default = "default_service_allowed_languages")] + pub allowed_languages: Vec, +} + /// Audio capture configuration #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AudioConfig { @@ -1797,6 +1841,39 @@ fn default_true() -> bool { true } +fn default_service_host() -> String { + "127.0.0.1".to_string() +} + +fn default_service_port() -> u16 { + 8427 +} + +fn default_service_max_upload_bytes() -> usize { + 200 * 1024 * 1024 +} + +fn default_service_request_timeout_ms() -> u64 { + 600_000 +} + +fn default_service_allowed_languages() -> Vec { + vec!["de".to_string(), "en".to_string()] +} + +impl Default for ServiceConfig { + fn default() -> Self { + Self { + enabled: false, + host: default_service_host(), + port: default_service_port(), + max_upload_bytes: default_service_max_upload_bytes(), + request_timeout_ms: default_service_request_timeout_ms(), + allowed_languages: default_service_allowed_languages(), + } + } +} + impl Default for Config { fn default() -> Self { Self { @@ -1877,6 +1954,7 @@ impl Default for Config { vad: VadConfig::default(), status: StatusConfig::default(), meeting: MeetingConfig::default(), + service: ServiceConfig::default(), state_file: Some("auto".to_string()), profiles: HashMap::new(), } @@ -2197,6 +2275,42 @@ pub fn load_config(path: Option<&Path>) -> Result { config.text.smart_auto_submit = parse_bool_env(&val); } + // Local service + if let Ok(val) = std::env::var("VOXTYPE_SERVICE_ENABLED") { + config.service.enabled = parse_bool_env(&val); + } + if let Ok(host) = std::env::var("VOXTYPE_SERVICE_HOST") { + let trimmed = host.trim(); + if !trimmed.is_empty() { + config.service.host = trimmed.to_string(); + } + } + if let Ok(val) = std::env::var("VOXTYPE_SERVICE_PORT") { + if let Ok(port) = val.parse::() { + config.service.port = port; + } + } + if let Ok(val) = std::env::var("VOXTYPE_SERVICE_MAX_UPLOAD_BYTES") { + if let Ok(bytes) = val.parse::() { + config.service.max_upload_bytes = bytes; + } + } + if let Ok(val) = 
std::env::var("VOXTYPE_SERVICE_REQUEST_TIMEOUT_MS") { + if let Ok(timeout_ms) = val.parse::() { + config.service.request_timeout_ms = timeout_ms; + } + } + if let Ok(raw) = std::env::var("VOXTYPE_SERVICE_ALLOWED_LANGUAGES") { + let langs: Vec = raw + .split(',') + .map(|s| s.trim().to_lowercase()) + .filter(|s| !s.is_empty()) + .collect(); + if !langs.is_empty() { + config.service.allowed_languages = langs; + } + } + Ok(config) } @@ -2232,6 +2346,46 @@ mod tests { assert_eq!(config.whisper.model, "base.en"); assert_eq!(config.output.mode, OutputMode::Type); assert!(!config.output.auto_submit); + assert!(!config.service.enabled); + assert_eq!(config.service.host, "127.0.0.1"); + assert_eq!(config.service.port, 8427); + assert_eq!(config.service.allowed_languages, vec!["de", "en"]); + } + + #[test] + fn test_parse_service_config_toml() { + let toml_str = r#" + [hotkey] + key = "SCROLLLOCK" + + [audio] + device = "default" + sample_rate = 16000 + max_duration_secs = 60 + + [whisper] + model = "base.en" + language = "en" + + [output] + mode = "type" + + [service] + enabled = true + host = "127.0.0.1" + port = 9027 + max_upload_bytes = 5242880 + request_timeout_ms = 45000 + allowed_languages = ["en", "de"] + "#; + + let config: Config = toml::from_str(toml_str).unwrap(); + assert!(config.service.enabled); + assert_eq!(config.service.host, "127.0.0.1"); + assert_eq!(config.service.port, 9027); + assert_eq!(config.service.max_upload_bytes, 5_242_880); + assert_eq!(config.service.request_timeout_ms, 45_000); + assert_eq!(config.service.allowed_languages, vec!["en", "de"]); } #[test] diff --git a/src/daemon.rs b/src/daemon.rs index 6e50c7a..28fbd4d 100644 --- a/src/daemon.rs +++ b/src/daemon.rs @@ -516,6 +516,8 @@ pub struct Daemon { meeting_loopback_buffer: Vec, // Meeting event receiver meeting_event_rx: Option>, + // Local HTTP transcription service handle + service_handle: Option, // GTCRN speech enhancer for mic echo cancellation #[cfg(feature = "onnx-common")] 
speech_enhancer: Option>, @@ -615,6 +617,7 @@ impl Daemon { meeting_mic_buffer: Vec::new(), meeting_loopback_buffer: Vec::new(), meeting_event_rx: None, + service_handle: None, #[cfg(feature = "onnx-common")] speech_enhancer: None, paused_media_players: Vec::new(), @@ -1689,6 +1692,25 @@ impl Daemon { self.model_manager = Some(model_manager); + // Start local HTTP service if enabled, sharing the daemon's transcriber + if self.config.service.enabled { + let shared_transcriber: Option> = match self.config.engine { + crate::config::TranscriptionEngine::Whisper => { + self.model_manager + .as_mut() + .and_then(|mm| mm.get_transcriber(None).ok()) + } + _ => transcriber_preloaded.clone(), + }; + let handle = crate::service::start( + &self.config, + self.config_path.clone(), + shared_transcriber, + ).await?; + tracing::info!("Local service listening on http://{}", handle.addr()); + self.service_handle = Some(handle); + } + // Start hotkey listener (if enabled) let mut hotkey_rx = if let Some(ref mut listener) = hotkey_listener { Some(listener.start().await?) 
@@ -2840,6 +2862,12 @@ impl Daemon { let _ = self.stop_meeting().await; } + // Stop local HTTP service + if let Some(handle) = self.service_handle.take() { + tracing::info!("Stopping local service"); + handle.shutdown().await; + } + // Remove override files on shutdown cleanup_profile_override(); diff --git a/src/lib.rs b/src/lib.rs index 301291a..5bb6788 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,6 +79,7 @@ pub mod hotkey; pub mod meeting; pub mod model_manager; pub mod output; +pub mod service; pub mod setup; pub mod state; pub mod text; diff --git a/src/main.rs b/src/main.rs index 6279e7e..5093d87 100644 --- a/src/main.rs +++ b/src/main.rs @@ -162,6 +162,20 @@ async fn main() -> anyhow::Result<()> { config.hotkey.model_modifier = Some(model_modifier); } + // Service overrides + if cli.service { + config.service.enabled = true; + } + if let Some(host) = cli.service_host { + let trimmed = host.trim(); + if !trimmed.is_empty() { + config.service.host = trimmed.to_string(); + } + } + if let Some(port) = cli.service_port { + config.service.port = port; + } + // Whisper overrides if let Some(delay) = cli.pre_type_delay { config.output.pre_type_delay_ms = delay; diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 0000000..eee652e --- /dev/null +++ b/src/service.rs @@ -0,0 +1,1147 @@ +//! Local OpenAI-compatible HTTP transcription service. +//! +//! Runs in-process with the daemon and exposes: +//! - `GET /healthz` +//! - `POST /v1/audio/transcriptions` +//! 
- `POST /v1/audio/translations` (alias to transcriptions) + +use axum::extract::{DefaultBodyLimit, Multipart, State}; +use axum::http::{header, HeaderValue, StatusCode}; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use serde::Serialize; +use std::collections::BTreeSet; +use std::io::Cursor; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::oneshot; + +use crate::config::{Config, LanguageConfig, ServiceConfig, TranscriptionEngine}; +use crate::error::{TranscribeError, VoxtypeError}; +use crate::meeting::VoiceActivityDetector; +use crate::transcribe::{Transcriber, TranscriptionResult, TranscriptionSegment}; + +const SERVICE_SAMPLE_RATE: usize = 16_000; +const LONG_FORM_CHUNK_SECS: usize = 30; +const LONG_FORM_CHUNK_THRESHOLD_SECS: usize = 90; +const LONG_FORM_VAD_THRESHOLD: f32 = 0.01; + +#[derive(Clone)] +struct AppState { + transcriber: Arc, + request_timeout: Duration, + allowed_languages: Arc>, +} + +#[derive(Serialize)] +struct HealthResponse { + status: &'static str, +} + +#[derive(Serialize)] +struct TranscriptionResponse { + text: String, +} + +#[derive(Serialize)] +struct VerboseTranscriptionResponse { + text: String, + language: String, + duration: f64, + segments: Vec, +} + +#[derive(Serialize)] +struct VerboseSegment { + id: usize, + start: f64, + end: f64, + text: String, +} + +#[derive(Serialize)] +struct ApiErrorResponse { + error: ApiErrorBody, +} + +#[derive(Serialize)] +struct ApiErrorBody { + message: String, + #[serde(rename = "type")] + error_type: String, +} + +#[derive(Debug)] +struct ApiError { + status: StatusCode, + message: String, + error_type: &'static str, +} + +impl ApiError { + fn bad_request(message: impl Into) -> Self { + Self { + status: StatusCode::BAD_REQUEST, + message: message.into(), + error_type: "invalid_request_error", + } + } + + fn internal(message: impl Into) -> Self { + Self { + status: 
StatusCode::INTERNAL_SERVER_ERROR, + message: message.into(), + error_type: "server_error", + } + } +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + let body = ApiErrorResponse { + error: ApiErrorBody { + message: self.message, + error_type: self.error_type.to_string(), + }, + }; + (self.status, Json(body)).into_response() + } +} + +/// Running local service handle. +pub struct ServiceHandle { + addr: SocketAddr, + shutdown_tx: Option>, + task: tokio::task::JoinHandle<()>, +} + +impl ServiceHandle { + pub fn addr(&self) -> SocketAddr { + self.addr + } + + pub async fn shutdown(mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + if let Err(e) = self.task.await { + tracing::warn!("Service task join error: {}", e); + } + } +} + +/// Start local OpenAI-compatible STT service sharing an existing transcriber. +/// +/// When `shared` is provided the service reuses it instead of loading a +/// second copy of the model into VRAM. Falls back to creating its own +/// transcriber when `shared` is `None`. +pub async fn start( + config: &Config, + config_path: Option, + shared: Option>, +) -> Result { + let service_cfg = config.service.clone(); + + let transcriber = if let Some(t) = shared { + tracing::info!("Service reusing daemon transcriber (shared VRAM)"); + t + } else { + let mut transcriber_config = config.clone(); + transcriber_config.whisper.language = + default_language_for_service(&service_cfg, &config.whisper.language); + + tokio::task::spawn_blocking(move || { + match transcriber_config.engine { + TranscriptionEngine::Whisper => { + crate::transcribe::create_transcriber_with_config_path( + &transcriber_config.whisper, + config_path, + ) + .map(Arc::from) + } + _ => crate::transcribe::create_transcriber(&transcriber_config).map(Arc::from), + } + }) + .await + .map_err(|e| { + VoxtypeError::Config(format!( + "Service transcriber initialization task failed: {}", + e + )) + })?? 
+ }; + + start_with_transcriber(service_cfg, transcriber).await +} + +fn default_language_for_service( + service_cfg: &ServiceConfig, + fallback: &LanguageConfig, +) -> LanguageConfig { + let normalized = normalize_languages(&service_cfg.allowed_languages); + if normalized.is_empty() { + fallback.clone() + } else if normalized.len() == 1 { + LanguageConfig::Single(normalized[0].clone()) + } else { + LanguageConfig::Multiple(normalized) + } +} + +async fn start_with_transcriber( + service_cfg: ServiceConfig, + transcriber: Arc, +) -> Result { + let bind_addr = format!("{}:{}", service_cfg.host, service_cfg.port); + let listener = tokio::net::TcpListener::bind(&bind_addr) + .await + .map_err(|e| { + VoxtypeError::Config(format!( + "Failed to bind service listener on {}: {}", + bind_addr, e + )) + })?; + let local_addr = listener.local_addr().map_err(|e| { + VoxtypeError::Config(format!("Failed to read service local address: {}", e)) + })?; + + let state = AppState { + transcriber, + request_timeout: Duration::from_millis(service_cfg.request_timeout_ms), + allowed_languages: Arc::new(normalize_languages(&service_cfg.allowed_languages)), + }; + + let app = build_router(state, service_cfg.max_upload_bytes); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let task = tokio::spawn(async move { + if let Err(e) = axum::serve(listener, app.into_make_service()) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + { + tracing::error!("Service HTTP server failed: {}", e); + } + }); + + Ok(ServiceHandle { + addr: local_addr, + shutdown_tx: Some(shutdown_tx), + task, + }) +} + +fn build_router(state: AppState, max_upload_bytes: usize) -> Router { + Router::new() + .route("/healthz", get(healthz)) + .route("/v1/audio/transcriptions", post(transcribe_handler)) + .route("/v1/audio/translations", post(transcribe_handler)) + .layer(DefaultBodyLimit::max(max_upload_bytes)) + .with_state(state) +} + +async fn healthz() -> Json { + 
Json(HealthResponse { status: "ok" }) +} + +async fn transcribe_handler( + State(state): State, + mut multipart: Multipart, +) -> Result { + let mut audio_data: Option> = None; + let mut language: Option = None; + let mut prompt: Option = None; + let mut response_format = "json".to_string(); + + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| ApiError::bad_request(format!("Invalid multipart request: {}", e)))? + { + let name = field.name().unwrap_or("").to_string(); + match name.as_str() { + "file" => { + let bytes = field.bytes().await.map_err(|e| { + ApiError::bad_request(format!("Failed to read file field: {}", e)) + })?; + audio_data = Some(bytes.to_vec()); + } + "language" => { + let value = field + .text() + .await + .map_err(|e| ApiError::bad_request(format!("Invalid language field: {}", e)))?; + language = Some(value); + } + "prompt" => { + let value = field + .text() + .await + .map_err(|e| ApiError::bad_request(format!("Invalid prompt field: {}", e)))?; + prompt = Some(value); + } + "response_format" => { + response_format = field.text().await.map_err(|e| { + ApiError::bad_request(format!("Invalid response_format field: {}", e)) + })?; + } + _ => { + // Ignore non-essential fields (model, temperature, etc.). 
+ } + } + } + + let audio_data = audio_data + .ok_or_else(|| ApiError::bad_request("Missing required multipart field: file"))?; + + let samples = decode_wav_to_mono_16k(&audio_data).map_err(ApiError::bad_request)?; + if samples.is_empty() { + return Err(ApiError::bad_request("Audio payload contains no samples")); + } + + let language_override = + normalize_language_override(language.as_deref(), &state.allowed_languages)?; + let prompt_override = prompt + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()) + .map(ToOwned::to_owned); + + let format = response_format.trim().to_lowercase(); + if format != "json" && format != "verbose_json" && format != "text" { + return Err(ApiError::bad_request(format!( + "Unsupported response_format '{}'; expected json, verbose_json, or text", + response_format + ))); + } + + let use_segments = format == "verbose_json"; + let transcriber = state.transcriber.clone(); + let timeout = state.request_timeout; + + if use_segments { + let mut task = tokio::task::spawn_blocking(move || { + transcribe_segments_adaptive( + transcriber, + &samples, + language_override.as_deref(), + prompt_override.as_deref(), + ) + }); + + let result = tokio::select! 
{ + join = &mut task => join, + _ = tokio::time::sleep(timeout) => { + task.abort(); + return Err(ApiError { + status: StatusCode::REQUEST_TIMEOUT, + message: format!("Transcription timed out after {}ms", timeout.as_millis()), + error_type: "timeout_error", + }); + } + }; + + let tr = match result { + Ok(Ok(tr)) => tr, + Ok(Err(e)) => return Err(map_transcription_error(e)), + Err(e) => return Err(ApiError::internal(format!("Transcription task failed: {}", e))), + }; + + let verbose = VerboseTranscriptionResponse { + text: tr.text, + language: tr.language, + duration: tr.duration, + segments: tr + .segments + .into_iter() + .enumerate() + .map(|(id, seg)| VerboseSegment { + id, + start: seg.start, + end: seg.end, + text: seg.text, + }) + .collect(), + }; + + Ok(Json(verbose).into_response()) + } else { + let mut task = tokio::task::spawn_blocking(move || { + transcribe_text_adaptive( + transcriber, + &samples, + language_override.as_deref(), + prompt_override.as_deref(), + ) + }); + + let result = tokio::select! 
{ + join = &mut task => join, + _ = tokio::time::sleep(timeout) => { + task.abort(); + return Err(ApiError { + status: StatusCode::REQUEST_TIMEOUT, + message: format!("Transcription timed out after {}ms", timeout.as_millis()), + error_type: "timeout_error", + }); + } + }; + + let text = match result { + Ok(Ok(text)) => text, + Ok(Err(e)) => return Err(map_transcription_error(e)), + Err(e) => return Err(ApiError::internal(format!("Transcription task failed: {}", e))), + }; + + if format == "text" { + let mut response = text.into_response(); + response.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("text/plain; charset=utf-8"), + ); + return Ok(response); + } + + Ok(Json(TranscriptionResponse { text }).into_response()) + } +} + +fn transcribe_text_adaptive( + transcriber: Arc, + samples: &[f32], + language_override: Option<&str>, + prompt_override: Option<&str>, +) -> Result { + if should_chunk_long_form(samples) { + return Ok( + transcribe_segments_adaptive(transcriber, samples, language_override, prompt_override)? 
+ .text, + ); + } + + transcriber.transcribe_with_options(samples, language_override, prompt_override) +} + +fn transcribe_segments_adaptive( + transcriber: Arc, + samples: &[f32], + language_override: Option<&str>, + prompt_override: Option<&str>, +) -> Result { + if !should_chunk_long_form(samples) { + return transcriber.transcribe_segments(samples, language_override, prompt_override); + } + + let vad = VoiceActivityDetector::new(LONG_FORM_VAD_THRESHOLD, SERVICE_SAMPLE_RATE as u32); + let chunk_len = LONG_FORM_CHUNK_SECS * SERVICE_SAMPLE_RATE; + let total_duration = samples.len() as f64 / SERVICE_SAMPLE_RATE as f64; + let mut combined_text = String::new(); + let mut combined_segments = Vec::new(); + let mut detected_languages = BTreeSet::new(); + + tracing::info!( + "Long-form service request ({:.2}s) chunked into {}s windows", + total_duration, + LONG_FORM_CHUNK_SECS + ); + + for (chunk_index, chunk_samples) in samples.chunks(chunk_len).enumerate() { + if !vad.contains_speech(chunk_samples) { + tracing::debug!( + "Skipping silent long-form chunk {} ({:.2}s)", + chunk_index, + chunk_samples.len() as f64 / SERVICE_SAMPLE_RATE as f64 + ); + continue; + } + + let chunk_result = + transcriber.transcribe_segments(chunk_samples, language_override, prompt_override)?; + + let detected_language = chunk_result.language.trim().to_lowercase(); + if !detected_language.is_empty() && detected_language != "auto" { + detected_languages.insert(detected_language); + } + + push_text_piece(&mut combined_text, &chunk_result.text); + + let chunk_offset = (chunk_index * chunk_len) as f64 / SERVICE_SAMPLE_RATE as f64; + let chunk_duration = chunk_samples.len() as f64 / SERVICE_SAMPLE_RATE as f64; + + if chunk_result.segments.is_empty() { + let text = chunk_result.text.trim(); + if !text.is_empty() { + combined_segments.push(TranscriptionSegment { + start: chunk_offset, + end: chunk_offset + chunk_duration, + text: text.to_string(), + }); + } + continue; + } + + for segment in 
chunk_result.segments { + let text = segment.text.trim(); + if text.is_empty() { + continue; + } + combined_segments.push(TranscriptionSegment { + start: chunk_offset + segment.start, + end: chunk_offset + segment.end, + text: text.to_string(), + }); + } + } + + if combined_segments.is_empty() { + tracing::warn!( + "Long-form chunking found no speech chunks; falling back to single-pass transcription" + ); + return transcriber.transcribe_segments(samples, language_override, prompt_override); + } + + if combined_text.trim().is_empty() { + combined_text = combined_segments + .iter() + .map(|segment| segment.text.trim()) + .filter(|text| !text.is_empty()) + .collect::>() + .join(" "); + } + + Ok(TranscriptionResult { + text: combined_text.trim().to_string(), + language: summarize_detected_languages(language_override, &detected_languages), + duration: total_duration, + segments: combined_segments, + }) +} + +fn should_chunk_long_form(samples: &[f32]) -> bool { + samples.len() > LONG_FORM_CHUNK_THRESHOLD_SECS * SERVICE_SAMPLE_RATE +} + +fn summarize_detected_languages( + language_override: Option<&str>, + detected_languages: &BTreeSet, +) -> String { + if let Some(language) = language_override { + let trimmed = language.trim(); + if !trimmed.is_empty() && !trimmed.eq_ignore_ascii_case("auto") { + return trimmed.to_lowercase(); + } + } + + match detected_languages.len() { + 0 => language_override.unwrap_or("auto").trim().to_lowercase(), + 1 => detected_languages.iter().next().cloned().unwrap_or_else(|| "auto".to_string()), + _ => "mixed".to_string(), + } +} + +fn push_text_piece(buffer: &mut String, piece: &str) { + let trimmed = piece.trim(); + if trimmed.is_empty() { + return; + } + + if !buffer.is_empty() { + buffer.push(' '); + } + buffer.push_str(trimmed); +} + +fn map_transcription_error(err: TranscribeError) -> ApiError { + match err { + TranscribeError::AudioFormat(msg) | TranscribeError::ConfigError(msg) => { + ApiError::bad_request(msg) + } + 
TranscribeError::ModelNotFound(msg) + | TranscribeError::InitFailed(msg) + | TranscribeError::InferenceFailed(msg) + | TranscribeError::NetworkError(msg) + | TranscribeError::RemoteError(msg) => ApiError { + status: StatusCode::BAD_GATEWAY, + message: msg, + error_type: "upstream_error", + }, + } +} + +fn normalize_languages(languages: &[String]) -> Vec { + let mut out = Vec::new(); + for lang in languages { + let normalized = lang.trim().to_lowercase(); + if !normalized.is_empty() && !out.contains(&normalized) { + out.push(normalized); + } + } + out +} + +fn normalize_language_override( + language: Option<&str>, + allowed_languages: &[String], +) -> Result, ApiError> { + let Some(raw) = language else { + return Ok(None); + }; + + let normalized = raw.trim().to_lowercase(); + if normalized.is_empty() { + return Ok(None); + } + + if normalized == "auto" { + return Ok(Some(normalized)); + } + + if !allowed_languages.is_empty() && !allowed_languages.contains(&normalized) { + return Err(ApiError::bad_request(format!( + "Language '{}' is not allowed; allowed languages: {}", + normalized, + allowed_languages.join(", ") + ))); + } + + Ok(Some(normalized)) +} + +fn decode_wav_to_mono_16k(wav_bytes: &[u8]) -> Result, String> { + let cursor = Cursor::new(wav_bytes); + let mut reader = + hound::WavReader::new(cursor).map_err(|e| format!("Invalid WAV payload: {}", e))?; + let spec = reader.spec(); + + let channels = spec.channels as usize; + if channels == 0 { + return Err("WAV payload has zero channels".to_string()); + } + let sample_rate = spec.sample_rate; + + let interleaved: Vec = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::() + .map(|s| s.map(|v| v.clamp(-1.0, 1.0))) + .collect::, _>>() + .map_err(|e| format!("Failed to decode float WAV samples: {}", e))?, + hound::SampleFormat::Int => { + if spec.bits_per_sample <= 8 { + reader + .samples::() + .map(|s| s.map(|v| v as f32 / i8::MAX as f32)) + .collect::, _>>() + .map_err(|e| 
format!("Failed to decode 8-bit WAV samples: {}", e))? + } else if spec.bits_per_sample <= 16 { + reader + .samples::() + .map(|s| s.map(|v| v as f32 / i16::MAX as f32)) + .collect::, _>>() + .map_err(|e| format!("Failed to decode 16-bit WAV samples: {}", e))? + } else { + let max_val = + ((1_i64 << (spec.bits_per_sample.saturating_sub(1) as u32)) - 1) as f32; + reader + .samples::() + .map(|s| s.map(|v| v as f32 / max_val)) + .collect::, _>>() + .map_err(|e| { + format!( + "Failed to decode {}-bit WAV samples: {}", + spec.bits_per_sample, e + ) + })? + } + } + }; + + let mono = if channels == 1 { + interleaved + } else { + let frame_count = interleaved.len() / channels; + let mut downmixed = Vec::with_capacity(frame_count); + for i in 0..frame_count { + let mut sum = 0.0f32; + for ch in 0..channels { + sum += interleaved[i * channels + ch]; + } + downmixed.push((sum / channels as f32).clamp(-1.0, 1.0)); + } + downmixed + }; + + if sample_rate == 16000 { + return Ok(mono); + } + + Ok(resample_linear(&mono, sample_rate, 16000)) +} + +fn resample_linear(samples: &[f32], source_rate: u32, target_rate: u32) -> Vec { + if samples.is_empty() || source_rate == target_rate { + return samples.to_vec(); + } + + let ratio = target_rate as f64 / source_rate as f64; + let output_len = (samples.len() as f64 * ratio).ceil() as usize; + let mut out = Vec::with_capacity(output_len); + + for i in 0..output_len { + let source_pos = i as f64 / ratio; + let idx = source_pos.floor() as usize; + let frac = (source_pos - idx as f64) as f32; + + let value = if idx + 1 < samples.len() { + samples[idx] * (1.0 - frac) + samples[idx + 1] * frac + } else { + samples.get(idx).copied().unwrap_or(0.0) + }; + out.push(value.clamp(-1.0, 1.0)); + } + + out +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{WhisperConfig, WhisperMode}; + use crate::transcribe::remote::RemoteTranscriber; + use crate::transcribe::{TranscriptionResult, TranscriptionSegment}; + use std::sync::Mutex; + + 
struct MockTranscriber { + text: String, + calls: Mutex, Option)>>, + } + + impl MockTranscriber { + fn new(text: &str) -> Self { + Self { + text: text.to_string(), + calls: Mutex::new(Vec::new()), + } + } + } + + impl Transcriber for MockTranscriber { + fn transcribe(&self, _samples: &[f32]) -> Result { + Ok(self.text.clone()) + } + + fn transcribe_with_options( + &self, + _samples: &[f32], + language_override: Option<&str>, + prompt_override: Option<&str>, + ) -> Result { + self.calls.lock().unwrap().push(( + language_override.map(ToOwned::to_owned), + prompt_override.map(ToOwned::to_owned), + )); + Ok(self.text.clone()) + } + + fn transcribe_segments( + &self, + samples: &[f32], + language_override: Option<&str>, + prompt_override: Option<&str>, + ) -> Result { + self.calls.lock().unwrap().push(( + language_override.map(ToOwned::to_owned), + prompt_override.map(ToOwned::to_owned), + )); + let duration = samples.len() as f64 / 16000.0; + Ok(TranscriptionResult { + text: self.text.clone(), + language: language_override.unwrap_or("en").to_string(), + duration, + segments: vec![ + TranscriptionSegment { + start: 0.0, + end: duration / 2.0, + text: "hello from".to_string(), + }, + TranscriptionSegment { + start: duration / 2.0, + end: duration, + text: "local service".to_string(), + }, + ], + }) + } + } + + struct ChunkCountingTranscriber { + calls: Mutex>, + } + + impl ChunkCountingTranscriber { + fn new() -> Self { + Self { + calls: Mutex::new(Vec::new()), + } + } + } + + impl Transcriber for ChunkCountingTranscriber { + fn transcribe(&self, samples: &[f32]) -> Result { + self.transcribe_with_options(samples, None, None) + } + + fn transcribe_with_options( + &self, + samples: &[f32], + _language_override: Option<&str>, + _prompt_override: Option<&str>, + ) -> Result { + let mut calls = self.calls.lock().unwrap(); + calls.push(samples.len()); + Ok(format!("chunk {}", calls.len())) + } + + fn transcribe_segments( + &self, + samples: &[f32], + _language_override: 
Option<&str>,
+        _prompt_override: Option<&str>,
+    ) -> Result<TranscriptionResult, TranscribeError> {
+        let mut calls = self.calls.lock().unwrap();
+        calls.push(samples.len());
+        let call_index = calls.len();
+        let duration = samples.len() as f64 / SERVICE_SAMPLE_RATE as f64;
+        Ok(TranscriptionResult {
+            text: format!("chunk {}", call_index),
+            language: "de".to_string(),
+            duration,
+            segments: vec![TranscriptionSegment {
+                start: 0.0,
+                end: duration,
+                text: format!("segment {}", call_index),
+            }],
+        })
+    }
+}
+
+struct LanguageTrackingTranscriber {
+    languages_seen: Mutex<Vec<Option<String>>>,
+    returned_languages: Vec<String>,
+}
+
+impl LanguageTrackingTranscriber {
+    fn new(returned_languages: &[&str]) -> Self {
+        Self {
+            languages_seen: Mutex::new(Vec::new()),
+            returned_languages: returned_languages.iter().map(|s| s.to_string()).collect(),
+        }
+    }
+}
+
+impl Transcriber for LanguageTrackingTranscriber {
+    fn transcribe(&self, samples: &[f32]) -> Result<String, TranscribeError> {
+        self.transcribe_with_options(samples, None, None)
+    }
+
+    fn transcribe_with_options(
+        &self,
+        samples: &[f32],
+        language_override: Option<&str>,
+        prompt_override: Option<&str>,
+    ) -> Result<String, TranscribeError> {
+        Ok(self
+            .transcribe_segments(samples, language_override, prompt_override)?
+ .text) + } + + fn transcribe_segments( + &self, + samples: &[f32], + language_override: Option<&str>, + _prompt_override: Option<&str>, + ) -> Result { + let mut seen = self.languages_seen.lock().unwrap(); + seen.push(language_override.map(ToOwned::to_owned)); + let call_index = seen.len(); + let duration = samples.len() as f64 / SERVICE_SAMPLE_RATE as f64; + let language = self + .returned_languages + .get(call_index - 1) + .cloned() + .unwrap_or_else(|| "auto".to_string()); + + Ok(TranscriptionResult { + text: format!("chunk {}", call_index), + language, + duration, + segments: vec![TranscriptionSegment { + start: 0.0, + end: duration, + text: format!("segment {}", call_index), + }], + }) + } + } + + async fn spawn_test_server( + transcriber: Arc, + allowed_languages: Vec, + ) -> ServiceHandle { + let service_cfg = ServiceConfig { + enabled: true, + host: "127.0.0.1".to_string(), + port: 0, + max_upload_bytes: 2_000_000, + request_timeout_ms: 5000, + allowed_languages, + }; + + start_with_transcriber(service_cfg, transcriber).await.unwrap() + } + + fn sine_samples(sample_rate: u32, duration_secs: f32, freq_hz: f32) -> Vec { + let sample_count = (sample_rate as f32 * duration_secs) as usize; + (0..sample_count) + .map(|i| { + let t = i as f32 / sample_rate as f32; + (2.0 * std::f32::consts::PI * freq_hz * t).sin() * 0.2 + }) + .collect() + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn remote_client_can_call_local_service() { + let mock = Arc::new(MockTranscriber::new("hello from local service")); + let handle = + spawn_test_server(mock.clone(), vec!["en".to_string(), "de".to_string()]).await; + let endpoint = format!("http://{}", handle.addr()); + + let cfg = WhisperConfig { + mode: Some(WhisperMode::Remote), + remote_endpoint: Some(endpoint), + remote_model: Some("whisper-1".to_string()), + language: LanguageConfig::Single("en".to_string()), + ..Default::default() + }; + + let client = RemoteTranscriber::new(&cfg).unwrap(); + let 
samples = sine_samples(16000, 0.3, 440.0); + let text = tokio::task::spawn_blocking(move || client.transcribe(&samples)) + .await + .unwrap() + .unwrap(); + assert_eq!(text, "hello from local service"); + + let calls = mock.calls.lock().unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].0.as_deref(), Some("en")); + + handle.shutdown().await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn language_outside_allowed_set_is_rejected() { + let mock = Arc::new(MockTranscriber::new("ignored")); + let handle = spawn_test_server(mock, vec!["en".to_string(), "de".to_string()]).await; + let endpoint = format!("http://{}", handle.addr()); + + let cfg = WhisperConfig { + mode: Some(WhisperMode::Remote), + remote_endpoint: Some(endpoint), + remote_model: Some("whisper-1".to_string()), + language: LanguageConfig::Single("fr".to_string()), + ..Default::default() + }; + + let client = RemoteTranscriber::new(&cfg).unwrap(); + let samples = sine_samples(16000, 0.2, 440.0); + let err = tokio::task::spawn_blocking(move || client.transcribe(&samples)) + .await + .unwrap() + .unwrap_err() + .to_string(); + assert!(err.contains("Language 'fr' is not allowed")); + + handle.shutdown().await; + } + + fn make_wav_bytes(samples: &[f32]) -> Vec { + let spec = hound::WavSpec { + channels: 1, + sample_rate: 16000, + bits_per_sample: 16, + sample_format: hound::SampleFormat::Int, + }; + let mut cursor = Cursor::new(Vec::new()); + { + let mut writer = hound::WavWriter::new(&mut cursor, spec).unwrap(); + for &s in samples { + writer.write_sample((s * i16::MAX as f32) as i16).unwrap(); + } + writer.finalize().unwrap(); + } + cursor.into_inner() + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn verbose_json_returns_segments() { + use std::io::{Read, Write}; + + let mock = Arc::new(MockTranscriber::new("hello from local service")); + let handle = + spawn_test_server(mock.clone(), vec!["en".to_string(), "de".to_string()]).await; + 
let addr = handle.addr(); + + let samples = sine_samples(16000, 0.3, 440.0); + let wav_bytes = make_wav_bytes(&samples); + + let boundary = "----TestBoundary1234"; + let mut body = Vec::new(); + + body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes()); + body.extend_from_slice( + b"Content-Disposition: form-data; name=\"response_format\"\r\n\r\n", + ); + body.extend_from_slice(b"verbose_json\r\n"); + + body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes()); + body.extend_from_slice( + b"Content-Disposition: form-data; name=\"file\"; filename=\"audio.wav\"\r\n", + ); + body.extend_from_slice(b"Content-Type: audio/wav\r\n\r\n"); + body.extend_from_slice(&wav_bytes); + body.extend_from_slice(b"\r\n"); + body.extend_from_slice(format!("--{}--\r\n", boundary).as_bytes()); + + // Use raw TCP to avoid needing reqwest + let request = format!( + "POST /v1/audio/transcriptions HTTP/1.1\r\n\ + Host: {}\r\n\ + Content-Type: multipart/form-data; boundary={}\r\n\ + Content-Length: {}\r\n\ + Connection: close\r\n\ + \r\n", + addr, boundary, body.len() + ); + + let mut stream = std::net::TcpStream::connect(addr).unwrap(); + stream.write_all(request.as_bytes()).unwrap(); + stream.write_all(&body).unwrap(); + + let mut response = String::new(); + stream.read_to_string(&mut response).unwrap(); + + // Parse HTTP response body (after blank line) + let body_start = response.find("\r\n\r\n").unwrap() + 4; + let response_body = &response[body_start..]; + + // Handle chunked transfer encoding + let json_str = if response.contains("transfer-encoding: chunked") { + // Parse chunked body: size\r\ndata\r\n...0\r\n + let mut decoded = String::new(); + let mut remaining = response_body; + loop { + let size_end = remaining.find("\r\n").unwrap_or(0); + let chunk_size = + usize::from_str_radix(remaining[..size_end].trim(), 16).unwrap_or(0); + if chunk_size == 0 { + break; + } + let chunk_start = size_end + 2; + decoded.push_str(&remaining[chunk_start..chunk_start + chunk_size]); 
+ remaining = &remaining[chunk_start + chunk_size + 2..]; + } + decoded + } else { + response_body.to_string() + }; + + let json: serde_json::Value = serde_json::from_str(&json_str).unwrap(); + + assert_eq!(json["text"], "hello from local service"); + assert!(json["duration"].as_f64().unwrap() > 0.0); + assert!(json["language"].as_str().is_some()); + + let segments = json["segments"].as_array().unwrap(); + assert_eq!(segments.len(), 2); + assert_eq!(segments[0]["id"], 0); + assert_eq!(segments[0]["text"], "hello from"); + assert!(segments[0]["start"].as_f64().unwrap() >= 0.0); + assert_eq!(segments[1]["id"], 1); + assert_eq!(segments[1]["text"], "local service"); + + let calls = mock.calls.lock().unwrap(); + assert_eq!(calls.len(), 1); + + handle.shutdown().await; + } + + #[test] + fn decode_wav_resamples_to_16k() { + let samples = sine_samples(8000, 0.5, 440.0); + + let spec = hound::WavSpec { + channels: 1, + sample_rate: 8000, + bits_per_sample: 16, + sample_format: hound::SampleFormat::Int, + }; + + let mut cursor = Cursor::new(Vec::new()); + { + let mut writer = hound::WavWriter::new(&mut cursor, spec).unwrap(); + for sample in samples { + let v = (sample * i16::MAX as f32) as i16; + writer.write_sample(v).unwrap(); + } + writer.finalize().unwrap(); + } + + let decoded = decode_wav_to_mono_16k(&cursor.into_inner()).unwrap(); + assert!(decoded.len() > 7000); + assert!(decoded.len() < 9000); + } + + #[test] + fn adaptive_long_form_chunking_splits_large_inputs() { + let counter = Arc::new(ChunkCountingTranscriber::new()); + let transcriber: Arc = counter.clone(); + let samples = sine_samples(SERVICE_SAMPLE_RATE as u32, 95.0, 440.0); + + let result = + transcribe_segments_adaptive(transcriber.clone(), &samples, Some("de"), None).unwrap(); + + let calls = counter.calls.lock().unwrap(); + assert_eq!(calls.len(), 4); + assert!(calls.iter().all(|&len| len <= 30 * SERVICE_SAMPLE_RATE)); + + assert_eq!(result.language, "de"); + assert_eq!(result.segments.len(), 4); + 
assert_eq!(result.segments[0].start, 0.0); + assert_eq!(result.segments[1].start, 30.0); + assert_eq!(result.segments[2].start, 60.0); + assert_eq!(result.segments[3].start, 90.0); + } + + #[test] + fn adaptive_long_form_chunking_reports_mixed_language_when_chunks_differ() { + let detector = Arc::new(LanguageTrackingTranscriber::new(&["de", "en", "de", "en"])); + let transcriber: Arc = detector.clone(); + let samples = sine_samples(SERVICE_SAMPLE_RATE as u32, 95.0, 440.0); + + let result = + transcribe_segments_adaptive(transcriber, &samples, Some("auto"), None).unwrap(); + + let seen = detector.languages_seen.lock().unwrap(); + assert_eq!(seen.len(), 4); + assert_eq!(seen[0].as_deref(), Some("auto")); + assert_eq!(seen[1].as_deref(), Some("auto")); + assert_eq!(seen[2].as_deref(), Some("auto")); + assert_eq!(seen[3].as_deref(), Some("auto")); + assert_eq!(result.language, "mixed"); + } +} diff --git a/src/transcribe/mod.rs b/src/transcribe/mod.rs index e42c6f5..5369a02 100644 --- a/src/transcribe/mod.rs +++ b/src/transcribe/mod.rs @@ -54,6 +54,8 @@ pub mod dolphin; #[cfg(feature = "omnilingual")] pub mod omnilingual; +use serde::Serialize; + use crate::config::{Config, TranscriptionEngine, WhisperConfig, WhisperMode}; use crate::error::TranscribeError; use crate::setup::gpu; @@ -68,6 +70,23 @@ pub struct TimedSegment { pub end_secs: f32, } +/// A single transcription segment with timestamps. +#[derive(Debug, Clone, Serialize)] +pub struct TranscriptionSegment { + pub start: f64, + pub end: f64, + pub text: String, +} + +/// Structured transcription result with segments. 
+#[derive(Debug, Clone, Serialize)]
+pub struct TranscriptionResult {
+    pub text: String,
+    pub language: String,
+    pub duration: f64,
+    pub segments: Vec<TranscriptionSegment>,
+}
+
 /// Trait for speech-to-text implementations
 pub trait Transcriber: Send + Sync {
     /// Transcribe audio samples to text
@@ -90,6 +109,46 @@ pub trait Transcriber: Send + Sync {
         }
     }
 
+    /// Transcribe audio with optional request-level overrides.
+    ///
+    /// Implementations may ignore unsupported options and fall back to
+    /// `transcribe(samples)`.
+    fn transcribe_with_options(
+        &self,
+        samples: &[f32],
+        language_override: Option<&str>,
+        prompt_override: Option<&str>,
+    ) -> Result<String, TranscribeError> {
+        let _ = language_override;
+        let _ = prompt_override;
+        self.transcribe(samples)
+    }
+
+    /// Transcribe with segment-level timestamps.
+    ///
+    /// Default implementation wraps `transcribe_with_options` into a single
+    /// segment spanning the full duration. Implementations that can produce
+    /// per-segment timestamps (e.g. WhisperTranscriber) should override this.
+    fn transcribe_segments(
+        &self,
+        samples: &[f32],
+        language_override: Option<&str>,
+        prompt_override: Option<&str>,
+    ) -> Result<TranscriptionResult, TranscribeError> {
+        let text = self.transcribe_with_options(samples, language_override, prompt_override)?;
+        let duration = samples.len() as f64 / 16000.0;
+        Ok(TranscriptionResult {
+            text: text.clone(),
+            language: language_override.unwrap_or("auto").to_string(),
+            duration,
+            segments: vec![TranscriptionSegment {
+                start: 0.0,
+                end: duration,
+                text,
+            }],
+        })
+    }
+
     /// Prepare for transcription (optional, called when recording starts)
     ///
     /// For subprocess-based transcribers, this spawns the worker process
diff --git a/src/transcribe/whisper.rs b/src/transcribe/whisper.rs
index 2c4a428..8aaca14 100644
--- a/src/transcribe/whisper.rs
+++ b/src/transcribe/whisper.rs
@@ -7,16 +7,75 @@
 //! - Auto-detect: Let Whisper detect from all ~99 supported languages
 //!
- Constrained auto-detect: Detect from a user-specified subset of languages
 
-use super::Transcriber;
+use super::{Transcriber, TranscriptionResult, TranscriptionSegment};
 use crate::config::{Config, LanguageConfig, WhisperConfig};
 use crate::error::TranscribeError;
+use std::ops::{Deref, DerefMut};
 use std::path::PathBuf;
-use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
+use std::sync::{Condvar, Mutex};
+use whisper_rs::{
+    FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState,
+};
+
+struct StatePool<T> {
+    idle: Mutex<Vec<T>>,
+    available: Condvar,
+}
+
+impl<T> StatePool<T> {
+    fn new(items: Vec<T>) -> Self {
+        assert!(!items.is_empty(), "state pool requires at least one item");
+        Self {
+            idle: Mutex::new(items),
+            available: Condvar::new(),
+        }
+    }
+
+    fn acquire(&self) -> StatePoolLease<'_, T> {
+        let mut idle = self.idle.lock().unwrap();
+        loop {
+            if let Some(item) = idle.pop() {
+                return StatePoolLease {
+                    pool: self,
+                    item: Some(item),
+                };
+            }
+            idle = self.available.wait(idle).unwrap();
+        }
+    }
+}
+
+struct StatePoolLease<'a, T> {
+    pool: &'a StatePool<T>,
+    item: Option<T>,
+}
+
+impl<T> Deref for StatePoolLease<'_, T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.item.as_ref().unwrap()
+    }
+}
+
+impl<T> DerefMut for StatePoolLease<'_, T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.item.as_mut().unwrap()
+    }
+}
+
+impl<T> Drop for StatePoolLease<'_, T> {
+    fn drop(&mut self) {
+        let mut idle = self.pool.idle.lock().unwrap();
+        idle.push(self.item.take().unwrap());
+        self.pool.available.notify_one();
+    }
+}
 
 /// Whisper-based transcriber
 pub struct WhisperTranscriber {
-    /// Whisper context (holds the model)
-    ctx: WhisperContext,
+    /// Reused inference states backed by the loaded model context.
+ state_pool: StatePool, /// Language configuration (single, auto, or array) language: LanguageConfig, /// Whether to translate to English @@ -56,11 +115,24 @@ impl WhisperTranscriber { .map_err(|e| TranscribeError::InitFailed(e.to_string()))?; tracing::info!("Model loaded in {:.2}s", start.elapsed().as_secs_f32()); - let threads = config.threads.unwrap_or_else(|| num_cpus::get().min(4)); + let state_pool_size = determine_state_pool_size(threads); + let mut states = Vec::with_capacity(state_pool_size); + for _ in 0..state_pool_size { + states.push( + ctx.create_state() + .map_err(|e| TranscribeError::InferenceFailed(e.to_string()))?, + ); + } + tracing::info!( + "Prepared {} reusable whisper inference state(s) with {} thread(s) each", + state_pool_size, + threads + ); + let state_pool = StatePool::new(states); Ok(Self { - ctx, + state_pool, language: config.language.clone(), translate: config.translate, threads, @@ -132,6 +204,15 @@ impl WhisperTranscriber { impl Transcriber for WhisperTranscriber { fn transcribe(&self, samples: &[f32]) -> Result { + self.transcribe_with_options(samples, None, None) + } + + fn transcribe_with_options( + &self, + samples: &[f32], + language_override: Option<&str>, + prompt_override: Option<&str>, + ) -> Result { if samples.is_empty() { return Err(TranscribeError::AudioFormat( "Empty audio buffer".to_string(), @@ -147,14 +228,31 @@ impl Transcriber for WhisperTranscriber { let start = std::time::Instant::now(); - // Create state for this transcription - let mut state = self - .ctx - .create_state() - .map_err(|e| TranscribeError::InferenceFailed(e.to_string()))?; - - // Determine language based on configuration mode - let selected_language: Option = if self.language.is_auto() { + let mut state = self.state_pool.acquire(); + + let override_lang = language_override + .map(|s| s.trim().to_lowercase()) + .filter(|s| !s.is_empty()); + + // Determine language based on request override and config mode. 
+ let selected_language: Option = if let Some(lang) = override_lang { + if lang == "auto" { + if self.language.is_multiple() { + let allowed = self.language.as_vec(); + tracing::debug!( + "Using constrained language detection from override auto set: {:?}", + allowed + ); + Some(self.select_language_from_allowed(&mut state, samples, &allowed)?) + } else { + tracing::debug!("Using unconstrained language auto-detection (override)"); + None + } + } else { + tracing::debug!("Using request language override: {}", lang); + Some(lang) + } + } else if self.language.is_auto() { // Unconstrained auto-detection: let Whisper detect from all languages tracing::debug!("Using unconstrained language auto-detection"); None @@ -192,8 +290,16 @@ impl Transcriber for WhisperTranscriber { params.set_suppress_blank(true); params.set_suppress_nst(true); - // Set initial prompt if configured - if let Some(prompt) = &self.initial_prompt { + // Set initial prompt (request override takes precedence). + let request_prompt = prompt_override.and_then(|p| { + let trimmed = p.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed) + } + }); + if let Some(prompt) = request_prompt.or(self.initial_prompt.as_deref()) { params.set_initial_prompt(prompt); tracing::debug!("Using initial prompt: {:?}", prompt); } @@ -235,10 +341,16 @@ impl Transcriber for WhisperTranscriber { } let result = text.trim().to_string(); + let result_language = selected_language.unwrap_or_else(|| { + whisper_rs::get_lang_str(state.full_lang_id_from_state()) + .unwrap_or("auto") + .to_string() + }); tracing::info!( - "Transcription completed in {:.2}s: {:?}", + "Transcription completed in {:.2}s (language={}): {:?}", start.elapsed().as_secs_f32(), + result_language, if result.chars().count() > 50 { format!("{}...", result.chars().take(50).collect::()) } else { @@ -248,6 +360,127 @@ impl Transcriber for WhisperTranscriber { Ok(result) } + + fn transcribe_segments( + &self, + samples: &[f32], + language_override: 
Option<&str>, + prompt_override: Option<&str>, + ) -> Result { + if samples.is_empty() { + return Err(TranscribeError::AudioFormat( + "Empty audio buffer".to_string(), + )); + } + + let duration_secs = samples.len() as f64 / 16000.0; + let start = std::time::Instant::now(); + + let mut state = self.state_pool.acquire(); + + let override_lang = language_override + .map(|s| s.trim().to_lowercase()) + .filter(|s| !s.is_empty()); + + let selected_language: Option = if let Some(lang) = override_lang { + if lang == "auto" { + if self.language.is_multiple() { + let allowed = self.language.as_vec(); + Some(self.select_language_from_allowed(&mut state, samples, &allowed)?) + } else { + None + } + } else { + Some(lang) + } + } else if self.language.is_auto() { + None + } else if self.language.is_multiple() { + let allowed = self.language.as_vec(); + Some(self.select_language_from_allowed(&mut state, samples, &allowed)?) + } else { + Some(self.language.primary().to_string()) + }; + + let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); + + match &selected_language { + Some(lang) => params.set_language(Some(lang)), + None => params.set_language(None), + } + + params.set_translate(self.translate); + params.set_n_threads(self.threads as i32); + params.set_print_special(false); + params.set_print_progress(false); + params.set_print_realtime(false); + params.set_print_timestamps(false); + params.set_suppress_blank(true); + params.set_suppress_nst(true); + + let request_prompt = prompt_override.and_then(|p| { + let trimmed = p.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed) + } + }); + if let Some(prompt) = request_prompt.or(self.initial_prompt.as_deref()) { + params.set_initial_prompt(prompt); + } + + if self.context_window_optimization { + params.set_no_context(true); + if let Some(audio_ctx) = calculate_audio_ctx(duration_secs as f32) { + params.set_audio_ctx(audio_ctx); + } + } + + state + .full(params, samples) + .map_err(|e| 
TranscribeError::InferenceFailed(e.to_string()))?; + + let mut full_text = String::new(); + let mut segments = Vec::new(); + + for segment in state.as_iter() { + let seg_text = segment + .to_str() + .map_err(|e| TranscribeError::InferenceFailed(e.to_string()))? + .to_string(); + // Timestamps are in centiseconds (10ms units) + let seg_start = segment.start_timestamp() as f64 / 100.0; + let seg_end = segment.end_timestamp() as f64 / 100.0; + + full_text.push_str(&seg_text); + segments.push(TranscriptionSegment { + start: seg_start, + end: seg_end, + text: seg_text.trim().to_string(), + }); + } + + let detected_lang = selected_language.unwrap_or_else(|| { + whisper_rs::get_lang_str(state.full_lang_id_from_state()) + .unwrap_or("auto") + .to_string() + }); + + tracing::info!( + "Segment transcription completed in {:.2}s: {} segments, language={}", + start.elapsed().as_secs_f32(), + segments.len(), + detected_lang, + ); + + Ok(TranscriptionResult { + text: full_text.trim().to_string(), + language: detected_lang, + duration: duration_secs, + segments, + }) + } } /// Resolve model name to file path @@ -337,6 +570,14 @@ fn calculate_audio_ctx(duration_secs: f32) -> Option { } } +fn determine_state_pool_size(threads: usize) -> usize { + let threads = threads.max(1); + let parallelism = std::thread::available_parallelism() + .map(|value| value.get()) + .unwrap_or(threads); + (parallelism / threads).clamp(1, 2) +} + /// Get the filename for a model pub fn get_model_filename(model: &str) -> String { match model { @@ -368,6 +609,9 @@ pub fn get_model_url(model: &str) -> String { #[cfg(test)] mod tests { use super::*; + use std::sync::Arc; + use std::thread; + use std::time::{Duration, Instant}; #[test] fn test_model_url() { @@ -448,4 +692,44 @@ mod tests { } } } + + #[test] + fn determine_state_pool_size_is_bounded() { + assert!((1..=2).contains(&determine_state_pool_size(0))); + assert_eq!(determine_state_pool_size(usize::MAX), 1); + 
assert!((1..=2).contains(&determine_state_pool_size(1))); + assert!((1..=2).contains(&determine_state_pool_size(4))); + } + + #[test] + fn state_pool_returns_items_after_release() { + let pool = StatePool::new(vec![1usize]); + { + let mut lease = pool.acquire(); + *lease += 1; + } + let lease = pool.acquire(); + assert_eq!(*lease, 2); + } + + #[test] + fn state_pool_blocks_until_item_is_released() { + let pool = Arc::new(StatePool::new(vec![7usize])); + let guard = pool.acquire(); + let worker_pool = Arc::clone(&pool); + let started = Instant::now(); + + let handle = thread::spawn(move || { + let lease = worker_pool.acquire(); + let waited = started.elapsed(); + (*lease, waited) + }); + + thread::sleep(Duration::from_millis(40)); + drop(guard); + + let (value, waited) = handle.join().unwrap(); + assert_eq!(value, 7); + assert!(waited >= Duration::from_millis(40)); + } }