From 615ed1006464a17ce6a48284ed3d48cb09dd331a Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 31 Mar 2026 16:34:49 -0700 Subject: [PATCH 01/52] wip: tool perms --- Cargo.lock | 39 ++++++--- Cargo.toml | 7 +- crates/atuin-ai/Cargo.toml | 3 + crates/atuin-ai/src/lib.rs | 1 + crates/atuin-ai/src/permissions/check.rs | 25 ++++++ crates/atuin-ai/src/permissions/file.rs | 22 +++++ crates/atuin-ai/src/permissions/mod.rs | 4 + crates/atuin-ai/src/permissions/rule.rs | 88 +++++++++++++++++++ crates/atuin-ai/src/permissions/walker.rs | 101 ++++++++++++++++++++++ crates/atuin-client/Cargo.toml | 2 +- 10 files changed, 276 insertions(+), 16 deletions(-) create mode 100644 crates/atuin-ai/src/permissions/check.rs create mode 100644 crates/atuin-ai/src/permissions/file.rs create mode 100644 crates/atuin-ai/src/permissions/mod.rs create mode 100644 crates/atuin-ai/src/permissions/rule.rs create mode 100644 crates/atuin-ai/src/permissions/walker.rs diff --git a/Cargo.lock b/Cargo.lock index 9a97372a147..90aeea31d3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -285,10 +285,13 @@ dependencies = [ "ratatui", "ratatui-core", "ratatui-widgets", + "regex", "reqwest", "serde", "serde_json", + "thiserror 2.0.18", "tokio", + "toml", "tracing", "tracing-appender", "tracing-subscriber", @@ -954,7 +957,7 @@ dependencies = [ "pathdiff", "serde_core", "toml", - "winnow", + "winnow 0.7.15", ] [[package]] @@ -4332,9 +4335,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "1.0.4" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" +checksum = "6662b5879511e06e8999a8a235d848113e942c9124f211511b16466ee2995f26" dependencies = [ "serde_core", ] @@ -5197,22 +5200,24 @@ dependencies = [ [[package]] name = "toml" -version = "1.0.6+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "399b1124a3c9e16766831c6bba21e50192572cdd98706ea114f9502509686ffc" +checksum = "994b95d9e7bae62b34bab0e2a4510b801fa466066a6a8b2b57361fa1eba068ee" dependencies = [ + "indexmap 2.13.0", "serde_core", "serde_spanned", "toml_datetime", "toml_parser", - "winnow", + "toml_writer", + "winnow 1.0.1", ] [[package]] name = "toml_datetime" -version = "1.0.0+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" dependencies = [ "serde_core", ] @@ -5227,23 +5232,23 @@ dependencies = [ "toml_datetime", "toml_parser", "toml_writer", - "winnow", + "winnow 0.7.15", ] [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "39ca317ebc49f06bd748bfba29533eac9485569dc9bf80b849024b025e814fb9" dependencies = [ - "winnow", + "winnow 1.0.1", ] [[package]] name = "toml_writer" -version = "1.0.6+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" +checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db" [[package]] name = "tonic" @@ -6549,6 +6554,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" + [[package]] name = "winreg" version = "0.10.1" diff --git a/Cargo.toml b/Cargo.toml index cdb73e19e07..fa346af9f90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,9 @@ [workspace] -members = ["crates/*", "crates/atuin-nucleo/matcher", "crates/atuin-nucleo/bench"] +members = [ + "crates/*", + "crates/atuin-nucleo/matcher", + "crates/atuin-nucleo/bench", +] resolver = "2" exclude = ["ui/backend", "crates/atuin-nucleo/matcher/fuzz"] @@ -65,6 +69,7 @@ rustls = { version = "0.23", default-features = false, features = [ "std", "tls12", ] } +regex = "1.10.5" [workspace.dependencies.tracing-subscriber] version = "0.3" diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 6e7315cdf5f..6d01abbf07d 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -42,6 +42,9 @@ unicode-width = "0.2" eye_declare = "0.3" ratatui-core = "0.1" ratatui-widgets = "0.3" +thiserror = { workspace = true } +regex = { workspace = true } +toml = "1.1" [dev-dependencies] pretty_assertions = { workspace = true } diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 2d86271dc02..f3e72cb5eb8 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,2 +1,3 @@ pub mod commands; +pub mod permissions; pub mod tui; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs new file mode 100644 index 00000000000..537f2787514 --- /dev/null +++ b/crates/atuin-ai/src/permissions/check.rs @@ -0,0 +1,25 @@ +use eyre::Result; + +pub(crate) struct PermissionRequest { + call: ToolCall, +} + +pub(crate) enum PermissionResponse { + Allowed, + Denied, + Ask, +} + +pub(crate) struct PermissionsChecker { + // +} + +impl PermissionsChecker { + pub fn new() -> Self { + Self {} + } + + pub async fn check(&self, request: &PermissionRequest) -> Result { + // + } +} diff --git a/crates/atuin-ai/src/permissions/file.rs b/crates/atuin-ai/src/permissions/file.rs new file mode 100644 index 00000000000..5e344c1c22a --- /dev/null +++ b/crates/atuin-ai/src/permissions/file.rs @@ -0,0 +1,22 @@ +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +use crate::permissions::rule::Rule; + +pub(crate) struct RuleFile { + pub path: PathBuf, + pub content: RuleFileContent, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFileContent { + pub permissions: RuleFilePermissions, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFilePermissions { + pub allow: Vec, + pub deny: Vec, + pub ask: Vec, +} diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs new file mode 100644 index 00000000000..defb70130a2 --- /dev/null +++ b/crates/atuin-ai/src/permissions/mod.rs @@ -0,0 +1,4 @@ +pub(crate) mod check; +pub(crate) mod file; +pub(crate) mod rule; +pub(crate) mod walker; diff --git a/crates/atuin-ai/src/permissions/rule.rs b/crates/atuin-ai/src/permissions/rule.rs new file mode 100644 index 00000000000..12d51245a01 --- /dev/null +++ b/crates/atuin-ai/src/permissions/rule.rs @@ -0,0 +1,88 @@ +use std::sync::OnceLock; + +use regex::Regex; +use serde::{Deserialize, Serialize}; + +static RULE_RE: OnceLock = OnceLock::new(); + +#[derive(Debug, thiserror::Error)] +pub(crate) enum RuleError { + #[error("invalid rule format: {0}")] + InvalidRule(String), +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct Rule { + pub tool: String, + pub scope: Option, +} + +impl std::fmt::Display for Rule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.scope.as_ref() { + Some(scope) => write!(f, "{}({})", self.tool, scope), + None => write!(f, "{}", self.tool), + } + } +} + +impl TryFrom<&str> for Rule { + type Error = RuleError; + + fn try_from(value: &str) -> Result { + let value = value.trim(); + let re = RULE_RE.get_or_init(|| Regex::new(r"^(\w+)(?:\((.*)\))?$").unwrap()); + let caps = re + .captures(value) + .ok_or(RuleError::InvalidRule(value.to_string()))?; + let tool = caps.get(1).unwrap().as_str().to_string(); + let scope = caps.get(2).map(|m| m.as_str().to_string()); + Ok(Rule { tool, scope }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rule_try_from() { + assert_eq!( + Rule::try_from("Read").unwrap(), + Rule { + tool: "Read".to_string(), + scope: None + } + ); + assert_eq!( + Rule::try_from("Read(*)").unwrap(), + Rule { + tool: "Read".to_string(), + scope: Some("*".to_string()) + } + ); + assert_eq!( + Rule::try_from("Write(*.md)").unwrap(), + Rule { + tool: "Write".to_string(), + scope: Some("*.md".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(git commit *)").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("git commit *".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(echo ())").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("echo ()".to_string()) + } + ); + assert!(Rule::try_from("Shell(git commit *").is_err()); + assert!(Rule::try_from("Shell(git commit *)!").is_err()); + } +} diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs new file mode 100644 index 00000000000..6ccfe7d7764 --- /dev/null +++ b/crates/atuin-ai/src/permissions/walker.rs @@ -0,0 +1,101 @@ +use std::path::{Path, PathBuf}; + +use eyre::Result; +use tokio::task::JoinSet; + +use crate::permissions::file::{RuleFile, RuleFileContent}; + +struct FoundRuleFile { + depth: usize, + file: RuleFile, +} + +pub(crate) struct PermissionsWalker { + start: PathBuf, + global_permissions_file: Option, + rules: Vec, +} + +impl PermissionsWalker { + pub fn new(start: PathBuf, global_permissions_file: Option) -> Self { + Self { + start, + global_permissions_file, + rules: Vec::new(), + } + } + + /// Walks the filesystem starting from the start path and collecting permission files along the way. + /// Walks to the root, then checks the global permissions file, if any. + pub async fn walk(&mut self) -> Result<()> { + let mut to_check = self + .start + .ancestors() + .map(PathBuf::from) + .collect::>(); + if let Some(global_path) = self.global_permissions_file.as_ref() { + to_check.push(global_path.clone()); + } + + let size = to_check.len(); + let mut set: JoinSet>> = JoinSet::new(); + + for (index, path) in to_check.into_iter().enumerate() { + set.spawn(async move { + match check_for_permissions(&path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth: index, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + let mut found = Vec::with_capacity(size); + while let Some(result) = set.join_next().await { + let result = result?; // JoinErrors result in failure to walk the filesystem + + match result { + Ok(Some(FoundRuleFile { depth, file })) => { + found.push((depth, file)); + } + Ok(None) => { + continue; + } + Err(e) => { + tracing::error!( + "Error while walking filesystem for permissions check; skipping folder: {}", + e + ); + continue; + } + } + } + // join_next() returns in order of completion, not order of spawn + found.sort_by_key(|(depth, _)| *depth); + self.rules = found.into_iter().map(|(_, file)| file).collect(); + + Ok(()) + } +} + +// Checks a directory for `.atuin/permissions.ai.toml` and returns the RuleFile if found. +// Returns None if no permissions file is found. +// Returns an error if any FS or deserialization errors occur. +async fn check_for_permissions(path: &Path) -> Result> { + let permissions_file = path.join(".atuin").join("permissions.ai.toml"); + + if !tokio::fs::try_exists(&permissions_file).await? { + return Ok(None); + } + + let content = tokio::fs::read_to_string(permissions_file).await?; + let content: RuleFileContent = toml::from_str(&content)?; + + Ok(Some(RuleFile { + path: path.to_path_buf(), + content, + })) +} diff --git a/crates/atuin-client/Cargo.toml b/crates/atuin-client/Cargo.toml index 763f9d4e859..2860c82b12d 100644 --- a/crates/atuin-client/Cargo.toml +++ b/crates/atuin-client/Cargo.toml @@ -41,7 +41,7 @@ rand = { workspace = true } shellexpand = "3" sqlx = { workspace = true, features = ["sqlite", "regexp"] } minspan = "0.1.5" -regex = "1.10.5" +regex = { workspace = true } serde_regex = "1.1.0" fs-err = { workspace = true } sql-builder = { workspace = true } From fced2b8f94b68a121c8b6d8f49872d08f7d86773 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 31 Mar 2026 22:48:00 -0700 Subject: [PATCH 02/52] wip: tool perms --- crates/atuin-ai/src/commands/inline.rs | 280 +--------------------- crates/atuin-ai/src/lib.rs | 2 + crates/atuin-ai/src/permissions/check.rs | 42 +++- crates/atuin-ai/src/stream.rs | 291 +++++++++++++++++++++++ crates/atuin-ai/src/tools/mod.rs | 232 ++++++++++++++++++ crates/atuin-ai/src/tui/state.rs | 6 + 6 files changed, 569 insertions(+), 284 deletions(-) create mode 100644 crates/atuin-ai/src/stream.rs create mode 100644 crates/atuin-ai/src/tools/mod.rs diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index c16e3dac5c5..9a0fc92a35e 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,17 +1,12 @@ use std::sync::mpsc; -use crate::commands::detect_shell; +use crate::stream::run_chat_stream; use crate::tui::events::AiTuiEvent; use crate::tui::state::{AppState, ExitAction}; use crate::tui::view::ai_view; -use atuin_client::distro::detect_linux_distribution; -use atuin_common::tls::ensure_crypto_provider; -use eventsource_stream::Eventsource; -use eye_declare::{Application, CtrlCBehavior, Handle}; +use eye_declare::{Application, CtrlCBehavior}; use eyre::{Context as _, Result, bail}; -use futures::StreamExt; -use reqwest::Url; -use tracing::{debug, error, info, trace}; +use tracing::{debug, info}; pub async fn run( initial_command: Option, @@ -108,254 +103,6 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu Ok(token) } -// ─────────────────────────────────────────────────────────────────── -// SSE streaming -// ─────────────────────────────────────────────────────────────────── - -#[derive(Debug, Clone)] -enum ChatStreamEvent { - TextChunk(String), - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - }, - Status(String), - Done { - session_id: String, - }, - Error(String), -} - -fn create_chat_stream( - hub_address: String, - token: String, - session_id: Option, - messages: Vec, - send_cwd: bool, -) -> std::pin::Pin> + Send>> { - Box::pin(async_stream::stream! { - ensure_crypto_provider(); - let endpoint = match hub_url(&hub_address, "/api/cli/chat") { - Ok(url) => url, - Err(e) => { - yield Err(e); - return; - } - }; - - debug!("Sending SSE request to {endpoint}"); - - let os = detect_os(); - let shell = detect_shell(); - - let mut context = serde_json::json!({ - "os": os, - "shell": shell, - "pwd": if send_cwd { std::env::current_dir() - .ok() - .map(|path| path.to_string_lossy().into_owned()) } else { None }, - }); - - if os == "linux" { - context["distro"] = serde_json::json!(detect_linux_distribution()); - } - - let mut request_body = serde_json::json!({ - "messages": messages, - "context": context, - }); - - if let Some(ref sid) = session_id { - trace!("Including session_id in request: {sid}"); - request_body["session_id"] = serde_json::json!(sid); - } - - let client = reqwest::Client::new(); - let response = match client - .post(endpoint.clone()) - .header("Accept", "text/event-stream") - .bearer_auth(&token) - .json(&request_body) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); - return; - } - }; - - let status = response.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - error!("SSE request failed with status: {status}, clearing session"); - let _ = atuin_client::hub::delete_session().await; - yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); - return; - } - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - error!("SSE request failed ({}): {}", status, body); - yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); - return; - } - - let byte_stream = response.bytes_stream(); - let mut stream = byte_stream.eventsource(); - - while let Some(event) = stream.next().await { - match event { - Ok(sse_event) => { - let event_type = sse_event.event.as_str(); - let data = sse_event.data.clone(); - - debug!(event_type = %event_type, "SSE event received"); - - match event_type { - "text" => { - if let Ok(json) = serde_json::from_str::(&data) - && let Some(content) = json.get("content").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::TextChunk(content.to_string())); - } - } - "tool_call" => { - if let Ok(json) = serde_json::from_str::(&data) { - let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(ChatStreamEvent::ToolCall { id, name, input }); - } - } - "tool_result" => { - if let Ok(json) = serde_json::from_str::(&data) { - let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error }); - } - } - "status" => { - if let Ok(json) = serde_json::from_str::(&data) - && let Some(state) = json.get("state").and_then(|v| v.as_str()) - { - yield Ok(ChatStreamEvent::Status(state.to_string())); - } - } - "done" => { - if let Ok(json) = serde_json::from_str::(&data) { - let session_id = json.get("session_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - yield Ok(ChatStreamEvent::Done { session_id }); - } else { - yield Ok(ChatStreamEvent::Done { session_id: String::new() }); - } - break; - } - "error" => { - if let Ok(json) = serde_json::from_str::(&data) { - let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); - error!("SSE error: {}", message); - yield Ok(ChatStreamEvent::Error(message)); - } else { - error!("SSE error: {}", data); - yield Ok(ChatStreamEvent::Error(data)); - } - break; - } - _ => {} - } - } - Err(e) => { - yield Err(eyre::eyre!("SSE error: {}", e)); - break; - } - } - } - }) -} - -// ─────────────────────────────────────────────────────────────────── -// Async streaming task — pushes updates to app state via Handle -// ─────────────────────────────────────────────────────────────────── - -async fn run_chat_stream( - handle: Handle, - endpoint: String, - token: String, - session_id: Option, - messages: Vec, - send_cwd: bool, -) { - let stream = create_chat_stream(endpoint, token, session_id, messages, send_cwd); - futures::pin_mut!(stream); - - while let Some(event) = stream.next().await { - match event { - Ok(ChatStreamEvent::TextChunk(text)) => { - trace!(text = %text, "Processing TextChunk"); - handle.update(move |state| { - state.append_streaming_text(&text); - }); - } - Ok(ChatStreamEvent::ToolCall { id, name, input }) => { - trace!(id = %id, name = %name, "Processing ToolCall"); - handle.update(move |state| { - state.add_tool_call(id, name, input); - }); - } - Ok(ChatStreamEvent::ToolResult { - tool_use_id, - content, - is_error, - }) => { - trace!(tool_use_id = %tool_use_id, "Processing ToolResult"); - handle.update(move |state| { - state.add_tool_result(tool_use_id, content, is_error); - }); - } - Ok(ChatStreamEvent::Status(status)) => { - trace!(status = %status, "Processing Status"); - handle.update(move |state| { - state.update_streaming_status(&status); - }); - } - Ok(ChatStreamEvent::Done { session_id }) => { - trace!(session_id = %session_id, "Processing Done"); - handle.update(move |state| { - if !session_id.is_empty() { - state.store_session_id(session_id); - } - state.finalize_streaming(); - }); - break; - } - Ok(ChatStreamEvent::Error(msg)) => { - trace!(error = %msg, "Processing Error"); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - Err(e) => { - let msg = e.to_string(); - handle.update(move |state| { - state.streaming_error(msg); - }); - break; - } - } - } -} - // ─────────────────────────────────────────────────────────────────── // Main TUI entry point // ─────────────────────────────────────────────────────────────────── @@ -544,27 +291,6 @@ async fn run_inline_tui( // Helpers // ─────────────────────────────────────────────────────────────────── -fn hub_url(base: &str, path: &str) -> Result { - let base_with_slash = if base.ends_with('/') { - base.to_string() - } else { - format!("{base}/") - }; - let stripped = path.strip_prefix('/').unwrap_or(path); - Url::parse(&base_with_slash)? - .join(stripped) - .context("failed to build hub URL") -} - -fn detect_os() -> String { - match std::env::consts::OS { - "macos" => "macos".to_string(), - "linux" => "linux".to_string(), - "windows" => "windows".to_string(), - other => format!("Other: {other}"), - } -} - #[derive(Clone)] enum Action { Execute(String), diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index f3e72cb5eb8..0663a9ffe3e 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,3 +1,5 @@ pub mod commands; pub mod permissions; +pub mod stream; +pub mod tools; pub mod tui; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs index 537f2787514..50c69e90214 100644 --- a/crates/atuin-ai/src/permissions/check.rs +++ b/crates/atuin-ai/src/permissions/check.rs @@ -1,7 +1,12 @@ +use std::path::PathBuf; + use eyre::Result; +use crate::{permissions::file::RuleFile, tools::PermissableToolCall}; + pub(crate) struct PermissionRequest { - call: ToolCall, + working_dir: PathBuf, + call: Box, } pub(crate) enum PermissionResponse { @@ -10,16 +15,39 @@ pub(crate) enum PermissionResponse { Ask, } -pub(crate) struct PermissionsChecker { - // +pub(crate) struct PermissionChecker { + files: Vec, } -impl PermissionsChecker { - pub fn new() -> Self { - Self {} +impl PermissionChecker { + pub fn new(files: Vec) -> Self { + Self { files } } pub async fn check(&self, request: &PermissionRequest) -> Result { - // + // Files are in order from deepest to shallowest, so we can stop at the first match. + // Within a file, deny rules take precedence over ask and allow rules. + // Ask rules take precedence over allow rules. + for file in &self.files { + for rule in &file.content.permissions.deny { + if request.call.matches_rule(rule) { + return Ok(PermissionResponse::Denied); + } + } + + for rule in &file.content.permissions.ask { + if request.call.matches_rule(rule) { + return Ok(PermissionResponse::Ask); + } + } + + for rule in &file.content.permissions.allow { + if request.call.matches_rule(rule) { + return Ok(PermissionResponse::Allowed); + } + } + } + + Ok(PermissionResponse::Ask) } } diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs new file mode 100644 index 00000000000..36956cbfa7c --- /dev/null +++ b/crates/atuin-ai/src/stream.rs @@ -0,0 +1,291 @@ +// ─────────────────────────────────────────────────────────────────── +// SSE streaming +// ─────────────────────────────────────────────────────────────────── + +use atuin_client::distro::detect_linux_distribution; +use atuin_common::tls::ensure_crypto_provider; + +use eventsource_stream::Eventsource; +use eye_declare::Handle; +use eyre::{Context, Result}; +use futures::StreamExt; +use reqwest::Url; + +use crate::{commands::detect_shell, tools::ToolCall, tui::AppState}; + +#[derive(Debug, Clone)] +enum ChatStreamEvent { + TextChunk(String), + ToolCall { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, + }, + Status(String), + Done { + session_id: String, + }, + Error(String), +} + +fn create_chat_stream( + hub_address: String, + token: String, + session_id: Option, + messages: Vec, + send_cwd: bool, +) -> std::pin::Pin> + Send>> { + Box::pin(async_stream::stream! { + ensure_crypto_provider(); + let endpoint = match hub_url(&hub_address, "/api/cli/chat") { + Ok(url) => url, + Err(e) => { + yield Err(e); + return; + } + }; + + tracing::debug!("Sending SSE request to {endpoint}"); + + let os = detect_os(); + let shell = detect_shell(); + + let mut context = serde_json::json!({ + "os": os, + "shell": shell, + "pwd": if send_cwd { std::env::current_dir() + .ok() + .map(|path| path.to_string_lossy().into_owned()) } else { None }, + }); + + if os == "linux" { + context["distro"] = serde_json::json!(detect_linux_distribution()); + } + + let mut request_body = serde_json::json!({ + "messages": messages, + "context": context, + "capabilities": [ + "client_tools_v1" + ] + }); + + if let Some(ref sid) = session_id { + tracing::trace!("Including session_id in request: {sid}"); + request_body["session_id"] = serde_json::json!(sid); + } + + let client = reqwest::Client::new(); + let response = match client + .post(endpoint.clone()) + .header("Accept", "text/event-stream") + .bearer_auth(&token) + .json(&request_body) + .send() + .await + { + Ok(resp) => resp, + Err(e) => { + yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); + return; + } + }; + + let status = response.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + tracing::error!("SSE request failed with status: {status}, clearing session"); + let _ = atuin_client::hub::delete_session().await; + yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); + return; + } + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + tracing::error!("SSE request failed ({}): {}", status, body); + yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); + return; + } + + let byte_stream = response.bytes_stream(); + let mut stream = byte_stream.eventsource(); + + while let Some(event) = stream.next().await { + match event { + Ok(sse_event) => { + let event_type = sse_event.event.as_str(); + let data = sse_event.data.clone(); + + tracing::debug!(event_type = %event_type, "SSE event received"); + + match event_type { + "text" => { + if let Ok(json) = serde_json::from_str::(&data) + && let Some(content) = json.get("content").and_then(|v| v.as_str()) + { + yield Ok(ChatStreamEvent::TextChunk(content.to_string())); + } + } + "tool_call" => { + if let Ok(json) = serde_json::from_str::(&data) { + let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); + yield Ok(ChatStreamEvent::ToolCall { id, name, input }); + } + } + "tool_result" => { + if let Ok(json) = serde_json::from_str::(&data) { + let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); + yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error }); + } + } + "status" => { + if let Ok(json) = serde_json::from_str::(&data) + && let Some(state) = json.get("state").and_then(|v| v.as_str()) + { + yield Ok(ChatStreamEvent::Status(state.to_string())); + } + } + "done" => { + if let Ok(json) = serde_json::from_str::(&data) { + let session_id = json.get("session_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + yield Ok(ChatStreamEvent::Done { session_id }); + } else { + yield Ok(ChatStreamEvent::Done { session_id: String::new() }); + } + break; + } + "error" => { + if let Ok(json) = serde_json::from_str::(&data) { + let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); + tracing::error!("SSE error: {}", message); + yield Ok(ChatStreamEvent::Error(message)); + } else { + tracing::error!("SSE error: {}", data); + yield Ok(ChatStreamEvent::Error(data)); + } + break; + } + _ => {} + } + } + Err(e) => { + yield Err(eyre::eyre!("SSE error: {}", e)); + break; + } + } + } + }) +} + +// ─────────────────────────────────────────────────────────────────── +// Async streaming task — pushes updates to app state via Handle +// ─────────────────────────────────────────────────────────────────── + +pub(crate) async fn run_chat_stream( + handle: Handle, + endpoint: String, + token: String, + session_id: Option, + messages: Vec, + send_cwd: bool, +) { + let stream = create_chat_stream(endpoint, token, session_id, messages, send_cwd); + futures::pin_mut!(stream); + + while let Some(event) = stream.next().await { + match event { + Ok(ChatStreamEvent::TextChunk(text)) => { + tracing::trace!(text = %text, "Processing TextChunk"); + handle.update(move |state| { + state.append_streaming_text(&text); + }); + } + Ok(ChatStreamEvent::ToolCall { id, name, input }) => { + tracing::trace!(id = %id, name = %name, "Processing ToolCall"); + + if let Ok(tool) = ToolCall::try_from((name.as_str(), &input)) { + // Recognized as a client-side tool call. + handle.update(move |state| { + state.handle_client_tool_call(tool); + }); + continue; + } + + handle.update(move |state| { + state.add_tool_call(id, name, input); + }); + } + Ok(ChatStreamEvent::ToolResult { + tool_use_id, + content, + is_error, + }) => { + tracing::trace!(tool_use_id = %tool_use_id, "Processing ToolResult"); + handle.update(move |state| { + state.add_tool_result(tool_use_id, content, is_error); + }); + } + Ok(ChatStreamEvent::Status(status)) => { + tracing::trace!(status = %status, "Processing Status"); + handle.update(move |state| { + state.update_streaming_status(&status); + }); + } + Ok(ChatStreamEvent::Done { session_id }) => { + tracing::trace!(session_id = %session_id, "Processing Done"); + handle.update(move |state| { + if !session_id.is_empty() { + state.store_session_id(session_id); + } + state.finalize_streaming(); + }); + break; + } + Ok(ChatStreamEvent::Error(msg)) => { + tracing::trace!(error = %msg, "Processing Error"); + handle.update(move |state| { + state.streaming_error(msg); + }); + break; + } + Err(e) => { + let msg = e.to_string(); + handle.update(move |state| { + state.streaming_error(msg); + }); + break; + } + } + } +} + +fn hub_url(base: &str, path: &str) -> Result { + let base_with_slash = if base.ends_with('/') { + base.to_string() + } else { + format!("{base}/") + }; + let stripped = path.strip_prefix('/').unwrap_or(path); + Url::parse(&base_with_slash)? + .join(stripped) + .context("failed to build hub URL") +} + +fn detect_os() -> String { + match std::env::consts::OS { + "macos" => "macos".to_string(), + "linux" => "linux".to_string(), + "windows" => "windows".to_string(), + other => format!("Other: {other}"), + } +} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs new file mode 100644 index 00000000000..207192436cf --- /dev/null +++ b/crates/atuin-ai/src/tools/mod.rs @@ -0,0 +1,232 @@ +use std::path::{Path, PathBuf}; + +use eyre::Result; + +use crate::permissions::rule::Rule; + +pub(crate) enum ToolCall { + Read(ReadToolCall), + Write(WriteToolCall), + Shell(ShellToolCall), + AtuinHistory(AtuinHistoryToolCall), +} + +impl TryFrom<(&str, &serde_json::Value)> for ToolCall { + type Error = eyre::Error; + + fn try_from((name, input): (&str, &serde_json::Value)) -> Result { + match name { + "read" => Ok(ToolCall::Read(ReadToolCall::try_from(input)?)), + "write" => Ok(ToolCall::Write(WriteToolCall::try_from(input)?)), + "shell" => Ok(ToolCall::Shell(ShellToolCall::try_from(input)?)), + "atuin_history" => Ok(ToolCall::AtuinHistory(AtuinHistoryToolCall::try_from( + input, + )?)), + _ => Err(eyre::eyre!("Unknown tool call: {name}")), + } + } +} + +pub(crate) trait PermissableToolCall { + fn matches_rule(&self, rule: &Rule) -> bool; + fn target_dir(&self) -> Option<&Path> { + None + } +} + +pub(crate) struct ReadToolCall { + path: PathBuf, +} + +impl TryFrom<&serde_json::Value> for ReadToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let path = value + .get("path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing path"))?; + + Ok(ReadToolCall { + path: PathBuf::from(path), + }) + } +} + +impl PermissableToolCall for ReadToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Read" { + return false; + } + + if let Some(scope) = rule.scope.as_ref() { + if scope == "*" { + return true; + } + + todo!("check path vs scope glob"); + } + + true + } +} + +pub(crate) struct WriteToolCall { + path: PathBuf, + content: String, +} + +impl TryFrom<&serde_json::Value> for WriteToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let path = value + .get("path") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing path"))?; + + let content = value + .get("content") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing content"))?; + + Ok(WriteToolCall { + path: PathBuf::from(path), + content: content.to_string(), + }) + } +} + +impl PermissableToolCall for WriteToolCall { + fn target_dir(&self) -> Option<&Path> { + Some(&self.path) + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Write" { + return false; + } + + if let Some(scope) = rule.scope.as_ref() { + if scope == "*" { + return true; + } + + todo!("check path vs scope glob"); + } + + true + } +} + +pub(crate) struct ShellToolCall { + dir: Option, + command: String, +} + +impl TryFrom<&serde_json::Value> for ShellToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let dir = value.get("dir").and_then(|v| v.as_str()); + + let command = value + .get("command") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing command"))?; + + Ok(ShellToolCall { + dir: dir.map(PathBuf::from), + command: command.to_string(), + }) + } +} + +impl PermissableToolCall for ShellToolCall { + fn target_dir(&self) -> Option<&Path> { + self.dir.as_deref() + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "Shell" { + return false; + } + + if let Some(scope) = rule.scope.as_ref() { + if scope == "*" { + return true; + } + + todo!("split command into subcommands, check each"); + } + + true + } +} + +pub(crate) struct AtuinHistoryToolCall { + filter_modes: Vec, + query: String, +} + +pub(crate) enum HistorySearchFilterMode { + Global, + Host, + Session, + Directory, + Workspace, +} + +impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let filter_modes = value + .get("filter_modes") + .and_then(|v| v.as_array()) + .ok_or(eyre::eyre!("Missing filter_modes"))?; + + let filter_modes = filter_modes + .iter() + .map(|v| { + let mode = v.as_str().ok_or(eyre::eyre!("Invalid filter mode"))?; + match mode { + "global" => Ok(HistorySearchFilterMode::Global), + "host" => Ok(HistorySearchFilterMode::Host), + "session" => Ok(HistorySearchFilterMode::Session), + "directory" => Ok(HistorySearchFilterMode::Directory), + "workspace" => Ok(HistorySearchFilterMode::Workspace), + _ => Err(eyre::eyre!("Invalid filter mode: {mode}")), + } + }) + .collect::>>()?; + + let query = value + .get("query") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing query"))?; + + Ok(AtuinHistoryToolCall { + filter_modes, + query: query.to_string(), + }) + } +} + +impl PermissableToolCall for AtuinHistoryToolCall { + fn target_dir(&self) -> Option<&Path> { + None + } + + fn matches_rule(&self, rule: &Rule) -> bool { + if rule.tool != "AtuinHistory" { + return false; + } + + todo!() + } +} diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 4c5c2a1e49e..aff2a8f45fd 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -5,6 +5,8 @@ use tokio::task::AbortHandle; +use crate::tools::ToolCall; + /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] pub enum StreamingStatus { @@ -354,6 +356,10 @@ impl AppState { } } + pub fn handle_client_tool_call(&mut self, tool: ToolCall) { + todo!("check permissions, handle tool call, send result - async") + } + /// Add a tool call event during streaming. /// The current streaming text is already in events, so we just push the tool call. pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { From 64b82478a3f6ce97efebc6df51b2fd1d83bc1aa1 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Fri, 3 Apr 2026 21:30:55 -0700 Subject: [PATCH 03/52] wip: permission system --- Cargo.lock | 1 + crates/atuin-ai/Cargo.toml | 1 + crates/atuin-ai/src/commands/inline.rs | 177 ++++++++++++++++++- crates/atuin-ai/src/permissions/check.rs | 33 +++- crates/atuin-ai/src/permissions/file.rs | 4 + crates/atuin-ai/src/permissions/rule.rs | 20 ++- crates/atuin-ai/src/permissions/walker.rs | 16 +- crates/atuin-ai/src/stream.rs | 11 +- crates/atuin-ai/src/tools/mod.rs | 94 ++++++++-- crates/atuin-ai/src/tui/components/mod.rs | 1 + crates/atuin-ai/src/tui/components/select.rs | 90 ++++++++++ crates/atuin-ai/src/tui/events.rs | 12 ++ crates/atuin-ai/src/tui/state.rs | 60 +++++-- crates/atuin-ai/src/tui/view/mod.rs | 103 +++++++++-- 14 files changed, 569 insertions(+), 54 deletions(-) create mode 100644 crates/atuin-ai/src/tui/components/select.rs diff --git a/Cargo.lock b/Cargo.lock index 90aeea31d3c..097341be9ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -296,6 +296,7 @@ dependencies = [ "tracing-appender", "tracing-subscriber", "tui-textarea-2", + "typed-builder 0.18.2", "unicode-width 0.2.2", "uuid", ] diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 6d01abbf07d..9b7cfff1827 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -45,6 +45,7 @@ ratatui-widgets = "0.3" thiserror = { workspace = true } regex = { workspace = true } toml = "1.1" +typed-builder = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index 9a0fc92a35e..e98c93e89be 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,7 +1,12 @@ +use std::path::PathBuf; use std::sync::mpsc; +use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; +use crate::permissions::walker::PermissionWalker; use crate::stream::run_chat_stream; -use crate::tui::events::AiTuiEvent; +use crate::tools::ToolCallState; +use crate::tui::ConversationEvent; +use crate::tui::events::{AiTuiEvent, PermissionResult}; use crate::tui::state::{AppState, ExitAction}; use crate::tui::view::ai_view; use eye_declare::{Application, CtrlCBehavior}; @@ -113,11 +118,33 @@ async fn run_inline_tui( initial_prompt: Option, settings: &atuin_client::settings::Settings, ) -> Result { - let initial_state = AppState::new(); + let (tx, rx) = mpsc::channel::(); - println!(); + let mut initial_state = AppState::new(tx.clone()); + initial_state + .pending_tool_calls + .push_back(crate::tools::PendingToolCall { + id: "1".to_string(), + state: crate::tools::ToolCallState::CheckingPermissions, + tool: crate::tools::ClientToolCall::Read(crate::tools::ReadToolCall { + path: std::path::PathBuf::from("test.txt"), + }), + }); + initial_state + .pending_tool_calls + .push_back(crate::tools::PendingToolCall { + id: "2".to_string(), + state: crate::tools::ToolCallState::CheckingPermissions, + tool: crate::tools::ClientToolCall::Shell(crate::tools::ShellToolCall { + dir: None, + command: "ls -lah".to_string(), + }), + }); + + let _ = tx.send(AiTuiEvent::CheckToolCallPermission("1".to_string())); + let _ = tx.send(AiTuiEvent::CheckToolCallPermission("2".to_string())); - let (tx, rx) = mpsc::channel::(); + println!(); // If there's an initial prompt, send it as a SubmitInput event // so it flows through the same path as user-typed input. @@ -151,6 +178,7 @@ async fn run_inline_tui( state.is_input_blank = input_blank; }); } + AiTuiEvent::SubmitInput(input) => { let input = input.trim().to_string(); if input.is_empty() { @@ -199,6 +227,147 @@ async fn run_inline_tui( }); } + AiTuiEvent::CheckToolCallPermission(id) => { + eprintln!("Checking tool call permission: {:?}", &id); + let h2 = h.clone(); + + let id_clone = id.clone(); + tokio::spawn(async move { + let Ok(Some(tool_call)) = h2 + .fetch(move |state| state.get_pending_tool_call(&id).cloned()) + .await + else { + // todo: raise error + eprintln!("Error getting pending tool call: {:?}", &id_clone); + return; + }; + + let Some(working_dir) = tool_call + .target_dir() + .map(PathBuf::from) + .or_else(|| std::env::current_dir().ok()) + else { + // todo: raise error + eprintln!( + "Error getting working directory for tool call: {:?}", + &tool_call + ); + return; + }; + + let mut walker = PermissionWalker::new(working_dir.clone(), None); // todo: get global dir + + let Ok(_) = walker.walk().await else { + eprintln!("Error walking filesystem for permissions check"); + // todo: raise error + return; + }; + + let checker = PermissionChecker::new(walker.rules().to_owned()); + let request = + PermissionRequest::new(working_dir, Box::new(&tool_call.tool)); + + let Ok(response) = checker.check(&request).await else { + // todo: raise error + eprintln!("Error checking tool call permission"); + return; + }; + + match response { + PermissionResponse::Allowed => { + eprintln!("Executing tool call: {:?}", tool_call); + h2.update(move |state| { + state.events.push(ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + content: format!( + "Permission granted for tool call {:?}", + &tool_call + ), + command: None, + }); + }); + } + PermissionResponse::Denied => { + eprintln!("Permission denied for tool call: {:?}", &tool_call); + h2.update(move |state| { + state.events.push(ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + content: format!( + "Permission denied for tool call {:?}", + &tool_call + ), + command: None, + }); + }); + } + PermissionResponse::Ask => { + eprintln!("Asking for permission for tool call: {:?}", &tool_call); + h2.update(move |state| { + let mut tool_call = state.get_pending_tool_call_mut(&id_clone); + + let Some(tool_call) = tool_call.as_mut() else { + eprintln!( + "Error getting pending tool call: {:?}", + &id_clone + ); + return; + }; + + eprintln!( + "Setting tool call state to AskingForPermission: {:?}", + &tool_call + ); + tool_call.state = ToolCallState::AskingForPermission; + eprintln!( + "Tool call state set to AskingForPermission: {:?}", + &tool_call + ); + }); + } + } + }); + } + + AiTuiEvent::SelectPermission(permission) => { + // Okay, we have permssion information. + // If accepted, we can start executing. + // If denied, we can show an error message. + h.update(move |state| { + let tool_call = state + .pending_tool_calls + .iter() + .enumerate() + .find(|(_, call)| call.state == ToolCallState::AskingForPermission); + + let Some((index, _)) = tool_call else { + return; + }; + + match permission { + PermissionResult::Allow => { + state.pending_tool_calls.remove(index); + } + PermissionResult::AlwaysAllowInDir => { + // + } + PermissionResult::AlwaysAllow => { + // + } + PermissionResult::Deny => { + let Some(call) = state.pending_tool_calls.remove(index) else { + return; + }; + + state.add_tool_result( + call.id, + "Permission denied on the user's system".to_string(), + true, + ); + } + } + }); + } + AiTuiEvent::CancelGeneration => { h.update(|state| match state.mode { crate::tui::state::AppMode::Generating => { diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs index 50c69e90214..2bcb04262ed 100644 --- a/crates/atuin-ai/src/permissions/check.rs +++ b/crates/atuin-ai/src/permissions/check.rs @@ -4,9 +4,18 @@ use eyre::Result; use crate::{permissions::file::RuleFile, tools::PermissableToolCall}; -pub(crate) struct PermissionRequest { +pub(crate) struct PermissionRequest<'t> { working_dir: PathBuf, - call: Box, + call: Box<&'t (dyn PermissableToolCall + Send + Sync)>, +} + +impl<'t> PermissionRequest<'t> { + pub fn new( + working_dir: PathBuf, + call: Box<&'t (dyn PermissableToolCall + Send + Sync)>, + ) -> Self { + Self { working_dir, call } + } } pub(crate) enum PermissionResponse { @@ -24,25 +33,43 @@ impl PermissionChecker { Self { files } } - pub async fn check(&self, request: &PermissionRequest) -> Result { + pub async fn check<'t>( + &self, + request: &'t PermissionRequest<'t>, + ) -> Result { // Files are in order from deepest to shallowest, so we can stop at the first match. // Within a file, deny rules take precedence over ask and allow rules. // Ask rules take precedence over allow rules. for file in &self.files { for rule in &file.content.permissions.deny { if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'DENY' by rule: {} in file: {}", + rule, + file.path.display() + ); return Ok(PermissionResponse::Denied); } } for rule in &file.content.permissions.ask { if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ASK' by rule: {} in file: {}", + rule, + file.path.display() + ); return Ok(PermissionResponse::Ask); } } for rule in &file.content.permissions.allow { if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ALLOW' by rule: {} in file: {}", + rule, + file.path.display() + ); return Ok(PermissionResponse::Allowed); } } diff --git a/crates/atuin-ai/src/permissions/file.rs b/crates/atuin-ai/src/permissions/file.rs index 5e344c1c22a..c973f55bd81 100644 --- a/crates/atuin-ai/src/permissions/file.rs +++ b/crates/atuin-ai/src/permissions/file.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::permissions::rule::Rule; +#[derive(Debug, Clone)] pub(crate) struct RuleFile { pub path: PathBuf, pub content: RuleFileContent, @@ -16,7 +17,10 @@ pub(crate) struct RuleFileContent { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub(crate) struct RuleFilePermissions { + #[serde(default)] pub allow: Vec, + #[serde(default)] pub deny: Vec, + #[serde(default)] pub ask: Vec, } diff --git a/crates/atuin-ai/src/permissions/rule.rs b/crates/atuin-ai/src/permissions/rule.rs index 12d51245a01..8fa3fa4a51f 100644 --- a/crates/atuin-ai/src/permissions/rule.rs +++ b/crates/atuin-ai/src/permissions/rule.rs @@ -11,7 +11,7 @@ pub(crate) enum RuleError { InvalidRule(String), } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct Rule { pub tool: String, pub scope: Option, @@ -26,6 +26,24 @@ impl std::fmt::Display for Rule { } } +impl Serialize for Rule { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for Rule { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::try_from(s.as_str()).map_err(serde::de::Error::custom) + } +} impl TryFrom<&str> for Rule { type Error = RuleError; diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs index 6ccfe7d7764..e7313117228 100644 --- a/crates/atuin-ai/src/permissions/walker.rs +++ b/crates/atuin-ai/src/permissions/walker.rs @@ -5,18 +5,19 @@ use tokio::task::JoinSet; use crate::permissions::file::{RuleFile, RuleFileContent}; +#[derive(Debug)] struct FoundRuleFile { depth: usize, file: RuleFile, } -pub(crate) struct PermissionsWalker { +pub(crate) struct PermissionWalker { start: PathBuf, global_permissions_file: Option, rules: Vec, } -impl PermissionsWalker { +impl PermissionWalker { pub fn new(start: PathBuf, global_permissions_file: Option) -> Self { Self { start, @@ -25,6 +26,10 @@ impl PermissionsWalker { } } + pub fn rules(&self) -> &[RuleFile] { + &self.rules + } + /// Walks the filesystem starting from the start path and collecting permission files along the way. /// Walks to the root, then checks the global permissions file, if any. pub async fn walk(&mut self) -> Result<()> { @@ -33,14 +38,18 @@ impl PermissionsWalker { .ancestors() .map(PathBuf::from) .collect::>(); + if let Some(global_path) = self.global_permissions_file.as_ref() { to_check.push(global_path.clone()); } + eprintln!("to_check: {:?}", to_check); + let size = to_check.len(); let mut set: JoinSet>> = JoinSet::new(); for (index, path) in to_check.into_iter().enumerate() { + eprintln!("Checking: {:?}", path); set.spawn(async move { match check_for_permissions(&path).await { Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { @@ -57,6 +66,7 @@ impl PermissionsWalker { while let Some(result) = set.join_next().await { let result = result?; // JoinErrors result in failure to walk the filesystem + eprintln!("result: {:?}", result); match result { Ok(Some(FoundRuleFile { depth, file })) => { found.push((depth, file)); @@ -77,6 +87,8 @@ impl PermissionsWalker { found.sort_by_key(|(depth, _)| *depth); self.rules = found.into_iter().map(|(_, file)| file).collect(); + eprintln!("rules: {:?}", self.rules); + Ok(()) } } diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 36956cbfa7c..16808a63418 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -11,7 +11,11 @@ use eyre::{Context, Result}; use futures::StreamExt; use reqwest::Url; -use crate::{commands::detect_shell, tools::ToolCall, tui::AppState}; +use crate::{ + commands::detect_shell, + tools::ClientToolCall, + tui::{AppState, events::AiTuiEvent}, +}; #[derive(Debug, Clone)] enum ChatStreamEvent { @@ -213,10 +217,11 @@ pub(crate) async fn run_chat_stream( Ok(ChatStreamEvent::ToolCall { id, name, input }) => { tracing::trace!(id = %id, name = %name, "Processing ToolCall"); - if let Ok(tool) = ToolCall::try_from((name.as_str(), &input)) { + if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { // Recognized as a client-side tool call. handle.update(move |state| { - state.handle_client_tool_call(tool); + state.handle_client_tool_call(id.clone(), tool); + let _ = state.tx.send(AiTuiEvent::CheckToolCallPermission(id)); }); continue; } diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 207192436cf..021d519b8f0 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -4,29 +4,78 @@ use eyre::Result; use crate::permissions::rule::Rule; -pub(crate) enum ToolCall { +#[derive(Debug, Clone)] +pub(crate) struct PendingToolCall { + pub id: String, + pub state: ToolCallState, + pub tool: ClientToolCall, +} + +impl PendingToolCall { + pub(crate) fn target_dir(&self) -> Option<&Path> { + self.tool.target_dir() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ToolCallState { + CheckingPermissions, + AskingForPermission, + Denied(String), + Executing, +} + +pub(crate) enum ClientToolCallType { + Read, + Write, + Shell, + AtuinHistory, +} + +#[derive(Debug, Clone)] +pub(crate) enum ClientToolCall { Read(ReadToolCall), Write(WriteToolCall), Shell(ShellToolCall), AtuinHistory(AtuinHistoryToolCall), } -impl TryFrom<(&str, &serde_json::Value)> for ToolCall { +impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { type Error = eyre::Error; fn try_from((name, input): (&str, &serde_json::Value)) -> Result { match name { - "read" => Ok(ToolCall::Read(ReadToolCall::try_from(input)?)), - "write" => Ok(ToolCall::Write(WriteToolCall::try_from(input)?)), - "shell" => Ok(ToolCall::Shell(ShellToolCall::try_from(input)?)), - "atuin_history" => Ok(ToolCall::AtuinHistory(AtuinHistoryToolCall::try_from( - input, - )?)), + "read" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), + "write" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), + "shell" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), + "atuin_history" => Ok(ClientToolCall::AtuinHistory( + AtuinHistoryToolCall::try_from(input)?, + )), _ => Err(eyre::eyre!("Unknown tool call: {name}")), } } } +impl ClientToolCall { + pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { + match self { + ClientToolCall::Read(tool) => tool.matches_rule(rule), + ClientToolCall::Write(tool) => tool.matches_rule(rule), + ClientToolCall::Shell(tool) => tool.matches_rule(rule), + ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), + } + } + + pub(crate) fn target_dir(&self) -> Option<&Path> { + match self { + ClientToolCall::Read(tool) => tool.target_dir(), + ClientToolCall::Write(tool) => tool.target_dir(), + ClientToolCall::Shell(tool) => tool.target_dir(), + ClientToolCall::AtuinHistory(tool) => tool.target_dir(), + } + } +} + pub(crate) trait PermissableToolCall { fn matches_rule(&self, rule: &Rule) -> bool; fn target_dir(&self) -> Option<&Path> { @@ -34,8 +83,19 @@ pub(crate) trait PermissableToolCall { } } +impl PermissableToolCall for ClientToolCall { + fn matches_rule(&self, rule: &Rule) -> bool { + self.matches_rule(rule) + } + + fn target_dir(&self) -> Option<&Path> { + self.target_dir() + } +} + +#[derive(Debug, Clone)] pub(crate) struct ReadToolCall { - path: PathBuf, + pub path: PathBuf, } impl TryFrom<&serde_json::Value> for ReadToolCall { @@ -75,9 +135,10 @@ impl PermissableToolCall for ReadToolCall { } } +#[derive(Debug, Clone)] pub(crate) struct WriteToolCall { - path: PathBuf, - content: String, + pub path: PathBuf, + pub content: String, } impl TryFrom<&serde_json::Value> for WriteToolCall { @@ -123,9 +184,10 @@ impl PermissableToolCall for WriteToolCall { } } +#[derive(Debug, Clone)] pub(crate) struct ShellToolCall { - dir: Option, - command: String, + pub dir: Option, + pub command: String, } impl TryFrom<&serde_json::Value> for ShellToolCall { @@ -168,11 +230,13 @@ impl PermissableToolCall for ShellToolCall { } } +#[derive(Debug, Clone)] pub(crate) struct AtuinHistoryToolCall { - filter_modes: Vec, - query: String, + pub filter_modes: Vec, + pub query: String, } +#[derive(Debug, Clone)] pub(crate) enum HistorySearchFilterMode { Global, Host, diff --git a/crates/atuin-ai/src/tui/components/mod.rs b/crates/atuin-ai/src/tui/components/mod.rs index 2f684f5f257..94cf4005319 100644 --- a/crates/atuin-ai/src/tui/components/mod.rs +++ b/crates/atuin-ai/src/tui/components/mod.rs @@ -1,3 +1,4 @@ pub mod atuin_ai; pub mod input_box; pub mod markdown; +pub mod select; diff --git a/crates/atuin-ai/src/tui/components/select.rs b/crates/atuin-ai/src/tui/components/select.rs new file mode 100644 index 00000000000..b59e973f551 --- /dev/null +++ b/crates/atuin-ai/src/tui/components/select.rs @@ -0,0 +1,90 @@ +use std::sync::mpsc; + +use crossterm::event::KeyCode; +use eye_declare::{Elements, EventResult, Hooks, Span, Text, View, component, element, props}; +use ratatui::style::Style; +use typed_builder::TypedBuilder; + +use crate::tui::events::AiTuiEvent; + +#[derive(TypedBuilder)] +pub(crate) struct SelectOption { + #[builder(setter(into))] + pub label: String, + #[builder(setter(into))] + pub value: String, + #[builder(default = Style::default())] + pub label_style: Style, + #[builder(default = Style::default().reversed())] + pub selected_style: Style, +} + +#[derive(Default)] +pub(crate) struct PermissionSelectorState { + selected_option: usize, + tx: Option>, +} + +#[props] +pub(crate) struct Select { + pub options: Vec, + pub on_select: Box, +} + +#[component(props = Select, state = PermissionSelectorState)] +pub(crate) fn permission_selector( + props: &Select, + state: &PermissionSelectorState, + hooks: &mut Hooks, +) -> Elements { + hooks.use_focusable(true); + hooks.use_autofocus(); + + hooks.use_context::>(|tx, _, state| { + state.tx = tx.cloned(); + }); + + hooks.use_event(move |event, props, state| { + if !event.is_key_press() { + return EventResult::Ignored; + } + + if let crossterm::event::Event::Key(key) = event { + if key.kind != crossterm::event::KeyEventKind::Press { + return EventResult::Ignored; + } + + match key.code { + KeyCode::Up => { + state.selected_option = + (state.selected_option + props.options.len() - 1) % props.options.len(); + return EventResult::Consumed; + } + KeyCode::Down => { + state.selected_option = (state.selected_option + 1) % props.options.len(); + return EventResult::Consumed; + } + KeyCode::Enter => { + let option = &props.options[state.selected_option]; + (props.on_select)(&option); + return EventResult::Consumed; + } + _ => {} + } + } + + EventResult::Ignored + }); + + element!( + View { + #(for (index, option) in props.options.iter().enumerate() { + Text { Span(text: &option.label, style: if index == state.selected_option { + option.selected_style + } else { + option.label_style + }) } + }) + } + ) +} diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs index a791bb80374..b31446f8447 100644 --- a/crates/atuin-ai/src/tui/events.rs +++ b/crates/atuin-ai/src/tui/events.rs @@ -12,6 +12,10 @@ pub enum AiTuiEvent { SubmitInput(String), /// User entered a slash command (e.g. "/help") SlashCommand(String), + /// Check the permission for a tool call + CheckToolCallPermission(String), + /// User selected a permission + SelectPermission(PermissionResult), /// Cancel active generation or streaming (Esc during Generating/Streaming) CancelGeneration, /// Execute the suggested command @@ -25,3 +29,11 @@ pub enum AiTuiEvent { /// Exit the application Exit, } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PermissionResult { + Allow, + AlwaysAllowInDir, + AlwaysAllow, + Deny, +} diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index aff2a8f45fd..b7d3fd05118 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -3,9 +3,15 @@ //! This module contains the core state types that represent the application's //! domain model. Conversation events match the API protocol format. +use std::{collections::VecDeque, path::PathBuf, sync::mpsc}; + use tokio::task::AbortHandle; -use crate::tools::ToolCall; +use crate::{ + permissions::walker::PermissionWalker, + tools::{ClientToolCall, PendingToolCall, ToolCallState}, + tui::events::AiTuiEvent, +}; /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] @@ -143,6 +149,8 @@ pub enum ExitAction { /// The view function derives the UI from this state. #[derive(Debug)] pub struct AppState { + /// Channel to send events to the main event loop + pub tx: mpsc::Sender, /// Current application mode pub mode: AppMode, /// Conversation events (source of truth, matches API protocol) @@ -163,11 +171,14 @@ pub struct AppState { pub confirmation_pending: bool, /// Abort handle for the active streaming task, if any pub stream_abort: Option, + /// Tool calls that are pending permission checking + execution + pub pending_tool_calls: VecDeque, } impl AppState { - pub fn new() -> Self { + pub fn new(tx: mpsc::Sender) -> Self { Self { + tx, mode: AppMode::Input, events: Vec::new(), error: None, @@ -178,6 +189,7 @@ impl AppState { was_interrupted: false, confirmation_pending: false, stream_abort: None, + pending_tool_calls: VecDeque::new(), } } @@ -356,8 +368,30 @@ impl AppState { } } - pub fn handle_client_tool_call(&mut self, tool: ToolCall) { - todo!("check permissions, handle tool call, send result - async") + pub(crate) fn handle_client_tool_call(&mut self, id: String, tool: ClientToolCall) { + self.pending_tool_calls.push_back(PendingToolCall { + id, + state: ToolCallState::CheckingPermissions, + tool, + }); + } + + pub(crate) fn handle_select_permission(&mut self, permission: String) { + match permission.as_str() { + "allow" => { + self.pending_tool_calls.pop_front(); + } + "always-allow-in-dir" => { + self.pending_tool_calls.pop_front(); + } + "always-allow" => { + self.pending_tool_calls.pop_front(); + } + "deny" => { + self.pending_tool_calls.pop_front(); + } + _ => {} + } } /// Add a tool call event during streaming. @@ -454,6 +488,18 @@ impl AppState { // ===== Query methods ===== + /// Get a pending tool call by ID + pub(crate) fn get_pending_tool_call(&self, id: &str) -> Option<&PendingToolCall> { + self.pending_tool_calls.iter().find(|call| call.id == id) + } + + /// Get a mutable pending tool call by ID + pub(crate) fn get_pending_tool_call_mut(&mut self, id: &str) -> Option<&mut PendingToolCall> { + self.pending_tool_calls + .iter_mut() + .find(|call| call.id == id) + } + /// Get the most recent command from events pub fn current_command(&self) -> Option<&str> { self.events.iter().rev().find_map(|e| e.as_command()) @@ -548,9 +594,3 @@ impl AppState { self.exit_action.is_some() } } - -impl Default for AppState { - fn default() -> Self { - Self::new() - } -} diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 0cd51dfadaa..859d6753a4c 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -1,13 +1,20 @@ //! View function that builds the eye-declare element tree from app state. +use std::sync::mpsc; + use eye_declare::{ Cells, Column, Elements, HStack, Span, Spinner, Text, View, WidthConstraint, element, }; use ratatui_core::style::{Color, Modifier, Style}; +use crate::tools::{ClientToolCall, PendingToolCall, ToolCallState}; +use crate::tui::components::select::SelectOption; +use crate::tui::events::{AiTuiEvent, PermissionResult}; + use super::components::atuin_ai::AtuinAi; use super::components::input_box::InputBox; use super::components::markdown::Markdown; +use super::components::select::Select; use super::state::{AppMode, AppState}; mod turn; @@ -53,25 +60,89 @@ pub fn ai_view(state: &AppState) -> Elements { }) #(if !state.is_exiting() { - View(key: "input-box", padding_top: Cells::from(1)) { - InputBox( - key: "input", - title: "Generate a command or ask a question", - title_right: "Atuin AI", - footer: state.footer_text(), - active: state.mode == AppMode::Input && !state.confirmation_pending, - ) + #(input_view(state)) + }) + } + } +} - #(if state.is_input_blank && state.has_any_command() && state.mode == AppMode::Input { - #(if state.confirmation_pending { - Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } - } else { - Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } - }) +fn input_view(state: &AppState) -> Elements { + let first_pending_tool_call = state + .pending_tool_calls + .iter() + .find(|call| call.state == ToolCallState::AskingForPermission); + + element! { + #(if first_pending_tool_call.is_some() { + #(tool_call_view(first_pending_tool_call.unwrap(), state.tx.clone())) + }) + + #(if first_pending_tool_call.is_none() { + View(key: "input-box", padding_top: Cells::from(1)) { + InputBox( + key: "input", + title: "Generate a command or ask a question", + title_right: "Atuin AI", + footer: state.footer_text(), + active: state.mode == AppMode::Input && !state.confirmation_pending, + ) + + #(if state.is_input_blank && state.has_any_command() && state.mode == AppMode::Input { + #(if state.confirmation_pending { + Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } + } else { + Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } }) + }) + } + }) + } +} - } - }) +fn tool_call_view(tool_call: &PendingToolCall, tx: mpsc::Sender) -> Elements { + let (verb, tool_desc) = match &tool_call.tool { + ClientToolCall::Read(tool) => ("read", tool.path.display().to_string()), + ClientToolCall::Write(tool) => ("write to", tool.path.display().to_string()), + ClientToolCall::Shell(tool) => ("run", tool.command.clone()), + ClientToolCall::AtuinHistory(tool) => ("search your Atuin history for", tool.query.clone()), + }; + + element! { + View(key: format!("tool-call-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) { + Text { + Span(text: format!("Atuin AI would like to {}: ", verb), style: Style::default()) + Span(text: &tool_desc, style: Style::default().fg(Color::Yellow)) + } + View(padding_left: Cells::from(2)) { + Select(options: [ + SelectOption::builder() + .label("Allow") + .value("allow") + .build(), + SelectOption::builder() + .label("Always allow in this directory") + .value("always-allow-in-dir") + .build(), + SelectOption::builder() + .label("Always allow") + .value("always-allow") + .build(), + SelectOption::builder() + .label("Deny") + .value("deny") + .build(), + ], on_select: Box::new(move |option: &SelectOption| { + let value = match option.value.as_str() { + "allow" => PermissionResult::Allow, + "always-allow-in-dir" => PermissionResult::AlwaysAllowInDir, + "always-allow" => PermissionResult::AlwaysAllow, + "deny" => PermissionResult::Deny, + _ => unreachable!(), + }; + + let _ = tx.send(AiTuiEvent::SelectPermission(value)); + }) as Box) + } } } } From d605c90746cde03097f5ab396fd92585bbab535a Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 10:50:27 -0700 Subject: [PATCH 04/52] Make most of atuin-ai crate pub(crate) --- crates/atuin-ai/src/commands.rs | 6 +++--- crates/atuin-ai/src/commands/init.rs | 2 +- crates/atuin-ai/src/commands/inline.rs | 2 +- crates/atuin-ai/src/lib.rs | 8 ++++---- crates/atuin-ai/src/tui/components/atuin_ai.rs | 2 +- crates/atuin-ai/src/tui/components/markdown.rs | 4 ++-- crates/atuin-ai/src/tui/components/mod.rs | 8 ++++---- crates/atuin-ai/src/tui/events.rs | 4 ++-- crates/atuin-ai/src/tui/mod.rs | 10 +++++----- crates/atuin-ai/src/tui/state.rs | 18 +++++++++--------- crates/atuin-ai/src/tui/view/mod.rs | 2 +- 11 files changed, 33 insertions(+), 33 deletions(-) diff --git a/crates/atuin-ai/src/commands.rs b/crates/atuin-ai/src/commands.rs index 6e79da61b22..6c12cca583d 100644 --- a/crates/atuin-ai/src/commands.rs +++ b/crates/atuin-ai/src/commands.rs @@ -9,10 +9,10 @@ use eyre::Result; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; pub mod init; -pub mod inline; +pub(crate) mod inline; #[derive(Args, Debug)] -pub struct AiArgs { +pub(crate) struct AiArgs { /// Enable verbose logging #[arg(short, long, global = true)] verbose: bool, @@ -71,7 +71,7 @@ pub async fn run( } } -pub fn detect_shell() -> Option { +pub(crate) fn detect_shell() -> Option { Some(Shell::current().to_string()) } diff --git a/crates/atuin-ai/src/commands/init.rs b/crates/atuin-ai/src/commands/init.rs index 77abc4f4086..f693d89225f 100644 --- a/crates/atuin-ai/src/commands/init.rs +++ b/crates/atuin-ai/src/commands/init.rs @@ -1,6 +1,6 @@ use crate::commands::detect_shell; -pub async fn run(shell: String) -> eyre::Result<()> { +pub(crate) async fn run(shell: String) -> eyre::Result<()> { let integration = match shell.as_str() { "zsh" => generate_zsh_integration(), "bash" => generate_bash_integration(), diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index e98c93e89be..aea9afbfe2f 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -13,7 +13,7 @@ use eye_declare::{Application, CtrlCBehavior}; use eyre::{Context as _, Result, bail}; use tracing::{debug, info}; -pub async fn run( +pub(crate) async fn run( initial_command: Option, api_endpoint: Option, api_token: Option, diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 0663a9ffe3e..3cbb0f4ec43 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,5 +1,5 @@ pub mod commands; -pub mod permissions; -pub mod stream; -pub mod tools; -pub mod tui; +pub(crate) mod permissions; +pub(crate) mod stream; +pub(crate) mod tools; +pub(crate) mod tui; diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs index fab295029ec..2db2b216495 100644 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ b/crates/atuin-ai/src/tui/components/atuin_ai.rs @@ -25,7 +25,7 @@ pub(crate) struct AtuinAi { } #[derive(Default)] -pub struct AtuinAiState { +pub(crate) struct AtuinAiState { tx: Option>, } diff --git a/crates/atuin-ai/src/tui/components/markdown.rs b/crates/atuin-ai/src/tui/components/markdown.rs index 1cd7dbcf75d..6bbcf41b63d 100644 --- a/crates/atuin-ai/src/tui/components/markdown.rs +++ b/crates/atuin-ai/src/tui/components/markdown.rs @@ -16,7 +16,7 @@ use ratatui_widgets::paragraph::{Paragraph, Wrap}; /// A markdown rendering component backed by pulldown-cmark. #[props] -pub struct Markdown { +pub(crate) struct Markdown { pub source: String, } @@ -29,7 +29,7 @@ impl Markdown { } /// Style configuration for markdown rendering. -pub struct MarkdownStyles { +pub(crate) struct MarkdownStyles { pub base: Style, pub code_inline: Style, pub code_block: Style, diff --git a/crates/atuin-ai/src/tui/components/mod.rs b/crates/atuin-ai/src/tui/components/mod.rs index 94cf4005319..3458327d3f0 100644 --- a/crates/atuin-ai/src/tui/components/mod.rs +++ b/crates/atuin-ai/src/tui/components/mod.rs @@ -1,4 +1,4 @@ -pub mod atuin_ai; -pub mod input_box; -pub mod markdown; -pub mod select; +pub(crate) mod atuin_ai; +pub(crate) mod input_box; +pub(crate) mod markdown; +pub(crate) mod select; diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs index b31446f8447..b7d48f68d2e 100644 --- a/crates/atuin-ai/src/tui/events.rs +++ b/crates/atuin-ai/src/tui/events.rs @@ -5,7 +5,7 @@ /// eye-declare's context system. The main event loop in `inline.rs` /// receives them and mutates `AppState` accordingly. #[derive(Debug)] -pub enum AiTuiEvent { +pub(crate) enum AiTuiEvent { /// User updated the input text InputUpdated(String), /// User submitted text input (Enter in Input mode) @@ -31,7 +31,7 @@ pub enum AiTuiEvent { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum PermissionResult { +pub(crate) enum PermissionResult { Allow, AlwaysAllowInDir, AlwaysAllow, diff --git a/crates/atuin-ai/src/tui/mod.rs b/crates/atuin-ai/src/tui/mod.rs index acb251a78ec..08c0d4da414 100644 --- a/crates/atuin-ai/src/tui/mod.rs +++ b/crates/atuin-ai/src/tui/mod.rs @@ -1,6 +1,6 @@ -pub mod components; -pub mod events; -pub mod state; -pub mod view; +pub(crate) mod components; +pub(crate) mod events; +pub(crate) mod state; +pub(crate) mod view; -pub use state::{AppMode, AppState, ConversationEvent, ExitAction}; +pub(crate) use state::{AppMode, AppState, ConversationEvent, ExitAction}; diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index b7d3fd05118..6053207e2da 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -15,7 +15,7 @@ use crate::{ /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] -pub enum StreamingStatus { +pub(crate) enum StreamingStatus { Processing, Searching, Thinking, @@ -23,7 +23,7 @@ pub enum StreamingStatus { } impl StreamingStatus { - pub fn from_status_str(s: &str) -> Self { + pub(crate) fn from_status_str(s: &str) -> Self { match s { "processing" => Self::Processing, "searching" => Self::Searching, @@ -32,7 +32,7 @@ impl StreamingStatus { } } - pub fn display_text(&self) -> &'static str { + pub(crate) fn display_text(&self) -> &'static str { match self { Self::Processing => "Processing...", Self::Searching => "Searching...", @@ -44,7 +44,7 @@ impl StreamingStatus { /// Conversation event types matching the API protocol #[derive(Debug, Clone)] -pub enum ConversationEvent { +pub(crate) enum ConversationEvent { /// User message (what the user typed) UserMessage { content: String }, /// Text content from assistant (streamed or complete) @@ -71,7 +71,7 @@ pub enum ConversationEvent { impl ConversationEvent { /// Convert to JSON for API calls - pub fn to_json(&self) -> serde_json::Value { + pub(crate) fn to_json(&self) -> serde_json::Value { match self { ConversationEvent::UserMessage { content } => serde_json::json!({ "type": "user_message", @@ -111,7 +111,7 @@ impl ConversationEvent { } /// Extract command from a suggest_command tool call - pub fn as_command(&self) -> Option<&str> { + pub(crate) fn as_command(&self) -> Option<&str> { if let ConversationEvent::ToolCall { name, input, .. } = self && name == "suggest_command" { @@ -122,7 +122,7 @@ impl ConversationEvent { } #[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum AppMode { +pub(crate) enum AppMode { /// User is typing input Input, /// Waiting for generation (showing spinner) @@ -134,7 +134,7 @@ pub enum AppMode { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ExitAction { +pub(crate) enum ExitAction { /// Run the command Execute(String), /// Insert command without running @@ -148,7 +148,7 @@ pub enum ExitAction { /// Conversation is stored as a sequence of events matching the API protocol. /// The view function derives the UI from this state. #[derive(Debug)] -pub struct AppState { +pub(crate) struct AppState { /// Channel to send events to the main event loop pub tx: mpsc::Sender, /// Current application mode diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 859d6753a4c..44de52710bc 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -27,7 +27,7 @@ mod turn; /// - Error display (if in error state) /// - Spacer /// - Input box (bordered, with contextual keybindings) -pub fn ai_view(state: &AppState) -> Elements { +pub(crate) fn ai_view(state: &AppState) -> Elements { let mut turn_builder = turn::TurnBuilder::new(); for event in &state.events { From 3429f87365910f4813cf239846a790c52d07534c Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 16:40:45 -0700 Subject: [PATCH 05/52] wip --- crates/atuin-ai/Cargo.toml | 2 +- crates/atuin-ai/src/commands/inline.rs | 181 +++++++++++++++++++++---- crates/atuin-ai/src/stream.rs | 1 - crates/atuin-ai/src/tools/mod.rs | 43 ++++-- crates/atuin-ai/src/tui/events.rs | 2 + crates/atuin-ai/src/tui/state.rs | 11 +- crates/atuin-ai/src/tui/view/mod.rs | 9 ++ crates/atuin-ai/src/tui/view/turn.rs | 15 +- 8 files changed, 215 insertions(+), 49 deletions(-) diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 9b7cfff1827..8c2d02e5a9f 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -39,7 +39,7 @@ async-stream = "0.3" uuid = { workspace = true } tui-textarea-2 = "0.10.2" unicode-width = "0.2" -eye_declare = "0.3" +eye_declare = { path = "../../../eye_declare/crates/eye_declare" } ratatui-core = "0.1" ratatui-widgets = "0.3" thiserror = { workspace = true } diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index aea9afbfe2f..fcae80384cf 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -11,6 +11,7 @@ use crate::tui::state::{AppState, ExitAction}; use crate::tui::view::ai_view; use eye_declare::{Application, CtrlCBehavior}; use eyre::{Context as _, Result, bail}; +use serde_json::json; use tracing::{debug, info}; pub(crate) async fn run( @@ -121,28 +122,28 @@ async fn run_inline_tui( let (tx, rx) = mpsc::channel::(); let mut initial_state = AppState::new(tx.clone()); - initial_state - .pending_tool_calls - .push_back(crate::tools::PendingToolCall { - id: "1".to_string(), - state: crate::tools::ToolCallState::CheckingPermissions, - tool: crate::tools::ClientToolCall::Read(crate::tools::ReadToolCall { - path: std::path::PathBuf::from("test.txt"), - }), - }); - initial_state - .pending_tool_calls - .push_back(crate::tools::PendingToolCall { - id: "2".to_string(), - state: crate::tools::ToolCallState::CheckingPermissions, - tool: crate::tools::ClientToolCall::Shell(crate::tools::ShellToolCall { - dir: None, - command: "ls -lah".to_string(), - }), - }); - - let _ = tx.send(AiTuiEvent::CheckToolCallPermission("1".to_string())); - let _ = tx.send(AiTuiEvent::CheckToolCallPermission("2".to_string())); + // initial_state + // .pending_tool_calls + // .push_back(crate::tools::PendingToolCall { + // id: "1".to_string(), + // state: crate::tools::ToolCallState::CheckingPermissions, + // tool: crate::tools::ClientToolCall::Read(crate::tools::ReadToolCall { + // path: std::path::PathBuf::from("test.txt"), + // }), + // }); + // initial_state + // .pending_tool_calls + // .push_back(crate::tools::PendingToolCall { + // id: "2".to_string(), + // state: crate::tools::ToolCallState::CheckingPermissions, + // tool: crate::tools::ClientToolCall::Shell(crate::tools::ShellToolCall { + // dir: None, + // command: "ls -lah".to_string(), + // }), + // }); + + // let _ = tx.send(AiTuiEvent::CheckToolCallPermission("1".to_string())); + // let _ = tx.send(AiTuiEvent::CheckToolCallPermission("2".to_string())); println!(); @@ -171,6 +172,21 @@ async fn run_inline_tui( tokio::task::spawn_blocking(move || { while let Ok(event) = rx.recv() { match event { + AiTuiEvent::ContinueAfterTools => { + let ep = ep.clone(); + let tk = tk.clone(); + let h2 = h.clone(); + h.update(move |state| { + state.start_streaming(); + let messages = state.events_to_messages(); + let sid = state.session_id.clone(); + let task = tokio::spawn(async move { + run_chat_stream(h2, ep, tk, sid, messages, send_cwd).await; + }); + state.stream_abort = Some(task.abort_handle()); + }); + } + AiTuiEvent::InputUpdated(input) => { let input_blank = input.trim().is_empty(); @@ -276,16 +292,121 @@ async fn run_inline_tui( match response { PermissionResponse::Allowed => { eprintln!("Executing tool call: {:?}", tool_call); + + let id_clone2 = id_clone.clone(); h2.update(move |state| { - state.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - content: format!( - "Permission granted for tool call {:?}", - &tool_call - ), - command: None, - }); + state.add_tool_call( + id_clone2.clone(), + "read".to_string(), + json!({}), + ); + + let mut tool_call = state.get_pending_tool_call_mut(&id_clone2); + + let Some(tool_call) = tool_call.as_mut() else { + eprintln!( + "Error getting pending tool call: {:?}", + &id_clone2 + ); + return; + }; + + tool_call.state = ToolCallState::Executing; + + // + + // state.events.push(ConversationEvent::OutOfBandOutput { + // name: "System".to_string(), + // content: format!( + // "Permission granted for tool call {:?}", + // &tool_call + // ), + // command: None, + // }); }); + + match tool_call.tool { + crate::tools::ClientToolCall::Read(read) => { + let mut path = read.path.clone(); + + if path.is_relative() { + if let Ok(current_dir) = std::env::current_dir() { + path = current_dir.join(path); + } + } + + if !path.exists() { + let id = id_clone.clone(); + h2.update(move |state| { + state.add_tool_result( + id.clone(), + format!( + "Error: file does not exist: {}", + path.display() + ), + true, + ); + state.pending_tool_calls.retain(|c| c.id != id); + }); + return; + } + + if path.is_dir() { + let Some(files) = std::fs::read_dir(&path) + .map_err(|e| { + eprintln!("Error reading directory: {}", e); + e + }) + .ok() + .and_then(|entries| { + entries + .filter_map(|entry| entry.ok()) + .map(|entry| { + entry + .file_name() + .to_string_lossy() + .to_string() + }) + .collect::>() + .into() + }) + else { + h2.update(move |state| { + state.add_tool_result( + id_clone.clone(), + format!( + "Error: could not read directory: {}", + path.display() + ), + true, + ); + state + .pending_tool_calls + .retain(|c| c.id != id_clone); + }); + return; + }; + + h2.update(move |state| { + state.add_tool_result( + id_clone.clone(), + format!( + "Directory contents:\n{}", + files.join("\n") + ), + false, + ); + state + .pending_tool_calls + .retain(|c| c.id != id_clone); + + let _ = + state.tx.send(AiTuiEvent::ContinueAfterTools); + }); + } + } + _ => {} + } } PermissionResponse::Denied => { eprintln!("Permission denied for tool call: {:?}", &tool_call); diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 16808a63418..d51bf7821c7 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -221,7 +221,6 @@ pub(crate) async fn run_chat_stream( // Recognized as a client-side tool call. handle.update(move |state| { state.handle_client_tool_call(id.clone(), tool); - let _ = state.tx.send(AiTuiEvent::CheckToolCallPermission(id)); }); continue; } diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 021d519b8f0..570801bb988 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -4,6 +4,7 @@ use eyre::Result; use crate::permissions::rule::Rule; +/// A pending tool call from the server, awaiting permissions or execution. #[derive(Debug, Clone)] pub(crate) struct PendingToolCall { pub id: String, @@ -17,6 +18,7 @@ impl PendingToolCall { } } +/// State of a pending tool call #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum ToolCallState { CheckingPermissions, @@ -25,13 +27,7 @@ pub(crate) enum ToolCallState { Executing, } -pub(crate) enum ClientToolCallType { - Read, - Write, - Shell, - AtuinHistory, -} - +/// A tool call from the server, with parsed input parameters. #[derive(Debug, Clone)] pub(crate) enum ClientToolCall { Read(ReadToolCall), @@ -45,8 +41,11 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { fn try_from((name, input): (&str, &serde_json::Value)) -> Result { match name { - "read" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), - "write" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), + "read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), + // TODO: split these into separate tool calls, but rely on Write permissions for all + "str_replace" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), + "file_create" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), + "file_insert" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), "shell" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), "atuin_history" => Ok(ClientToolCall::AtuinHistory( AtuinHistoryToolCall::try_from(input)?, @@ -76,8 +75,11 @@ impl ClientToolCall { } } +/// A trait for tool calls that can be checked against permission rules. pub(crate) trait PermissableToolCall { + /// Checks if this tool call matches the given permission rule. fn matches_rule(&self, rule: &Rule) -> bool; + /// Returns the target directory of this tool call, if applicable, for checking against directory-based rules. fn target_dir(&self) -> Option<&Path> { None } @@ -96,6 +98,7 @@ impl PermissableToolCall for ClientToolCall { #[derive(Debug, Clone)] pub(crate) struct ReadToolCall { pub path: PathBuf, + pub view_range: Option<(u64, u64)>, } impl TryFrom<&serde_json::Value> for ReadToolCall { @@ -103,12 +106,32 @@ impl TryFrom<&serde_json::Value> for ReadToolCall { fn try_from(value: &serde_json::Value) -> Result { let path = value - .get("path") + .get("file_path") .and_then(|v| v.as_str()) .ok_or(eyre::eyre!("Missing path"))?; + let view_range = value.get("view_range").and_then(|v| v.as_array()); + + let is_proper_size = view_range + .map(|arr| arr.len() == 2 && arr.iter().all(|v| v.is_u64())) + .unwrap_or(true); + + if !is_proper_size { + return Err(eyre::eyre!( + "Invalid view_range: must be an array of two integers" + )); + } + + let view_range = view_range.map(|arr| { + // SAFETY: already checked that the array has two elements and they are both u64 + let start = arr[0].as_u64().unwrap(); + let end = arr[1].as_u64().unwrap(); + (start, end) + }); + Ok(ReadToolCall { path: PathBuf::from(path), + view_range, }) } } diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs index b7d48f68d2e..a3aa87942ee 100644 --- a/crates/atuin-ai/src/tui/events.rs +++ b/crates/atuin-ai/src/tui/events.rs @@ -16,6 +16,8 @@ pub(crate) enum AiTuiEvent { CheckToolCallPermission(String), /// User selected a permission SelectPermission(PermissionResult), + /// Continue after client tools have completed + ContinueAfterTools, /// Cancel active generation or streaming (Esc during Generating/Streaming) CancelGeneration, /// Execute the suggested command diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 6053207e2da..c6d6858db13 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -3,12 +3,11 @@ //! This module contains the core state types that represent the application's //! domain model. Conversation events match the API protocol format. -use std::{collections::VecDeque, path::PathBuf, sync::mpsc}; +use std::{collections::VecDeque, sync::mpsc}; use tokio::task::AbortHandle; use crate::{ - permissions::walker::PermissionWalker, tools::{ClientToolCall, PendingToolCall, ToolCallState}, tui::events::AiTuiEvent, }; @@ -370,10 +369,16 @@ impl AppState { pub(crate) fn handle_client_tool_call(&mut self, id: String, tool: ClientToolCall) { self.pending_tool_calls.push_back(PendingToolCall { - id, + id: id.clone(), state: ToolCallState::CheckingPermissions, tool, }); + + // Client tool calls can only happen at the last part of a turn + self.streaming_status = None; + self.mode = AppMode::Input; + + let _ = self.tx.send(AiTuiEvent::CheckToolCallPermission(id)); } pub(crate) fn handle_select_permission(&mut self, permission: String) { diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 44de52710bc..82931064558 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -207,6 +207,15 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { turn::UiEvent::SuggestedCommand(details) => { suggested_command_view(details) }, + turn::UiEvent::ToolCall(details) => { + element! { + View(padding_left: Cells::from(2)) { + Text { + Span(text: format!("Running tool: {}", details.name), style: Style::default().fg(Color::Blue)) + } + } + } + } _ => element!{} }) }) diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index 861da64c947..01d2f47e3a9 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -1,5 +1,6 @@ use crate::tui::ConversationEvent; +/// Server-sent danger level for a suggested command #[derive(Debug)] pub(crate) enum DangerLevel { Low(Option), @@ -37,6 +38,7 @@ impl From<(&String, &String)> for DangerLevel { } } +/// Server-sent confidence level for a suggested command #[derive(Debug)] pub(crate) enum ConfidenceLevel { Low(Option), @@ -85,9 +87,10 @@ pub(crate) enum UiEvent { #[derive(Debug)] pub(crate) struct ToolCallDetails { - tool_use_id: String, - name: String, - status: ToolResultStatus, + pub(crate) tool_use_id: String, + pub(crate) name: String, + pub(crate) status: ToolResultStatus, + pub(crate) is_client: bool, } #[derive(Debug)] @@ -123,6 +126,7 @@ pub(crate) struct TurnBuilder { current_turn: Option, } +/// A struct to iteratively build [UiTurn] events from [ConversationEvent]s. impl TurnBuilder { pub(crate) fn new() -> Self { Self { @@ -174,7 +178,7 @@ impl TurnBuilder { for event in events.drain(..) { match event { - UiEvent::ToolCall(details) => { + UiEvent::ToolCall(details) if !details.is_client => { pending_tools.push(details); } other => { @@ -308,10 +312,13 @@ impl TurnBuilder { fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) { self.start_agent_turn(); if let UiTurn::Agent { events } = self.turn_mut_unsafe() { + let is_client = matches!(name, "file_read"); + events.push(UiEvent::ToolCall(ToolCallDetails { tool_use_id: id.to_string(), name: name.to_string(), status: ToolResultStatus::Pending, + is_client, })); } } From 424f41fa2ba2786bf7eb13afcb6c8183a51f02a0 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 17:41:45 -0700 Subject: [PATCH 06/52] =?UTF-8?q?atuin-ai:=20Phase=201=20=E2=80=94=20estab?= =?UTF-8?q?lish=20foundation=20types=20and=20resolve=20mutation=20boundary?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change extracts four new abstractions and sets a clear architectural rule that will guide subsequent refactoring: **Mutation boundary (Phase 1a):** AppState no longer owns an mpsc::Sender. State is passive data; the event loop (or stream task) is the only thing that emits events. handle_client_tool_call is now a pure mutation — the CheckToolCallPermission event is sent by the caller. **AppContext (Phase 1b):** Session-scoped API configuration (endpoint, token, send_cwd) is now a single Clone struct instead of three individually-cloned bindings in the event loop. **ClientContext + ChatRequest (Phase 1c):** Machine identity (OS, shell, distro) is computed once per session via ClientContext::detect() instead of on every SSE request. ChatRequest wraps the per-turn message/session payload, replacing the inline request body construction in create_chat_stream. **ToolDescriptor (Phase 1d):** Centralizes tool metadata — canonical names, display verbs, progressive/past verbs, and client/server classification — into static descriptors. Replaces four separate name-to-text match sites and fixes a bug where is_client was checked against 'file_read' (wrong) instead of 'read_file' (correct). Also fixes several clippy warnings in modified code. --- .atuin/permissions.ai.toml | 3 + Cargo.lock | 4 - crates/atuin-ai/src/commands/inline.rs | 185 ++++++++++--------- crates/atuin-ai/src/context.rs | 62 +++++++ crates/atuin-ai/src/lib.rs | 1 + crates/atuin-ai/src/stream.rs | 80 ++++---- crates/atuin-ai/src/tools/descriptor.rs | 98 ++++++++++ crates/atuin-ai/src/tools/mod.rs | 11 ++ crates/atuin-ai/src/tui/components/select.rs | 10 +- crates/atuin-ai/src/tui/mod.rs | 2 +- crates/atuin-ai/src/tui/state.rs | 14 +- crates/atuin-ai/src/tui/view/mod.rs | 21 +-- crates/atuin-ai/src/tui/view/turn.rs | 25 +-- 13 files changed, 341 insertions(+), 175 deletions(-) create mode 100644 .atuin/permissions.ai.toml create mode 100644 crates/atuin-ai/src/context.rs create mode 100644 crates/atuin-ai/src/tools/descriptor.rs diff --git a/.atuin/permissions.ai.toml b/.atuin/permissions.ai.toml new file mode 100644 index 00000000000..399c89906c6 --- /dev/null +++ b/.atuin/permissions.ai.toml @@ -0,0 +1,3 @@ +[permissions] + +allow = ["Read"] diff --git a/Cargo.lock b/Cargo.lock index 097341be9ec..79b655c4447 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1503,8 +1503,6 @@ dependencies = [ [[package]] name = "eye_declare" version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac5c1a3b194e6674e9e44dbfb035f31c4df7a1ff6c8765181c50e8482bb393a" dependencies = [ "crossterm", "eye_declare_macros", @@ -1519,8 +1517,6 @@ dependencies = [ [[package]] name = "eye_declare_macros" version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98595776b5e10c6ea519c09940fb7995b64da1e9a70cc94aa6c08b3bd404925a" dependencies = [ "proc-macro2", "quote", diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index fcae80384cf..9201cf954bc 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,9 +1,10 @@ use std::path::PathBuf; use std::sync::mpsc; +use crate::context::{AppContext, ClientContext}; use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; use crate::permissions::walker::PermissionWalker; -use crate::stream::run_chat_stream; +use crate::stream::{ChatRequest, run_chat_stream}; use crate::tools::ToolCallState; use crate::tui::ConversationEvent; use crate::tui::events::{AiTuiEvent, PermissionResult}; @@ -47,7 +48,13 @@ pub(crate) async fn run( ensure_hub_session(settings).await? }; - let action = run_inline_tui(endpoint.to_string(), token, initial_command, settings).await?; + let ctx = AppContext { + endpoint: endpoint.to_string(), + token, + send_cwd: settings.ai.send_cwd, + }; + + let action = run_inline_tui(ctx, initial_command).await?; emit_shell_result(action, output_for_hook); Ok(()) @@ -113,15 +120,11 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu // Main TUI entry point // ─────────────────────────────────────────────────────────────────── -async fn run_inline_tui( - endpoint: String, - token: String, - initial_prompt: Option, - settings: &atuin_client::settings::Settings, -) -> Result { +async fn run_inline_tui(ctx: AppContext, initial_prompt: Option) -> Result { let (tx, rx) = mpsc::channel::(); - let mut initial_state = AppState::new(tx.clone()); + let initial_state = AppState::new(); + let client_ctx = ClientContext::detect(); // initial_state // .pending_tool_calls // .push_back(crate::tools::PendingToolCall { @@ -159,29 +162,31 @@ async fn run_inline_tui( .ctrl_c(CtrlCBehavior::Deliver) .keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced) .bracketed_paste(true) - .with_context(tx) + .with_context(tx.clone()) .extra_newlines_at_exit(1) .build()?; - let send_cwd = settings.ai.send_cwd; - // Event loop: receives AiTuiEvent from components, mutates state via Handle. let h = handle.clone(); - let ep = endpoint.clone(); - let tk = token.clone(); tokio::task::spawn_blocking(move || { + // Clone tx for use in each loop iteration (run_chat_stream and update closures + // move tx clones into spawned tasks/closures). + let tx = tx.clone(); + let client_ctx = client_ctx; while let Ok(event) = rx.recv() { match event { AiTuiEvent::ContinueAfterTools => { - let ep = ep.clone(); - let tk = tk.clone(); + let ctx = ctx.clone(); let h2 = h.clone(); + let tx2 = tx.clone(); + let cc = client_ctx.clone(); h.update(move |state| { state.start_streaming(); let messages = state.events_to_messages(); let sid = state.session_id.clone(); + let request = ChatRequest::new(messages, sid); let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd).await; + run_chat_stream(h2, tx2, ctx, cc, request).await; }); state.stream_abort = Some(task.abort_handle()); }); @@ -221,17 +226,19 @@ async fn run_inline_tui( } // Start generation and spawn streaming task - let ep = ep.clone(); - let tk = tk.clone(); + let ctx = ctx.clone(); let h2 = h.clone(); + let tx2 = tx.clone(); + let cc = client_ctx.clone(); h.update(move |state| { state.start_generating(input); state.start_streaming(); state.is_input_blank = true; let messages = state.events_to_messages(); let sid = state.session_id.clone(); + let request = ChatRequest::new(messages, sid); let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd).await; + run_chat_stream(h2, tx2, ctx, cc, request).await; }); state.stream_abort = Some(task.abort_handle()); }); @@ -246,6 +253,7 @@ async fn run_inline_tui( AiTuiEvent::CheckToolCallPermission(id) => { eprintln!("Checking tool call permission: {:?}", &id); let h2 = h.clone(); + let tx_for_task = tx.clone(); let id_clone = id.clone(); tokio::spawn(async move { @@ -325,87 +333,82 @@ async fn run_inline_tui( // }); }); - match tool_call.tool { - crate::tools::ClientToolCall::Read(read) => { - let mut path = read.path.clone(); + if let crate::tools::ClientToolCall::Read(read) = tool_call.tool { + let mut path = read.path.clone(); - if path.is_relative() { - if let Ok(current_dir) = std::env::current_dir() { - path = current_dir.join(path); - } - } + if path.is_relative() + && let Ok(current_dir) = std::env::current_dir() + { + path = current_dir.join(path); + } - if !path.exists() { - let id = id_clone.clone(); - h2.update(move |state| { - state.add_tool_result( - id.clone(), - format!( - "Error: file does not exist: {}", - path.display() - ), - true, - ); - state.pending_tool_calls.retain(|c| c.id != id); - }); - return; - } - - if path.is_dir() { - let Some(files) = std::fs::read_dir(&path) - .map_err(|e| { - eprintln!("Error reading directory: {}", e); - e - }) - .ok() - .and_then(|entries| { - entries - .filter_map(|entry| entry.ok()) - .map(|entry| { - entry - .file_name() - .to_string_lossy() - .to_string() - }) - .collect::>() - .into() - }) - else { - h2.update(move |state| { - state.add_tool_result( - id_clone.clone(), - format!( - "Error: could not read directory: {}", - path.display() - ), - true, - ); - state - .pending_tool_calls - .retain(|c| c.id != id_clone); - }); - return; - }; + if !path.exists() { + let id = id_clone.clone(); + h2.update(move |state| { + state.add_tool_result( + id.clone(), + format!( + "Error: file does not exist: {}", + path.display() + ), + true, + ); + state.pending_tool_calls.retain(|c| c.id != id); + }); + return; + } + if path.is_dir() { + let Some(files) = std::fs::read_dir(&path) + .map_err(|e| { + eprintln!("Error reading directory: {}", e); + e + }) + .ok() + .and_then(|entries| { + entries + .filter_map(|entry| entry.ok()) + .map(|entry| { + entry + .file_name() + .to_string_lossy() + .to_string() + }) + .collect::>() + .into() + }) + else { h2.update(move |state| { state.add_tool_result( id_clone.clone(), format!( - "Directory contents:\n{}", - files.join("\n") + "Error: could not read directory: {}", + path.display() ), - false, + true, ); state .pending_tool_calls .retain(|c| c.id != id_clone); - - let _ = - state.tx.send(AiTuiEvent::ContinueAfterTools); }); - } + return; + }; + + h2.update(move |state| { + state.add_tool_result( + id_clone.clone(), + format!( + "Directory contents:\n{}", + files.join("\n") + ), + false, + ); + state.pending_tool_calls.retain(|c| c.id != id_clone); + + let _ = + tx_for_task.send(AiTuiEvent::ContinueAfterTools); + }); } - _ => {} } } PermissionResponse::Denied => { @@ -536,16 +539,18 @@ async fn run_inline_tui( } AiTuiEvent::Retry => { - let ep = ep.clone(); - let tk = tk.clone(); + let ctx = ctx.clone(); let h2 = h.clone(); + let tx2 = tx.clone(); + let cc = client_ctx.clone(); h.update(move |state| { state.retry(); state.start_streaming(); let messages = state.events_to_messages(); let sid = state.session_id.clone(); + let request = ChatRequest::new(messages, sid); let task = tokio::spawn(async move { - run_chat_stream(h2, ep, tk, sid, messages, send_cwd).await; + run_chat_stream(h2, tx2, ctx, cc, request).await; }); state.stream_abort = Some(task.abort_handle()); }); diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs new file mode 100644 index 00000000000..03dc3f891e2 --- /dev/null +++ b/crates/atuin-ai/src/context.rs @@ -0,0 +1,62 @@ +use atuin_client::distro::detect_linux_distribution; + +/// Session-scoped context for the AI chat session. +/// Holds the API configuration and client settings needed by the event loop and stream task. +#[derive(Clone, Debug)] +pub(crate) struct AppContext { + pub endpoint: String, + pub token: String, + pub send_cwd: bool, +} + +/// Machine identity — computed once per session. +#[derive(Clone, Debug)] +pub(crate) struct ClientContext { + pub os: String, + pub shell: Option, + pub distro: Option, +} + +impl ClientContext { + pub(crate) fn detect() -> Self { + let os = detect_os(); + let shell = crate::commands::detect_shell(); + let distro = if os == "linux" { + Some(detect_linux_distribution()) + } else { + None + }; + Self { os, shell, distro } + } + + /// Serialize to the JSON format the API expects for the "context" field. + /// The `pwd` field is always dynamic (current working directory), so it's + /// computed fresh on each call if `send_cwd` is true. + pub(crate) fn to_json(&self, send_cwd: bool) -> serde_json::Value { + let mut ctx = serde_json::json!({ + "os": self.os, + "shell": self.shell, + "pwd": if send_cwd { + std::env::current_dir().ok().map(|p| p.to_string_lossy().into_owned()) + } else { + None + }, + }); + + if let Some(ref distro) = self.distro { + ctx["distro"] = serde_json::json!(distro); + } + + ctx + } +} + +/// Move the `detect_os` function here since it's about client identity. +fn detect_os() -> String { + match std::env::consts::OS { + "macos" => "macos".to_string(), + "linux" => "linux".to_string(), + "windows" => "windows".to_string(), + other => format!("Other: {other}"), + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 3cbb0f4ec43..6f431179a4e 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,4 +1,5 @@ pub mod commands; +pub(crate) mod context; pub(crate) mod permissions; pub(crate) mod stream; pub(crate) mod tools; diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index d51bf7821c7..49f8ef4e729 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -2,7 +2,8 @@ // SSE streaming // ─────────────────────────────────────────────────────────────────── -use atuin_client::distro::detect_linux_distribution; +use std::sync::mpsc; + use atuin_common::tls::ensure_crypto_provider; use eventsource_stream::Eventsource; @@ -12,7 +13,7 @@ use futures::StreamExt; use reqwest::Url; use crate::{ - commands::detect_shell, + context::{AppContext, ClientContext}, tools::ClientToolCall, tui::{AppState, events::AiTuiEvent}, }; @@ -37,11 +38,29 @@ enum ChatStreamEvent { Error(String), } +/// Per-turn request payload for the chat API. +pub(crate) struct ChatRequest { + pub messages: Vec, + pub session_id: Option, + /// Requested capabilities. Currently always ["client_tools_v1"]. + pub capabilities: Vec, +} + +impl ChatRequest { + pub(crate) fn new(messages: Vec, session_id: Option) -> Self { + Self { + messages, + session_id, + capabilities: vec!["client_tools_v1".to_string()], + } + } +} + fn create_chat_stream( hub_address: String, token: String, - session_id: Option, - messages: Vec, + request: ChatRequest, + client_ctx: ClientContext, send_cwd: bool, ) -> std::pin::Pin> + Send>> { Box::pin(async_stream::stream! { @@ -56,30 +75,15 @@ fn create_chat_stream( tracing::debug!("Sending SSE request to {endpoint}"); - let os = detect_os(); - let shell = detect_shell(); - - let mut context = serde_json::json!({ - "os": os, - "shell": shell, - "pwd": if send_cwd { std::env::current_dir() - .ok() - .map(|path| path.to_string_lossy().into_owned()) } else { None }, - }); - - if os == "linux" { - context["distro"] = serde_json::json!(detect_linux_distribution()); - } + let context = client_ctx.to_json(send_cwd); let mut request_body = serde_json::json!({ - "messages": messages, + "messages": request.messages, "context": context, - "capabilities": [ - "client_tools_v1" - ] + "capabilities": request.capabilities, }); - if let Some(ref sid) = session_id { + if let Some(ref sid) = request.session_id { tracing::trace!("Including session_id in request: {sid}"); request_body["session_id"] = serde_json::json!(sid); } @@ -197,13 +201,18 @@ fn create_chat_stream( pub(crate) async fn run_chat_stream( handle: Handle, - endpoint: String, - token: String, - session_id: Option, - messages: Vec, - send_cwd: bool, + tx: mpsc::Sender, + app_ctx: AppContext, + client_ctx: ClientContext, + request: ChatRequest, ) { - let stream = create_chat_stream(endpoint, token, session_id, messages, send_cwd); + let stream = create_chat_stream( + app_ctx.endpoint.clone(), + app_ctx.token.clone(), + request, + client_ctx, + app_ctx.send_cwd, + ); futures::pin_mut!(stream); while let Some(event) = stream.next().await { @@ -219,9 +228,11 @@ pub(crate) async fn run_chat_stream( if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { // Recognized as a client-side tool call. + let id_for_update = id.clone(); handle.update(move |state| { - state.handle_client_tool_call(id.clone(), tool); + state.handle_client_tool_call(id_for_update, tool); }); + let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); continue; } @@ -284,12 +295,3 @@ fn hub_url(base: &str, path: &str) -> Result { .join(stripped) .context("failed to build hub URL") } - -fn detect_os() -> String { - match std::env::consts::OS { - "macos" => "macos".to_string(), - "linux" => "linux".to_string(), - "windows" => "windows".to_string(), - other => format!("Other: {other}"), - } -} diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs new file mode 100644 index 00000000000..4518c88a6f0 --- /dev/null +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -0,0 +1,98 @@ +/// Centralized metadata for a tool type. +/// +/// Covers both client-side tools (ones the CLI executes locally) and +/// server-side tools (ones the API executes remotely). This is the single +/// source of truth for display text and classification. +pub(crate) struct ToolDescriptor { + /// Canonical wire names for this tool (the names the server sends). + pub canonical_names: &'static [&'static str], + /// Imperative verb for permission prompts (e.g. "read", "run"). + pub display_verb: &'static str, + /// Present-tense progressive verb for spinners (e.g. "Reading file..."). + pub progressive_verb: &'static str, + /// Past-tense verb for summaries (e.g. "Read file"). + pub past_verb: &'static str, + /// Whether this tool is executed client-side (by the CLI). + pub is_client: bool, +} + +// ── Client-side tool descriptors ── + +pub(crate) const READ: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["read_file"], + display_verb: "read", + progressive_verb: "Reading file...", + past_verb: "Read file", + is_client: true, +}; + +pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["str_replace", "file_create", "file_insert"], + display_verb: "write to", + progressive_verb: "Writing file...", + past_verb: "Wrote file", + is_client: true, +}; + +pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["shell"], + display_verb: "run", + progressive_verb: "Running command...", + past_verb: "Ran command", + is_client: true, +}; + +pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["atuin_history"], + display_verb: "search your Atuin history for", + progressive_verb: "Searching...", + past_verb: "Searched", + is_client: true, +}; + +// ── Server-side tool descriptors ── +// These appear in tool summaries but aren't client-side tools. + +pub(crate) const SERVER_SEARCH: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["search"], + display_verb: "search", + progressive_verb: "Searching...", + past_verb: "Searched", + is_client: false, +}; + +pub(crate) const SERVER_EXECUTE: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["execute", "run", "bash"], + display_verb: "run", + progressive_verb: "Running command...", + past_verb: "Ran command", + is_client: false, +}; + +pub(crate) const SERVER_LIST: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["list", "list_files"], + display_verb: "list", + progressive_verb: "Listing files...", + past_verb: "Listed files", + is_client: false, +}; + +/// All known tool descriptors, for lookup by name. +const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ + READ, + WRITE, + SHELL, + ATUIN_HISTORY, + SERVER_SEARCH, + SERVER_EXECUTE, + SERVER_LIST, +]; + +/// Look up a tool descriptor by its canonical wire name. +/// Returns None for unknown tool names. +pub(crate) fn by_name(name: &str) -> Option<&'static ToolDescriptor> { + ALL_DESCRIPTORS + .iter() + .find(|d| d.canonical_names.contains(&name)) + .copied() +} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 570801bb988..d92b42ae72c 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -2,6 +2,8 @@ use std::path::{Path, PathBuf}; use eyre::Result; +pub(crate) mod descriptor; + use crate::permissions::rule::Rule; /// A pending tool call from the server, awaiting permissions or execution. @@ -56,6 +58,15 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { } impl ClientToolCall { + pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor { + match self { + ClientToolCall::Read(_) => descriptor::READ, + ClientToolCall::Write(_) => descriptor::WRITE, + ClientToolCall::Shell(_) => descriptor::SHELL, + ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, + } + } + pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { match self { ClientToolCall::Read(tool) => tool.matches_rule(rule), diff --git a/crates/atuin-ai/src/tui/components/select.rs b/crates/atuin-ai/src/tui/components/select.rs index b59e973f551..5abbe655eb2 100644 --- a/crates/atuin-ai/src/tui/components/select.rs +++ b/crates/atuin-ai/src/tui/components/select.rs @@ -7,6 +7,8 @@ use typed_builder::TypedBuilder; use crate::tui::events::AiTuiEvent; +type OnSelectFn = Box Option + Send + Sync + 'static>; + #[derive(TypedBuilder)] pub(crate) struct SelectOption { #[builder(setter(into))] @@ -28,7 +30,7 @@ pub(crate) struct PermissionSelectorState { #[props] pub(crate) struct Select { pub options: Vec, - pub on_select: Box, + pub on_select: OnSelectFn, } #[component(props = Select, state = PermissionSelectorState)] @@ -66,7 +68,11 @@ pub(crate) fn permission_selector( } KeyCode::Enter => { let option = &props.options[state.selected_option]; - (props.on_select)(&option); + if let Some(event) = (props.on_select)(option) + && let Some(ref tx) = state.tx + { + let _ = tx.send(event); + } return EventResult::Consumed; } _ => {} diff --git a/crates/atuin-ai/src/tui/mod.rs b/crates/atuin-ai/src/tui/mod.rs index 08c0d4da414..081e79fdab1 100644 --- a/crates/atuin-ai/src/tui/mod.rs +++ b/crates/atuin-ai/src/tui/mod.rs @@ -3,4 +3,4 @@ pub(crate) mod events; pub(crate) mod state; pub(crate) mod view; -pub(crate) use state::{AppMode, AppState, ConversationEvent, ExitAction}; +pub(crate) use state::{AppState, ConversationEvent}; diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index c6d6858db13..95e333a36d4 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -3,14 +3,11 @@ //! This module contains the core state types that represent the application's //! domain model. Conversation events match the API protocol format. -use std::{collections::VecDeque, sync::mpsc}; +use std::collections::VecDeque; use tokio::task::AbortHandle; -use crate::{ - tools::{ClientToolCall, PendingToolCall, ToolCallState}, - tui::events::AiTuiEvent, -}; +use crate::tools::{ClientToolCall, PendingToolCall, ToolCallState}; /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] @@ -148,8 +145,6 @@ pub(crate) enum ExitAction { /// The view function derives the UI from this state. #[derive(Debug)] pub(crate) struct AppState { - /// Channel to send events to the main event loop - pub tx: mpsc::Sender, /// Current application mode pub mode: AppMode, /// Conversation events (source of truth, matches API protocol) @@ -175,9 +170,8 @@ pub(crate) struct AppState { } impl AppState { - pub fn new(tx: mpsc::Sender) -> Self { + pub fn new() -> Self { Self { - tx, mode: AppMode::Input, events: Vec::new(), error: None, @@ -377,8 +371,6 @@ impl AppState { // Client tool calls can only happen at the last part of a turn self.streaming_status = None; self.mode = AppMode::Input; - - let _ = self.tx.send(AiTuiEvent::CheckToolCallPermission(id)); } pub(crate) fn handle_select_permission(&mut self, permission: String) { diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 82931064558..540aa5e7eb8 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -1,7 +1,5 @@ //! View function that builds the eye-declare element tree from app state. -use std::sync::mpsc; - use eye_declare::{ Cells, Column, Elements, HStack, Span, Spinner, Text, View, WidthConstraint, element, }; @@ -74,7 +72,7 @@ fn input_view(state: &AppState) -> Elements { element! { #(if first_pending_tool_call.is_some() { - #(tool_call_view(first_pending_tool_call.unwrap(), state.tx.clone())) + #(tool_call_view(first_pending_tool_call.unwrap())) }) #(if first_pending_tool_call.is_none() { @@ -99,12 +97,13 @@ fn input_view(state: &AppState) -> Elements { } } -fn tool_call_view(tool_call: &PendingToolCall, tx: mpsc::Sender) -> Elements { - let (verb, tool_desc) = match &tool_call.tool { - ClientToolCall::Read(tool) => ("read", tool.path.display().to_string()), - ClientToolCall::Write(tool) => ("write to", tool.path.display().to_string()), - ClientToolCall::Shell(tool) => ("run", tool.command.clone()), - ClientToolCall::AtuinHistory(tool) => ("search your Atuin history for", tool.query.clone()), +fn tool_call_view(tool_call: &PendingToolCall) -> Elements { + let verb = tool_call.tool.descriptor().display_verb; + let tool_desc = match &tool_call.tool { + ClientToolCall::Read(tool) => tool.path.display().to_string(), + ClientToolCall::Write(tool) => tool.path.display().to_string(), + ClientToolCall::Shell(tool) => tool.command.clone(), + ClientToolCall::AtuinHistory(tool) => tool.query.clone(), }; element! { @@ -140,8 +139,8 @@ fn tool_call_view(tool_call: &PendingToolCall, tx: mpsc::Sender) -> _ => unreachable!(), }; - let _ = tx.send(AiTuiEvent::SelectPermission(value)); - }) as Box) + Some(AiTuiEvent::SelectPermission(value)) + }) as Box Option + Send + Sync>) } } } diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index 01d2f47e3a9..c92785c4ea8 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -1,3 +1,4 @@ +use crate::tools::descriptor; use crate::tui::ConversationEvent; /// Server-sent danger level for a suggested command @@ -312,7 +313,7 @@ impl TurnBuilder { fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) { self.start_agent_turn(); if let UiTurn::Agent { events } = self.turn_mut_unsafe() { - let is_client = matches!(name, "file_read"); + let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client); events.push(UiEvent::ToolCall(ToolCallDetails { tool_use_id: id.to_string(), @@ -392,25 +393,15 @@ impl ToolSummary { /// Present-tense progressive verb for a tool name (e.g. "Searching...") fn progressive_verb(name: &str) -> String { - match name { - "search" => "Searching...".into(), - "read" | "read_file" => "Reading file...".into(), - "write" | "write_file" => "Writing file...".into(), - "execute" | "run" | "bash" => "Running command...".into(), - "list" | "list_files" => "Listing files...".into(), - _ => format!("Running {}...", name.replace('_', " ")), - } + descriptor::by_name(name) + .map(|d| d.progressive_verb.to_string()) + .unwrap_or_else(|| format!("Running {}...", name.replace('_', " "))) } /// Past-tense verb for a tool name (e.g. "Searched") fn past_verb(name: &str) -> String { - match name { - "search" => "Searched".into(), - "read" | "read_file" => "Read file".into(), - "write" | "write_file" => "Wrote file".into(), - "execute" | "run" | "bash" => "Ran command".into(), - "list" | "list_files" => "Listed files".into(), - _ => format!("Ran {}", name.replace('_', " ")), - } + descriptor::by_name(name) + .map(|d| d.past_verb.to_string()) + .unwrap_or_else(|| format!("Ran {}", name.replace('_', " "))) } } From 66cd10cdd858eb5ce12ee3d71f3a4316ca9ded90 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 17:59:43 -0700 Subject: [PATCH 07/52] =?UTF-8?q?atuin-ai:=20Phase=202=20=E2=80=94=20decom?= =?UTF-8?q?pose=20state,=20extract=20dispatch,=20split=20stream=20frames?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **State decomposition (Phase 2a):** AppState is replaced by three types with clear ownership: - Conversation: owns the event log and session_id. All pure query and event-manipulation methods (events_to_messages, current_command, has_any_command, etc.) move here. - Interaction: owns ephemeral UI state (mode, is_input_blank, confirmation_pending, streaming_status, error). - Session: the top-level type containing conversation, interaction, pending_tool_calls, exit_action, and stream_abort. Cross-cutting lifecycle methods (start_streaming, cancel_streaming, add_tool_call, etc.) stay here. **Dispatch extraction (Phase 2b):** The 400-line match event in the spawn_blocking loop is now a 5-line dispatch call. All 12 handlers are named functions in a new tui/dispatch.rs module. inline.rs shrinks from ~640 lines to ~240 lines. **Stream launch centralization (Phase 2c):** The duplicated 8-line stream launch ritual (present in ContinueAfterTools, SubmitInput, and Retry) is replaced by a single launch_stream function that takes a setup callback for pre-work. Each handler collapses to one line. **Stream frame split (Phase 2d):** ChatStreamEvent is replaced by StreamFrame with explicit Content (TextChunk, ToolCall, ToolResult) and Control (Done, Error, StatusChanged) variants. run_chat_stream now dispatches on frame type, with apply_content_frame and apply_control_frame as separate functions. --- crates/atuin-ai/src/commands/inline.rs | 410 +------------------- crates/atuin-ai/src/stream.rs | 170 +++++---- crates/atuin-ai/src/tui/dispatch.rs | 430 +++++++++++++++++++++ crates/atuin-ai/src/tui/mod.rs | 3 +- crates/atuin-ai/src/tui/state.rs | 495 +++++++++++++------------ crates/atuin-ai/src/tui/view/mod.rs | 25 +- 6 files changed, 805 insertions(+), 728 deletions(-) create mode 100644 crates/atuin-ai/src/tui/dispatch.rs diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index 9201cf954bc..2fb5c56fa22 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,18 +1,12 @@ -use std::path::PathBuf; use std::sync::mpsc; use crate::context::{AppContext, ClientContext}; -use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; -use crate::permissions::walker::PermissionWalker; -use crate::stream::{ChatRequest, run_chat_stream}; -use crate::tools::ToolCallState; -use crate::tui::ConversationEvent; -use crate::tui::events::{AiTuiEvent, PermissionResult}; -use crate::tui::state::{AppState, ExitAction}; +use crate::tui::dispatch; +use crate::tui::events::AiTuiEvent; +use crate::tui::state::{ExitAction, Session}; use crate::tui::view::ai_view; use eye_declare::{Application, CtrlCBehavior}; use eyre::{Context as _, Result, bail}; -use serde_json::json; use tracing::{debug, info}; pub(crate) async fn run( @@ -123,7 +117,7 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu async fn run_inline_tui(ctx: AppContext, initial_prompt: Option) -> Result { let (tx, rx) = mpsc::channel::(); - let initial_state = AppState::new(); + let initial_state = Session::new(); let client_ctx = ClientContext::detect(); // initial_state // .pending_tool_calls @@ -169,404 +163,10 @@ async fn run_inline_tui(ctx: AppContext, initial_prompt: Option) -> Resu // Event loop: receives AiTuiEvent from components, mutates state via Handle. let h = handle.clone(); tokio::task::spawn_blocking(move || { - // Clone tx for use in each loop iteration (run_chat_stream and update closures - // move tx clones into spawned tasks/closures). let tx = tx.clone(); let client_ctx = client_ctx; while let Ok(event) = rx.recv() { - match event { - AiTuiEvent::ContinueAfterTools => { - let ctx = ctx.clone(); - let h2 = h.clone(); - let tx2 = tx.clone(); - let cc = client_ctx.clone(); - h.update(move |state| { - state.start_streaming(); - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let request = ChatRequest::new(messages, sid); - let task = tokio::spawn(async move { - run_chat_stream(h2, tx2, ctx, cc, request).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::InputUpdated(input) => { - let input_blank = input.trim().is_empty(); - - h.update(move |state| { - state.is_input_blank = input_blank; - }); - } - - AiTuiEvent::SubmitInput(input) => { - let input = input.trim().to_string(); - if input.is_empty() { - let h2 = h.clone(); - h.update(move |state| { - if state.has_any_command() { - state.exit_action = Some(ExitAction::Execute( - state.current_command().unwrap().to_string(), - )); - } else { - state.exit_action = Some(ExitAction::Cancel); - } - h2.exit(); - }); - continue; - } - - if input.starts_with('/') { - let input_clone = input.clone(); - h.update(move |state| { - state.handle_slash_command(&input_clone); - }); - continue; - } - - // Start generation and spawn streaming task - let ctx = ctx.clone(); - let h2 = h.clone(); - let tx2 = tx.clone(); - let cc = client_ctx.clone(); - h.update(move |state| { - state.start_generating(input); - state.start_streaming(); - state.is_input_blank = true; - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let request = ChatRequest::new(messages, sid); - let task = tokio::spawn(async move { - run_chat_stream(h2, tx2, ctx, cc, request).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::SlashCommand(command) => { - h.update(move |state| { - state.handle_slash_command(&command); - }); - } - - AiTuiEvent::CheckToolCallPermission(id) => { - eprintln!("Checking tool call permission: {:?}", &id); - let h2 = h.clone(); - let tx_for_task = tx.clone(); - - let id_clone = id.clone(); - tokio::spawn(async move { - let Ok(Some(tool_call)) = h2 - .fetch(move |state| state.get_pending_tool_call(&id).cloned()) - .await - else { - // todo: raise error - eprintln!("Error getting pending tool call: {:?}", &id_clone); - return; - }; - - let Some(working_dir) = tool_call - .target_dir() - .map(PathBuf::from) - .or_else(|| std::env::current_dir().ok()) - else { - // todo: raise error - eprintln!( - "Error getting working directory for tool call: {:?}", - &tool_call - ); - return; - }; - - let mut walker = PermissionWalker::new(working_dir.clone(), None); // todo: get global dir - - let Ok(_) = walker.walk().await else { - eprintln!("Error walking filesystem for permissions check"); - // todo: raise error - return; - }; - - let checker = PermissionChecker::new(walker.rules().to_owned()); - let request = - PermissionRequest::new(working_dir, Box::new(&tool_call.tool)); - - let Ok(response) = checker.check(&request).await else { - // todo: raise error - eprintln!("Error checking tool call permission"); - return; - }; - - match response { - PermissionResponse::Allowed => { - eprintln!("Executing tool call: {:?}", tool_call); - - let id_clone2 = id_clone.clone(); - h2.update(move |state| { - state.add_tool_call( - id_clone2.clone(), - "read".to_string(), - json!({}), - ); - - let mut tool_call = state.get_pending_tool_call_mut(&id_clone2); - - let Some(tool_call) = tool_call.as_mut() else { - eprintln!( - "Error getting pending tool call: {:?}", - &id_clone2 - ); - return; - }; - - tool_call.state = ToolCallState::Executing; - - // - - // state.events.push(ConversationEvent::OutOfBandOutput { - // name: "System".to_string(), - // content: format!( - // "Permission granted for tool call {:?}", - // &tool_call - // ), - // command: None, - // }); - }); - - if let crate::tools::ClientToolCall::Read(read) = tool_call.tool { - let mut path = read.path.clone(); - - if path.is_relative() - && let Ok(current_dir) = std::env::current_dir() - { - path = current_dir.join(path); - } - - if !path.exists() { - let id = id_clone.clone(); - h2.update(move |state| { - state.add_tool_result( - id.clone(), - format!( - "Error: file does not exist: {}", - path.display() - ), - true, - ); - state.pending_tool_calls.retain(|c| c.id != id); - }); - return; - } - - if path.is_dir() { - let Some(files) = std::fs::read_dir(&path) - .map_err(|e| { - eprintln!("Error reading directory: {}", e); - e - }) - .ok() - .and_then(|entries| { - entries - .filter_map(|entry| entry.ok()) - .map(|entry| { - entry - .file_name() - .to_string_lossy() - .to_string() - }) - .collect::>() - .into() - }) - else { - h2.update(move |state| { - state.add_tool_result( - id_clone.clone(), - format!( - "Error: could not read directory: {}", - path.display() - ), - true, - ); - state - .pending_tool_calls - .retain(|c| c.id != id_clone); - }); - return; - }; - - h2.update(move |state| { - state.add_tool_result( - id_clone.clone(), - format!( - "Directory contents:\n{}", - files.join("\n") - ), - false, - ); - state.pending_tool_calls.retain(|c| c.id != id_clone); - - let _ = - tx_for_task.send(AiTuiEvent::ContinueAfterTools); - }); - } - } - } - PermissionResponse::Denied => { - eprintln!("Permission denied for tool call: {:?}", &tool_call); - h2.update(move |state| { - state.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - content: format!( - "Permission denied for tool call {:?}", - &tool_call - ), - command: None, - }); - }); - } - PermissionResponse::Ask => { - eprintln!("Asking for permission for tool call: {:?}", &tool_call); - h2.update(move |state| { - let mut tool_call = state.get_pending_tool_call_mut(&id_clone); - - let Some(tool_call) = tool_call.as_mut() else { - eprintln!( - "Error getting pending tool call: {:?}", - &id_clone - ); - return; - }; - - eprintln!( - "Setting tool call state to AskingForPermission: {:?}", - &tool_call - ); - tool_call.state = ToolCallState::AskingForPermission; - eprintln!( - "Tool call state set to AskingForPermission: {:?}", - &tool_call - ); - }); - } - } - }); - } - - AiTuiEvent::SelectPermission(permission) => { - // Okay, we have permssion information. - // If accepted, we can start executing. - // If denied, we can show an error message. - h.update(move |state| { - let tool_call = state - .pending_tool_calls - .iter() - .enumerate() - .find(|(_, call)| call.state == ToolCallState::AskingForPermission); - - let Some((index, _)) = tool_call else { - return; - }; - - match permission { - PermissionResult::Allow => { - state.pending_tool_calls.remove(index); - } - PermissionResult::AlwaysAllowInDir => { - // - } - PermissionResult::AlwaysAllow => { - // - } - PermissionResult::Deny => { - let Some(call) = state.pending_tool_calls.remove(index) else { - return; - }; - - state.add_tool_result( - call.id, - "Permission denied on the user's system".to_string(), - true, - ); - } - } - }); - } - - AiTuiEvent::CancelGeneration => { - h.update(|state| match state.mode { - crate::tui::state::AppMode::Generating => { - state.cancel_generation(); - } - crate::tui::state::AppMode::Streaming => { - state.cancel_streaming(); - } - _ => {} - }); - } - - AiTuiEvent::ExecuteCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - if state.is_current_command_dangerous() && !state.confirmation_pending { - state.confirmation_pending = true; - } else { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Execute(cmd)); - h2.exit(); - } - } - }); - } - - AiTuiEvent::CancelConfirmation => { - h.update(move |state| { - state.confirmation_pending = false; - }); - } - - AiTuiEvent::InsertCommand => { - let h2 = h.clone(); - h.update(move |state| { - let cmd = state.current_command().map(|c| c.to_string()); - if let Some(cmd) = cmd { - state.confirmation_pending = false; - state.exit_action = Some(ExitAction::Insert(cmd)); - h2.exit(); - } - }); - } - - AiTuiEvent::Retry => { - let ctx = ctx.clone(); - let h2 = h.clone(); - let tx2 = tx.clone(); - let cc = client_ctx.clone(); - h.update(move |state| { - state.retry(); - state.start_streaming(); - let messages = state.events_to_messages(); - let sid = state.session_id.clone(); - let request = ChatRequest::new(messages, sid); - let task = tokio::spawn(async move { - run_chat_stream(h2, tx2, ctx, cc, request).await; - }); - state.stream_abort = Some(task.abort_handle()); - }); - } - - AiTuiEvent::Exit => { - let h2 = h.clone(); - h.update(move |state| { - if let Some(abort) = state.stream_abort.take() { - abort.abort(); - } - state.exit_action = Some(ExitAction::Cancel); - h2.exit(); - }); - } - } + dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx); } }); diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 49f8ef4e729..b93da09de26 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -15,11 +15,20 @@ use reqwest::Url; use crate::{ context::{AppContext, ClientContext}, tools::ClientToolCall, - tui::{AppState, events::AiTuiEvent}, + tui::{Session, events::AiTuiEvent}, }; +/// Frames that alter the stream lifecycle — terminal or state-changing. #[derive(Debug, Clone)] -enum ChatStreamEvent { +pub(crate) enum StreamControl { + Done { session_id: String }, + Error(String), + StatusChanged(String), +} + +/// Frames that carry conversation content — they mutate the event log. +#[derive(Debug, Clone)] +pub(crate) enum StreamContent { TextChunk(String), ToolCall { id: String, @@ -31,11 +40,13 @@ enum ChatStreamEvent { content: String, is_error: bool, }, - Status(String), - Done { - session_id: String, - }, - Error(String), +} + +/// A frame from the SSE stream, classified as control or content. +#[derive(Debug, Clone)] +pub(crate) enum StreamFrame { + Content(StreamContent), + Control(StreamControl), } /// Per-turn request payload for the chat API. @@ -62,7 +73,7 @@ fn create_chat_stream( request: ChatRequest, client_ctx: ClientContext, send_cwd: bool, -) -> std::pin::Pin> + Send>> { +) -> std::pin::Pin> + Send>> { Box::pin(async_stream::stream! { ensure_crypto_provider(); let endpoint = match hub_url(&hub_address, "/api/cli/chat") { @@ -134,7 +145,7 @@ fn create_chat_stream( if let Ok(json) = serde_json::from_str::(&data) && let Some(content) = json.get("content").and_then(|v| v.as_str()) { - yield Ok(ChatStreamEvent::TextChunk(content.to_string())); + yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string()))); } } "tool_call" => { @@ -142,7 +153,7 @@ fn create_chat_stream( let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(ChatStreamEvent::ToolCall { id, name, input }); + yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input })); } } "tool_result" => { @@ -150,14 +161,14 @@ fn create_chat_stream( let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error }); + yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error })); } } "status" => { if let Ok(json) = serde_json::from_str::(&data) && let Some(state) = json.get("state").and_then(|v| v.as_str()) { - yield Ok(ChatStreamEvent::Status(state.to_string())); + yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string()))); } } "done" => { @@ -166,9 +177,9 @@ fn create_chat_stream( .and_then(|v| v.as_str()) .unwrap_or("") .to_string(); - yield Ok(ChatStreamEvent::Done { session_id }); + yield Ok(StreamFrame::Control(StreamControl::Done { session_id })); } else { - yield Ok(ChatStreamEvent::Done { session_id: String::new() }); + yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() })); } break; } @@ -176,10 +187,10 @@ fn create_chat_stream( if let Ok(json) = serde_json::from_str::(&data) { let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); tracing::error!("SSE error: {}", message); - yield Ok(ChatStreamEvent::Error(message)); + yield Ok(StreamFrame::Control(StreamControl::Error(message))); } else { tracing::error!("SSE error: {}", data); - yield Ok(ChatStreamEvent::Error(data)); + yield Ok(StreamFrame::Control(StreamControl::Error(data))); } break; } @@ -200,7 +211,7 @@ fn create_chat_stream( // ─────────────────────────────────────────────────────────────────── pub(crate) async fn run_chat_stream( - handle: Handle, + handle: Handle, tx: mpsc::Sender, app_ctx: AppContext, client_ctx: ClientContext, @@ -217,70 +228,93 @@ pub(crate) async fn run_chat_stream( while let Some(event) = stream.next().await { match event { - Ok(ChatStreamEvent::TextChunk(text)) => { - tracing::trace!(text = %text, "Processing TextChunk"); - handle.update(move |state| { - state.append_streaming_text(&text); - }); + Ok(StreamFrame::Content(content)) => { + apply_content_frame(&handle, &tx, content); } - Ok(ChatStreamEvent::ToolCall { id, name, input }) => { - tracing::trace!(id = %id, name = %name, "Processing ToolCall"); - - if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { - // Recognized as a client-side tool call. - let id_for_update = id.clone(); - handle.update(move |state| { - state.handle_client_tool_call(id_for_update, tool); - }); - let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); - continue; + Ok(StreamFrame::Control(control)) => { + let terminal = apply_control_frame(&handle, control); + if terminal { + break; } - - handle.update(move |state| { - state.add_tool_call(id, name, input); - }); - } - Ok(ChatStreamEvent::ToolResult { - tool_use_id, - content, - is_error, - }) => { - tracing::trace!(tool_use_id = %tool_use_id, "Processing ToolResult"); - handle.update(move |state| { - state.add_tool_result(tool_use_id, content, is_error); - }); - } - Ok(ChatStreamEvent::Status(status)) => { - tracing::trace!(status = %status, "Processing Status"); - handle.update(move |state| { - state.update_streaming_status(&status); - }); } - Ok(ChatStreamEvent::Done { session_id }) => { - tracing::trace!(session_id = %session_id, "Processing Done"); + Err(e) => { + let msg = e.to_string(); handle.update(move |state| { - if !session_id.is_empty() { - state.store_session_id(session_id); - } - state.finalize_streaming(); + state.streaming_error(msg); }); break; } - Ok(ChatStreamEvent::Error(msg)) => { - tracing::trace!(error = %msg, "Processing Error"); + } + } +} + +/// Apply a content frame to session state. +/// Control flow: always continues the stream. +fn apply_content_frame( + handle: &Handle, + tx: &mpsc::Sender, + content: StreamContent, +) { + match content { + StreamContent::TextChunk(text) => { + handle.update(move |state| { + state.conversation.append_streaming_text(&text); + }); + } + StreamContent::ToolCall { id, name, input } => { + if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { + // Client-side tool — queue for permission check + let id_for_update = id.clone(); handle.update(move |state| { - state.streaming_error(msg); + state.handle_client_tool_call(id_for_update, tool); }); - break; - } - Err(e) => { - let msg = e.to_string(); + let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); + } else { + // Server-side tool handle.update(move |state| { - state.streaming_error(msg); + state.add_tool_call(id, name, input); }); - break; } } + StreamContent::ToolResult { + tool_use_id, + content, + is_error, + } => { + handle.update(move |state| { + state + .conversation + .add_tool_result(tool_use_id, content, is_error); + }); + } + } +} + +/// Apply a control frame to session state. +/// Returns true if the stream should terminate. +fn apply_control_frame(handle: &Handle, control: StreamControl) -> bool { + match control { + StreamControl::StatusChanged(status) => { + handle.update(move |state| { + state.update_streaming_status(&status); + }); + false + } + StreamControl::Done { session_id } => { + handle.update(move |state| { + if !session_id.is_empty() { + state.conversation.store_session_id(session_id); + } + state.finalize_streaming(); + }); + true + } + StreamControl::Error(msg) => { + handle.update(move |state| { + state.streaming_error(msg); + }); + true + } } } diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs new file mode 100644 index 00000000000..b2f8fcc41a6 --- /dev/null +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -0,0 +1,430 @@ +use std::path::PathBuf; +use std::sync::mpsc; + +use crate::context::{AppContext, ClientContext}; +use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; +use crate::permissions::walker::PermissionWalker; +use crate::stream::{ChatRequest, run_chat_stream}; +use crate::tools::ToolCallState; +use crate::tui::ConversationEvent; +use crate::tui::events::{AiTuiEvent, PermissionResult}; +use crate::tui::state::{ExitAction, Session}; +use eye_declare::Handle; +use serde_json::json; +use tokio::task::JoinHandle; + +pub(crate) fn dispatch( + handle: &Handle, + event: AiTuiEvent, + tx: &mpsc::Sender, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + match event { + AiTuiEvent::ContinueAfterTools => { + on_continue_after_tools(handle, tx, app_ctx, client_ctx); + } + AiTuiEvent::InputUpdated(input) => { + on_input_updated(handle, input); + } + AiTuiEvent::SubmitInput(input) => { + on_submit_input(handle, tx, app_ctx, client_ctx, input); + } + AiTuiEvent::SlashCommand(cmd) => { + on_slash_command(handle, cmd); + } + AiTuiEvent::CheckToolCallPermission(id) => { + on_check_tool_permission(handle, tx, id); + } + AiTuiEvent::SelectPermission(result) => { + on_select_permission(handle, result); + } + AiTuiEvent::CancelGeneration => { + on_cancel_generation(handle); + } + AiTuiEvent::ExecuteCommand => { + on_execute_command(handle); + } + AiTuiEvent::CancelConfirmation => { + on_cancel_confirmation(handle); + } + AiTuiEvent::InsertCommand => { + on_insert_command(handle); + } + AiTuiEvent::Retry => { + on_retry(handle, tx, app_ctx, client_ctx); + } + AiTuiEvent::Exit => { + on_exit(handle); + } + } +} + +fn launch_stream( + handle: &Handle, + tx: &mpsc::Sender, + app_ctx: &AppContext, + client_ctx: &ClientContext, + setup: impl FnOnce(&mut Session) + Send + 'static, +) { + let h2 = handle.clone(); + let tx2 = tx.clone(); + let app = app_ctx.clone(); + let cc = client_ctx.clone(); + handle.update(move |state| { + (setup)(state); + state.start_streaming(); + let messages = state.conversation.events_to_messages(); + let sid = state.conversation.session_id.clone(); + let request = ChatRequest::new(messages, sid); + let task: JoinHandle<()> = tokio::spawn(async move { + run_chat_stream(h2, tx2, app, cc, request).await; + }); + state.stream_abort = Some(task.abort_handle()); + }); +} + +fn on_continue_after_tools( + handle: &Handle, + tx: &mpsc::Sender, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + launch_stream(handle, tx, app_ctx, client_ctx, |_state| {}); +} + +fn on_input_updated(handle: &Handle, input: String) { + let input_blank = input.trim().is_empty(); + + handle.update(move |state| { + state.interaction.is_input_blank = input_blank; + }); +} + +fn on_submit_input( + handle: &Handle, + tx: &mpsc::Sender, + app_ctx: &AppContext, + client_ctx: &ClientContext, + input: String, +) { + let input = input.trim().to_string(); + if input.is_empty() { + let h2 = handle.clone(); + handle.update(move |state| { + if state.conversation.has_any_command() { + state.exit_action = Some(ExitAction::Execute( + state.conversation.current_command().unwrap().to_string(), + )); + } else { + state.exit_action = Some(ExitAction::Cancel); + } + h2.exit(); + }); + return; + } + + if input.starts_with('/') { + let input_clone = input.clone(); + handle.update(move |state| { + state.conversation.handle_slash_command(&input_clone); + }); + return; + } + + // Start generation and spawn streaming task + launch_stream(handle, tx, app_ctx, client_ctx, |state| { + state.start_generating(input); + state.interaction.is_input_blank = true; + }); +} + +fn on_slash_command(handle: &Handle, command: String) { + handle.update(move |state| { + state.conversation.handle_slash_command(&command); + }); +} + +fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender, id: String) { + eprintln!("Checking tool call permission: {:?}", &id); + let h2 = handle.clone(); + let tx_for_task = tx.clone(); + + let id_clone = id.clone(); + tokio::spawn(async move { + let Ok(Some(tool_call)) = h2 + .fetch(move |state| state.pending_tool_call(&id).cloned()) + .await + else { + // todo: raise error + eprintln!("Error getting pending tool call: {:?}", &id_clone); + return; + }; + + let Some(working_dir) = tool_call + .target_dir() + .map(PathBuf::from) + .or_else(|| std::env::current_dir().ok()) + else { + // todo: raise error + eprintln!( + "Error getting working directory for tool call: {:?}", + &tool_call + ); + return; + }; + + let mut walker = PermissionWalker::new(working_dir.clone(), None); // todo: get global dir + + let Ok(_) = walker.walk().await else { + eprintln!("Error walking filesystem for permissions check"); + // todo: raise error + return; + }; + + let checker = PermissionChecker::new(walker.rules().to_owned()); + let request = PermissionRequest::new(working_dir, Box::new(&tool_call.tool)); + + let Ok(response) = checker.check(&request).await else { + // todo: raise error + eprintln!("Error checking tool call permission"); + return; + }; + + match response { + PermissionResponse::Allowed => { + eprintln!("Executing tool call: {:?}", tool_call); + + let id_clone2 = id_clone.clone(); + h2.update(move |state| { + state.add_tool_call(id_clone2.clone(), "read".to_string(), json!({})); + + let mut tool_call = state.pending_tool_call_mut(&id_clone2); + + let Some(tool_call) = tool_call.as_mut() else { + eprintln!("Error getting pending tool call: {:?}", &id_clone2); + return; + }; + + tool_call.state = ToolCallState::Executing; + + // + + // state.events.push(ConversationEvent::OutOfBandOutput { + // name: "System".to_string(), + // content: format!( + // "Permission granted for tool call {:?}", + // &tool_call + // ), + // command: None, + // }); + }); + + if let crate::tools::ClientToolCall::Read(read) = tool_call.tool { + let mut path = read.path.clone(); + + if path.is_relative() + && let Ok(current_dir) = std::env::current_dir() + { + path = current_dir.join(path); + } + + if !path.exists() { + let id = id_clone.clone(); + h2.update(move |state| { + state.conversation.add_tool_result( + id.clone(), + format!("Error: file does not exist: {}", path.display()), + true, + ); + state.pending_tool_calls.retain(|c| c.id != id); + }); + return; + } + + if path.is_dir() { + let Some(files) = std::fs::read_dir(&path) + .map_err(|e| { + eprintln!("Error reading directory: {}", e); + e + }) + .ok() + .and_then(|entries| { + entries + .filter_map(|entry| entry.ok()) + .map(|entry| entry.file_name().to_string_lossy().to_string()) + .collect::>() + .into() + }) + else { + h2.update(move |state| { + state.conversation.add_tool_result( + id_clone.clone(), + format!("Error: could not read directory: {}", path.display()), + true, + ); + state.pending_tool_calls.retain(|c| c.id != id_clone); + }); + return; + }; + + h2.update(move |state| { + state.conversation.add_tool_result( + id_clone.clone(), + format!("Directory contents:\n{}", files.join("\n")), + false, + ); + state.pending_tool_calls.retain(|c| c.id != id_clone); + + let _ = tx_for_task.send(AiTuiEvent::ContinueAfterTools); + }); + } + } + } + PermissionResponse::Denied => { + eprintln!("Permission denied for tool call: {:?}", &tool_call); + h2.update(move |state| { + state + .conversation + .events + .push(ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + content: format!("Permission denied for tool call {:?}", &tool_call), + command: None, + }); + }); + } + PermissionResponse::Ask => { + eprintln!("Asking for permission for tool call: {:?}", &tool_call); + h2.update(move |state| { + let mut tool_call = state.pending_tool_call_mut(&id_clone); + + let Some(tool_call) = tool_call.as_mut() else { + eprintln!("Error getting pending tool call: {:?}", &id_clone); + return; + }; + + eprintln!( + "Setting tool call state to AskingForPermission: {:?}", + &tool_call + ); + tool_call.state = ToolCallState::AskingForPermission; + eprintln!( + "Tool call state set to AskingForPermission: {:?}", + &tool_call + ); + }); + } + } + }); +} + +fn on_select_permission(handle: &Handle, permission: PermissionResult) { + // Okay, we have permssion information. + // If accepted, we can start executing. + // If denied, we can show an error message. + handle.update(move |state| { + let tool_call = state + .pending_tool_calls + .iter() + .enumerate() + .find(|(_, call)| call.state == ToolCallState::AskingForPermission); + + let Some((index, _)) = tool_call else { + return; + }; + + match permission { + PermissionResult::Allow => { + state.pending_tool_calls.remove(index); + } + PermissionResult::AlwaysAllowInDir => { + // + } + PermissionResult::AlwaysAllow => { + // + } + PermissionResult::Deny => { + let Some(call) = state.pending_tool_calls.remove(index) else { + return; + }; + + state.conversation.add_tool_result( + call.id, + "Permission denied on the user's system".to_string(), + true, + ); + } + } + }); +} + +fn on_cancel_generation(handle: &Handle) { + handle.update(|state| match state.interaction.mode { + crate::tui::state::AppMode::Generating => { + state.cancel_generation(); + } + crate::tui::state::AppMode::Streaming => { + state.cancel_streaming(); + } + _ => {} + }); +} + +fn on_execute_command(handle: &Handle) { + let h2 = handle.clone(); + handle.update(move |state| { + let cmd = state.conversation.current_command().map(|c| c.to_string()); + if let Some(cmd) = cmd { + if state.conversation.is_current_command_dangerous() + && !state.interaction.confirmation_pending + { + state.interaction.confirmation_pending = true; + } else { + state.interaction.confirmation_pending = false; + state.exit_action = Some(ExitAction::Execute(cmd)); + h2.exit(); + } + } + }); +} + +fn on_cancel_confirmation(handle: &Handle) { + handle.update(move |state| { + state.interaction.confirmation_pending = false; + }); +} + +fn on_insert_command(handle: &Handle) { + let h2 = handle.clone(); + handle.update(move |state| { + let cmd = state.conversation.current_command().map(|c| c.to_string()); + if let Some(cmd) = cmd { + state.interaction.confirmation_pending = false; + state.exit_action = Some(ExitAction::Insert(cmd)); + h2.exit(); + } + }); +} + +fn on_retry( + handle: &Handle, + tx: &mpsc::Sender, + app_ctx: &AppContext, + client_ctx: &ClientContext, +) { + launch_stream(handle, tx, app_ctx, client_ctx, |state| { + state.retry(); + }); +} + +fn on_exit(handle: &Handle) { + let h2 = handle.clone(); + handle.update(move |state| { + if let Some(abort) = state.stream_abort.take() { + abort.abort(); + } + state.exit_action = Some(ExitAction::Cancel); + h2.exit(); + }); +} diff --git a/crates/atuin-ai/src/tui/mod.rs b/crates/atuin-ai/src/tui/mod.rs index 081e79fdab1..afd63312f7b 100644 --- a/crates/atuin-ai/src/tui/mod.rs +++ b/crates/atuin-ai/src/tui/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod components; +pub(crate) mod dispatch; pub(crate) mod events; pub(crate) mod state; pub(crate) mod view; -pub(crate) use state::{AppState, ConversationEvent}; +pub(crate) use state::{ConversationEvent, Session}; diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 95e333a36d4..561d3e6af26 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -139,50 +139,20 @@ pub(crate) enum ExitAction { Cancel, } -/// Application state — the domain model -/// -/// Conversation is stored as a sequence of events matching the API protocol. -/// The view function derives the UI from this state. +/// Owned event log and session ID #[derive(Debug)] -pub(crate) struct AppState { - /// Current application mode - pub mode: AppMode, +pub(crate) struct Conversation { /// Conversation events (source of truth, matches API protocol) pub events: Vec, - /// Current error message - pub error: Option, - /// Exit action (set when exiting) - pub exit_action: Option, /// Session ID from server pub session_id: Option, - /// Current streaming status - pub streaming_status: Option, - /// Whether the input is blank - pub is_input_blank: bool, - /// Whether current turn was interrupted by user - pub was_interrupted: bool, - /// True when user has pressed Enter once on a dangerous command - pub confirmation_pending: bool, - /// Abort handle for the active streaming task, if any - pub stream_abort: Option, - /// Tool calls that are pending permission checking + execution - pub pending_tool_calls: VecDeque, } -impl AppState { +impl Conversation { pub fn new() -> Self { Self { - mode: AppMode::Input, events: Vec::new(), - error: None, - exit_action: None, session_id: None, - streaming_status: None, - is_input_blank: false, - was_interrupted: false, - confirmation_pending: false, - stream_abort: None, - pending_tool_calls: VecDeque::new(), } } @@ -254,53 +224,74 @@ impl AppState { messages } - // ===== Generation lifecycle methods ===== - - /// Start generating from submitted input - pub fn start_generating(&mut self, input: String) { - self.events - .push(ConversationEvent::UserMessage { content: input }); - self.mode = AppMode::Generating; + /// Get the most recent command from events + pub fn current_command(&self) -> Option<&str> { + self.events.iter().rev().find_map(|e| e.as_command()) } - /// Generation error occurred - pub fn generation_error(&mut self, error: String) { - self.error = Some(error); - self.mode = AppMode::Error; + /// Check if any turn in the conversation has a command + pub fn has_any_command(&self) -> bool { + self.events.iter().any(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e { + name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() + } else { + false + } + }) } - /// Cancel during generation - pub fn cancel_generation(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - if let Some(ConversationEvent::UserMessage { .. }) = self.events.last() { - self.events.pop(); - } - self.mode = AppMode::Input; + /// Check if the most recent command is marked dangerous + pub fn is_current_command_dangerous(&self) -> bool { + self.events + .iter() + .rev() + .find_map(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e + && name == "suggest_command" + { + let danger_level = input + .get("danger") + .and_then(|v| v.as_str()) + .unwrap_or("low"); + return Some( + danger_level == "high" || danger_level == "medium" || danger_level == "med", + ); + } + None + }) + .unwrap_or(false) } - // ===== Streaming lifecycle methods ===== + /// Count non-suggest_command tool calls since the last user message + pub fn tool_count_since_last_user(&self) -> usize { + let last_user_idx = self + .events + .iter() + .rposition(|e| matches!(e, ConversationEvent::UserMessage { .. })) + .unwrap_or(0); - /// Start streaming response. - /// Pushes an empty Text event that will be mutated in-place as chunks arrive. - pub fn start_streaming(&mut self) { - self.events.push(ConversationEvent::Text { - content: String::new(), - }); - self.streaming_status = None; - self.was_interrupted = false; - self.mode = AppMode::Streaming; - } + let mut completed = 0; + let mut in_flight = false; - /// Store session ID from server response - pub fn store_session_id(&mut self, session_id: String) { - self.session_id = Some(session_id); - } + for event in &self.events[last_user_idx..] { + match event { + ConversationEvent::ToolCall { name, .. } if name != "suggest_command" => { + if in_flight { + completed += 1; + } + in_flight = true; + } + ConversationEvent::ToolResult { .. } => { + if in_flight { + completed += 1; + in_flight = false; + } + } + _ => {} + } + } - /// Update streaming status from SSE event - pub fn update_streaming_status(&mut self, status: &str) { - self.streaming_status = Some(StreamingStatus::from_status_str(status)); + completed } /// Get a mutable reference to the last Text event's content (the streaming buffer). @@ -314,28 +305,15 @@ impl AppState { }) } - /// Cancel streaming with context preservation - pub fn cancel_streaming(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - self.was_interrupted = true; - - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - if trimmed.is_empty() { - // Remove the empty text event - *content = String::new(); + /// Remove trailing empty Text events from the events list + fn remove_empty_trailing_text(&mut self) { + while let Some(ConversationEvent::Text { content }) = self.events.last() { + if content.is_empty() { + self.events.pop(); } else { - *content = format!("{trimmed}\n\n[User cancelled this generation]"); + break; } } - // Remove trailing empty Text events - self.remove_empty_trailing_text(); - - self.streaming_status = None; - self.confirmation_pending = false; - self.mode = AppMode::Input; } /// Append text chunk during streaming (mutates the last Text event in-place) @@ -361,218 +339,251 @@ impl AppState { } } - pub(crate) fn handle_client_tool_call(&mut self, id: String, tool: ClientToolCall) { - self.pending_tool_calls.push_back(PendingToolCall { - id: id.clone(), - state: ToolCallState::CheckingPermissions, - tool, + /// Add a tool result event during streaming + pub fn add_tool_result(&mut self, tool_use_id: String, content: String, is_error: bool) { + self.events.push(ConversationEvent::ToolResult { + tool_use_id, + content, + is_error, }); + } - // Client tool calls can only happen at the last part of a turn - self.streaming_status = None; - self.mode = AppMode::Input; + /// Store session ID from server response + pub fn store_session_id(&mut self, session_id: String) { + self.session_id = Some(session_id); } - pub(crate) fn handle_select_permission(&mut self, permission: String) { - match permission.as_str() { - "allow" => { - self.pending_tool_calls.pop_front(); - } - "always-allow-in-dir" => { - self.pending_tool_calls.pop_front(); - } - "always-allow" => { - self.pending_tool_calls.pop_front(); + /// Handle a slash command + pub fn handle_slash_command(&mut self, command: &str) { + match command.trim() { + "/help" => { + let content = include_str!("./content/help.md"); + + self.events.push(ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + command: Some("/help".to_string()), + content: content.to_string(), + }); } - "deny" => { - self.pending_tool_calls.pop_front(); + _ => self.events.push(ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + command: None, + content: (format!("Unknown command: {command}")), + }), + } + } +} + +/// Ephemeral UI/presentation state +#[derive(Debug)] +pub(crate) struct Interaction { + /// Current application mode + pub mode: AppMode, + /// Whether the input is blank + pub is_input_blank: bool, + /// True when user has pressed Enter once on a dangerous command + pub confirmation_pending: bool, + /// Current streaming status + pub streaming_status: Option, + /// Whether current turn was interrupted by user + pub was_interrupted: bool, + /// Current error message + pub error: Option, +} + +impl Interaction { + pub fn new() -> Self { + Self { + mode: AppMode::Input, + is_input_blank: false, + confirmation_pending: false, + streaming_status: None, + was_interrupted: false, + error: None, + } + } +} + +/// Top-level session state +/// +/// Decomposed into `Conversation` (event log + session ID) and +/// `Interaction` (ephemeral UI state). Session methods that cross +/// both sub-structs live here. +#[derive(Debug)] +pub(crate) struct Session { + pub conversation: Conversation, + pub interaction: Interaction, + /// Tool calls that are pending permission checking + execution + pub pending_tool_calls: VecDeque, + /// Exit action (set when exiting) + pub exit_action: Option, + /// Abort handle for the active streaming task, if any + pub stream_abort: Option, +} + +impl Session { + pub fn new() -> Self { + Self { + conversation: Conversation::new(), + interaction: Interaction::new(), + pending_tool_calls: VecDeque::new(), + exit_action: None, + stream_abort: None, + } + } + + // ===== Generation lifecycle methods ===== + + /// Start generating from submitted input + pub fn start_generating(&mut self, input: String) { + self.conversation + .events + .push(ConversationEvent::UserMessage { content: input }); + self.interaction.mode = AppMode::Generating; + } + + /// Generation error occurred + pub fn generation_error(&mut self, error: String) { + self.interaction.error = Some(error); + self.interaction.mode = AppMode::Error; + } + + /// Cancel during generation + pub fn cancel_generation(&mut self) { + if let Some(abort) = self.stream_abort.take() { + abort.abort(); + } + if let Some(ConversationEvent::UserMessage { .. }) = self.conversation.events.last() { + self.conversation.events.pop(); + } + self.interaction.mode = AppMode::Input; + } + + // ===== Streaming lifecycle methods ===== + + /// Start streaming response. + /// Pushes an empty Text event that will be mutated in-place as chunks arrive. + pub fn start_streaming(&mut self) { + self.conversation.events.push(ConversationEvent::Text { + content: String::new(), + }); + self.interaction.streaming_status = None; + self.interaction.was_interrupted = false; + self.interaction.mode = AppMode::Streaming; + } + + /// Update streaming status from SSE event + pub fn update_streaming_status(&mut self, status: &str) { + self.interaction.streaming_status = Some(StreamingStatus::from_status_str(status)); + } + + /// Cancel streaming with context preservation + pub fn cancel_streaming(&mut self) { + if let Some(abort) = self.stream_abort.take() { + abort.abort(); + } + self.interaction.was_interrupted = true; + + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + if trimmed.is_empty() { + // Remove the empty text event + *content = String::new(); + } else { + *content = format!("{trimmed}\n\n[User cancelled this generation]"); } - _ => {} } + // Remove trailing empty Text events + self.conversation.remove_empty_trailing_text(); + + self.interaction.streaming_status = None; + self.interaction.confirmation_pending = false; + self.interaction.mode = AppMode::Input; } /// Add a tool call event during streaming. /// The current streaming text is already in events, so we just push the tool call. pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { // Trim the streaming text event - if let Some(content) = self.streaming_content_mut() { + if let Some(content) = self.conversation.streaming_content_mut() { let trimmed = content.trim_start().to_string(); *content = trimmed; } - self.remove_empty_trailing_text(); + self.conversation.remove_empty_trailing_text(); let is_suggest_command = name == "suggest_command"; - self.events + self.conversation + .events .push(ConversationEvent::ToolCall { id, name, input }); if is_suggest_command { - self.streaming_status = None; - self.mode = AppMode::Input; + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; } } - /// Add a tool result event during streaming - pub fn add_tool_result(&mut self, tool_use_id: String, content: String, is_error: bool) { - self.events.push(ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - }); - } - /// Finalize streaming — trim the accumulated text and change mode pub fn finalize_streaming(&mut self) { - if let Some(content) = self.streaming_content_mut() { + if let Some(content) = self.conversation.streaming_content_mut() { let trimmed = content.trim_start().to_string(); *content = trimmed; } - self.remove_empty_trailing_text(); - self.streaming_status = None; - self.mode = AppMode::Input; + self.conversation.remove_empty_trailing_text(); + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; } /// Streaming error — remove the partial text event pub fn streaming_error(&mut self, error: String) { - self.remove_empty_trailing_text(); - self.error = Some(error); - self.mode = AppMode::Error; + self.conversation.remove_empty_trailing_text(); + self.interaction.error = Some(error); + self.interaction.mode = AppMode::Error; } - /// Remove trailing empty Text events from the events list - fn remove_empty_trailing_text(&mut self) { - while let Some(ConversationEvent::Text { content }) = self.events.last() { - if content.is_empty() { - self.events.pop(); - } else { - break; - } - } - } + pub(crate) fn handle_client_tool_call(&mut self, id: String, tool: ClientToolCall) { + self.pending_tool_calls.push_back(PendingToolCall { + id: id.clone(), + state: ToolCallState::CheckingPermissions, + tool, + }); - // ===== Edit mode and exit methods ===== + // Client tool calls can only happen at the last part of a turn + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } /// Start edit mode for refinement pub fn start_edit_mode(&mut self) { - self.confirmation_pending = false; - self.mode = AppMode::Input; + self.interaction.confirmation_pending = false; + self.interaction.mode = AppMode::Input; } /// Retry after error pub fn retry(&mut self) { - self.error = None; - self.mode = AppMode::Generating; - } - - /// Handle a slash command - pub fn handle_slash_command(&mut self, command: &str) { - match command.trim() { - "/help" => { - let content = include_str!("./content/help.md"); - - self.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some("/help".to_string()), - content: content.to_string(), - }); - } - _ => self.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: None, - content: (format!("Unknown command: {command}")), - }), - } + self.interaction.error = None; + self.interaction.mode = AppMode::Generating; } // ===== Query methods ===== /// Get a pending tool call by ID - pub(crate) fn get_pending_tool_call(&self, id: &str) -> Option<&PendingToolCall> { + pub(crate) fn pending_tool_call(&self, id: &str) -> Option<&PendingToolCall> { self.pending_tool_calls.iter().find(|call| call.id == id) } /// Get a mutable pending tool call by ID - pub(crate) fn get_pending_tool_call_mut(&mut self, id: &str) -> Option<&mut PendingToolCall> { + pub(crate) fn pending_tool_call_mut(&mut self, id: &str) -> Option<&mut PendingToolCall> { self.pending_tool_calls .iter_mut() .find(|call| call.id == id) } - /// Get the most recent command from events - pub fn current_command(&self) -> Option<&str> { - self.events.iter().rev().find_map(|e| e.as_command()) - } - - /// Check if the most recent command is marked dangerous - pub fn is_current_command_dangerous(&self) -> bool { - self.events - .iter() - .rev() - .find_map(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e - && name == "suggest_command" - { - let danger_level = input - .get("danger") - .and_then(|v| v.as_str()) - .unwrap_or("low"); - return Some( - danger_level == "high" || danger_level == "medium" || danger_level == "med", - ); - } - None - }) - .unwrap_or(false) - } - - /// Count non-suggest_command tool calls since the last user message - pub fn tool_count_since_last_user(&self) -> usize { - let last_user_idx = self - .events - .iter() - .rposition(|e| matches!(e, ConversationEvent::UserMessage { .. })) - .unwrap_or(0); - - let mut completed = 0; - let mut in_flight = false; - - for event in &self.events[last_user_idx..] { - match event { - ConversationEvent::ToolCall { name, .. } if name != "suggest_command" => { - if in_flight { - completed += 1; - } - in_flight = true; - } - ConversationEvent::ToolResult { .. } => { - if in_flight { - completed += 1; - in_flight = false; - } - } - _ => {} - } - } - - completed - } - - /// Check if any turn in the conversation has a command - pub fn has_any_command(&self) -> bool { - self.events.iter().any(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e { - name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() - } else { - false - } - }) - } - /// Get the footer text for current mode pub fn footer_text(&self) -> &'static str { - match self.mode { + match self.interaction.mode { AppMode::Input => { - if self.has_any_command() && self.is_input_blank { - if self.confirmation_pending { + if self.conversation.has_any_command() && self.interaction.is_input_blank { + if self.interaction.confirmation_pending { "[Enter] Confirm dangerous command [Esc] Cancel" } else { "[Enter] Execute suggested command [Tab] Insert Command" diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 540aa5e7eb8..0954ab2f049 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -13,7 +13,7 @@ use super::components::atuin_ai::AtuinAi; use super::components::input_box::InputBox; use super::components::markdown::Markdown; use super::components::select::Select; -use super::state::{AppMode, AppState}; +use super::state::{AppMode, Session}; mod turn; @@ -25,23 +25,24 @@ mod turn; /// - Error display (if in error state) /// - Spacer /// - Input box (bordered, with contextual keybindings) -pub(crate) fn ai_view(state: &AppState) -> Elements { +pub(crate) fn ai_view(state: &Session) -> Elements { let mut turn_builder = turn::TurnBuilder::new(); - for event in &state.events { + for event in &state.conversation.events { turn_builder.add_event(event); } let turns = turn_builder.build(); - let busy = state.mode == AppMode::Streaming || state.mode == AppMode::Generating; + let busy = state.interaction.mode == AppMode::Streaming + || state.interaction.mode == AppMode::Generating; let last_index = turns.len().saturating_sub(1); element! { AtuinAi( - mode: state.mode, - has_command: state.has_any_command(), - is_input_blank: state.is_input_blank, - pending_confirmation: state.confirmation_pending, + mode: state.interaction.mode, + has_command: state.conversation.has_any_command(), + is_input_blank: state.interaction.is_input_blank, + pending_confirmation: state.interaction.confirmation_pending, ) { #(for (index, turn) in turns.iter().enumerate() { #(match turn { @@ -64,7 +65,7 @@ pub(crate) fn ai_view(state: &AppState) -> Elements { } } -fn input_view(state: &AppState) -> Elements { +fn input_view(state: &Session) -> Elements { let first_pending_tool_call = state .pending_tool_calls .iter() @@ -82,11 +83,11 @@ fn input_view(state: &AppState) -> Elements { title: "Generate a command or ask a question", title_right: "Atuin AI", footer: state.footer_text(), - active: state.mode == AppMode::Input && !state.confirmation_pending, + active: state.interaction.mode == AppMode::Input && !state.interaction.confirmation_pending, ) - #(if state.is_input_blank && state.has_any_command() && state.mode == AppMode::Input { - #(if state.confirmation_pending { + #(if state.interaction.is_input_blank && state.conversation.has_any_command() && state.interaction.mode == AppMode::Input { + #(if state.interaction.confirmation_pending { Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } } else { Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } From 339dcb8714042e07f97c72668c75ee55fc849d7d Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 18:18:11 -0700 Subject: [PATCH 08/52] =?UTF-8?q?atuin-ai:=20Phase=203=20=E2=80=94=20extra?= =?UTF-8?q?ct=20tool=20execution=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extracts three abstractions that untangle the ~180-line on_check_tool_permission handler: **PermissionResolver (permissions/resolver.rs):** Composes the PermissionWalker + PermissionChecker into a single new/check API. The handler no longer imports Walker, Checker, or Request directly. **ToolOutcome + ClientToolCall::execute (tools/mod.rs):** Each tool variant owns its execution logic. ReadToolCall::execute handles both directory listing and file reading (previously only directory listing worked — file reading was a no-op). Returns ToolOutcome::Success or ToolOutcome::Error, replacing the inline match arms. **PendingToolCall state transitions (tools/mod.rs):** mark_asking(), mark_executing(), mark_denied() methods formalize the state machine instead of direct enum variant assignment. **Session::complete_tool_call (tui/state.rs):** Combines add_tool_result + pending_tool_calls.retain into one method, replacing the repeated cleanup pattern in the handler. The handler drops from ~180 lines to ~60 lines. --- crates/atuin-ai/src/permissions/mod.rs | 1 + crates/atuin-ai/src/permissions/resolver.rs | 32 +++++ crates/atuin-ai/src/tools/mod.rs | 69 ++++++++++ crates/atuin-ai/src/tui/dispatch.rs | 144 +++----------------- crates/atuin-ai/src/tui/state.rs | 16 ++- 5 files changed, 137 insertions(+), 125 deletions(-) create mode 100644 crates/atuin-ai/src/permissions/resolver.rs diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs index defb70130a2..5920e7ac893 100644 --- a/crates/atuin-ai/src/permissions/mod.rs +++ b/crates/atuin-ai/src/permissions/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod check; pub(crate) mod file; +pub(crate) mod resolver; pub(crate) mod rule; pub(crate) mod walker; diff --git a/crates/atuin-ai/src/permissions/resolver.rs b/crates/atuin-ai/src/permissions/resolver.rs new file mode 100644 index 00000000000..6f26ed4a3b8 --- /dev/null +++ b/crates/atuin-ai/src/permissions/resolver.rs @@ -0,0 +1,32 @@ +use std::path::PathBuf; + +use eyre::Result; + +use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; +use crate::permissions::walker::PermissionWalker; +use crate::tools::ClientToolCall; + +pub(crate) struct PermissionResolver { + checker: PermissionChecker, + working_dir: PathBuf, +} + +impl PermissionResolver { + /// Walk the filesystem from `working_dir` to find permission files, + /// then build a checker from them. + pub async fn new(working_dir: PathBuf, global_dir: Option) -> Result { + let mut walker = PermissionWalker::new(working_dir.clone(), global_dir); + walker.walk().await?; + let checker = PermissionChecker::new(walker.rules().to_owned()); + Ok(Self { + checker, + working_dir, + }) + } + + /// Check whether `tool` is allowed, denied, or needs user confirmation. + pub async fn check(&self, tool: &ClientToolCall) -> Result { + let request = PermissionRequest::new(self.working_dir.clone(), Box::new(tool)); + self.checker.check(&request).await + } +} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index d92b42ae72c..0a02d8ff429 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -6,6 +6,12 @@ pub(crate) mod descriptor; use crate::permissions::rule::Rule; +/// Result of executing a client-side tool. +pub(crate) enum ToolOutcome { + Success(String), + Error(String), +} + /// A pending tool call from the server, awaiting permissions or execution. #[derive(Debug, Clone)] pub(crate) struct PendingToolCall { @@ -18,6 +24,23 @@ impl PendingToolCall { pub(crate) fn target_dir(&self) -> Option<&Path> { self.tool.target_dir() } + + /// Mark this tool call as waiting for user permission. + pub fn mark_asking(&mut self) { + self.state = ToolCallState::AskingForPermission; + } + + /// Mark this tool call as currently executing. + #[expect(dead_code)] + pub fn mark_executing(&mut self) { + self.state = ToolCallState::Executing; + } + + /// Mark this tool call as denied. + #[expect(dead_code)] + pub fn mark_denied(&mut self, reason: String) { + self.state = ToolCallState::Denied(reason); + } } /// State of a pending tool call @@ -84,6 +107,14 @@ impl ClientToolCall { ClientToolCall::AtuinHistory(tool) => tool.target_dir(), } } + + /// Execute this client-side tool and return the result. + pub fn execute(&self) -> ToolOutcome { + match self { + ClientToolCall::Read(tool) => tool.execute(), + _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()), + } + } } /// A trait for tool calls that can be checked against permission rules. @@ -147,6 +178,44 @@ impl TryFrom<&serde_json::Value> for ReadToolCall { } } +impl ReadToolCall { + fn execute(&self) -> ToolOutcome { + let mut path = self.path.clone(); + + if path.is_relative() + && let Ok(current_dir) = std::env::current_dir() + { + path = current_dir.join(path); + } + + if !path.exists() { + return ToolOutcome::Error(format!("Error: file does not exist: {}", path.display())); + } + + if path.is_dir() { + let Some(files) = std::fs::read_dir(&path).ok().and_then(|entries| { + entries + .filter_map(|entry| entry.ok()) + .map(|entry| entry.file_name().to_string_lossy().to_string()) + .collect::>() + .into() + }) else { + return ToolOutcome::Error(format!( + "Error: could not read directory: {}", + path.display() + )); + }; + + return ToolOutcome::Success(format!("Directory contents:\n{}", files.join("\n"))); + } + + match std::fs::read_to_string(&path) { + Ok(content) => ToolOutcome::Success(content), + Err(e) => ToolOutcome::Error(format!("Error reading file: {e}")), + } + } +} + impl PermissableToolCall for ReadToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index b2f8fcc41a6..0c9537952aa 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -2,15 +2,14 @@ use std::path::PathBuf; use std::sync::mpsc; use crate::context::{AppContext, ClientContext}; -use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; -use crate::permissions::walker::PermissionWalker; +use crate::permissions::check::PermissionResponse; +use crate::permissions::resolver::PermissionResolver; use crate::stream::{ChatRequest, run_chat_stream}; use crate::tools::ToolCallState; use crate::tui::ConversationEvent; use crate::tui::events::{AiTuiEvent, PermissionResult}; use crate::tui::state::{ExitAction, Session}; use eye_declare::Handle; -use serde_json::json; use tokio::task::JoinHandle; pub(crate) fn dispatch( @@ -146,143 +145,54 @@ fn on_slash_command(handle: &Handle, command: String) { } fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender, id: String) { - eprintln!("Checking tool call permission: {:?}", &id); let h2 = handle.clone(); let tx_for_task = tx.clone(); - let id_clone = id.clone(); + tokio::spawn(async move { + // 1. Fetch the pending tool call let Ok(Some(tool_call)) = h2 .fetch(move |state| state.pending_tool_call(&id).cloned()) .await else { - // todo: raise error - eprintln!("Error getting pending tool call: {:?}", &id_clone); + eprintln!("Pending tool call not found: {:?}", &id_clone); return; }; + // 2. Resolve working directory let Some(working_dir) = tool_call .target_dir() .map(PathBuf::from) .or_else(|| std::env::current_dir().ok()) else { - // todo: raise error eprintln!( - "Error getting working directory for tool call: {:?}", - &tool_call + "Cannot resolve working directory for tool call: {:?}", + &id_clone ); return; }; - let mut walker = PermissionWalker::new(working_dir.clone(), None); // todo: get global dir - - let Ok(_) = walker.walk().await else { - eprintln!("Error walking filesystem for permissions check"); - // todo: raise error + // 3. Create permission resolver and check + let Ok(resolver) = PermissionResolver::new(working_dir, None).await else { + eprintln!("Failed to create permission resolver"); return; }; - let checker = PermissionChecker::new(walker.rules().to_owned()); - let request = PermissionRequest::new(working_dir, Box::new(&tool_call.tool)); - - let Ok(response) = checker.check(&request).await else { - // todo: raise error - eprintln!("Error checking tool call permission"); + let Ok(response) = resolver.check(&tool_call.tool).await else { + eprintln!("Permission check failed for tool call: {:?}", &id_clone); return; }; + // 4. Handle response match response { PermissionResponse::Allowed => { - eprintln!("Executing tool call: {:?}", tool_call); - - let id_clone2 = id_clone.clone(); + let outcome = tool_call.tool.execute(); h2.update(move |state| { - state.add_tool_call(id_clone2.clone(), "read".to_string(), json!({})); - - let mut tool_call = state.pending_tool_call_mut(&id_clone2); - - let Some(tool_call) = tool_call.as_mut() else { - eprintln!("Error getting pending tool call: {:?}", &id_clone2); - return; - }; - - tool_call.state = ToolCallState::Executing; - - // - - // state.events.push(ConversationEvent::OutOfBandOutput { - // name: "System".to_string(), - // content: format!( - // "Permission granted for tool call {:?}", - // &tool_call - // ), - // command: None, - // }); + state.complete_tool_call(&id_clone, outcome); }); - - if let crate::tools::ClientToolCall::Read(read) = tool_call.tool { - let mut path = read.path.clone(); - - if path.is_relative() - && let Ok(current_dir) = std::env::current_dir() - { - path = current_dir.join(path); - } - - if !path.exists() { - let id = id_clone.clone(); - h2.update(move |state| { - state.conversation.add_tool_result( - id.clone(), - format!("Error: file does not exist: {}", path.display()), - true, - ); - state.pending_tool_calls.retain(|c| c.id != id); - }); - return; - } - - if path.is_dir() { - let Some(files) = std::fs::read_dir(&path) - .map_err(|e| { - eprintln!("Error reading directory: {}", e); - e - }) - .ok() - .and_then(|entries| { - entries - .filter_map(|entry| entry.ok()) - .map(|entry| entry.file_name().to_string_lossy().to_string()) - .collect::>() - .into() - }) - else { - h2.update(move |state| { - state.conversation.add_tool_result( - id_clone.clone(), - format!("Error: could not read directory: {}", path.display()), - true, - ); - state.pending_tool_calls.retain(|c| c.id != id_clone); - }); - return; - }; - - h2.update(move |state| { - state.conversation.add_tool_result( - id_clone.clone(), - format!("Directory contents:\n{}", files.join("\n")), - false, - ); - state.pending_tool_calls.retain(|c| c.id != id_clone); - - let _ = tx_for_task.send(AiTuiEvent::ContinueAfterTools); - }); - } - } + let _ = tx_for_task.send(AiTuiEvent::ContinueAfterTools); } PermissionResponse::Denied => { - eprintln!("Permission denied for tool call: {:?}", &tool_call); h2.update(move |state| { state .conversation @@ -295,24 +205,10 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender { - eprintln!("Asking for permission for tool call: {:?}", &tool_call); h2.update(move |state| { - let mut tool_call = state.pending_tool_call_mut(&id_clone); - - let Some(tool_call) = tool_call.as_mut() else { - eprintln!("Error getting pending tool call: {:?}", &id_clone); - return; - }; - - eprintln!( - "Setting tool call state to AskingForPermission: {:?}", - &tool_call - ); - tool_call.state = ToolCallState::AskingForPermission; - eprintln!( - "Tool call state set to AskingForPermission: {:?}", - &tool_call - ); + if let Some(tc) = state.pending_tool_call_mut(&id_clone) { + tc.mark_asking(); + } }); } } diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 561d3e6af26..1c99702a3be 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -7,7 +7,7 @@ use std::collections::VecDeque; use tokio::task::AbortHandle; -use crate::tools::{ClientToolCall, PendingToolCall, ToolCallState}; +use crate::tools::{ClientToolCall, PendingToolCall, ToolCallState, ToolOutcome}; /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] @@ -578,6 +578,20 @@ impl Session { .find(|call| call.id == id) } + /// Record a tool execution result and remove the pending tool call. + pub fn complete_tool_call(&mut self, id: &str, outcome: ToolOutcome) { + match outcome { + ToolOutcome::Success(content) => { + self.conversation + .add_tool_result(id.to_string(), content, false); + } + ToolOutcome::Error(msg) => { + self.conversation.add_tool_result(id.to_string(), msg, true); + } + } + self.pending_tool_calls.retain(|c| c.id != id); + } + /// Get the footer text for current mode pub fn footer_text(&self) -> &'static str { match self.interaction.mode { From 2d3885cfaf8b5dc2174517628a02efa3a7bbc10c Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 19:02:20 -0700 Subject: [PATCH 09/52] Close the client-side tool call loop --- crates/atuin-ai/src/tools/mod.rs | 36 ++++++++++++++++++++++++++--- crates/atuin-ai/src/tui/dispatch.rs | 4 ++-- crates/atuin-ai/src/tui/state.rs | 21 +++++++++++++---- crates/atuin-ai/src/tui/view/mod.rs | 8 +++---- 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 0a02d8ff429..ed47d2f5f90 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -1,4 +1,7 @@ -use std::path::{Path, PathBuf}; +use std::{ + io::BufRead, + path::{Path, PathBuf}, +}; use eyre::Result; @@ -209,8 +212,35 @@ impl ReadToolCall { return ToolOutcome::Success(format!("Directory contents:\n{}", files.join("\n"))); } - match std::fs::read_to_string(&path) { - Ok(content) => ToolOutcome::Success(content), + let file = match std::fs::File::open(&path) { + Ok(file) => file, + Err(e) => return ToolOutcome::Error(format!("Error opening file: {e}")), + }; + let reader = std::io::BufReader::new(file); + + let relevent_lines = if let Some((start, end)) = self.view_range { + reader + .lines() + .skip(start as usize) + .take((end - start) as usize) + .collect::, _>>() + } else { + reader.lines().collect::, _>>() + }; + + match relevent_lines { + Ok(lines) => { + let joined = lines.join("\n"); + if joined.len() > 100_000 { + ToolOutcome::Error(format!( + "Error: file is too large to read ({} bytes in {} lines); use view_range to read a subset of the file", + joined.len(), + lines.len() + )) + } else { + ToolOutcome::Success(joined) + } + } Err(e) => ToolOutcome::Error(format!("Error reading file: {e}")), } } diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index 0c9537952aa..7d862f35083 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -188,7 +188,7 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender { let outcome = tool_call.tool.execute(); h2.update(move |state| { - state.complete_tool_call(&id_clone, outcome); + state.complete_tool_call(&tool_call, outcome); }); let _ = tx_for_task.send(AiTuiEvent::ContinueAfterTools); } @@ -199,7 +199,7 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender { self.conversation - .add_tool_result(id.to_string(), content, false); + .add_tool_result(pending.id.clone(), content, false); } ToolOutcome::Error(msg) => { - self.conversation.add_tool_result(id.to_string(), msg, true); + self.conversation + .add_tool_result(pending.id.clone(), msg, true); } } - self.pending_tool_calls.retain(|c| c.id != id); + self.pending_tool_calls.retain(|c| c.id != pending.id); } /// Get the footer text for current mode diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 0954ab2f049..0c9a93c33d5 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -157,7 +157,7 @@ fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements { element! { View(padding_top: Cells::from(padding)) { Text { - Span(text: "You", style: label_style) + Span(text: " You ", style: label_style.reversed()) } #(for event in events { #(match event { @@ -185,9 +185,9 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { element! { View { Spinner( - label: "Atuin AI", - label_style: label_style, - done_label_style: label_style, + label: " Atuin AI ", + label_style: label_style.reversed(), + done_label_style: label_style.reversed(), hide_checkmark: true, label_first: true, done: !busy, From 1b84eb6e89df8fcea258a4e664d7d9c75ed0ffd8 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 19:08:52 -0700 Subject: [PATCH 10/52] Remove temp debugging --- crates/atuin-ai/src/permissions/walker.rs | 6 ------ crates/atuin-ai/src/tui/dispatch.rs | 14 +++++++------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs index e7313117228..ee26c5598fe 100644 --- a/crates/atuin-ai/src/permissions/walker.rs +++ b/crates/atuin-ai/src/permissions/walker.rs @@ -43,13 +43,10 @@ impl PermissionWalker { to_check.push(global_path.clone()); } - eprintln!("to_check: {:?}", to_check); - let size = to_check.len(); let mut set: JoinSet>> = JoinSet::new(); for (index, path) in to_check.into_iter().enumerate() { - eprintln!("Checking: {:?}", path); set.spawn(async move { match check_for_permissions(&path).await { Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { @@ -66,7 +63,6 @@ impl PermissionWalker { while let Some(result) = set.join_next().await { let result = result?; // JoinErrors result in failure to walk the filesystem - eprintln!("result: {:?}", result); match result { Ok(Some(FoundRuleFile { depth, file })) => { found.push((depth, file)); @@ -87,8 +83,6 @@ impl PermissionWalker { found.sort_by_key(|(depth, _)| *depth); self.rules = found.into_iter().map(|(_, file)| file).collect(); - eprintln!("rules: {:?}", self.rules); - Ok(()) } } diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index 7d862f35083..0f56faaf108 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -155,7 +155,7 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender, tx: &mpsc::Sender Date: Tue, 7 Apr 2026 21:06:32 -0700 Subject: [PATCH 11/52] Ensure all local tool calls resolve before resuming stream --- crates/atuin-ai/src/tui/dispatch.rs | 30 ++++++++++++++++++++++------- crates/atuin-ai/src/tui/state.rs | 10 ++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index 0f56faaf108..dd5996b38b8 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -36,7 +36,7 @@ pub(crate) fn dispatch( on_check_tool_permission(handle, tx, id); } AiTuiEvent::SelectPermission(result) => { - on_select_permission(handle, result); + on_select_permission(handle, tx, result); } AiTuiEvent::CancelGeneration => { on_cancel_generation(handle); @@ -189,10 +189,13 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender { + let tx = tx_for_task.clone(); h2.update(move |state| { state .conversation @@ -202,6 +205,10 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender { @@ -215,10 +222,12 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender, permission: PermissionResult) { - // Okay, we have permssion information. - // If accepted, we can start executing. - // If denied, we can show an error message. +fn on_select_permission( + handle: &Handle, + tx: &mpsc::Sender, + permission: PermissionResult, +) { + let tx = tx.clone(); handle.update(move |state| { let tool_call = state .pending_tool_calls @@ -232,7 +241,11 @@ fn on_select_permission(handle: &Handle, permission: PermissionResult) match permission { PermissionResult::Allow => { - state.pending_tool_calls.remove(index); + if let Some(call) = state.pending_tool_calls.remove(index) { + if !state.has_unresolved_tool_calls() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + } } PermissionResult::AlwaysAllowInDir => { // @@ -250,6 +263,9 @@ fn on_select_permission(handle: &Handle, permission: PermissionResult) "Permission denied on the user's system".to_string(), true, ); + if !state.has_unresolved_tool_calls() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } } } }); diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index b15c5f83f0e..221f95ea44b 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -603,6 +603,16 @@ impl Session { self.pending_tool_calls.retain(|c| c.id != pending.id); } + /// Returns true if any tool calls are still in CheckingPermissions or AskingForPermission state. + pub fn has_unresolved_tool_calls(&self) -> bool { + self.pending_tool_calls.iter().any(|tc| { + matches!( + tc.state, + ToolCallState::CheckingPermissions | ToolCallState::AskingForPermission + ) + }) + } + /// Get the footer text for current mode pub fn footer_text(&self) -> &'static str { match self.interaction.mode { From d52ada197d974544e80051821c051a31ca3d8b07 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 21:26:40 -0700 Subject: [PATCH 12/52] Implement AtuinHistory client-side tool execution Add the history database to AppContext and plumb it through dispatch so client tools can run async database queries. AtuinHistory::execute uses atuin-client's Database::search with fuzzy matching, the first filter mode from the tool call, and configurable limit. Also fix two pre-existing bugs that prevented client tools from working end-to-end: AtuinHistory::matches_rule had a todo!() panic that crashed the permission check task, and on_select_permission Allow was discarding the tool call instead of executing it. --- .atuin/permissions.ai.toml | 2 +- crates/atuin-ai/src/commands/inline.rs | 7 ++ crates/atuin-ai/src/context.rs | 3 + crates/atuin-ai/src/tools/mod.rs | 82 +++++++++++++++++++++-- crates/atuin-ai/src/tui/dispatch.rs | 93 ++++++++++++++++---------- 5 files changed, 147 insertions(+), 40 deletions(-) diff --git a/.atuin/permissions.ai.toml b/.atuin/permissions.ai.toml index 399c89906c6..fe127ab4ee6 100644 --- a/.atuin/permissions.ai.toml +++ b/.atuin/permissions.ai.toml @@ -1,3 +1,3 @@ [permissions] -allow = ["Read"] +allow = ["Read", "AtuinHistory"] diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index 2fb5c56fa22..0d1adaff09e 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,3 +1,4 @@ +use std::path::PathBuf; use std::sync::mpsc; use crate::context::{AppContext, ClientContext}; @@ -42,10 +43,16 @@ pub(crate) async fn run( ensure_hub_session(settings).await? }; + let history_db_path = PathBuf::from(settings.db_path.as_str()); + let history_db = atuin_client::database::Sqlite::new(history_db_path, settings.local_timeout) + .await + .context("failed to open history database for AI")?; + let ctx = AppContext { endpoint: endpoint.to_string(), token, send_cwd: settings.ai.send_cwd, + history_db: std::sync::Arc::new(history_db), }; let action = run_inline_tui(ctx, initial_command).await?; diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs index 03dc3f891e2..fca4aa606d2 100644 --- a/crates/atuin-ai/src/context.rs +++ b/crates/atuin-ai/src/context.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use atuin_client::distro::detect_linux_distribution; /// Session-scoped context for the AI chat session. @@ -7,6 +9,7 @@ pub(crate) struct AppContext { pub endpoint: String, pub token: String, pub send_cwd: bool, + pub history_db: Arc, } /// Machine identity — computed once per session. diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index ed47d2f5f90..25cd5130a1f 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -112,9 +112,10 @@ impl ClientToolCall { } /// Execute this client-side tool and return the result. - pub fn execute(&self) -> ToolOutcome { + pub async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { match self { ClientToolCall::Read(tool) => tool.execute(), + ClientToolCall::AtuinHistory(tool) => tool.execute(db).await, _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()), } } @@ -367,6 +368,7 @@ impl PermissableToolCall for ShellToolCall { pub(crate) struct AtuinHistoryToolCall { pub filter_modes: Vec, pub query: String, + pub limit: i64, } #[derive(Debug, Clone)] @@ -378,6 +380,18 @@ pub(crate) enum HistorySearchFilterMode { Workspace, } +impl From<&HistorySearchFilterMode> for atuin_client::settings::FilterMode { + fn from(mode: &HistorySearchFilterMode) -> Self { + match mode { + HistorySearchFilterMode::Global => Self::Global, + HistorySearchFilterMode::Host => Self::Host, + HistorySearchFilterMode::Session => Self::Session, + HistorySearchFilterMode::Directory => Self::Directory, + HistorySearchFilterMode::Workspace => Self::Workspace, + } + } +} + impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { type Error = eyre::Error; @@ -407,9 +421,16 @@ impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { .and_then(|v| v.as_str()) .ok_or(eyre::eyre!("Missing query"))?; + let limit = value + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(10) + .clamp(1, 50); + Ok(AtuinHistoryToolCall { filter_modes, query: query.to_string(), + limit, }) } } @@ -420,10 +441,63 @@ impl PermissableToolCall for AtuinHistoryToolCall { } fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "AtuinHistory" { - return false; + rule.tool == "AtuinHistory" + } +} + +impl AtuinHistoryToolCall { + pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { + use atuin_client::database::{self, Database as _, OptFilters}; + use atuin_client::settings::SearchMode; + + let context = match database::current_context().await { + Ok(ctx) => ctx, + Err(e) => return ToolOutcome::Error(format!("Failed to get history context: {e}")), + }; + + let filter_mode = self + .filter_modes + .first() + .map(atuin_client::settings::FilterMode::from) + .unwrap_or(atuin_client::settings::FilterMode::Global); + + let filter_options = OptFilters { + limit: Some(self.limit), + ..Default::default() + }; + + let results = match db + .search( + SearchMode::Fuzzy, + filter_mode, + &context, + &self.query, + filter_options, + ) + .await + { + Ok(results) => results, + Err(e) => return ToolOutcome::Error(format!("History search failed: {e}")), + }; + + if results.is_empty() { + return ToolOutcome::Success("No matching history entries found.".to_string()); } - todo!() + let formatted: Vec = results + .iter() + .enumerate() + .map(|(i, h)| { + format!( + "{}. `{}` (cwd: {}, exit: {})", + i + 1, + h.command, + h.cwd, + h.exit + ) + }) + .collect(); + + ToolOutcome::Success(formatted.join("\n")) } } diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index dd5996b38b8..44a11a84e5b 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -33,10 +33,10 @@ pub(crate) fn dispatch( on_slash_command(handle, cmd); } AiTuiEvent::CheckToolCallPermission(id) => { - on_check_tool_permission(handle, tx, id); + on_check_tool_permission(handle, tx, app_ctx, id); } AiTuiEvent::SelectPermission(result) => { - on_select_permission(handle, tx, result); + on_select_permission(handle, tx, app_ctx, result); } AiTuiEvent::CancelGeneration => { on_cancel_generation(handle); @@ -144,10 +144,16 @@ fn on_slash_command(handle: &Handle, command: String) { }); } -fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender, id: String) { +fn on_check_tool_permission( + handle: &Handle, + tx: &mpsc::Sender, + app_ctx: &AppContext, + id: String, +) { let h2 = handle.clone(); let tx_for_task = tx.clone(); let id_clone = id.clone(); + let db = app_ctx.history_db.clone(); tokio::spawn(async move { // 1. Fetch the pending tool call @@ -155,7 +161,6 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender, tx: &mpsc::Sender { - let outcome = tool_call.tool.execute(); + let outcome = tool_call.tool.execute(&db).await; h2.update(move |state| { state.complete_tool_call(&tool_call, outcome); if !state.has_unresolved_tool_calls() { @@ -225,35 +224,59 @@ fn on_check_tool_permission(handle: &Handle, tx: &mpsc::Sender, tx: &mpsc::Sender, + app_ctx: &AppContext, permission: PermissionResult, ) { let tx = tx.clone(); - handle.update(move |state| { - let tool_call = state - .pending_tool_calls - .iter() - .enumerate() - .find(|(_, call)| call.state == ToolCallState::AskingForPermission); + let h2 = handle.clone(); + let db = app_ctx.history_db.clone(); - let Some((index, _)) = tool_call else { - return; - }; + match permission { + PermissionResult::Allow => { + // Fetch the tool call that's asking for permission, then execute it async + let h3 = h2.clone(); + let tx2 = tx.clone(); + tokio::spawn(async move { + let Ok(Some(tool_call)) = h3 + .fetch(move |state| { + state + .pending_tool_calls + .iter() + .find(|tc| tc.state == ToolCallState::AskingForPermission) + .cloned() + }) + .await + else { + return; + }; - match permission { - PermissionResult::Allow => { - if let Some(call) = state.pending_tool_calls.remove(index) { + let outcome = tool_call.tool.execute(&db).await; + h3.update(move |state| { + state.complete_tool_call(&tool_call, outcome); if !state.has_unresolved_tool_calls() { - let _ = tx.send(AiTuiEvent::ContinueAfterTools); + let _ = tx2.send(AiTuiEvent::ContinueAfterTools); } - } - } - PermissionResult::AlwaysAllowInDir => { - // - } - PermissionResult::AlwaysAllow => { - // - } - PermissionResult::Deny => { + }); + }); + } + PermissionResult::AlwaysAllowInDir => { + // + } + PermissionResult::AlwaysAllow => { + // + } + PermissionResult::Deny => { + h2.update(move |state| { + let tool_call = state + .pending_tool_calls + .iter() + .enumerate() + .find(|(_, call)| call.state == ToolCallState::AskingForPermission); + + let Some((index, _)) = tool_call else { + return; + }; + let Some(call) = state.pending_tool_calls.remove(index) else { return; }; @@ -266,9 +289,9 @@ fn on_select_permission( if !state.has_unresolved_tool_calls() { let _ = tx.send(AiTuiEvent::ContinueAfterTools); } - } + }); } - }); + } } fn on_cancel_generation(handle: &Handle) { From c29add214d6eb3afc2d537c19cb2e41887193ba7 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 21:29:29 -0700 Subject: [PATCH 13/52] Remove permissions file --- .atuin/permissions.ai.toml | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .atuin/permissions.ai.toml diff --git a/.atuin/permissions.ai.toml b/.atuin/permissions.ai.toml deleted file mode 100644 index fe127ab4ee6..00000000000 --- a/.atuin/permissions.ai.toml +++ /dev/null @@ -1,3 +0,0 @@ -[permissions] - -allow = ["Read", "AtuinHistory"] From 885d32aeaa6120f8a0e96be287908c5819a245c1 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 21:29:37 -0700 Subject: [PATCH 14/52] ignore permissions files --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 2b1d63f989c..c4ccffb0795 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ ui/backend/target ui/backend/gen sqlite-server.db* + +.atuin/permissions.*.toml From 0ecfa746c8f73ff226ed167d3e3a06258ba127eb Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 7 Apr 2026 21:47:46 -0700 Subject: [PATCH 15/52] Add timestamp and duration to AtuinHistory results Format results with local timezone timestamp and human-readable duration (e.g. 3s, 1m23s, 120ms) alongside command, cwd, and exit code. --- Cargo.lock | 1 + crates/atuin-ai/Cargo.toml | 1 + crates/atuin-ai/src/tools/mod.rs | 50 ++++++++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 79b655c4447..d195b9b79f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -290,6 +290,7 @@ dependencies = [ "serde", "serde_json", "thiserror 2.0.18", + "time", "tokio", "toml", "tracing", diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 8c2d02e5a9f..d6c6a302aa6 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -44,6 +44,7 @@ ratatui-core = "0.1" ratatui-widgets = "0.3" thiserror = { workspace = true } regex = { workspace = true } +time = { workspace = true } toml = "1.1" typed-builder = { workspace = true } diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 25cd5130a1f..72b573d6716 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -449,6 +449,7 @@ impl AtuinHistoryToolCall { pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { use atuin_client::database::{self, Database as _, OptFilters}; use atuin_client::settings::SearchMode; + use time::UtcOffset; let context = match database::current_context().await { Ok(ctx) => ctx, @@ -484,16 +485,33 @@ impl AtuinHistoryToolCall { return ToolOutcome::Success("No matching history entries found.".to_string()); } + let local_offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + let formatted: Vec = results .iter() .enumerate() .map(|(i, h)| { + let ts = h.timestamp.to_offset(local_offset); + let time_str = format!( + "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", + ts.year(), + ts.month() as u8, + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ); + + let duration_str = format_duration(h.duration); + format!( - "{}. `{}` (cwd: {}, exit: {})", + "{}. `{}` [{}] ({}, exit: {}){}", i + 1, h.command, + time_str, h.cwd, - h.exit + h.exit, + duration_str, ) }) .collect(); @@ -501,3 +519,31 @@ impl AtuinHistoryToolCall { ToolOutcome::Success(formatted.join("\n")) } } + +fn format_duration(nanos: i64) -> String { + if nanos <= 0 { + return String::new(); + } + + let total_secs = nanos / 1_000_000_000; + let millis = (nanos % 1_000_000_000) / 1_000_000; + + if total_secs >= 3600 { + let hours = total_secs / 3600; + let mins = (total_secs % 3600) / 60; + let secs = total_secs % 60; + format!(", {hours}h{mins}m{secs}s") + } else if total_secs >= 60 { + let mins = total_secs / 60; + let secs = total_secs % 60; + format!(", {mins}m{secs}s") + } else if total_secs > 0 { + if millis > 0 { + format!(", {total_secs}.{millis:03}s") + } else { + format!(", {total_secs}s") + } + } else { + format!(", {millis}ms") + } +} From 5c8b22b7a5ac39f22b655c93930da8842a0b45ff Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Wed, 8 Apr 2026 10:35:15 -0700 Subject: [PATCH 16/52] Add tree-sitter shell command parsing for scoped permissions Parse shell commands with tree-sitter-bash (POSIX family) and tree-sitter-fish to extract all subcommands from compound expressions (&&, ||, pipes, subshells, $(...), etc). Wire into ShellToolCall::matches_rule for scope matching. Scope matching supports three wildcard styles: - `ls *` (space before *): word-boundary match - `ls*` (no space): prefix/glob match - `git * amend` (middle *): matches zero+ words between segments Also fixes: variable assignments excluded from command text, fallback parser double-split bug, and shell field parsed from tool call input for correct parser selection. --- Cargo.lock | 50 ++ crates/atuin-ai/Cargo.toml | 3 + crates/atuin-ai/src/permissions/mod.rs | 1 + crates/atuin-ai/src/permissions/resolver.rs | 3 +- crates/atuin-ai/src/permissions/shell.rs | 687 ++++++++++++++++++++ crates/atuin-ai/src/tools/mod.rs | 23 +- 6 files changed, 757 insertions(+), 10 deletions(-) create mode 100644 crates/atuin-ai/src/permissions/shell.rs diff --git a/Cargo.lock b/Cargo.lock index d195b9b79f5..12088674605 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -296,6 +296,9 @@ dependencies = [ "tracing", "tracing-appender", "tracing-subscriber", + "tree-sitter", + "tree-sitter-bash", + "tree-sitter-fish", "tui-textarea-2", "typed-builder 0.18.2", "unicode-width 0.2.2", @@ -4303,6 +4306,7 @@ version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ + "indexmap 2.13.0", "itoa", "memchr", "serde", @@ -4805,6 +4809,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + [[package]] name = "stringprep" version = "0.1.5" @@ -5476,6 +5486,46 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "tree-sitter" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "887bd495d0582c5e3e0d8ece2233666169fa56a9644d172fc22ad179ab2d0538" +dependencies = [ + "cc", + "regex", + "regex-syntax", + "serde_json", + "streaming-iterator", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-bash" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5ec769279cc91b561d3df0d8a5deb26b0ad40d183127f409494d6d8fc53062" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-fish" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "014e3b299f251e9c2e372e3b5e1b0323ef21196e9aa2e90a5bc1f6130cbe8b18" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-language" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782" + [[package]] name = "tree_magic_mini" version = "3.2.2" diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index d6c6a302aa6..05984c68be0 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -46,6 +46,9 @@ thiserror = { workspace = true } regex = { workspace = true } time = { workspace = true } toml = "1.1" +tree-sitter = "0.26.8" +tree-sitter-bash = "0.25.1" +tree-sitter-fish = "3.6.0" typed-builder = { workspace = true } [dev-dependencies] diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs index 5920e7ac893..b7e4814e91d 100644 --- a/crates/atuin-ai/src/permissions/mod.rs +++ b/crates/atuin-ai/src/permissions/mod.rs @@ -2,4 +2,5 @@ pub(crate) mod check; pub(crate) mod file; pub(crate) mod resolver; pub(crate) mod rule; +pub(crate) mod shell; pub(crate) mod walker; diff --git a/crates/atuin-ai/src/permissions/resolver.rs b/crates/atuin-ai/src/permissions/resolver.rs index 6f26ed4a3b8..31afbadbce1 100644 --- a/crates/atuin-ai/src/permissions/resolver.rs +++ b/crates/atuin-ai/src/permissions/resolver.rs @@ -6,14 +6,13 @@ use crate::permissions::check::{PermissionChecker, PermissionRequest, Permission use crate::permissions::walker::PermissionWalker; use crate::tools::ClientToolCall; +/// Resolves permissions for client tool calls by walking the filesystem to find permission files, pub(crate) struct PermissionResolver { checker: PermissionChecker, working_dir: PathBuf, } impl PermissionResolver { - /// Walk the filesystem from `working_dir` to find permission files, - /// then build a checker from them. pub async fn new(working_dir: PathBuf, global_dir: Option) -> Result { let mut walker = PermissionWalker::new(working_dir.clone(), global_dir); walker.walk().await?; diff --git a/crates/atuin-ai/src/permissions/shell.rs b/crates/atuin-ai/src/permissions/shell.rs new file mode 100644 index 00000000000..077201ac98c --- /dev/null +++ b/crates/atuin-ai/src/permissions/shell.rs @@ -0,0 +1,687 @@ +use tree_sitter::{Parser, Tree}; + +/// Extracted command info from a shell command string. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ShellCommand { + /// The command name (first word), e.g. "git" + pub name: String, + /// The full invocation including arguments, e.g. "git commit -m msg" + pub full: String, +} + +/// A parsed shell command with all subcommands extracted. +#[derive(Debug)] +pub(crate) struct ParsedShellCommand { + pub subcommands: Vec, +} + +/// Supported shell families for parsing. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ShellKind { + /// POSIX sh, bash, zsh — all share similar syntax + Posix, + /// fish shell + Fish, + /// nushell or unknown — fallback to word-level extraction + Other, +} + +impl ShellKind { + pub(crate) fn from_shell_name(name: &str) -> Self { + match name { + "bash" | "sh" | "zsh" | "dash" | "ksh" => Self::Posix, + "fish" => Self::Fish, + _ => Self::Other, + } + } +} + +/// Parse a shell command string and extract all subcommands. +pub(crate) fn parse_shell_command(code: &str, shell: ShellKind) -> ParsedShellCommand { + match shell { + ShellKind::Posix => parse_posix(code), + ShellKind::Fish => parse_fish(code), + ShellKind::Other => parse_fallback(code), + } +} + +// ──────────────────────────────────────────────────────────────── +// POSIX (bash/zsh/sh) parser +// ──────────────────────────────────────────────────────────────── + +fn bash_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_bash::LANGUAGE.into()) + .expect("failed to set bash language"); + parser +} + +fn parse_posix(code: &str) -> ParsedShellCommand { + let mut parser = bash_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_bash_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } +} + +/// Leaf node kinds that never contain nested commands. +const BASH_LEAVES: &[&str] = &[ + "command_name", + "word", + "number", + "simple_expansion", + "expansion", + "arithmetic_expansion", + "ansi_c_string", + "special_variable_name", + "variable_name", + "file_descriptor", + "heredoc_body", + "heredoc_start", + "regex", + "heredoc_redirect", + "concatenation", +]; + +fn walk_bash_tree(tree: &Tree, source: &[u8], commands: &mut Vec) { + walk_bash_node(tree.root_node(), source, commands); +} + +fn walk_bash_node(node: tree_sitter::Node, source: &[u8], commands: &mut Vec) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_bash_command(node, source) { + commands.push(cmd); + } + // Descend into all non-leaf children to find nested commands + // (e.g. command_substitution inside a string inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if !BASH_LEAVES.contains(&child.kind()) { + walk_bash_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_bash_node(child, source, commands); + } + } + } +} + +/// Extract the full command string and name from a bash `command` node. +fn extract_bash_command(node: tree_sitter::Node, source: &[u8]) -> Option { + // A `command` node has children like: + // variable_assignment* command_name argument* redirect* + // We want the command_name and all arguments (skipping assignments and redirects). + let mut name = None; + let mut name_start = None; + let mut arg_end = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" => { + name = child.utf8_text(source).ok().map(|s| s.to_string()); + name_start = Some(child.start_byte()); + } + "word" + | "string" + | "raw_string" + | "concatenation" + | "number" + | "simple_expansion" + | "expansion" + | "arithmetic_expansion" + | "ansi_c_string" + | "process_substitution" => { + arg_end = Some(child.end_byte()); + } + _ => {} + } + } + + let name = name?; + let full = if let (Some(start), Some(end)) = (name_start, arg_end) { + std::str::from_utf8(&source[start..end]).ok()?.to_string() + } else { + name.clone() + }; + + Some(ShellCommand { name, full }) +} + +// ──────────────────────────────────────────────────────────────── +// Fish parser +// ──────────────────────────────────────────────────────────────── + +fn fish_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_fish::language()) + .expect("failed to set fish language"); + parser +} + +fn parse_fish(code: &str) -> ParsedShellCommand { + let mut parser = fish_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_fish_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } +} + +const FISH_COMPOUND: &[&str] = &[ + "conditional_execution", + "pipe", + "job", + "command_substitution", + "block", + "for_statement", + "while_statement", + "if_statement", + "switch_statement", + "function_definition", + "begin_statement", + "redirected_statement", +]; + +fn walk_fish_tree(tree: &Tree, source: &[u8], commands: &mut Vec) { + walk_fish_node(tree.root_node(), source, commands); +} + +fn walk_fish_node(node: tree_sitter::Node, source: &[u8], commands: &mut Vec) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_fish_command(node, source) { + commands.push(cmd); + } + // Still descend into compound children (e.g. command_substitution inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if FISH_COMPOUND.contains(&child.kind()) { + walk_fish_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_fish_node(child, source, commands); + } + } + } +} + +fn extract_fish_command(node: tree_sitter::Node, source: &[u8]) -> Option { + // In fish, a `command` node has: + // name (command_name or word) followed by arguments (word, string, etc.) + let mut name = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" | "word" => { + let text = child.utf8_text(source).ok()?.to_string(); + if name.is_none() { + name = Some(text); + } + } + "string" + | "concatenation" + | "command_substitution" + | "escape_sequence" + | "double_quote_string" + | "single_quote_string" => {} + _ => {} + } + } + + let name = name?; + // Get the full text of the command node + let full = node.utf8_text(source).ok()?.trim().to_string(); + + Some(ShellCommand { name, full }) +} + +// ──────────────────────────────────────────────────────────────── +// Fallback (word-level extraction for nushell / unknown shells) +// ──────────────────────────────────────────────────────────────── + +fn parse_fallback(code: &str) -> ParsedShellCommand { + // Simple heuristic: split by &&, ||, ;, | and take the first word of each segment. + // This is intentionally simple — for unknown shells we can't do better. + let mut commands = Vec::new(); + let mut segment = String::new(); + let mut chars = code.chars().peekable(); + + while let Some(c) = chars.next() { + match c { + ';' => { + push_segment(&mut segment, &mut commands); + } + '|' => { + if chars.peek() == Some(&'|') { + chars.next(); + } + push_segment(&mut segment, &mut commands); + } + '&' if chars.peek() == Some(&'&') => { + chars.next(); + push_segment(&mut segment, &mut commands); + } + _ => segment.push(c), + } + } + push_segment(&mut segment, &mut commands); + + ParsedShellCommand { + subcommands: commands, + } +} + +fn push_segment(segment: &mut String, commands: &mut Vec) { + let trimmed = segment.trim(); + if !trimmed.is_empty() { + if let Some(name) = trimmed.split_whitespace().next() { + commands.push(ShellCommand { + name: name.to_string(), + full: trimmed.to_string(), + }); + } + } + segment.clear(); +} + +// ──────────────────────────────────────────────────────────────── +// Scope matching +// ──────────────────────────────────────────────────────────────── + +/// Check if any of the extracted subcommands match the given scope pattern. +/// +/// Matching semantics depend on where the `*` wildcard appears: +/// - `*` alone — matches everything +/// - `ls *` (space before `*`) — matches `ls` and `ls -a` but not `lsof` +/// - `git commit *` — matches `git commit -m "msg"` (word boundary) +/// - `ls*` (no space before `*`) — matches `lsof`, `ls`, `ls -a` (prefix/glob) +/// - `rm` (no wildcard) — matches exactly `rm` +/// - `git * amend` — matches `git commit amend` (middle wildcard matches zero+ words) +pub(crate) fn any_subcommand_matches(subcommands: &[ShellCommand], scope: &str) -> bool { + let scope = scope.trim(); + + if scope == "*" { + return true; + } + + if let Some(prefix) = scope.strip_suffix(" *") { + // Word-boundary matching: `ls *` matches `ls` and `ls -a` but not `lsof` + return subcommands.iter().any(|cmd| { + if prefix.is_empty() { + return true; + } + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + let prefix_words: Vec<&str> = prefix.split_whitespace().collect(); + cmd_words.len() >= prefix_words.len() + && cmd_words[..prefix_words.len()] == prefix_words[..] + }); + } + + if scope.ends_with('*') { + // Prefix/glob matching: `ls*` matches `lsof`, `ls`, etc. + let prefix = &scope[..scope.len() - 1]; + return subcommands.iter().any(|cmd| cmd.full.starts_with(prefix)); + } + + if scope.contains('*') { + // Middle wildcard: `git * amend` — each `*` matches zero or more words + return subcommands + .iter() + .any(|cmd| scope_matches_words(scope, cmd.full.split_whitespace().collect())); + } + + // No wildcard: word-boundary prefix match + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + subcommands.iter().any(|cmd| { + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + cmd_words.len() >= scope_words.len() && cmd_words[..scope_words.len()] == scope_words[..] + }) +} + +/// Match a scope pattern containing `*` wildcards against a sequence of words. +/// Each `*` matches zero or more words. Consecutive `*` collapse into one. +fn scope_matches_words(scope: &str, words: Vec<&str>) -> bool { + let parts: Vec<&str> = scope.split('*').collect(); + if parts.len() == 1 { + // No wildcard (shouldn't reach here, but handle it) + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + return words.len() >= scope_words.len() && words[..scope_words.len()] == scope_words[..]; + } + + // Each segment between * is a sequence of literal words that must appear in order. + // Walk through `words` consuming segments left to right. + let mut word_idx = 0; + + for (i, part) in parts.iter().enumerate() { + let segment_words: Vec<&str> = part.split_whitespace().collect(); + if segment_words.is_empty() { + continue; + } + + // Find the segment words starting from word_idx + if i == 0 { + // First segment must match at the start + if words.len() < segment_words.len() + || words[..segment_words.len()] != segment_words[..] + { + return false; + } + word_idx = segment_words.len(); + } else if i == parts.len() - 1 { + // Last segment must match at the end + if words.len() - word_idx < segment_words.len() { + return false; + } + let start = words.len() - segment_words.len(); + return words[start..] == segment_words[..]; + } else { + // Middle segment: find it anywhere after word_idx + let found = find_subslice(&words[word_idx..], &segment_words); + match found { + Some(idx) => word_idx += idx + segment_words.len(), + None => return false, + } + } + } + + true +} + +/// Find the first occurrence of `needle` as a contiguous subsequence in `haystack`. +fn find_subslice(haystack: &[&str], needle: &[&str]) -> Option { + if needle.is_empty() { + return Some(0); + } + if haystack.len() < needle.len() { + return None; + } + (0..=haystack.len() - needle.len()).find(|&i| haystack[i..i + needle.len()] == needle[..]) +} + +// ──────────────────────────────────────────────────────────────── +// Tests +// ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + fn fulls(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.full.as_str()).collect() + } + + #[test] + fn simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[test] + fn pipeline() { + let result = parse_shell_command("cat file.txt | grep foo | wc -l", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cat", "grep", "wc"]); + } + + #[test] + fn command_chaining() { + let result = parse_shell_command("git add . && git commit -m 'hi'", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git", "git"]); + assert_eq!( + fulls(&result.subcommands), + vec!["git add .", "git commit -m 'hi'"] + ); + } + + #[test] + fn command_substitution() { + let result = parse_shell_command("echo $(git rev-parse HEAD)", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[test] + fn backtick_substitution() { + let result = parse_shell_command("echo `date`", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[test] + fn subshell() { + let result = parse_shell_command("(cd /tmp && ls)", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cd", "ls"]); + } + + #[test] + fn semicolon_separated() { + let result = parse_shell_command("echo hello; echo world", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["echo", "echo"]); + } + + #[test] + fn for_loop() { + let result = parse_shell_command("for f in *.txt; do cat $f; done", ShellKind::Posix); + assert!(names(&result.subcommands).contains(&"cat")); + } + + #[test] + fn if_statement() { + let result = parse_shell_command( + "if [ -f foo ]; then cat foo; else echo nope; fi", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn scope_matching_wildcard() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "*")); + } + + #[test] + fn scope_matching_prefix() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "git commit *")); + assert!(any_subcommand_matches(&commands, "git commit")); + assert!(!any_subcommand_matches(&commands, "git push *")); + assert!(!any_subcommand_matches(&commands, "git push")); + assert!(any_subcommand_matches(&commands, "npm *")); + } + + #[test] + fn scope_word_boundary_vs_glob() { + let commands = vec![ + ShellCommand { + name: "ls".into(), + full: "ls -a".into(), + }, + ShellCommand { + name: "lsof".into(), + full: "lsof -i :3000".into(), + }, + ]; + // `ls *` — word boundary: matches `ls -a` but not `lsof` + assert!(any_subcommand_matches(&commands, "ls *")); + assert!(!any_subcommand_matches(&commands, "cat *")); + assert!(any_subcommand_matches(&commands, "lsof *")); + + // `ls*` — glob/prefix: matches both `ls -a` and `lsof` + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn scope_exact_match() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + assert!(any_subcommand_matches(&commands, "ls")); + assert!(!any_subcommand_matches(&commands, "cat")); + } + + #[test] + fn nested_substitution() { + let result = parse_shell_command( + "echo \"Result: $(git log --oneline | head -1)\"", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + assert!(n.contains(&"head"), "should contain head: {n:?}"); + } + + #[test] + fn fallback_splits_correctly() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + let n = names(&result.subcommands); + assert!(n.contains(&"ls"), "should contain ls: {n:?}"); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn fish_simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Fish); + assert_eq!(names(&result.subcommands), vec!["ls"]); + } + + #[test] + fn fish_conditional() { + let result = parse_shell_command("git add .; and git commit -m hi", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[test] + fn fish_command_substitution() { + let result = parse_shell_command("echo (date)", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[test] + fn variable_assignment_excluded() { + let result = parse_shell_command("FOO=bar ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[test] + fn variable_assignment_multiple() { + let result = parse_shell_command("A=1 B=2 git status", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git"]); + assert_eq!(fulls(&result.subcommands), vec!["git status"]); + } + + #[test] + fn fallback_double_ampersand_and_pipe_pipe() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "cat", "echo"]); + assert_eq!( + fulls(&result.subcommands), + vec!["ls", "cat foo", "echo fail"] + ); + } + + #[test] + fn fallback_pipe_without_double() { + let result = parse_shell_command("ls | grep foo", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "grep"]); + assert_eq!(fulls(&result.subcommands), vec!["ls", "grep foo"]); + } + + #[test] + fn scope_middle_wildcard() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit -m amend".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * amend")); + assert!(any_subcommand_matches(&commands, "git commit * amend")); + assert!(!any_subcommand_matches(&commands, "git push * amend")); + } + + #[test] + fn scope_middle_wildcard_zero_words() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + // `*` matches zero words, so `git * commit` should match `git commit` + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn scope_leading_wildcard() { + let commands = vec![ShellCommand { + name: "docker".into(), + full: "docker run --rm alpine".into(), + }]; + assert!(any_subcommand_matches(&commands, "* alpine")); + assert!(!any_subcommand_matches(&commands, "* ubuntu")); + } + + #[test] + fn scope_multiple_wildcards() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git rebase -i HEAD~5".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * -i * HEAD~5")); + assert!(!any_subcommand_matches(&commands, "git * -i * HEAD~10")); + } +} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 72b573d6716..242a470a16b 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -322,6 +322,7 @@ impl PermissableToolCall for WriteToolCall { pub(crate) struct ShellToolCall { pub dir: Option, pub command: String, + pub shell: String, } impl TryFrom<&serde_json::Value> for ShellToolCall { @@ -335,9 +336,16 @@ impl TryFrom<&serde_json::Value> for ShellToolCall { .and_then(|v| v.as_str()) .ok_or(eyre::eyre!("Missing command"))?; + let shell = value + .get("shell") + .and_then(|v| v.as_str()) + .unwrap_or("bash") + .to_string(); + Ok(ShellToolCall { dir: dir.map(PathBuf::from), command: command.to_string(), + shell, }) } } @@ -352,15 +360,14 @@ impl PermissableToolCall for ShellToolCall { return false; } - if let Some(scope) = rule.scope.as_ref() { - if scope == "*" { - return true; - } - - todo!("split command into subcommands, check each"); - } + let Some(scope) = rule.scope.as_ref() else { + // Shell without scope matches all shell commands + return true; + }; - true + let shell_kind = crate::permissions::shell::ShellKind::from_shell_name(&self.shell); + let parsed = crate::permissions::shell::parse_shell_command(&self.command, shell_kind); + crate::permissions::shell::any_subcommand_matches(&parsed.subcommands, scope) } } From 4271f2086d61adee3a053737c42b4b4516e3e442 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Wed, 8 Apr 2026 10:50:54 -0700 Subject: [PATCH 17/52] Stress-test shell parsing and fix concatenation gap Add 77 adversarial tests covering nested substitutions, variable assignments, control flow, redirections, subshells, background jobs, real-world commands, fish-specific syntax, and scope matching edge cases. Fix: remove `concatenation` from BASH_LEAVES so that subcommands inside argument concatenations (e.g. `make -j$(nproc)`) are properly extracted. Known limitations verified by tests: - `find -exec` body is opaque to tree-sitter (not parsed as commands) - `[` in test conditions is not extracted as a command - `eval`/`exec`/`source` argument bodies are not recursively parsed --- crates/atuin-ai/src/permissions/shell.rs | 607 ++++++++++++++++++++++- 1 file changed, 606 insertions(+), 1 deletion(-) diff --git a/crates/atuin-ai/src/permissions/shell.rs b/crates/atuin-ai/src/permissions/shell.rs index 077201ac98c..c79e7946169 100644 --- a/crates/atuin-ai/src/permissions/shell.rs +++ b/crates/atuin-ai/src/permissions/shell.rs @@ -86,7 +86,6 @@ const BASH_LEAVES: &[&str] = &[ "heredoc_start", "regex", "heredoc_redirect", - "concatenation", ]; fn walk_bash_tree(tree: &Tree, source: &[u8], commands: &mut Vec) { @@ -685,3 +684,609 @@ mod tests { assert!(!any_subcommand_matches(&commands, "git * -i * HEAD~10")); } } + +#[cfg(test)] +mod adversarial { + use super::*; + + /// Helper: parse with POSIX and return sorted unique command names + fn posix(code: &str) -> Vec { + let mut n: Vec = parse_shell_command(code, ShellKind::Posix) + .subcommands + .iter() + .map(|c| c.name.clone()) + .collect(); + n.sort(); + n.dedup(); + n + } + + /// Helper: parse with fish and return sorted unique command names + fn fish(code: &str) -> Vec { + let mut n: Vec = parse_shell_command(code, ShellKind::Fish) + .subcommands + .iter() + .map(|c| c.name.clone()) + .collect(); + n.sort(); + n.dedup(); + n + } + + fn cmd_names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + /// Helper: assert that parsing POSIX extracts all expected command names + fn assert_posix(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Posix); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "POSIX parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + fn assert_fish(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Fish); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "Fish parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 1: Basic compounds + // ──────────────────────────────────────────────────────────── + + #[test] + fn a01_triple_chain() { + assert_posix("a && b && c", &["a", "b", "c"]); + } + + #[test] + fn a02_or_chain() { + assert_posix("a || b || c", &["a", "b", "c"]); + } + + #[test] + fn a03_mixed_chain() { + assert_posix("a && b || c && d", &["a", "b", "c", "d"]); + } + + #[test] + fn a04_long_pipeline() { + assert_posix( + "cat foo | grep bar | awk '{print $1}' | sort | uniq -c", + &["cat", "grep", "awk", "sort", "uniq"], + ); + } + + #[test] + fn a05_semicolons() { + assert_posix("a; b; c; d", &["a", "b", "c", "d"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 2: Nested substitution + // ──────────────────────────────────────────────────────────── + + #[test] + fn a06_nested_dollar() { + assert_posix( + "echo $(basename $(dirname /foo/bar))", + &["echo", "basename", "dirname"], + ); + } + + #[test] + fn a07_deeply_nested() { + // 4 nested echos, all should be extracted + assert_posix( + "echo $(echo $(echo $(echo deep)))", + &["echo", "echo", "echo", "echo"], + ); + } + + #[test] + fn a08_backtick_in_echo() { + assert_posix("echo `hostname`", &["echo", "hostname"]); + } + + #[test] + fn a09_mixed_substitutions() { + assert_posix("echo $(date) `uname`", &["echo", "date", "uname"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 3: Subshells and grouping + // ──────────────────────────────────────────────────────────── + + #[test] + fn a10_subshell_chain() { + assert_posix("(cd /tmp && ls -la)", &["cd", "ls"]); + } + + #[test] + fn a11_nested_subshells() { + assert_posix("( (inner_cmd) )", &["inner_cmd"]); + } + + #[test] + fn a12_brace_group() { + assert_posix("{ cd /tmp; ls; }", &["cd", "ls"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 4: Variable assignments + // ──────────────────────────────────────────────────────────── + + #[test] + fn a13_single_var_assignment() { + let result = parse_shell_command("FOO=bar ls", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["ls"]); + assert_eq!(result.subcommands[0].full, "ls"); + } + + #[test] + fn a14_multiple_var_assignments() { + let result = parse_shell_command("A=1 B=2 C=3 git status", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["git"]); + assert_eq!(result.subcommands[0].full, "git status"); + } + + #[test] + fn a15_var_assignment_no_command() { + // Variable assignment only — no command to extract + assert_posix("FOO=bar", &[]); + } + + #[test] + fn a16_var_assignment_in_pipeline() { + assert_posix("FOO=bar ls | BAZ=qux grep foo", &["ls", "grep"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 5: Control flow + // ──────────────────────────────────────────────────────────── + + #[test] + fn a17_if_then_else() { + assert_posix( + "if [ -f foo ]; then cat foo; else echo missing; fi", + &["cat", "echo"], + ); + } + + #[test] + fn a18_elif_chain() { + // Two cat commands (then + elif branch), one echo (else branch). + // [ is part of the test_condition, not extracted as a command. + assert_posix( + "if [ -f a ]; then cat a; elif [ -f b ]; then cat b; else echo none; fi", + &["cat", "cat", "echo"], + ); + } + + #[test] + fn a19_for_loop() { + assert_posix("for f in *.txt; do cat \"$f\"; done", &["cat"]); + } + + #[test] + fn a20_while_loop() { + // read in the condition is a real command + assert_posix( + "while read line; do echo \"$line\"; done < input.txt", + &["echo", "read"], + ); + } + + #[test] + fn f07_if_statement() { + // test in if-condition is a real command + assert_fish( + "if test -f foo; cat foo; else; echo missing; end", + &["cat", "echo", "test"], + ); + } + + #[test] + fn f09_while_loop() { + // `true` in the condition is a real command + assert_fish( + "while true; echo tick; sleep 1; end", + &["echo", "sleep", "true"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 6: Redirections + // ──────────────────────────────────────────────────────────── + + #[test] + fn a23_redirect_out() { + assert_posix("ls > output.txt", &["ls"]); + } + + #[test] + fn a24_redirect_append() { + assert_posix("ls >> output.txt 2>&1", &["ls"]); + } + + #[test] + fn a25_here_string() { + assert_posix("grep foo <<< \"hello world\"", &["grep"]); + } + + #[test] + fn a26_redirect_in_pipeline() { + assert_posix("cat < input.txt | sort | uniq", &["cat", "sort", "uniq"]); + } + + #[test] + fn a27_process_substitution() { + assert_posix( + "diff <(sort a.txt) <(sort b.txt)", + &["diff", "sort", "sort"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 7: Function definitions + // ──────────────────────────────────────────────────────────── + + #[test] + fn a28_function_def() { + assert_posix("foo() { echo hello; }", &["echo"]); + } + + #[test] + fn a29_function_with_subshell() { + assert_posix( + "build() { cargo build && cargo test; }", + &["cargo", "cargo"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 8: Edge cases — empties, weird quoting + // ──────────────────────────────────────────────────────────── + + #[test] + fn a30_empty_string() { + let result = parse_shell_command("", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a31_whitespace_only() { + let result = parse_shell_command(" \t \n ", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a32_single_command_no_args() { + assert_posix("ls", &["ls"]); + } + + #[test] + fn a33_command_with_single_quotes() { + assert_posix("echo 'hello world'", &["echo"]); + } + + #[test] + fn a34_command_with_double_quotes() { + assert_posix("echo \"hello world\"", &["echo"]); + } + + #[test] + fn a35_escaped_spaces() { + // ls\ -la is a single word in bash, not "ls" with flag "-la" + assert_posix("ls\\ -la", &["ls\\ -la"]); + } + + #[test] + fn a36_command_with_dollar_var() { + assert_posix("echo $HOME/.bashrc", &["echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 9: Background jobs and coproc + // ──────────────────────────────────────────────────────────── + + #[test] + fn a37_background_job() { + assert_posix("sleep 10 &", &["sleep"]); + } + + #[test] + fn a38_background_chain() { + assert_posix("sleep 10 && echo done &", &["sleep", "echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 10: Real-world complex commands + // ──────────────────────────────────────────────────────────── + + #[test] + fn a39_docker_build_and_run() { + assert_posix( + "docker build -t app . && docker run --rm app npm test", + &["docker", "docker"], + ); + } + + #[test] + fn a40_git_rebase_interactive() { + assert_posix( + "GIT_SEQUENCE_EDITOR=\"sed -i 's/pick/reword/'\" git rebase -i HEAD~5", + &["git"], + ); + } + + #[test] + fn a41_find_with_exec() { + // tree-sitter-bash does not parse -exec body as commands — only `find` is extracted. + // This is a known limitation: args to -exec/-execdir are opaque to the parser. + assert_posix("find . -name '*.rs' -exec grep -l 'unsafe' {} +", &["find"]); + } + + #[test] + fn a42_curl_pipe_sh() { + assert_posix( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn a43_xargs() { + assert_posix("find . -name '*.tmp' | xargs rm -f", &["find", "xargs"]); + } + + #[test] + fn a44_npm_script_chain() { + assert_posix( + "npm run build && npm run test && npm run lint", + &["npm", "npm", "npm"], + ); + } + + #[test] + fn a45_make_with_redirect() { + assert_posix( + "make -j$(nproc) 2>&1 | tee build.log", + &["make", "nproc", "tee"], + ); + } + + #[test] + fn a46_sudo_chain() { + assert_posix("sudo apt update && sudo apt upgrade -y", &["sudo", "sudo"]); + } + + #[test] + fn a47_here_doc_with_subcommand() { + assert_posix("cat < output.txt", &["ls"]); + } + + #[test] + fn f13_redirect_append() { + assert_fish("ls >> output.txt", &["ls"]); + } + + #[test] + fn f14_here_string() { + assert_fish("grep foo <<< \"hello\"", &["grep"]); + } + + #[test] + fn f15_curl_pipe() { + assert_fish( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn f16_double_ampersand() { + assert_fish("git add . && git commit -m hi", &["git", "git"]); + } + + #[test] + fn f17_double_pipe() { + assert_fish("test -f foo || echo missing", &["test", "echo"]); + } + + #[test] + fn f18_empty() { + let result = parse_shell_command("", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn f19_whitespace() { + let result = parse_shell_command(" ", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + // ──────────────────────────────────────────────────────────── + // Level 12: Scope matching adversarial + // ──────────────────────────────────────────────────────────── + + #[test] + fn s01_empty_scope() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // Empty scope matches everything (nothing to constrain) + assert!(any_subcommand_matches(&commands, "")); + } + + #[test] + fn s03_only_wildcard_space_star() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // " *" with empty prefix = match anything + assert!(any_subcommand_matches(&commands, " *")); + } + + #[test] + fn s04_glob_matches_empty() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // `ls*` matches `ls` (prefix match with nothing after) + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn s05_middle_wildcard_empty_match() { + // `git * commit` matches `git commit` (* = zero words) + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn s06_consecutive_wildcards() { + // `git ** commit` should behave like `git * commit` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git ** commit")); + } + + #[test] + fn s07_case_sensitivity() { + let commands = vec![ShellCommand { + name: "LS".into(), + full: "LS -la".into(), + }]; + assert!(!any_subcommand_matches(&commands, "ls")); + assert!(any_subcommand_matches(&commands, "LS")); + } + + #[test] + fn s08_multi_word_exact_no_subcommand() { + // `git commit` should not match `git commit-amend` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit-amend".into(), + }]; + assert!(!any_subcommand_matches(&commands, "git commit")); + } +} From e55d8986f0289823f9f1c8fcf303a91e412e330a Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Wed, 8 Apr 2026 17:14:52 -0700 Subject: [PATCH 18/52] implement client-side shell command execution with VT100 preview Shell tool calls now execute locally with a streaming VT100 preview in the TUI. The full stdout and stderr are captured separately and sent to the LLM as structured results with exit code and duration. Key changes: - ToolOutcome::Structured variant with separated stdout/stderr/exit code/duration - execute_shell_command_streaming uses vt100 crate for ANSI/progress bar handling - Shared execute_tool dispatch eliminates duplicated shell execution paths - begin_tool_call/finish_tool_call for ToolCall persistence in chat output - Ctrl+C interrupts running commands instead of exiting the app - Viewport component in eye_declare for fixed-height tail rendering --- Cargo.lock | 26 +- crates/atuin-ai/Cargo.toml | 1 + crates/atuin-ai/src/tools/descriptor.rs | 2 +- crates/atuin-ai/src/tools/mod.rs | 249 +++++++++++++++++- .../atuin-ai/src/tui/components/atuin_ai.rs | 18 +- crates/atuin-ai/src/tui/dispatch.rs | 192 ++++++++++++-- crates/atuin-ai/src/tui/events.rs | 2 + crates/atuin-ai/src/tui/state.rs | 49 ++-- crates/atuin-ai/src/tui/view/mod.rs | 84 +++++- 9 files changed, 572 insertions(+), 51 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 12088674605..2490ade5e24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -303,6 +303,7 @@ dependencies = [ "typed-builder 0.18.2", "unicode-width 0.2.2", "uuid", + "vt100 0.16.2", ] [[package]] @@ -436,7 +437,7 @@ dependencies = [ "eyre", "portable-pty", "signal-hook", - "vt100", + "vt100 0.15.2", ] [[package]] @@ -5769,7 +5770,18 @@ dependencies = [ "itoa", "log", "unicode-width 0.1.14", - "vte", + "vte 0.11.1", +] + +[[package]] +name = "vt100" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ff75fb8fa83e609e685106df4faeffdf3a735d3c74ebce97ec557d5d36fd9" +dependencies = [ + "itoa", + "unicode-width 0.2.2", + "vte 0.15.0", ] [[package]] @@ -5783,6 +5795,16 @@ dependencies = [ "vte_generate_state_changes", ] +[[package]] +name = "vte" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5924018406ce0063cd67f8e008104968b74b563ee1b85dde3ed1f7cb87d3dbd" +dependencies = [ + "arrayvec", + "memchr", +] + [[package]] name = "vte_generate_state_changes" version = "0.1.2" diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 05984c68be0..fab7e12370d 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -50,6 +50,7 @@ tree-sitter = "0.26.8" tree-sitter-bash = "0.25.1" tree-sitter-fish = "3.6.0" typed-builder = { workspace = true } +vt100 = "0.16" [dev-dependencies] pretty_assertions = { workspace = true } diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs index 4518c88a6f0..740f20c6ec9 100644 --- a/crates/atuin-ai/src/tools/descriptor.rs +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -35,7 +35,7 @@ pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { }; pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["shell"], + canonical_names: &["execute_shell_command"], display_verb: "run", progressive_verb: "Running command...", past_verb: "Ran command", diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 242a470a16b..1206edbb101 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -1,6 +1,7 @@ use std::{ io::BufRead, path::{Path, PathBuf}, + time::Duration, }; use eyre::Result; @@ -11,8 +12,73 @@ use crate::permissions::rule::Rule; /// Result of executing a client-side tool. pub(crate) enum ToolOutcome { + /// Simple success with a text result (used by Read, AtuinHistory). Success(String), + /// Error with a message. Error(String), + /// Structured shell result with separated stdout, stderr, exit code, and duration. + Structured { + stdout: String, + stderr: String, + exit_code: Option, + duration_ms: u64, + interrupted: bool, + }, +} + +impl ToolOutcome { + /// Format this outcome as a string for the tool result sent to the LLM. + pub fn format_for_llm(&self) -> String { + match self { + ToolOutcome::Success(s) => s.clone(), + ToolOutcome::Error(e) => e.clone(), + ToolOutcome::Structured { + stdout, + stderr, + exit_code, + duration_ms, + interrupted, + } => { + let mut parts = Vec::new(); + + if let Some(code) = exit_code { + parts.push(format!("Exit code: {code}")); + } + + parts.push(format!("Duration: {duration_ms}ms")); + + if !stdout.is_empty() { + parts.push(format!("stdout:\n{stdout}")); + } else { + parts.push("stdout: (empty)".to_string()); + } + + if !stderr.is_empty() { + parts.push(format!("stderr:\n{stderr}")); + } else { + parts.push("stderr: (empty)".to_string()); + } + + if *interrupted { + parts.push("[Interrupted by user]".to_string()); + } + + parts.join("\n\n") + } + } + } + + /// Whether this outcome represents an error. + pub fn is_error(&self) -> bool { + match self { + ToolOutcome::Error(_) => true, + ToolOutcome::Structured { + exit_code: Some(code), + .. + } if *code != 0 => true, + _ => false, + } + } } /// A pending tool call from the server, awaiting permissions or execution. @@ -44,6 +110,16 @@ impl PendingToolCall { pub fn mark_denied(&mut self, reason: String) { self.state = ToolCallState::Denied(reason); } + + /// Mark this tool call as executing with a live preview. + pub fn mark_executing_preview(&mut self, command: String) { + self.state = ToolCallState::ExecutingPreview { + command, + output_lines: Vec::new(), + exit_code: None, + interrupted: false, + }; + } } /// State of a pending tool call @@ -53,6 +129,16 @@ pub(crate) enum ToolCallState { AskingForPermission, Denied(String), Executing, + /// Shell command is executing with live preview output. + ExecutingPreview { + command: String, + /// Current VT100 screen lines (plain text, viewport-sized). + output_lines: Vec, + /// Exit code once the process completes. + exit_code: Option, + /// Whether the command was interrupted by the user. + interrupted: bool, + }, } /// A tool call from the server, with parsed input parameters. @@ -74,7 +160,7 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { "str_replace" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), "file_create" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), "file_insert" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), - "shell" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), + "execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), "atuin_history" => Ok(ClientToolCall::AtuinHistory( AtuinHistoryToolCall::try_from(input)?, )), @@ -371,6 +457,167 @@ impl PermissableToolCall for ShellToolCall { } } +/// Preview viewport height for VT100 emulation. +const PREVIEW_HEIGHT: u16 = 10; + +/// Default terminal width for VT100 emulation. +const PREVIEW_WIDTH: u16 = 120; + +/// Extract plain text lines from a VT100 screen buffer. +fn vt100_screen_lines(screen: &vt100::Screen) -> Vec { + let (rows, cols) = screen.size(); + let mut lines = Vec::with_capacity(rows as usize); + for row in 0..rows { + let mut line = String::with_capacity(cols as usize); + for col in 0..cols { + if let Some(cell) = screen.cell(row, col) { + line.push_str(cell.contents()); + } + } + // Trim trailing whitespace for cleaner display + lines.push(line.trim_end().to_string()); + } + lines +} + +/// Execute a shell command with VT100 emulation and streaming output. +/// +/// Feeds stdout+stderr into a `vt100::Parser` so that ANSI escape sequences, +/// progress bars (`\r`), and cursor movement are handled correctly. Periodically +/// sends the current screen state as `Vec` through `output_tx` for the +/// live preview. +/// +/// Captures the FULL stdout and stderr separately for the tool result sent to the LLM. +/// Returns a `ToolOutcome::Structured` with full output, exit code, and duration. +pub(crate) async fn execute_shell_command_streaming( + shell_call: &ShellToolCall, + output_tx: tokio::sync::mpsc::Sender>, + mut interrupt_rx: tokio::sync::oneshot::Receiver<()>, +) -> ToolOutcome { + use tokio::io::AsyncReadExt; + + let start = std::time::Instant::now(); + + // TODO: check if this is proper for all shells we support + let mut cmd = tokio::process::Command::new(&shell_call.shell); + cmd.arg("-c").arg(&shell_call.command); + cmd.stdout(std::process::Stdio::piped()); + cmd.stderr(std::process::Stdio::piped()); + + if let Some(ref dir) = shell_call.dir { + cmd.current_dir(dir); + } + + let mut child = match cmd.spawn() { + Ok(child) => child, + Err(e) => return ToolOutcome::Error(format!("Failed to spawn command: {e}")), + }; + + let stdout = child.stdout.take().expect("stdout was piped"); + let stderr = child.stderr.take().expect("stderr was piped"); + + // VT100 emulator for the live preview (viewport-sized) + let mut parser = vt100::Parser::new(PREVIEW_HEIGHT, PREVIEW_WIDTH, 0); + + let mut stdout_reader = tokio::io::BufReader::new(stdout); + let mut stderr_reader = tokio::io::BufReader::new(stderr); + + let mut stdout_buf = [0u8; 4096]; + let mut stderr_buf = [0u8; 4096]; + let mut stdout_done = false; + let mut stderr_done = false; + + // Full output buffers (for the LLM, not the preview) + let mut full_stdout = Vec::::new(); + let mut full_stderr = Vec::::new(); + + let mut interval = tokio::time::interval(Duration::from_millis(50)); + + // Send initial empty screen + let initial_lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(initial_lines).await; + + let mut interrupted = false; + + loop { + tokio::select! { + biased; + + // Check for interrupt signal + _ = &mut interrupt_rx, if !interrupted => { + interrupted = true; + let _ = child.start_kill(); + } + + // Read stdout + result = stdout_reader.read(&mut stdout_buf), if !stdout_done => { + match result { + Ok(0) => stdout_done = true, + Ok(n) => { + full_stdout.extend_from_slice(&stdout_buf[..n]); + parser.process(&stdout_buf[..n]); + } + Err(_) => stdout_done = true, + } + } + + // Read stderr + result = stderr_reader.read(&mut stderr_buf), if !stderr_done => { + match result { + Ok(0) => stderr_done = true, + Ok(n) => { + full_stderr.extend_from_slice(&stderr_buf[..n]); + // Feed stderr to the preview parser too, so it shows in the VT100 screen + parser.process(&stderr_buf[..n]); + } + Err(_) => stderr_done = true, + } + } + + // Periodic screen snapshot for preview + _ = interval.tick() => { + let lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(lines).await; + } + } + + // Exit when both streams are done + if stdout_done && stderr_done { + break; + } + } + + // Wait for process to finish + let exit_code = match child.wait().await { + Ok(status) => status.code(), + Err(e) => { + if interrupted { + None + } else { + return ToolOutcome::Error(format!("Failed to wait for command: {e}")); + } + } + }; + + let duration = start.elapsed(); + + // Send final screen state + let final_lines = vt100_screen_lines(parser.screen()); + let _ = output_tx.send(final_lines).await; + + // Strip ANSI from the raw bytes for clean LLM output + let stdout_text = String::from_utf8_lossy(&full_stdout).to_string(); + let stderr_text = String::from_utf8_lossy(&full_stderr).to_string(); + + ToolOutcome::Structured { + stdout: stdout_text, + stderr: stderr_text, + exit_code, + duration_ms: duration.as_millis() as u64, + interrupted, + } +} + #[derive(Debug, Clone)] pub(crate) struct AtuinHistoryToolCall { pub filter_modes: Vec, diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs index 2db2b216495..f52952d05b9 100644 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ b/crates/atuin-ai/src/tui/components/atuin_ai.rs @@ -55,9 +55,16 @@ fn atuin_ai( return EventResult::Ignored; }; - // Ctrl+C always exits + // Ctrl+C — interrupt executing command or exit if modifiers.contains(KeyModifiers::CONTROL) && *code == KeyCode::Char('c') { - let _ = tx.send(AiTuiEvent::Exit); + match props.mode { + AppMode::ExecutingPreview => { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + } + _ => { + let _ = tx.send(AiTuiEvent::Exit); + } + } return EventResult::Consumed; } @@ -97,6 +104,13 @@ fn atuin_ai( } _ => EventResult::Ignored, }, + AppMode::ExecutingPreview => match code { + KeyCode::Esc => { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + EventResult::Consumed + } + _ => EventResult::Ignored, + }, AppMode::Error => match code { KeyCode::Esc => { let _ = tx.send(AiTuiEvent::Exit); diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index 44a11a84e5b..5955d46a66f 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -5,10 +5,10 @@ use crate::context::{AppContext, ClientContext}; use crate::permissions::check::PermissionResponse; use crate::permissions::resolver::PermissionResolver; use crate::stream::{ChatRequest, run_chat_stream}; -use crate::tools::ToolCallState; +use crate::tools::{ClientToolCall, PendingToolCall, ToolCallState}; use crate::tui::ConversationEvent; use crate::tui::events::{AiTuiEvent, PermissionResult}; -use crate::tui::state::{ExitAction, Session}; +use crate::tui::state::{AppMode, ExitAction, Session}; use eye_declare::Handle; use tokio::task::JoinHandle; @@ -47,6 +47,9 @@ pub(crate) fn dispatch( AiTuiEvent::CancelConfirmation => { on_cancel_confirmation(handle); } + AiTuiEvent::InterruptToolExecution => { + on_interrupt_tool_execution(handle); + } AiTuiEvent::InsertCommand => { on_insert_command(handle); } @@ -144,6 +147,138 @@ fn on_slash_command(handle: &Handle, command: String) { }); } +// ─────────────────────────────────────────────────────────────────── +// Tool execution dispatch +// ─────────────────────────────────────────────────────────────────── + +/// Execute a tool call. Handles Shell tools (streaming with preview) and +/// non-shell tools (synchronous) uniformly. Callers provide the resolved +/// PendingToolCall; this function handles all state transitions. +fn execute_tool( + handle: &Handle, + tx: &mpsc::Sender, + tool_call: PendingToolCall, + db: &std::sync::Arc, +) { + match &tool_call.tool { + ClientToolCall::Shell(shell_call) => { + let shell_call = shell_call.clone(); + execute_shell_tool(handle, tx, tool_call, &shell_call); + } + _ => { + execute_simple_tool(handle, tx, tool_call, db); + } + } +} + +/// Execute a non-shell tool synchronously and complete the tool call. +fn execute_simple_tool( + handle: &Handle, + tx: &mpsc::Sender, + tool_call: PendingToolCall, + db: &std::sync::Arc, +) { + let h = handle.clone(); + let tx = tx.clone(); + let db = db.clone(); + + tokio::spawn(async move { + let outcome = tool_call.tool.execute(&db).await; + h.update(move |state| { + state.complete_tool_call(&tool_call, outcome); + if !state.has_unresolved_tool_calls() { + let _ = tx.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); +} + +/// Execute a shell tool with streaming VT100 preview. The ToolCall event is +/// added to the conversation immediately so it persists in chat output. +/// A live preview renders in the input area during execution. +fn execute_shell_tool( + handle: &Handle, + tx: &mpsc::Sender, + tool_call: PendingToolCall, + shell_call: &crate::tools::ShellToolCall, +) { + let h = handle.clone(); + let tx = tx.clone(); + let shell_call = shell_call.clone(); + let command = shell_call.command.clone(); + + // Extract all data we need before moving into closures + let tc_id = tool_call.id.clone(); + let tc_id_for_update = tc_id.clone(); + let tc_for_finish = tool_call.clone(); + let tool_for_begin = tool_call.tool.clone(); + + // Build the input JSON for the ToolCall event (matches server format) + let input_json = serde_json::json!({ + "command": shell_call.command, + "dir": shell_call.dir, + "shell": shell_call.shell, + }); + + // 1. Immediately add the ToolCall event to conversation and enter preview mode + h.update(move |state| { + state.begin_tool_call(&tc_id_for_update, &tool_for_begin, input_json); + if let Some(tc) = state.pending_tool_call_mut(&tc_id_for_update) { + tc.mark_executing_preview(command); + } + state.interaction.mode = AppMode::ExecutingPreview; + state.shell_abort_tx = None; // will be set below + }); + + // 2. Set up channels for streaming output and interruption + let (output_tx, mut output_rx) = tokio::sync::mpsc::channel::>(32); + let (abort_tx, abort_rx) = tokio::sync::oneshot::channel::<()>(); + + h.update(move |state| { + state.shell_abort_tx = Some(abort_tx); + }); + + // 3. Spawn the streaming execution task + let h_exec = h.clone(); + let tx_exec = tx.clone(); + tokio::spawn(async move { + let outcome = + crate::tools::execute_shell_command_streaming(&shell_call, output_tx, abort_rx).await; + + h_exec.update(move |state| { + state.finish_tool_call(&tc_for_finish, outcome); + state.shell_abort_tx = None; + state.interaction.mode = AppMode::Input; + if !state.has_unresolved_tool_calls() { + let _ = tx_exec.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); + + // 4. Spawn a task to consume output updates and feed them to state + let h_output = h.clone(); + let preview_id = tc_id; + tokio::spawn(async move { + while let Some(lines) = output_rx.recv().await { + let id = preview_id.clone(); + h_output.update(move |state| { + if let Some(tc) = state.pending_tool_call_mut(&id) + && let ToolCallState::ExecutingPreview { + ref mut output_lines, + .. + } = tc.state + { + *output_lines = lines; + } + }); + } + }); +} + +// ─────────────────────────────────────────────────────────────────── +// Permission handlers +// ─────────────────────────────────────────────────────────────────── + fn on_check_tool_permission( handle: &Handle, tx: &mpsc::Sender, @@ -185,13 +320,7 @@ fn on_check_tool_permission( // 4. Handle response match response { PermissionResponse::Allowed => { - let outcome = tool_call.tool.execute(&db).await; - h2.update(move |state| { - state.complete_tool_call(&tool_call, outcome); - if !state.has_unresolved_tool_calls() { - let _ = tx_for_task.send(AiTuiEvent::ContinueAfterTools); - } - }); + execute_tool(&h2, &tx_for_task, tool_call, &db); } PermissionResponse::Denied => { let tx = tx_for_task.clone(); @@ -229,13 +358,12 @@ fn on_select_permission( ) { let tx = tx.clone(); let h2 = handle.clone(); - let db = app_ctx.history_db.clone(); match permission { PermissionResult::Allow => { - // Fetch the tool call that's asking for permission, then execute it async + // Fetch the tool call that's asking for permission, then execute it let h3 = h2.clone(); - let tx2 = tx.clone(); + let db = app_ctx.history_db.clone(); tokio::spawn(async move { let Ok(Some(tool_call)) = h3 .fetch(move |state| { @@ -250,13 +378,7 @@ fn on_select_permission( return; }; - let outcome = tool_call.tool.execute(&db).await; - h3.update(move |state| { - state.complete_tool_call(&tool_call, outcome); - if !state.has_unresolved_tool_calls() { - let _ = tx2.send(AiTuiEvent::ContinueAfterTools); - } - }); + execute_tool(&h3, &tx, tool_call, &db); }); } PermissionResult::AlwaysAllowInDir => { @@ -294,6 +416,10 @@ fn on_select_permission( } } +// ─────────────────────────────────────────────────────────────────── +// Other handlers +// ─────────────────────────────────────────────────────────────────── + fn on_cancel_generation(handle: &Handle) { handle.update(|state| match state.interaction.mode { crate::tui::state::AppMode::Generating => { @@ -363,3 +489,31 @@ fn on_exit(handle: &Handle) { h2.exit(); }); } + +fn on_interrupt_tool_execution(handle: &Handle) { + handle.update(move |state| { + // Send interrupt signal to the running shell command + if let Some(abort_tx) = state.shell_abort_tx.take() { + let _ = abort_tx.send(()); + } + + // Mark the executing preview as interrupted + for tc in &mut state.pending_tool_calls { + if let ToolCallState::ExecutingPreview { + ref mut interrupted, + ref mut exit_code, + .. + } = tc.state + { + *interrupted = true; + if exit_code.is_none() { + *exit_code = Some(-1); + } + } + } + + // Return to input mode — the spawned execution task will handle + // finalizing and sending ContinueAfterTools when the process exits. + state.interaction.mode = AppMode::Input; + }); +} diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs index a3aa87942ee..78cc71ebda8 100644 --- a/crates/atuin-ai/src/tui/events.rs +++ b/crates/atuin-ai/src/tui/events.rs @@ -26,6 +26,8 @@ pub(crate) enum AiTuiEvent { InsertCommand, /// Cancel confirmation of dangerous command CancelConfirmation, + /// Interrupt a running tool execution (Ctrl+C during ExecutingPreview) + InterruptToolExecution, /// Retry after error Retry, /// Exit the application diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 221f95ea44b..9c1b5b790dc 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -127,6 +127,8 @@ pub(crate) enum AppMode { Streaming, /// Error state, can retry Error, + /// Shell tool is executing with live preview + ExecutingPreview, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -419,6 +421,8 @@ pub(crate) struct Session { pub exit_action: Option, /// Abort handle for the active streaming task, if any pub stream_abort: Option, + /// Sender to interrupt a running shell command preview. + pub shell_abort_tx: Option>, } impl Session { @@ -429,6 +433,7 @@ impl Session { pending_tool_calls: VecDeque::new(), exit_action: None, stream_abort: None, + shell_abort_tx: None, } } @@ -578,29 +583,30 @@ impl Session { .find(|call| call.id == id) } + /// Record a tool call event in the conversation. + /// Call this BEFORE execution begins so the ToolCall shows in chat output. + pub fn begin_tool_call(&mut self, id: &str, tool: &ClientToolCall, input: serde_json::Value) { + let desc = tool.descriptor(); + self.add_tool_call(id.to_string(), desc.canonical_names[0].to_string(), input); + } + + /// Record the result of a tool call and remove it from the pending queue. + /// Call this AFTER execution completes. The ToolCall event must already exist + /// in the conversation (added by `begin_tool_call`). + pub fn finish_tool_call(&mut self, pending: &PendingToolCall, outcome: ToolOutcome) { + let content = outcome.format_for_llm(); + let is_error = outcome.is_error(); + self.conversation + .add_tool_result(pending.id.clone(), content, is_error); + self.pending_tool_calls.retain(|c| c.id != pending.id); + } + /// Record a tool call, its execution result, and remove it from the pending queue. + /// Convenience method that combines begin + finish for tools that don't need + /// the ToolCall visible during execution. pub fn complete_tool_call(&mut self, pending: &PendingToolCall, outcome: ToolOutcome) { - let desc = pending.tool.descriptor(); - - // Record the tool call so the view can render a ToolCall → ToolResult pair - self.add_tool_call( - pending.id.clone(), - desc.canonical_names[0].to_string(), - serde_json::json!({}), - ); - - // Record the result - match outcome { - ToolOutcome::Success(content) => { - self.conversation - .add_tool_result(pending.id.clone(), content, false); - } - ToolOutcome::Error(msg) => { - self.conversation - .add_tool_result(pending.id.clone(), msg, true); - } - } - self.pending_tool_calls.retain(|c| c.id != pending.id); + self.begin_tool_call(&pending.id, &pending.tool, serde_json::json!({})); + self.finish_tool_call(pending, outcome); } /// Returns true if any tool calls are still in CheckingPermissions or AskingForPermission state. @@ -628,6 +634,7 @@ impl Session { } } AppMode::Generating | AppMode::Streaming => "[Esc] Cancel", + AppMode::ExecutingPreview => "[Ctrl+C] Interrupt [Esc] Interrupt", AppMode::Error => "[Enter]/[r] Retry [Esc] Exit", } } diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 0c9a93c33d5..a60d79787fc 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -1,7 +1,8 @@ //! View function that builds the eye-declare element tree from app state. use eye_declare::{ - Cells, Column, Elements, HStack, Span, Spinner, Text, View, WidthConstraint, element, + BorderType, Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport, + WidthConstraint, element, }; use ratatui_core::style::{Color, Modifier, Style}; @@ -66,17 +67,26 @@ pub(crate) fn ai_view(state: &Session) -> Elements { } fn input_view(state: &Session) -> Elements { - let first_pending_tool_call = state + let asking_tool = state .pending_tool_calls .iter() .find(|call| call.state == ToolCallState::AskingForPermission); + let executing_tool = state + .pending_tool_calls + .iter() + .find(|call| matches!(call.state, ToolCallState::ExecutingPreview { .. })); + element! { - #(if first_pending_tool_call.is_some() { - #(tool_call_view(first_pending_tool_call.unwrap())) + #(if let Some(tc) = asking_tool { + #(tool_call_view(tc)) + }) + + #(if let Some(tc) = executing_tool { + #(executing_preview_view(tc)) }) - #(if first_pending_tool_call.is_none() { + #(if asking_tool.is_none() && executing_tool.is_none() { View(key: "input-box", padding_top: Cells::from(1)) { InputBox( key: "input", @@ -147,6 +157,70 @@ fn tool_call_view(tool_call: &PendingToolCall) -> Elements { } } +fn executing_preview_view(tool_call: &PendingToolCall) -> Elements { + let (command, output_lines, exit_code, interrupted) = match &tool_call.state { + ToolCallState::ExecutingPreview { + command, + output_lines, + exit_code, + interrupted, + } => ( + command.clone(), + output_lines.clone(), + *exit_code, + *interrupted, + ), + _ => return element! {}, + }; + + let spinner_done = exit_code.is_some() || interrupted; + + element! { + View(key: format!("preview-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) { + // Command header with spinner + Spinner( + label: format!(" Running: {}", command), + label_style: Style::default().fg(Color::Yellow), + done: spinner_done, + ) + + // Fixed-height viewport showing the VT100 screen output + Viewport( + lines: output_lines, + height: 10, + border: BorderType::Plain, + border_style: Style::default().fg(Color::DarkGray), + style: Style::default().fg(Color::White), + ) + + // Status line + #(if let Some(code) = exit_code { + #(if code == 0 { + Text { + Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green)) + } + } else { + Text { + Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red)) + } + }) + }) + + #(if interrupted { + Text { + Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) + } + }) + + #(if !spinner_done { + Text { + Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray)) + } + }) + } + } +} + fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements { let label_style = Style::default() .fg(Color::Cyan) From 04ef52cc8d3cecce73223d3b2df9fd81fd1f3771 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Wed, 8 Apr 2026 21:08:15 -0700 Subject: [PATCH 19/52] Fix bug in events -> messages conversion --- crates/atuin-ai/src/tui/state.rs | 47 ++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 9c1b5b790dc..b430981e27b 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -174,13 +174,50 @@ impl Conversation { i += 1; } ConversationEvent::Text { content } => { - messages.push(serde_json::json!({ - "role": "assistant", - "content": content - })); - i += 1; + // Check if the next event(s) are ToolCalls — if so, combine + // into a single assistant message with mixed content blocks. + let next_is_tool_call = events + .get(i + 1) + .is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. })); + + if next_is_tool_call { + let mut content_blocks = Vec::new(); + + if !content.is_empty() { + content_blocks.push(serde_json::json!({ + "type": "text", + "text": content + })); + } + + while let Some(ConversationEvent::ToolCall { id, name, input }) = + events.get(i + 1) + { + content_blocks.push(serde_json::json!({ + "type": "tool_use", + "id": id, + "name": name, + "input": input + })); + i += 1; + } + + messages.push(serde_json::json!({ + "role": "assistant", + "content": content_blocks + })); + i += 1; + } else { + messages.push(serde_json::json!({ + "role": "assistant", + "content": content + })); + i += 1; + } } ConversationEvent::ToolCall { .. } => { + // ToolCalls without preceding Text (shouldn't normally happen, + // but handle defensively) let mut tool_uses = Vec::new(); while i < events.len() { if let ConversationEvent::ToolCall { id, name, input } = &events[i] { From 45f41707db8e5d43b95b0105717da2bc3e46a0bd Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Thu, 9 Apr 2026 12:01:13 -0700 Subject: [PATCH 20/52] Introduce ToolTracker as single source of truth for tool execution state Replace the split state model (pending_tool_calls VecDeque + preview field on ConversationEvent::ToolCall) with a unified ToolTracker that owns each tool call through its full lifecycle, including after completion. This eliminates the preview copy-back dance, the two-place preview lookup, and the TurnBuilder second-pass update_previews step. Key changes: - New ToolTracker/TrackedTool/ToolPhase types replace PendingToolCall/ToolCallState - ConversationEvent::ToolCall drops its preview field (now purely API-facing) - shell_abort_tx moves from Session to TrackedTool.abort_tx (per-tool, not per-session) - TurnBuilder takes &ToolTracker reference, looks up previews inline - Fix spinner not animating during shell preview (work around eye_declare interval reset by computing frame from system clock) - Fix word wrapping in shell output preview (use truncation instead of word-boundary wrapping for VT100 content) - Use multi-thread tokio runtime for AI commands --- crates/atuin-ai/src/stream.rs | 9 +- crates/atuin-ai/src/tools/mod.rs | 179 ++++++++++++---- .../atuin-ai/src/tui/components/atuin_ai.rs | 24 +-- crates/atuin-ai/src/tui/dispatch.rs | 194 ++++++++---------- crates/atuin-ai/src/tui/state.rs | 129 ++++++------ crates/atuin-ai/src/tui/view/mod.rs | 166 ++++++++------- crates/atuin-ai/src/tui/view/turn.rs | 16 +- crates/atuin/src/command/client.rs | 14 +- 8 files changed, 411 insertions(+), 320 deletions(-) diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index b93da09de26..6a0f5c264c1 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -263,14 +263,15 @@ fn apply_content_frame( } StreamContent::ToolCall { id, name, input } => { if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) { - // Client-side tool — queue for permission check - let id_for_update = id.clone(); + // Client-side tool — add to tracker and conversation, queue permission check + let id_for_event = id.clone(); + let input_for_event = input.clone(); handle.update(move |state| { - state.handle_client_tool_call(id_for_update, tool); + state.handle_client_tool_call(id_for_event, tool, input_for_event); }); let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id)); } else { - // Server-side tool + // Server-side tool — just add to conversation events handle.update(move |state| { state.add_tool_call(id, name, input); }); diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index 1206edbb101..40a4070e321 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -81,64 +81,167 @@ impl ToolOutcome { } } -/// A pending tool call from the server, awaiting permissions or execution. -#[derive(Debug, Clone)] -pub(crate) struct PendingToolCall { +/// Cached VT100 preview data for a shell tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ToolPreview { + pub lines: Vec, + pub exit_code: Option, + pub interrupted: bool, +} + +/// Lifecycle phase of a tracked tool call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ToolPhase { + CheckingPermissions, + AskingForPermission, + Denied(String), + Executing, + /// Shell command is executing with live preview output. + ExecutingWithPreview { + command: String, + /// Current VT100 screen lines (plain text, viewport-sized). + output_lines: Vec, + /// Exit code once the process completes. + exit_code: Option, + /// Whether the command was interrupted by the user. + interrupted: bool, + }, + /// Tool execution has completed. Preview is cached for rendering history. + Completed { + preview: Option, + }, +} + +/// A tracked tool call through its full lifecycle. +#[derive(Debug)] +pub(crate) struct TrackedTool { pub id: String, - pub state: ToolCallState, pub tool: ClientToolCall, + pub phase: ToolPhase, + /// Sender to interrupt a running shell command (only set during ExecutingWithPreview). + pub abort_tx: Option>, } -impl PendingToolCall { +impl TrackedTool { pub(crate) fn target_dir(&self) -> Option<&Path> { self.tool.target_dir() } - /// Mark this tool call as waiting for user permission. pub fn mark_asking(&mut self) { - self.state = ToolCallState::AskingForPermission; - } - - /// Mark this tool call as currently executing. - #[expect(dead_code)] - pub fn mark_executing(&mut self) { - self.state = ToolCallState::Executing; - } - - /// Mark this tool call as denied. - #[expect(dead_code)] - pub fn mark_denied(&mut self, reason: String) { - self.state = ToolCallState::Denied(reason); + self.phase = ToolPhase::AskingForPermission; } - /// Mark this tool call as executing with a live preview. pub fn mark_executing_preview(&mut self, command: String) { - self.state = ToolCallState::ExecutingPreview { + self.phase = ToolPhase::ExecutingWithPreview { command, output_lines: Vec::new(), exit_code: None, interrupted: false, }; } + + pub fn complete(&mut self, preview: Option) { + self.phase = ToolPhase::Completed { preview }; + self.abort_tx = None; + } + + /// Extract the current preview, whether live or completed. + pub fn preview(&self) -> Option { + match &self.phase { + ToolPhase::ExecutingWithPreview { + output_lines, + exit_code, + interrupted, + .. + } => Some(ToolPreview { + lines: output_lines.clone(), + exit_code: *exit_code, + interrupted: *interrupted, + }), + ToolPhase::Completed { preview } => preview.clone(), + _ => None, + } + } } -/// State of a pending tool call -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ToolCallState { - CheckingPermissions, - AskingForPermission, - Denied(String), - Executing, - /// Shell command is executing with live preview output. - ExecutingPreview { - command: String, - /// Current VT100 screen lines (plain text, viewport-sized). - output_lines: Vec, - /// Exit code once the process completes. - exit_code: Option, - /// Whether the command was interrupted by the user. - interrupted: bool, - }, +/// Tracks all tool calls through their full lifecycle. +/// +/// Single source of truth for tool execution state. Entries persist after +/// completion so cached previews remain available for rendering history. +#[derive(Debug)] +pub(crate) struct ToolTracker { + tools: Vec, +} + +impl ToolTracker { + pub fn new() -> Self { + Self { tools: Vec::new() } + } + + /// Insert a new tool call in CheckingPermissions phase. + pub fn insert(&mut self, id: String, tool: ClientToolCall) { + self.tools.push(TrackedTool { + id, + tool, + phase: ToolPhase::CheckingPermissions, + abort_tx: None, + }); + } + + pub fn get(&self, id: &str) -> Option<&TrackedTool> { + self.tools.iter().find(|t| t.id == id) + } + + pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { + self.tools.iter_mut().find(|t| t.id == id) + } + + /// Remove a tool by ID and return it. + pub fn remove(&mut self, id: &str) -> Option { + let pos = self.tools.iter().position(|t| t.id == id)?; + Some(self.tools.remove(pos)) + } + + /// True if any tool is still in CheckingPermissions or AskingForPermission. + pub fn has_unresolved(&self) -> bool { + self.tools.iter().any(|t| { + matches!( + t.phase, + ToolPhase::CheckingPermissions | ToolPhase::AskingForPermission + ) + }) + } + + /// True if any tool is currently executing with a preview. + pub fn has_executing_preview(&self) -> bool { + self.tools + .iter() + .any(|t| matches!(t.phase, ToolPhase::ExecutingWithPreview { .. })) + } + + /// Find the first tool that is asking for permission. + pub fn asking_for_permission(&self) -> Option<&TrackedTool> { + self.tools + .iter() + .find(|t| t.phase == ToolPhase::AskingForPermission) + } + + /// Find the first tool that is asking for permission (mutable). + pub fn asking_for_permission_mut(&mut self) -> Option<&mut TrackedTool> { + self.tools + .iter_mut() + .find(|t| t.phase == ToolPhase::AskingForPermission) + } + + /// Get the preview for a tool by ID (live or cached). + pub fn preview_for(&self, id: &str) -> Option { + self.get(id)?.preview() + } + + /// Iterate mutably over all tracked tools. + pub fn iter_mut(&mut self) -> impl Iterator { + self.tools.iter_mut() + } } /// A tool call from the server, with parsed input parameters. diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs index f52952d05b9..c04ac72292a 100644 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ b/crates/atuin-ai/src/tui/components/atuin_ai.rs @@ -22,6 +22,7 @@ pub(crate) struct AtuinAi { pub has_command: bool, pub is_input_blank: bool, pub pending_confirmation: bool, + pub has_executing_preview: bool, } #[derive(Default)] @@ -57,13 +58,10 @@ fn atuin_ai( // Ctrl+C — interrupt executing command or exit if modifiers.contains(KeyModifiers::CONTROL) && *code == KeyCode::Char('c') { - match props.mode { - AppMode::ExecutingPreview => { - let _ = tx.send(AiTuiEvent::InterruptToolExecution); - } - _ => { - let _ = tx.send(AiTuiEvent::Exit); - } + if props.has_executing_preview { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + } else { + let _ = tx.send(AiTuiEvent::Exit); } return EventResult::Consumed; } @@ -71,6 +69,11 @@ fn atuin_ai( match props.mode { AppMode::Input => match code { KeyCode::Esc => { + if props.has_executing_preview { + let _ = tx.send(AiTuiEvent::InterruptToolExecution); + return EventResult::Consumed; + } + if props.pending_confirmation { let _ = tx.send(AiTuiEvent::CancelConfirmation); return EventResult::Consumed; @@ -104,13 +107,6 @@ fn atuin_ai( } _ => EventResult::Ignored, }, - AppMode::ExecutingPreview => match code { - KeyCode::Esc => { - let _ = tx.send(AiTuiEvent::InterruptToolExecution); - EventResult::Consumed - } - _ => EventResult::Ignored, - }, AppMode::Error => match code { KeyCode::Esc => { let _ = tx.send(AiTuiEvent::Exit); diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index 5955d46a66f..6d33652f934 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -5,10 +5,10 @@ use crate::context::{AppContext, ClientContext}; use crate::permissions::check::PermissionResponse; use crate::permissions::resolver::PermissionResolver; use crate::stream::{ChatRequest, run_chat_stream}; -use crate::tools::{ClientToolCall, PendingToolCall, ToolCallState}; +use crate::tools::{ClientToolCall, ToolPhase}; use crate::tui::ConversationEvent; use crate::tui::events::{AiTuiEvent, PermissionResult}; -use crate::tui::state::{AppMode, ExitAction, Session}; +use crate::tui::state::{ExitAction, Session}; use eye_declare::Handle; use tokio::task::JoinHandle; @@ -152,21 +152,21 @@ fn on_slash_command(handle: &Handle, command: String) { // ─────────────────────────────────────────────────────────────────── /// Execute a tool call. Handles Shell tools (streaming with preview) and -/// non-shell tools (synchronous) uniformly. Callers provide the resolved -/// PendingToolCall; this function handles all state transitions. +/// non-shell tools (synchronous) uniformly. fn execute_tool( handle: &Handle, tx: &mpsc::Sender, - tool_call: PendingToolCall, + tool_id: String, + tool: ClientToolCall, db: &std::sync::Arc, ) { - match &tool_call.tool { + match &tool { ClientToolCall::Shell(shell_call) => { let shell_call = shell_call.clone(); - execute_shell_tool(handle, tx, tool_call, &shell_call); + execute_shell_tool(handle, tx, &tool_id, &shell_call); } _ => { - execute_simple_tool(handle, tx, tool_call, db); + execute_simple_tool(handle, tx, tool_id, tool, db); } } } @@ -175,7 +175,8 @@ fn execute_tool( fn execute_simple_tool( handle: &Handle, tx: &mpsc::Sender, - tool_call: PendingToolCall, + tool_id: String, + tool: ClientToolCall, db: &std::sync::Arc, ) { let h = handle.clone(); @@ -183,96 +184,79 @@ fn execute_simple_tool( let db = db.clone(); tokio::spawn(async move { - let outcome = tool_call.tool.execute(&db).await; + let outcome = tool.execute(&db).await; h.update(move |state| { - state.complete_tool_call(&tool_call, outcome); - if !state.has_unresolved_tool_calls() { + state.complete_tool_call(&tool_id, &tool, outcome); + if !state.tool_tracker.has_unresolved() { let _ = tx.send(AiTuiEvent::ContinueAfterTools); } }); }); } -/// Execute a shell tool with streaming VT100 preview. The ToolCall event is -/// added to the conversation immediately so it persists in chat output. -/// A live preview renders in the input area during execution. +/// Execute a shell tool with streaming VT100 preview. fn execute_shell_tool( handle: &Handle, tx: &mpsc::Sender, - tool_call: PendingToolCall, + tool_id: &str, shell_call: &crate::tools::ShellToolCall, ) { let h = handle.clone(); let tx = tx.clone(); let shell_call = shell_call.clone(); let command = shell_call.command.clone(); + let tc_id = tool_id.to_string(); - // Extract all data we need before moving into closures - let tc_id = tool_call.id.clone(); - let tc_id_for_update = tc_id.clone(); - let tc_for_finish = tool_call.clone(); - let tool_for_begin = tool_call.tool.clone(); - - // Build the input JSON for the ToolCall event (matches server format) - let input_json = serde_json::json!({ - "command": shell_call.command, - "dir": shell_call.dir, - "shell": shell_call.shell, - }); - - // 1. Immediately add the ToolCall event to conversation and enter preview mode - h.update(move |state| { - state.begin_tool_call(&tc_id_for_update, &tool_for_begin, input_json); - if let Some(tc) = state.pending_tool_call_mut(&tc_id_for_update) { - tc.mark_executing_preview(command); - } - state.interaction.mode = AppMode::ExecutingPreview; - state.shell_abort_tx = None; // will be set below - }); - - // 2. Set up channels for streaming output and interruption + // 1. Set up channels for streaming output and interruption let (output_tx, mut output_rx) = tokio::sync::mpsc::channel::>(32); let (abort_tx, abort_rx) = tokio::sync::oneshot::channel::<()>(); + // 2. Mark as executing with preview and store the abort sender on the tracker entry + let tc_id_setup = tc_id.clone(); h.update(move |state| { - state.shell_abort_tx = Some(abort_tx); - }); - - // 3. Spawn the streaming execution task - let h_exec = h.clone(); - let tx_exec = tx.clone(); - tokio::spawn(async move { - let outcome = - crate::tools::execute_shell_command_streaming(&shell_call, output_tx, abort_rx).await; - - h_exec.update(move |state| { - state.finish_tool_call(&tc_for_finish, outcome); - state.shell_abort_tx = None; - state.interaction.mode = AppMode::Input; - if !state.has_unresolved_tool_calls() { - let _ = tx_exec.send(AiTuiEvent::ContinueAfterTools); - } - }); + if let Some(tracked) = state.tool_tracker.get_mut(&tc_id_setup) { + tracked.mark_executing_preview(command); + tracked.abort_tx = Some(abort_tx); + } }); - // 4. Spawn a task to consume output updates and feed them to state + // 3. Spawn a task to consume output updates and feed them to state let h_output = h.clone(); - let preview_id = tc_id; - tokio::spawn(async move { + let preview_id = tc_id.clone(); + let output_task = tokio::spawn(async move { while let Some(lines) = output_rx.recv().await { let id = preview_id.clone(); h_output.update(move |state| { - if let Some(tc) = state.pending_tool_call_mut(&id) - && let ToolCallState::ExecutingPreview { + if let Some(tracked) = state.tool_tracker.get_mut(&id) + && let ToolPhase::ExecutingWithPreview { ref mut output_lines, .. - } = tc.state + } = tracked.phase { *output_lines = lines; } }); } }); + + // 4. Spawn the streaming execution task + let h_exec = h.clone(); + let tx_exec = tx.clone(); + let tc_id_finish = tc_id; + tokio::spawn(async move { + let outcome = + crate::tools::execute_shell_command_streaming(&shell_call, output_tx, abort_rx).await; + + // Wait for the output task to finish so the final preview lines are captured + let _ = output_task.await; + + h_exec.update(move |state| { + state.finish_tool_call(&tc_id_finish, outcome); + if !state.tool_tracker.has_unresolved() { + let _ = tx_exec.send(AiTuiEvent::ContinueAfterTools); + } + }); + }); } // ─────────────────────────────────────────────────────────────────── @@ -291,20 +275,21 @@ fn on_check_tool_permission( let db = app_ctx.history_db.clone(); tokio::spawn(async move { - // 1. Fetch the pending tool call - let Ok(Some(tool_call)) = h2 - .fetch(move |state| state.pending_tool_call(&id).cloned()) + // 1. Fetch the tracked tool's data (clone what we need for permission check) + let Ok(Some((tool, target_dir))) = h2 + .fetch(move |state| { + state + .tool_tracker + .get(&id) + .map(|t| (t.tool.clone(), t.target_dir().map(PathBuf::from))) + }) .await else { return; }; // 2. Resolve working directory - let Some(working_dir) = tool_call - .target_dir() - .map(PathBuf::from) - .or_else(|| std::env::current_dir().ok()) - else { + let Some(working_dir) = target_dir.or_else(|| std::env::current_dir().ok()) else { return; }; @@ -313,14 +298,14 @@ fn on_check_tool_permission( return; }; - let Ok(response) = resolver.check(&tool_call.tool).await else { + let Ok(response) = resolver.check(&tool).await else { return; }; // 4. Handle response match response { PermissionResponse::Allowed => { - execute_tool(&h2, &tx_for_task, tool_call, &db); + execute_tool(&h2, &tx_for_task, id_clone, tool, &db); } PermissionResponse::Denied => { let tx = tx_for_task.clone(); @@ -333,16 +318,16 @@ fn on_check_tool_permission( content: format!("Permission denied for tool call {:?}", &id_clone), command: None, }); - state.pending_tool_calls.retain(|c| c.id != id_clone); - if !state.has_unresolved_tool_calls() { + state.tool_tracker.remove(&id_clone); + if !state.tool_tracker.has_unresolved() { let _ = tx.send(AiTuiEvent::ContinueAfterTools); } }); } PermissionResponse::Ask => { h2.update(move |state| { - if let Some(tc) = state.pending_tool_call_mut(&id_clone) { - tc.mark_asking(); + if let Some(tracked) = state.tool_tracker.get_mut(&id_clone) { + tracked.mark_asking(); } }); } @@ -361,24 +346,23 @@ fn on_select_permission( match permission { PermissionResult::Allow => { - // Fetch the tool call that's asking for permission, then execute it + // Fetch the tool that's asking for permission, then execute it let h3 = h2.clone(); let db = app_ctx.history_db.clone(); tokio::spawn(async move { - let Ok(Some(tool_call)) = h3 + let Ok(Some((tool_id, tool))) = h3 .fetch(move |state| { state - .pending_tool_calls - .iter() - .find(|tc| tc.state == ToolCallState::AskingForPermission) - .cloned() + .tool_tracker + .asking_for_permission() + .map(|t| (t.id.clone(), t.tool.clone())) }) .await else { return; }; - execute_tool(&h3, &tx, tool_call, &db); + execute_tool(&h3, &tx, tool_id, tool, &db); }); } PermissionResult::AlwaysAllowInDir => { @@ -389,26 +373,18 @@ fn on_select_permission( } PermissionResult::Deny => { h2.update(move |state| { - let tool_call = state - .pending_tool_calls - .iter() - .enumerate() - .find(|(_, call)| call.state == ToolCallState::AskingForPermission); - - let Some((index, _)) = tool_call else { - return; - }; - - let Some(call) = state.pending_tool_calls.remove(index) else { + let Some(tracked) = state.tool_tracker.asking_for_permission_mut() else { return; }; + let tool_id = tracked.id.clone(); + tracked.complete(None); state.conversation.add_tool_result( - call.id, + tool_id, "Permission denied on the user's system".to_string(), true, ); - if !state.has_unresolved_tool_calls() { + if !state.tool_tracker.has_unresolved() { let _ = tx.send(AiTuiEvent::ContinueAfterTools); } }); @@ -492,28 +468,26 @@ fn on_exit(handle: &Handle) { fn on_interrupt_tool_execution(handle: &Handle) { handle.update(move |state| { - // Send interrupt signal to the running shell command - if let Some(abort_tx) = state.shell_abort_tx.take() { - let _ = abort_tx.send(()); - } - - // Mark the executing preview as interrupted - for tc in &mut state.pending_tool_calls { - if let ToolCallState::ExecutingPreview { + // Find executing previews, send interrupt, and mark as interrupted + for tracked in state.tool_tracker.iter_mut() { + if let ToolPhase::ExecutingWithPreview { ref mut interrupted, ref mut exit_code, .. - } = tc.state + } = tracked.phase { *interrupted = true; if exit_code.is_none() { *exit_code = Some(-1); } + // Send interrupt signal via the tracker entry's abort channel + if let Some(abort_tx) = tracked.abort_tx.take() { + let _ = abort_tx.send(()); + } } } - // Return to input mode — the spawned execution task will handle - // finalizing and sending ContinueAfterTools when the process exits. - state.interaction.mode = AppMode::Input; + // The spawned execution task will handle finalizing and sending + // ContinueAfterTools when the process exits. Input mode is already active. }); } diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index b430981e27b..21e3f45d64c 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -3,11 +3,9 @@ //! This module contains the core state types that represent the application's //! domain model. Conversation events match the API protocol format. -use std::collections::VecDeque; - use tokio::task::AbortHandle; -use crate::tools::{ClientToolCall, PendingToolCall, ToolCallState, ToolOutcome}; +use crate::tools::{ClientToolCall, ToolOutcome, ToolTracker}; /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] @@ -117,6 +115,7 @@ impl ConversationEvent { } } +/// Application mode for key handling and footer text. #[derive(Debug, Clone, PartialEq, Eq, Copy)] pub(crate) enum AppMode { /// User is typing input @@ -127,8 +126,6 @@ pub(crate) enum AppMode { Streaming, /// Error state, can retry Error, - /// Shell tool is executing with live preview - ExecutingPreview, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -190,8 +187,9 @@ impl Conversation { })); } - while let Some(ConversationEvent::ToolCall { id, name, input }) = - events.get(i + 1) + while let Some(ConversationEvent::ToolCall { + id, name, input, .. + }) = events.get(i + 1) { content_blocks.push(serde_json::json!({ "type": "tool_use", @@ -220,7 +218,10 @@ impl Conversation { // but handle defensively) let mut tool_uses = Vec::new(); while i < events.len() { - if let ConversationEvent::ToolCall { id, name, input } = &events[i] { + if let ConversationEvent::ToolCall { + id, name, input, .. + } = &events[i] + { tool_uses.push(serde_json::json!({ "type": "tool_use", "id": id, @@ -452,14 +453,12 @@ impl Interaction { pub(crate) struct Session { pub conversation: Conversation, pub interaction: Interaction, - /// Tool calls that are pending permission checking + execution - pub pending_tool_calls: VecDeque, + /// Tracks all tool calls through their full lifecycle. + pub tool_tracker: ToolTracker, /// Exit action (set when exiting) pub exit_action: Option, /// Abort handle for the active streaming task, if any pub stream_abort: Option, - /// Sender to interrupt a running shell command preview. - pub shell_abort_tx: Option>, } impl Session { @@ -467,10 +466,9 @@ impl Session { Self { conversation: Conversation::new(), interaction: Interaction::new(), - pending_tool_calls: VecDeque::new(), + tool_tracker: ToolTracker::new(), exit_action: None, stream_abort: None, - shell_abort_tx: None, } } @@ -582,12 +580,22 @@ impl Session { self.interaction.mode = AppMode::Error; } - pub(crate) fn handle_client_tool_call(&mut self, id: String, tool: ClientToolCall) { - self.pending_tool_calls.push_back(PendingToolCall { - id: id.clone(), - state: ToolCallState::CheckingPermissions, - tool, - }); + pub(crate) fn handle_client_tool_call( + &mut self, + id: String, + tool: ClientToolCall, + input: serde_json::Value, + ) { + let desc = tool.descriptor(); + let name = desc.canonical_names[0].to_string(); + + self.tool_tracker.insert(id.clone(), tool); + + // Add the ToolCall event to the conversation immediately so it appears + // in the view. Preview data is sourced from tool_tracker. + self.conversation + .events + .push(ConversationEvent::ToolCall { id, name, input }); // Client tool calls can only happen at the last part of a turn self.interaction.streaming_status = None; @@ -606,54 +614,56 @@ impl Session { self.interaction.mode = AppMode::Generating; } - // ===== Query methods ===== + // ===== Tool lifecycle methods ===== - /// Get a pending tool call by ID - pub(crate) fn pending_tool_call(&self, id: &str) -> Option<&PendingToolCall> { - self.pending_tool_calls.iter().find(|call| call.id == id) - } + /// Finish a tool call: transition tracker to Completed, push ToolResult to conversation. + /// + /// For shell commands, captures the final preview from the ExecutingWithPreview phase + /// and patches exit_code/interrupted from the authoritative ToolOutcome. + pub fn finish_tool_call(&mut self, tool_id: &str, outcome: ToolOutcome) { + let mut preview = self.tool_tracker.get(tool_id).and_then(|t| t.preview()); - /// Get a mutable pending tool call by ID - pub(crate) fn pending_tool_call_mut(&mut self, id: &str) -> Option<&mut PendingToolCall> { - self.pending_tool_calls - .iter_mut() - .find(|call| call.id == id) - } + // Patch preview with authoritative outcome data (handles race where + // final VT100 update hasn't been applied yet). + if let Some(ref mut p) = preview + && let ToolOutcome::Structured { + exit_code, + interrupted, + .. + } = &outcome + { + p.interrupted = *interrupted; + if p.exit_code.is_none() { + p.exit_code = *exit_code; + } + } - /// Record a tool call event in the conversation. - /// Call this BEFORE execution begins so the ToolCall shows in chat output. - pub fn begin_tool_call(&mut self, id: &str, tool: &ClientToolCall, input: serde_json::Value) { - let desc = tool.descriptor(); - self.add_tool_call(id.to_string(), desc.canonical_names[0].to_string(), input); - } + // Transition tracker entry to Completed + if let Some(tracked) = self.tool_tracker.get_mut(tool_id) { + tracked.complete(preview); + } - /// Record the result of a tool call and remove it from the pending queue. - /// Call this AFTER execution completes. The ToolCall event must already exist - /// in the conversation (added by `begin_tool_call`). - pub fn finish_tool_call(&mut self, pending: &PendingToolCall, outcome: ToolOutcome) { let content = outcome.format_for_llm(); let is_error = outcome.is_error(); self.conversation - .add_tool_result(pending.id.clone(), content, is_error); - self.pending_tool_calls.retain(|c| c.id != pending.id); - } - - /// Record a tool call, its execution result, and remove it from the pending queue. - /// Convenience method that combines begin + finish for tools that don't need - /// the ToolCall visible during execution. - pub fn complete_tool_call(&mut self, pending: &PendingToolCall, outcome: ToolOutcome) { - self.begin_tool_call(&pending.id, &pending.tool, serde_json::json!({})); - self.finish_tool_call(pending, outcome); + .add_tool_result(tool_id.to_string(), content, is_error); } - /// Returns true if any tool calls are still in CheckingPermissions or AskingForPermission state. - pub fn has_unresolved_tool_calls(&self) -> bool { - self.pending_tool_calls.iter().any(|tc| { - matches!( - tc.state, - ToolCallState::CheckingPermissions | ToolCallState::AskingForPermission - ) - }) + /// Record a tool call event + its result in one step (for simple non-preview tools). + pub fn complete_tool_call( + &mut self, + tool_id: &str, + tool: &ClientToolCall, + outcome: ToolOutcome, + ) { + // Push the ToolCall event so it appears in the conversation + let desc = tool.descriptor(); + self.add_tool_call( + tool_id.to_string(), + desc.canonical_names[0].to_string(), + serde_json::json!({}), + ); + self.finish_tool_call(tool_id, outcome); } /// Get the footer text for current mode @@ -671,7 +681,6 @@ impl Session { } } AppMode::Generating | AppMode::Streaming => "[Esc] Cancel", - AppMode::ExecutingPreview => "[Ctrl+C] Interrupt [Esc] Interrupt", AppMode::Error => "[Enter]/[r] Retry [Esc] Exit", } } diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index a60d79787fc..e44f3068876 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -6,7 +6,7 @@ use eye_declare::{ }; use ratatui_core::style::{Color, Modifier, Style}; -use crate::tools::{ClientToolCall, PendingToolCall, ToolCallState}; +use crate::tools::{ClientToolCall, TrackedTool}; use crate::tui::components::select::SelectOption; use crate::tui::events::{AiTuiEvent, PermissionResult}; @@ -27,7 +27,7 @@ mod turn; /// - Spacer /// - Input box (bordered, with contextual keybindings) pub(crate) fn ai_view(state: &Session) -> Elements { - let mut turn_builder = turn::TurnBuilder::new(); + let mut turn_builder = turn::TurnBuilder::new(&state.tool_tracker); for event in &state.conversation.events { turn_builder.add_event(event); @@ -44,6 +44,7 @@ pub(crate) fn ai_view(state: &Session) -> Elements { has_command: state.conversation.has_any_command(), is_input_blank: state.interaction.is_input_blank, pending_confirmation: state.interaction.confirmation_pending, + has_executing_preview: state.tool_tracker.has_executing_preview(), ) { #(for (index, turn) in turns.iter().enumerate() { #(match turn { @@ -67,26 +68,14 @@ pub(crate) fn ai_view(state: &Session) -> Elements { } fn input_view(state: &Session) -> Elements { - let asking_tool = state - .pending_tool_calls - .iter() - .find(|call| call.state == ToolCallState::AskingForPermission); - - let executing_tool = state - .pending_tool_calls - .iter() - .find(|call| matches!(call.state, ToolCallState::ExecutingPreview { .. })); + let asking_tool = state.tool_tracker.asking_for_permission(); element! { #(if let Some(tc) = asking_tool { #(tool_call_view(tc)) }) - #(if let Some(tc) = executing_tool { - #(executing_preview_view(tc)) - }) - - #(if asking_tool.is_none() && executing_tool.is_none() { + #(if asking_tool.is_none() { View(key: "input-box", padding_top: Cells::from(1)) { InputBox( key: "input", @@ -108,7 +97,7 @@ fn input_view(state: &Session) -> Elements { } } -fn tool_call_view(tool_call: &PendingToolCall) -> Elements { +fn tool_call_view(tool_call: &TrackedTool) -> Elements { let verb = tool_call.tool.descriptor().display_verb; let tool_desc = match &tool_call.tool { ClientToolCall::Read(tool) => tool.path.display().to_string(), @@ -157,70 +146,6 @@ fn tool_call_view(tool_call: &PendingToolCall) -> Elements { } } -fn executing_preview_view(tool_call: &PendingToolCall) -> Elements { - let (command, output_lines, exit_code, interrupted) = match &tool_call.state { - ToolCallState::ExecutingPreview { - command, - output_lines, - exit_code, - interrupted, - } => ( - command.clone(), - output_lines.clone(), - *exit_code, - *interrupted, - ), - _ => return element! {}, - }; - - let spinner_done = exit_code.is_some() || interrupted; - - element! { - View(key: format!("preview-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) { - // Command header with spinner - Spinner( - label: format!(" Running: {}", command), - label_style: Style::default().fg(Color::Yellow), - done: spinner_done, - ) - - // Fixed-height viewport showing the VT100 screen output - Viewport( - lines: output_lines, - height: 10, - border: BorderType::Plain, - border_style: Style::default().fg(Color::DarkGray), - style: Style::default().fg(Color::White), - ) - - // Status line - #(if let Some(code) = exit_code { - #(if code == 0 { - Text { - Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green)) - } - } else { - Text { - Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red)) - } - }) - }) - - #(if interrupted { - Text { - Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) - } - }) - - #(if !spinner_done { - Text { - Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray)) - } - }) - } - } -} - fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements { let label_style = Style::default() .fg(Color::Cyan) @@ -282,11 +207,50 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { suggested_command_view(details) }, turn::UiEvent::ToolCall(details) => { + let preview_done = details.preview.as_ref().is_some_and(|p| p.exit_code.is_some() || p.interrupted); + let tool_key = details.tool_use_id.clone(); + element! { - View(padding_left: Cells::from(2)) { - Text { - Span(text: format!("Running tool: {}", details.name), style: Style::default().fg(Color::Blue)) - } + View(key: format!("tool-output-{tool_key}"), padding_left: Cells::from(2)) { + #(if let Some(ref preview) = details.preview { + View(key: format!("preview-{tool_key}")) { + #(preview_spinner_view(&details.name, preview_done)) + Viewport( + key: format!("viewport-{tool_key}"), + lines: preview.lines.clone(), + height: 10, + border: BorderType::Plain, + border_style: Style::default().fg(Color::DarkGray), + style: Style::default().fg(Color::White), + wrap: false, + ) + #(if let Some(code) = preview.exit_code { + #(if code == 0 { + Text { + Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green)) + } + } else { + Text { + Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red)) + } + }) + }) + #(if preview.interrupted { + Text { + Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) + } + }) + #(if !preview_done { + Text { + Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray)) + } + }) + } + } else { + Text { + Span(text: format!("Running tool: {}", details.name), style: Style::default().fg(Color::Blue)) + } + }) } } } @@ -334,6 +298,38 @@ fn tool_summary_view(summary: &turn::ToolSummary) -> Elements { } } +const SPINNER_FRAMES: &[&str] = &["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; + +/// Render a spinner/status line for a command preview. +/// +/// Uses the system clock to compute the animation frame so it advances on +/// every re-render (triggered by output updates) without needing a separate +/// interval timer. This works around eye_declare's use_interval resetting +/// last_tick on every rebuild. +fn preview_spinner_view(name: &str, done: bool) -> Elements { + if done { + element! { + Text { + Span(text: "✓ ", style: Style::default().fg(Color::Green)) + Span(text: format!("Ran: {name}"), style: Style::default().fg(Color::Green)) + } + } + } else { + let millis = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + let frame = (millis / 80) as usize % SPINNER_FRAMES.len(); + + element! { + Text { + Span(text: format!("{} ", SPINNER_FRAMES[frame]), style: Style::default().fg(Color::DarkGray)) + Span(text: format!("Running: {name}"), style: Style::default().fg(Color::Yellow)) + } + } + } +} + fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements { let is_dangerous = matches!( details.danger_level, diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index c92785c4ea8..6949236c300 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -1,4 +1,5 @@ use crate::tools::descriptor; +use crate::tools::{ToolPreview, ToolTracker}; use crate::tui::ConversationEvent; /// Server-sent danger level for a suggested command @@ -92,6 +93,7 @@ pub(crate) struct ToolCallDetails { pub(crate) name: String, pub(crate) status: ToolResultStatus, pub(crate) is_client: bool, + pub(crate) preview: Option, } #[derive(Debug)] @@ -122,17 +124,19 @@ pub(crate) enum UiTurn { OutOfBand { events: Vec }, } -pub(crate) struct TurnBuilder { +pub(crate) struct TurnBuilder<'a> { turns: Vec, current_turn: Option, + tracker: &'a ToolTracker, } /// A struct to iteratively build [UiTurn] events from [ConversationEvent]s. -impl TurnBuilder { - pub(crate) fn new() -> Self { +impl<'a> TurnBuilder<'a> { + pub(crate) fn new(tracker: &'a ToolTracker) -> Self { Self { turns: Vec::new(), current_turn: None, + tracker, } } @@ -311,15 +315,17 @@ impl TurnBuilder { } fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) { + let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client); + let preview = self.tracker.preview_for(id); + self.start_agent_turn(); if let UiTurn::Agent { events } = self.turn_mut_unsafe() { - let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client); - events.push(UiEvent::ToolCall(ToolCallDetails { tool_use_id: id.to_string(), name: name.to_string(), status: ToolResultStatus::Pending, is_client, + preview, })); } } diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs index 7a7dc153238..af6464cebdd 100644 --- a/crates/atuin/src/command/client.rs +++ b/crates/atuin/src/command/client.rs @@ -150,11 +150,17 @@ impl Cmd { daemon::daemonize_current_process()?; } - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + let mut runtime = if matches!(&self, Self::Ai(_)) { + tokio::runtime::Builder::new_multi_thread() + } else { + tokio::runtime::Builder::new_current_thread() + }; + + let runtime = runtime.enable_all().build().unwrap(); + // For non-history commands, we want to initialize logging and the theme manager before + // doing anything else. History commands are performance-sensitive and run before and after + // every shell command, so we want to skip any unnecessary initialization for them. let settings = Settings::new().wrap_err("could not load client settings")?; let theme_manager = theme::ThemeManager::new(settings.theme.debug, None); let res = runtime.block_on(self.run_inner(settings, theme_manager)); From cb6b12db49fd77f86ccef87c73693192f6d86836 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Thu, 9 Apr 2026 12:10:50 -0700 Subject: [PATCH 21/52] Change client capabilities to individual tools --- crates/atuin-ai/src/commands/inline.rs | 22 ---------------------- crates/atuin-ai/src/stream.rs | 10 ++++++++-- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index 0d1adaff09e..59b98d52441 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -126,28 +126,6 @@ async fn run_inline_tui(ctx: AppContext, initial_prompt: Option) -> Resu let initial_state = Session::new(); let client_ctx = ClientContext::detect(); - // initial_state - // .pending_tool_calls - // .push_back(crate::tools::PendingToolCall { - // id: "1".to_string(), - // state: crate::tools::ToolCallState::CheckingPermissions, - // tool: crate::tools::ClientToolCall::Read(crate::tools::ReadToolCall { - // path: std::path::PathBuf::from("test.txt"), - // }), - // }); - // initial_state - // .pending_tool_calls - // .push_back(crate::tools::PendingToolCall { - // id: "2".to_string(), - // state: crate::tools::ToolCallState::CheckingPermissions, - // tool: crate::tools::ClientToolCall::Shell(crate::tools::ShellToolCall { - // dir: None, - // command: "ls -lah".to_string(), - // }), - // }); - - // let _ = tx.send(AiTuiEvent::CheckToolCallPermission("1".to_string())); - // let _ = tx.send(AiTuiEvent::CheckToolCallPermission("2".to_string())); println!(); diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 6a0f5c264c1..e7d9080c755 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -53,7 +53,6 @@ pub(crate) enum StreamFrame { pub(crate) struct ChatRequest { pub messages: Vec, pub session_id: Option, - /// Requested capabilities. Currently always ["client_tools_v1"]. pub capabilities: Vec, } @@ -62,7 +61,14 @@ impl ChatRequest { Self { messages, session_id, - capabilities: vec!["client_tools_v1".to_string()], + capabilities: vec![ + "client_v1_read_file".to_string(), + "client_v1_atuin_history".to_string(), + "client_v1_execute_shell_command".to_string(), + // "client_v1_create_file".to_string() + // "client_v1_append_to_file".to_string() + // "client_v1_str_replace".to_string() + ], } } } From d380aa399336b91fd35850faa3692d39e45d47e1 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Thu, 9 Apr 2026 12:31:19 -0700 Subject: [PATCH 22/52] Fix list item rendering --- .../atuin-ai/src/tui/components/markdown.rs | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/crates/atuin-ai/src/tui/components/markdown.rs b/crates/atuin-ai/src/tui/components/markdown.rs index 6bbcf41b63d..98a1170b488 100644 --- a/crates/atuin-ai/src/tui/components/markdown.rs +++ b/crates/atuin-ai/src/tui/components/markdown.rs @@ -98,6 +98,10 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat let mut style_stack: Vec