diff --git a/.gitignore b/.gitignore index 23b7ea60..43d9f3ff 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,9 @@ override.tf.json # Ignore transient lock info files created by terraform apply .terraform.tfstate.lock.info +# Local vendored mistral.rs for PIC development +/mistral.rs + # macOS ephemera .DS_Store .env diff --git a/Cargo.lock b/Cargo.lock index b6bb2fa7..de057127 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -516,17 +516,6 @@ dependencies = [ "sha2 0.10.9", ] -[[package]] -name = "bindgen_cuda" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be55fb326843bb67cccceeeaf21c961ef303f60018f9a2ab69494dad8eaf9" -dependencies = [ - "glob", - "num_cpus", - "rayon", -] - [[package]] name = "bit-set" version = "0.5.3" @@ -765,15 +754,14 @@ checksum = "ade8366b8bd5ba243f0a58f036cc0ca8a2f069cff1a2351ef1cac6b083e16fc0" [[package]] name = "candle-core" version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c15b675b80d994b2eadb20a4bbe434eabeb454eac3ee5e2b4cf6f147ee9be091" +source = "git+https://github.com/huggingface/candle.git?rev=c3bb5bf#c3bb5bfb90c9cb5e003a803e3d8941067f32d880" dependencies = [ "byteorder", "candle-kernels", "candle-metal-kernels", "candle-ug", "cudarc 0.19.2", - "float8 0.6.1", + "float8", "gemm 0.19.0", "half", "libm", @@ -794,34 +782,22 @@ dependencies = [ [[package]] name = "candle-flash-attn" version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c94ddd2e7bb828777b0a8d999ed40d2d6c3c96c9ef2a3111a69e0d96efc436d2" +source = "git+https://github.com/huggingface/candle.git?rev=c3bb5bf#c3bb5bfb90c9cb5e003a803e3d8941067f32d880" dependencies = [ "anyhow", - "bindgen_cuda", "candle-core", - "candle-flash-attn-build", + "cudaforge", "half", ] -[[package]] -name = "candle-flash-attn-build" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bd79da06f2a3b831cb4f5a1ee393d6f2c5a913e28f5000c678a84108519a78c" -dependencies = [ - "anyhow", -] - [[package]] name = "candle-flash-attn-v3" version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d1b66ff9f2e95c7502a75ab7304df861c82326b62fef4cdca381c40144afb22" +source = "git+https://github.com/huggingface/candle.git?rev=c3bb5bf#c3bb5bfb90c9cb5e003a803e3d8941067f32d880" dependencies = [ "anyhow", "candle-core", - "candle-flash-attn-build", + "cudaforge", "half", "num_cpus", "rayon", @@ -830,17 +806,15 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8455f84bd810047c7c41216683c1020c915a9f8a740b3b0eabdd4fb2fbaa660" +source = "git+https://github.com/huggingface/candle.git?rev=c3bb5bf#c3bb5bfb90c9cb5e003a803e3d8941067f32d880" dependencies = [ - "bindgen_cuda", + "cudaforge", ] [[package]] name = "candle-metal-kernels" version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fdfe9d06de16ce49961e49084e5b79a75a9bdf157246e7c7b6328e87a7aa25d" +source = "git+https://github.com/huggingface/candle.git?rev=c3bb5bf#c3bb5bfb90c9cb5e003a803e3d8941067f32d880" dependencies = [ "half", "objc2", @@ -854,8 +828,7 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3045fa9e7aef8567d209a27d56b692f60b96f4d0569f4c3011f8ca6715c65e03" +source = 
"git+https://github.com/huggingface/candle.git?rev=c3bb5bf#c3bb5bfb90c9cb5e003a803e3d8941067f32d880" dependencies = [ "candle-core", "candle-metal-kernels", @@ -872,8 +845,7 @@ dependencies = [ [[package]] name = "candle-ug" version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c22d62be69068bf58987a45f690612739d8d2ea1bf508c1b87dc6815a019575d" +source = "git+https://github.com/huggingface/candle.git?rev=c3bb5bf#c3bb5bfb90c9cb5e003a803e3d8941067f32d880" dependencies = [ "ug", "ug-cuda", @@ -1508,7 +1480,7 @@ version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aed81f178e780f3d5d354d12b4c5c5a484c4a9c329ecd037ac57f2a0e0648397" dependencies = [ - "float8 0.7.0", + "float8", "half", "libloading 0.9.0", ] @@ -1877,6 +1849,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec" dependencies = [ "bitflags 2.10.0", + "block2", "objc2", ] @@ -2366,26 +2339,16 @@ dependencies = [ [[package]] name = "float8" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719a903cc23e4a89e87962c2a80fdb45cdaad0983a89bd150bb57b4c8571a7d5" +checksum = "c2d1f04709a8ac06e8e8042875a3c466cc4832d3c1a18dbcb9dba3c6e83046bc" dependencies = [ - "cudarc 0.19.2", "half", "num-traits", "rand 0.9.2", "rand_distr 0.5.1", ] -[[package]] -name = "float8" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2d1f04709a8ac06e8e8042875a3c466cc4832d3c1a18dbcb9dba3c6e83046bc" -dependencies = [ - "half", -] - [[package]] name = "fnv" version = "1.0.7" @@ -3586,7 +3549,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core 0.62.2", + "windows-core 0.58.0", ] [[package]] @@ -4633,8 +4596,8 @@ dependencies = [ [[package]] name = "mistralrs" -version = "0.7.0" -source = "git+https://github.com/EricLBuehler/mistral.rs?rev=dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4#dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4" +version = "0.7.1" +source = "git+https://github.com/starpit/mistral.rs?branch=pic-cache-reuse#e32243c7aa7e5373f4d8384dd7d3534c606e7ad0" dependencies = [ "anyhow", "candle-core", @@ -4651,6 +4614,7 @@ dependencies = [ "schemars 1.2.1", "serde", "serde_json", + "thiserror 2.0.18", "tokio", "tracing", "tracing-subscriber", @@ -4659,8 +4623,8 @@ dependencies = [ [[package]] name = "mistralrs-audio" -version = "0.7.0" -source = "git+https://github.com/EricLBuehler/mistral.rs?rev=dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4#dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4" +version = "0.7.1" +source = "git+https://github.com/starpit/mistral.rs?branch=pic-cache-reuse#e32243c7aa7e5373f4d8384dd7d3534c606e7ad0" dependencies = [ "anyhow", "apodize", @@ -4670,8 +4634,8 @@ dependencies = [ [[package]] name = "mistralrs-core" -version = "0.7.0" -source = "git+https://github.com/EricLBuehler/mistral.rs?rev=dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4#dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4" +version = "0.7.1" +source = "git+https://github.com/starpit/mistral.rs?branch=pic-cache-reuse#e32243c7aa7e5373f4d8384dd7d3534c606e7ad0" dependencies = [ "ahash", "akin", @@ -4697,7 +4661,7 @@ dependencies = [ "derive_more", "dirs", "either", - "float8 0.6.1", + "float8", "futures", "galil-seiferas", "half", @@ -4756,6 +4720,7 @@ dependencies = [ "tokio", "tokio-rayon", "tokio-tungstenite", + "toktrie", "toktrie_hf_tokenizers", "toml", "tqdm", @@ -4769,8 +4734,8 @@ dependencies = [ 
[[package]] name = "mistralrs-macros" -version = "0.7.0" -source = "git+https://github.com/EricLBuehler/mistral.rs?rev=dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4#dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4" +version = "0.7.1" +source = "git+https://github.com/starpit/mistral.rs?branch=pic-cache-reuse#e32243c7aa7e5373f4d8384dd7d3534c606e7ad0" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -4780,8 +4745,8 @@ dependencies = [ [[package]] name = "mistralrs-mcp" -version = "0.7.0" -source = "git+https://github.com/EricLBuehler/mistral.rs?rev=dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4#dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4" +version = "0.7.1" +source = "git+https://github.com/starpit/mistral.rs?branch=pic-cache-reuse#e32243c7aa7e5373f4d8384dd7d3534c606e7ad0" dependencies = [ "anyhow", "async-trait", @@ -4800,14 +4765,15 @@ dependencies = [ [[package]] name = "mistralrs-paged-attn" -version = "0.7.0" -source = "git+https://github.com/EricLBuehler/mistral.rs?rev=dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4#dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4" +version = "0.7.1" +source = "git+https://github.com/starpit/mistral.rs?branch=pic-cache-reuse#e32243c7aa7e5373f4d8384dd7d3534c606e7ad0" dependencies = [ "anyhow", "candle-core", "candle-metal-kernels", "cudaforge", - "float8 0.6.1", + "dispatch2", + "float8", "half", "objc2-foundation", "objc2-metal", @@ -4816,15 +4782,16 @@ dependencies = [ [[package]] name = "mistralrs-quant" -version = "0.7.0" -source = "git+https://github.com/EricLBuehler/mistral.rs?rev=dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4#dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4" +version = "0.7.1" +source = "git+https://github.com/starpit/mistral.rs?branch=pic-cache-reuse#e32243c7aa7e5373f4d8384dd7d3534c606e7ad0" dependencies = [ "byteorder", "candle-core", "candle-metal-kernels", "candle-nn", "cudaforge", - "float8 0.6.1", + "dispatch2", + "float8", "half", "hf-hub", "lazy_static 1.5.0", @@ -4845,8 +4812,8 @@ dependencies = [ [[package]] name = "mistralrs-vision" -version = "0.7.0" -source = "git+https://github.com/EricLBuehler/mistral.rs?rev=dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4#dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4" +version = "0.7.1" +source = "git+https://github.com/starpit/mistral.rs?branch=pic-cache-reuse#e32243c7aa7e5373f4d8384dd7d3534c606e7ad0" dependencies = [ "candle-core", "image", @@ -5674,7 +5641,7 @@ checksum = "9cd31dcfdbbd7431a807ef4df6edd6473228e94d5c805e8cf671227a21bad068" dependencies = [ "anyhow", "clap", - "itertools 0.14.0", + "itertools 0.13.0", "proc-macro2", "quote", "rand 0.8.5", @@ -6045,7 +6012,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools 0.14.0", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.114", diff --git a/Cargo.toml b/Cargo.toml index baa73580..674bdcbe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,16 @@ [workspace] resolver = "3" members = ["cli", "spnl"] -exclude = ["cli"] +exclude = ["cli", "mistral.rs"] [workspace.dependencies] cargo-husky = { version = "1", default-features = false, features = ["user-hooks"] } +# For continued development of PIC work, uncomment: +# [patch."https://github.com/starpit/mistral.rs"] +# mistralrs = { path = "mistral.rs/mistralrs" } +# mistralrs-core = { path = "mistral.rs/mistralrs-core" } + [workspace.metadata.cross.target.x86_64-unknown-linux-gnu] # Install libssl-dev:arm64, see pre-build = [ diff --git a/README_PIC.md b/README_PIC.md new file mode 100644 index 
00000000..ad1cea2b --- /dev/null +++ b/README_PIC.md @@ -0,0 +1,144 @@ +# Position-Independent Caching (PIC) for spnl Plus + +PIC enables the mistral.rs backend to **reuse KV cache entries regardless of where they appear in a sequence**. This dramatically improves cache hit rates for multi-turn RAG workloads where the same documents are injected at different positions across requests. + +For engine internals (deferred RoPE, block-attention masking, cache assembly, optimization decisions), see [`mistral.rs/PIC.md`](mistral.rs/PIC.md). + +## Problem + +Standard transformer KV caches store positionally-encoded (RoPE'd) key tensors. A cached entry for "document A at position 100" cannot be reused as "document A at position 500" because the rotary position encoding differs. This means every time a RAG document appears at a new position, the entire KV cache for that document must be recomputed from scratch. + +spnl's `Plus` operator (`+`) marks input fragments as **position-independent** (commutative) -- the semantic contract is that the model's output should be invariant to the ordering of Plus blocks. PIC exploits this contract to share cached computations. + +## Usage in spnl + +### Query structure + +Use the `Plus` operator to mark position-independent blocks in your spnl query: + +```lisp +(seq + (system "You are a helpful assistant.") + (plus + (user "Document 1: The capital of France is Paris...") + (user "Document 2: Quantum computing uses qubits...")) + (user "What is the capital of France?")) +``` + +The `plus` block tells spnl that "Document 1" and "Document 2" are commutative -- their order doesn't matter, and their KV cache can be reused if they appear in a different position in a subsequent request. + +### JSON equivalent + +```json +{ + "seq": [ + {"system": "You are a helpful assistant."}, + {"plus": [ + {"user": "Document 1: The capital of France is Paris..."}, + {"user": "Document 2: Quantum computing uses qubits..."} + ]}, + {"user": "What is the capital of France?"} + ] +} +``` + +### Multi-turn RAG example + +On turn 1: + +```lisp +(seq + (system "You are a helpful assistant.") + (plus (user "Doc A: ...") (user "Doc B: ...")) + (user "Question about Doc A?")) +``` + +On turn 2, the documents appear in a different position but their cached KV entries are reused: + +```lisp +(seq + (system "You are a helpful assistant.") + (user "Question about Doc A?") + (assistant "Answer about Doc A...") + (plus (user "Doc A: ...") (user "Doc C: ...")) + (user "Follow-up question?")) +``` + +Doc A's KV cache from turn 1 is reused in turn 2 despite appearing at a different sequence position. + +## Benchmarking + +The `spnl bench pic` command measures PIC cache reuse. Output is controlled by `-o/--output` (comma-separated, default: `speedup`): + +| Mode | Description | |------|-------------| | `speedup` | TTFT speedup ratio (no-cache vs PIC reuse) | | `latency` | Prefix and PIC p50 latency in ms | | `hitrate` | PIC cache hit rate | | `iqr` | PIC interquartile range in ms | | `json` | All TTFT data as JSON | | `accuracy` | Plus (PIC) vs flat (causal) response comparison (token F1) | + +Multiple modes can be combined, e.g. `-o speedup,accuracy`.
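+
+The `json` mode prints an array with one entry per model × size combination; each entry carries `model`, `size`, `doc_length`, the derived `speedup`, `prefix_p50_ms`, `pic_p50_ms`, `pic_p25_ms`, `pic_p75_ms`, `hit_rate`, `hits`, `misses`, and the raw `prefix_ttfts_ms`/`pic_ttfts_ms` samples.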
+ +```sh +# Single doc length (default: 200 words) +spnl bench pic -m llama3.1:8b + +# Sweep across doc sizes (xs=10w, sm=50w, m=200w, lg=500w, xl=1000w, xxl=2000w) +spnl bench pic -m llama3.1:8b -s xs,sm,m,lg,xl,xxl + +# Select specific sizes +spnl bench pic -m llama3.1:8b -s sm,m,xl + +# Multiple models +spnl bench pic -m llama3.2:1b,llama3.1:8b -s sm,m,xl + +# Full sweep: all sizes × default model set +spnl bench pic --full + +# Accuracy comparison (Plus vs flat attention) +spnl bench pic -m llama3.1:8b -o accuracy + +# Both speedup and accuracy +spnl bench pic -m llama3.1:8b -s sm,m,xl -o speedup,accuracy + +# Accuracy with LLM judge +spnl bench pic -m llama3.1:8b -o accuracy --grading-model llama3.2:3b +``` + +Build with `--features bench,metal` (or `bench,cuda`). Works with `local/` prefix models or pretty names (e.g. `llama3.2:3b`). + +### Protocol + +For each doc-length, the benchmark runs multiple trials. Each trial: +1. Generates fresh synthetic documents (unique per trial) +2. Sends a **no-cache** request (first time these docs are seen -- full prefill) +3. Sends N **reuse** requests with the same docs shuffled into different orders + +With PIC active, reuse requests skip prefill for all Plus blocks (documents) and only compute KV for Cross tokens (system prompt + question). + +### Sample results (llama3.1:8b, Metal) + +``` + Doc Size No-cache p50 Reuse p50 Speedup Saved Hit Rate + ──────────── ──────────── ──────────── ────────── ────────── ────────── + xs 10w 640 ms 272 ms 2.35x 57.4% 100% + sm 50w 1551 ms 278 ms 5.59x 82.1% 100% + m 200w 5146 ms 292 ms 17.62x 94.3% 100% + lg 500w 13005 ms 367 ms 35.48x 97.2% 100% + xl 1000w 30030 ms 449 ms 66.89x 98.5% 100% + xxl 2000w 81731 ms 680 ms 120.24x 99.2% 100% +``` + +## Comparison with alternatives + +### CacheBlend / LMCache + +[CacheBlend](https://arxiv.org/pdf/2405.16444) reuses position-encoded KV cache from one context in a different position and applies a small correction. It works as an external layer without model changes but produces **approximate** results. + +PIC stores un-rotated K and applies RoPE at attention time using the correct position for each token, so positional encoding is always consistent with the current layout. The tradeoff is that PIC requires model-level changes (deferred RoPE path) and re-applies RoPE to the full cached K sequence on each forward step. + +### vLLM block attention + +spnl's vLLM backend already supports position-independent caching via block attention at the serving layer. PIC brings the same capability to the mistral.rs backend for local/on-device inference.
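+
+## Appendix: why deferred RoPE is exact
+
+To see why PIC's deferred-RoPE reuse is exact where correction-based reuse is approximate, here is a minimal, self-contained sketch. It is illustrative only -- the function name, the toy 2-dimensional head size, and the frequency value are invented here, not taken from the engine. The invariant PIC relies on: rotating a cached un-rotated key to whatever position it occupies in the current layout reproduces exactly what full prefill would compute there, while a key cached with its old rotation does not match.
+
+```rust
+/// Rotate one 2-D key pair to absolute position `pos`. Real RoPE applies
+/// this pairwise across the whole head dimension; one pair suffices here.
+fn rope(k: [f32; 2], pos: usize, theta: f32) -> [f32; 2] {
+    let (sin, cos) = (pos as f32 * theta).sin_cos();
+    [k[0] * cos - k[1] * sin, k[0] * sin + k[1] * cos]
+}
+
+fn main() {
+    let theta = 0.01;
+    let k_unrotated = [0.8_f32, -0.3]; // what PIC caches: K with no position applied
+    let cached = k_unrotated; // retrieved from the PIC cache on a later request
+
+    // The same document token now sits at position 500 in the new layout.
+    let k_full_prefill = rope(k_unrotated, 500, theta); // ground truth
+    let k_pic = rope(cached, 500, theta); // PIC path: rotate the cached K late
+    assert_eq!(k_pic, k_full_prefill); // exact, no correction needed
+
+    // A cache entry rotated at its old position (100) would be wrong here:
+    assert_ne!(rope(k_unrotated, 100, theta), k_full_prefill);
+}
+```
+
+The cost of this exactness is the one noted above: because K is stored un-rotated, it must be re-rotated for the current layout on each forward step.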
diff --git a/cli/src/bench/mod.rs b/cli/src/bench/mod.rs index 76a55606..2b900bab 100644 --- a/cli/src/bench/mod.rs +++ b/cli/src/bench/mod.rs @@ -1,5 +1,6 @@ mod haystack; mod niah; +mod pic; mod ragcsv; mod ruler; @@ -16,6 +17,8 @@ pub enum BenchCommands { Ruler(ruler::RulerArgs), /// RAG CSV evaluation (accuracy grading from a CSV dataset) Ragcsv(ragcsv::RagcsvArgs), + /// PIC cross-request cache reuse benchmark (latency) + Pic(pic::PicArgs), } pub async fn run(command: BenchCommands) -> Result<(), SpnlError> { @@ -24,6 +27,7 @@ pub async fn run(command: BenchCommands) -> Result<(), SpnlError> { BenchCommands::Niah(args) => niah::run(args), BenchCommands::Ruler(args) => ruler::run(args), BenchCommands::Ragcsv(args) => ragcsv::run(args).await, + BenchCommands::Pic(args) => pic::run(args).await, } } diff --git a/cli/src/bench/pic.rs b/cli/src/bench/pic.rs new file mode 100644 index 00000000..eebbce24 --- /dev/null +++ b/cli/src/bench/pic.rs @@ -0,0 +1,1719 @@ +//! PIC (Position-Independent Caching) cross-request benchmark. +//! +//! Measures the **prefill latency** (TTFT) benefit of reusing cached Plus block +//! KV entries across requests that contain the same documents in different orders. +//! Also reports PIC hit/miss rates. +//! +//! # What is measured +//! +//! - **TTFT (time to first token)**: Approximated by setting `max_tokens=1` so +//! wall-clock time ≈ prefill time. This is where PIC saves work — cached Plus +//! blocks skip prefill entirely. +//! - **PIC hit rate**: How many reuse requests found all their Plus blocks +//! in the content-based PIC. +//! - **Accuracy** (`-o accuracy`): Tests ground-truth correctness using fictional +//! factual documents with verifiable answers. Runs four queries per trial: +//! flat (causal), flat-shuffled, PIC (Plus blocks), and PIC-shuffled (Plus +//! blocks, shuffled doc order). Reports `flat/fshuf/pic/pshuf` correctness +//! counts. Token F1 and +//! optional LLM-judge (`--grading-model`) are secondary metrics. +//! +//! # Options +//! +//! - **`--length`/`-l`**: Words per document for TTFT, or max tokens for accuracy (default: 200). +//! - **`--size`/`-s`**: Sweep across doc lengths, e.g. `-s xs,sm,m,lg,xl,xxl`. +//! - **`--model`/`-m`**: One or more models: `-m model1,model2`. +//! - **`--full`**: Sugar for `-s xs,sm,m,lg,xl,xxl -m llama3.2:1b,llama3.2:3b,llama3.1:8b,qwen2.5:0.5b,qwen2.5:14b`. +//! +//! Output is controlled by `-o/--output` (comma-separated): speedup, iqr, hitrate, +//! latency, json, accuracy. Multiple modes can be combined, e.g. `-o speedup,accuracy`. +//! +//! # Prerequisites +//! +//! PIC cross-request caching only works with the **local** (mistral.rs) backend. +//! The model must resolve to the local backend, either via: +//! - A `local/` prefix: `-m local/my-hf-model-id` +//! - A pretty name: `-m llama3.2:3b` (resolved to a HuggingFace GGUF model) +//! +//! Remote backends (`ollama/`, `openai/`, `gemini/`, `spnl/`) flatten Plus blocks +//! into ordinary messages and do not support PIC caching. +//! +//! ## Required features +//! +//! Build with `--features bench,local` (or `bench,metal` / `bench,cuda`): +//! +//! ```sh +//! cargo run -p spnl-cli --features bench,metal -- bench pic -m llama3.2:3b -s xs,sm,m +//! cargo run -p spnl-cli --features bench,metal -- bench pic --full +//! ``` +//! +//! ## Required environment variables +//! +//! The PIC path is activated by sentinel token IDs that delimit Plus/Cross +//! blocks in the tokenized sequence. Set these to the token IDs your model uses +//! for the Plus and Cross sentinel tokens:
+//! ```sh +//! export SPNL_PIC_PLUS_TOKEN=128011 # example: a reserved special token +//! export SPNL_PIC_CROSS_TOKEN=128012 # example: a reserved special token +//! ``` +//! +//! Without these env vars, cross-request cache reuse will not activate. The +//! benchmark will still run, but reuse requests will show no speedup and the +//! hit rate will be 0%. +//! +//! ## Verifying PIC is active +//! +//! Run with `RUST_LOG=info` to see per-request cache hit messages: +//! +//! ```text +//! INFO ... PIC hit: N Plus blocks reused +//! ``` +//! +//! # Protocol +//! +//! For each model, for each doc-length configuration, for each trial: +//! 1. Generate fresh synthetic documents (unique per trial to avoid inter-trial hits) +//! 2. **No-cache** request (`max_tokens=1`): first time these docs are seen — full prefill +//! 3. N **reuse** requests (`max_tokens=1`): same docs shuffled into different orders +//! - With PIC active, Plus block KVs are reused; only Cross tokens need prefill + +use indicatif::{ProgressBar, ProgressStyle}; +use rand::Rng; +use rand::seq::SliceRandom; +use spnl::{ + ExecuteOptions, SpnlError, execute, + ir::{Message::Assistant, Query}, + optimizer::hlo, + spnl, +}; +use std::time::Instant; + +/// Doc-length spectrum for size sweeps. +/// Each entry is (words_per_doc, short_name). +const FULL_SPECTRUM: &[(usize, &str)] = &[ + (10, "xs"), + (50, "sm"), + (200, "m"), + (500, "lg"), + (1000, "xl"), + (2000, "xxl"), +]; + +/// All valid t-shirt size names, for help text. +const ALL_SIZES: &str = "xs,s(m),m,l(g),xl,xxl"; + +/// Default models for `--full` mode. +const FULL_MODELS: &[&str] = &[ + "llama3.2:1b", + "llama3.2:3b", + "llama3.1:8b", + "qwen2.5:0.5b", + "qwen2.5:14b", +]; + +/// Resolve size names into spectrum entries. +/// Accepts canonical names (xs, sm, m, lg, xl, xxl) and short aliases (s, l). +/// Returns `Err` if any size name is unrecognized. +fn resolve_spectrum(sizes: &[String]) -> Result<Vec<(usize, String)>, String> { + let mut selected = Vec::new(); + for s in sizes { + let s = s.trim(); + // Allow short aliases: s→sm, l→lg + let canonical = match s { + "s" => "sm", + "l" => "lg", + other => other, + }; + match FULL_SPECTRUM.iter().find(|(_, name)| *name == canonical) { + Some(&(len, name)) => selected.push((len, name.to_string())), + None => return Err(s.to_string()), + } + } + Ok(selected) +} + +#[derive(clap::ValueEnum, Clone, Debug, PartialEq, Eq, serde::Serialize)] +pub enum OutputMode { + /// Speedup ratio: "3.75x" + Speedup, + /// PIC interquartile range (ms): "90,150" + Iqr, + /// PIC cache hit rate: "100%" + Hitrate, + /// Prefix and PIC p50 latency (ms): "450,120" + Latency, + /// All data as JSON + Json, + /// Accuracy: Plus (PIC) vs flat (causal) response comparison + Accuracy, +} + +#[derive(clap::Args, Debug, serde::Serialize)] +pub struct PicArgs { + /// Generative model(s): -m model1,model2 or -m model1 -m model2 + #[arg(short, long, value_delimiter = ',', env = "BENCH_MODEL")] + pub model: Vec<String>, + + /// Document sizes to sweep (comma-separated): -s xs,sm,m,lg,xl,xxl + /// Sizes: xs(10w) sm(50w) m(200w) lg(500w) xl(1000w) xxl(2000w) + #[arg(short, long, value_delimiter = ',')] + pub size: Vec<String>, + + /// Number of documents (Plus blocks) + #[arg(long, default_value_t = 4, env = "BENCH_PIC_NUM_DOCS")] + pub num_docs: usize, + + /// Length parameter (used when no --size is specified). + /// For TTFT modes: approximate words per document (default: 200). + /// For accuracy mode: max tokens per response (default: 200).
+ #[arg(short, long, default_value_t = 200, env = "BENCH_PIC_LENGTH")] + pub length: usize, + + /// Number of reuse (reshuffled) requests after the initial no-cache request + #[arg(long, default_value_t = 5, env = "BENCH_PIC_REUSE_ITERS")] + pub reuse_iters: usize, + + /// Number of full trials (no-cache + reuse cycle) per doc-length + #[arg(long, default_value_t = 3, env = "BENCH_PIC_TRIALS")] + pub trials: usize, + + /// Full sweep: all sizes × default models + /// Equivalent to -s xs,sm,m,lg,xl,xxl -m llama3.2:1b,llama3.2:3b,llama3.1:8b,qwen2.5:0.5b,qwen2.5:14b + #[arg(long)] + pub full: bool, + + /// Output mode(s): -o speedup,accuracy or -o latency -o accuracy + /// Modes: speedup, iqr, hitrate, latency, json, accuracy + #[arg( + short = 'o', + long, + value_delimiter = ',', + value_enum, + default_value = "speedup" + )] + pub output: Vec<OutputMode>, + + /// Grading model for LLM-judge semantic equivalence scoring. + /// If omitted, only token F1 is reported (no LLM judge). + #[arg(long, env = "BENCH_PIC_GRADING_MODEL")] + pub grading_model: Option<String>, + + /// Print the optimized query tree before each request + #[arg(short, long)] + pub verbose: bool, +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn make_documents(num_docs: usize, doc_length: usize) -> Vec<String> { + let mut rng = rand::rng(); + (0..num_docs) + .map(|i| { + format!( + "Document {}: The topic of this document is item-{}. {}", + i, + i, + lipsum::lipsum_words_with_rng(&mut rng, doc_length) + ) + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Factual document generation for accuracy testing +// --------------------------------------------------------------------------- +// +// Design notes — why this catches order-sensitivity bugs: +// +// The core idea is CROSS-DOCUMENT INFERENCE CHAINS. Each test uses a pair of +// documents where one defines an alias and the other uses it: +// +// doc_def: "A Vorpal is the local name for a school in Zephyria." +// doc_use: "There are 45 Vorpals in the capital district of Zephyria." +// question: "How many schools are in the capital district of Zephyria?" +// answer: "45" +// +// Answering requires composing "Vorpal = school" (from doc_def) with "45 +// Vorpals" (from doc_use). This makes document order mechanically relevant: +// +// - In CAUSAL (flat) attention, tokens can only attend to earlier positions. +// If doc_def precedes doc_use, the model can resolve the alias when it +// reaches doc_use. If the order is reversed, doc_use is processed before +// the definition exists — the model may fail. +// +// - In PIC (Plus block attention), each document is an independent Plus block +// whose KV cache is computed without attending to other blocks. The Cross +// block (question) then attends to ALL Plus blocks simultaneously. So +// document order should not matter — both docs are equally accessible. +// +// This gives us a deterministic, mechanistic test rather than relying on +// weak statistical signals like primacy/recency bias.
+// +// The benchmark runs 4 queries per trial and the comparison matrix is: +// +// flat (causal, original order) → should work (def before use, causal chains forward) +// fshuf (causal, shuffled order) → may fail (use before def, causal can't look ahead) +// pic (Plus blocks, orig order) → should work (Plus blocks are independent) +// pshuf (Plus blocks, shuffled) → should work (same reason — PIC is order-invariant) +// +// Key diagnostic comparisons: +// +// flat vs fshuf → measures causal order-sensitivity (expected for inference chains) +// pic vs pshuf → measures PIC order-sensitivity (should be zero if PIC is correct) +// fshuf vs pshuf → PIC's value proposition: does Plus attention recover what causal loses? +// flat vs pic → sanity check: does Plus attention match causal in the easy case? +// +// If pshuf < pic, the PIC cache likely has a position-encoding bug: cached +// KV entries retain RoPE rotations from their original positions and produce +// garbled attention when reused at different positions. +// +// If fshuf ≈ pshuf (both low), Plus blocks may not be helping with cross-doc +// reasoning at all — the model is equally order-sensitive in both modes. +// +// Additional design choices: +// +// - FICTIONAL ALIASES: Names like "Vorpal" and "Zephyria" are fictional, so +// the model can't resolve them from pretraining. It must read both docs. +// +// - LIPSUM PADDING: Facts are buried in lipsum filler (controlled by +// --length). This forces real attention over long contexts. +// +// - FILLER DOCUMENTS: Additional lipsum-only docs are added to reach +// num_docs, acting as distractors that dilute attention. +// +// - TERSE ANSWER FORMAT: The question asks for ONLY the value. This keeps +// responses short and makes substring checking reliable. + +/// Inference-chain entries for cross-document accuracy testing. +/// +/// Each entry: (alias, real_meaning, location, quantity, unit). +/// +/// These produce two documents per entry: +/// doc_def: "A {alias} is the local name for a {real_meaning} in {location}." +/// doc_use: "There are {quantity} {alias}s {unit} {location}." +/// +/// The question asks about the real_meaning, forcing the model to resolve the +/// alias across documents.
+const INFERENCE_BANK: &[InferenceEntry] = &[ + InferenceEntry { + alias: "Vorpal", + real_meaning: "school", + location: "Zephyria", + quantity: "45", + unit: "in the capital district of", + }, + InferenceEntry { + alias: "Thalweg", + real_meaning: "hospital", + location: "Brontaal", + quantity: "12", + unit: "across the northern provinces of", + }, + InferenceEntry { + alias: "Crenn", + real_meaning: "bridge", + location: "Velouria", + quantity: "89", + unit: "spanning the rivers of", + }, + InferenceEntry { + alias: "Spelkraft", + real_meaning: "factory", + location: "Caskara", + quantity: "23", + unit: "along the coast of", + }, + InferenceEntry { + alias: "Darvon", + real_meaning: "park", + location: "Nimbustan", + quantity: "67", + unit: "within the borders of", + }, + InferenceEntry { + alias: "Quelm", + real_meaning: "library", + location: "Flarovia", + quantity: "31", + unit: "in the eastern region of", + }, + InferenceEntry { + alias: "Broxite", + real_meaning: "market", + location: "Glacendia", + quantity: "8", + unit: "in the highland towns of", + }, + InferenceEntry { + alias: "Vintara", + real_meaning: "temple", + location: "Terranova", + quantity: "54", + unit: "throughout the valleys of", + }, + InferenceEntry { + alias: "Pellith", + real_meaning: "harbor", + location: "Quintara", + quantity: "19", + unit: "dotting the shoreline of", + }, + InferenceEntry { + alias: "Orrenth", + real_meaning: "mine", + location: "Marivel", + quantity: "76", + unit: "beneath the mountains of", + }, + InferenceEntry { + alias: "Straven", + real_meaning: "theater", + location: "Pyrothen", + quantity: "14", + unit: "in the old quarter of", + }, + InferenceEntry { + alias: "Calyx", + real_meaning: "well", + location: "Solanthis", + quantity: "103", + unit: "across the desert of", + }, + InferenceEntry { + alias: "Nimrath", + real_meaning: "tower", + location: "Verdantia", + quantity: "37", + unit: "rising above the canopy of", + }, + InferenceEntry { + alias: "Glenth", + real_meaning: "dam", + location: "Aethermoor", + quantity: "5", + unit: "controlling the rivers of", + }, + InferenceEntry { + alias: "Feroshi", + real_meaning: "clinic", + location: "Korundel", + quantity: "28", + unit: "in the outer villages of", + }, + InferenceEntry { + alias: "Mordwyn", + real_meaning: "fort", + location: "Drakmere", + quantity: "41", + unit: "guarding the passes of", + }, +]; + +struct InferenceEntry { + alias: &'static str, + real_meaning: &'static str, + location: &'static str, + quantity: &'static str, + unit: &'static str, +} + +/// Result of generating accuracy test documents. +struct AccuracyDocs { + /// All documents in the "correct" order (definition before usage, then fillers). + docs: Vec, + /// The question that requires cross-document inference. + question: String, + /// The expected answer (substring to check for). + expected: String, +} + +/// Generate documents for one accuracy trial. +/// +/// Picks a random inference-chain entry, generates a definition doc and a +/// usage doc (in that order), then fills up to `num_docs` with lipsum filler +/// documents. All documents are padded to `doc_length` words. +/// +/// The returned `docs` are in "correct" order: definition first, usage second, +/// then fillers. The caller shuffles as needed for the shuffled test. 
+fn make_accuracy_docs(num_docs: usize, doc_length: usize) -> AccuracyDocs { + assert!(num_docs >= 2, "need at least 2 docs for inference chain"); + let mut rng = rand::rng(); + + // Pick a random inference chain + let entry = &INFERENCE_BANK[rng.random_range(0..INFERENCE_BANK.len())]; + + let mut pad = |text: &str| -> String { + let word_count = text.split_whitespace().count(); + if doc_length > word_count { + let padding = lipsum::lipsum_words_with_rng(&mut rng, doc_length - word_count); + format!("{text} {padding}") + } else { + text.to_string() + } + }; + + // Definition document: defines the alias + let doc_def = pad(&format!( + "A {} is the local name for a {} in {}.", + entry.alias, entry.real_meaning, entry.location, + )); + + // Usage document: uses the alias with a quantity + let doc_use = pad(&format!( + "There are {} {}s {} {}.", + entry.quantity, entry.alias, entry.unit, entry.location, + )); + + let mut docs = vec![doc_def, doc_use]; + + // Filler documents (lipsum only, as distractors) + for _ in 2..num_docs { + docs.push(lipsum::lipsum_words_with_rng(&mut rng, doc_length)); + } + + let question = format!( + "How many {}s are there {} {}? Answer with ONLY the number, nothing else.", + entry.real_meaning, entry.unit, entry.location, + ); + + AccuracyDocs { + docs, + question, + expected: entry.quantity.to_string(), + } +} + +/// Case-insensitive substring check for ground-truth answer verification. +fn check_answer(response: &str, expected: &str) -> bool { + response.to_lowercase().contains(&expected.to_lowercase()) +} + +/// Build a query with Plus-wrapped documents (PIC block attention). +fn build_query(model: &str, docs: &[String], question: &str, max_tokens: i32) -> Query { + let model = model.to_string(); + let system_prompt = + "You are a helpful assistant. Answer based on the provided documents.".to_string(); + let temperature: f32 = 0.0; + + let doc_messages: Vec<Query> = docs + .iter() + .map(|text| { + let text = text.clone(); + spnl!(user text) + }) + .collect(); + + let question = question.to_string(); + + spnl!( + g model + (cross + (system system_prompt) + (plus doc_messages) + (user question) + ) + temperature + max_tokens + ) +} + +/// Build a query with docs as plain sequential messages (standard causal attention, no Plus). +fn build_query_flat(model: &str, docs: &[String], question: &str, max_tokens: i32) -> Query { + use spnl::ir::{Generate, GenerateMetadata}; + + let system_prompt = + "You are a helpful assistant.
Answer based on the provided documents.".to_string(); + + let mut messages: Vec<Query> = Vec::new(); + messages.push(spnl!(system system_prompt)); + for text in docs { + let text = text.clone(); + messages.push(spnl!(user text)); + } + let question = question.to_string(); + messages.push(spnl!(user question)); + + Query::Generate(Generate { + metadata: GenerateMetadata { + model: model.to_string(), + max_tokens: Some(max_tokens), + temperature: Some(0.0), + }, + input: Box::new(Query::Cross(messages)), + }) +} + +async fn timed_request( + query: &Query, + options: &ExecuteOptions, + verbose: bool, +) -> anyhow::Result<(f64, String)> { + let query = &hlo::optimize(query, &hlo::Options::default()).await?; + if verbose { + eprintln!("--- optimized query ---"); + ptree::write_tree(query, ::std::io::stderr())?; + eprintln!("--- end ---"); + } + let start = Instant::now(); + let result = execute(query, options).await?; + let ms = start.elapsed().as_secs_f64() * 1000.0; + let text = match result { + Query::Message(Assistant(s)) => s, + _ => String::new(), + }; + Ok((ms, text)) +} + +fn shuffle_docs(docs: &[String]) -> Vec<String> { + let mut shuffled = docs.to_vec(); + let mut rng = rand::rng(); + shuffled.shuffle(&mut rng); + shuffled +} + +fn take_pic_stats() -> (u64, u64) { + #[cfg(feature = "local")] + { + spnl::pic_stats::take_cache_stats() + } + #[cfg(not(feature = "local"))] + { + (0, 0) + } +} + +async fn unload_models() { + #[cfg(feature = "local")] + { + spnl::model_pool::unload_all().await; + } +} + +// --------------------------------------------------------------------------- +// Shared run context +// --------------------------------------------------------------------------- + +pub(crate) struct RunCtx<'a> { + pub(crate) model: &'a str, + pub(crate) num_docs: usize, + pub(crate) doc_length: usize, + pub(crate) label: &'a str, + pub(crate) pb: &'a ProgressBar, + pub(crate) step_prefix: &'a str, + pub(crate) verbose: bool, +} + +// --------------------------------------------------------------------------- +// Single doc-length run — returns collected results +// --------------------------------------------------------------------------- + +pub(crate) struct RunResult { + label: String, + doc_length: usize, + pub(crate) nocache_ttfts: Vec<f64>, + pub(crate) reuse_ttfts: Vec<f64>, + reuse_hits: u64, + reuse_misses: u64, +} + +pub(crate) async fn run_one( + ctx: &RunCtx<'_>, + reuse_iters: usize, + trials: usize, +) -> anyhow::Result<RunResult> { + let RunCtx { + model, + num_docs, + doc_length, + label, + pb, + step_prefix, + verbose, + } = ctx; + let options = ExecuteOptions { + silent: true, + ..Default::default() + }; + let verbose = *verbose; + let question = "Summarize the key topics from all the documents."; + + let mut nocache_ttfts = Vec::new(); + let mut reuse_ttfts = Vec::new(); + let mut total_hits: u64 = 0; + let mut total_misses: u64 = 0; + + let _ = take_pic_stats(); + + for trial in 0..trials { + let docs = make_documents(*num_docs, *doc_length); + + // No-cache request: first time these docs are seen — full prefill + let _ = take_pic_stats(); + let nocache_query = build_query(model, &docs, question, 1); + let (nocache_ms, _) = timed_request(&nocache_query, &options, verbose).await?; + nocache_ttfts.push(nocache_ms); + pb.inc(1); + pb.set_message(format!( + "{step_prefix}{label} · trial {}/{trials} · Prefix · {nocache_ms:.0}ms", + trial + 1, + )); + + // Reuse requests: same docs reshuffled — PIC should serve Plus blocks from cache + let _ = take_pic_stats(); + for reuse_i in 0..reuse_iters { + let
shuffled = shuffle_docs(&docs); + let reuse_query = build_query(model, &shuffled, question, 1); + let (reuse_ms, _) = timed_request(&reuse_query, &options, verbose).await?; + reuse_ttfts.push(reuse_ms); + pb.inc(1); + pb.set_message(format!( + "{step_prefix}{label} · trial {}/{trials} · PIC {}/{reuse_iters} · {reuse_ms:.0}ms", + trial + 1, + reuse_i + 1, + )); + } + + let (hits, misses) = take_pic_stats(); + total_hits += hits; + total_misses += misses; + } + + Ok(RunResult { + label: label.to_string(), + doc_length: *doc_length, + nocache_ttfts, + reuse_ttfts, + reuse_hits: total_hits, + reuse_misses: total_misses, + }) +} + +// --------------------------------------------------------------------------- +// Accuracy trials — compare Plus (PIC attention) vs flat (standard causal) +// --------------------------------------------------------------------------- + +use std::collections::HashMap; + +struct AccuracyResult { + label: String, + doc_length: usize, + trials: Vec<AccuracyTrial>, +} + +struct AccuracyTrial { + flat_correct: bool, + flat_shuffled_correct: bool, + pic_correct: bool, + pic_shuffled_correct: bool, + /// Token F1 between flat and PIC responses (secondary metric) + token_f1: f64, + /// LLM-judge semantic equivalence score (0-100), None if no grading model + llm_score: Option<f64>, +} + +impl AccuracyResult { + fn flat_accuracy(&self) -> (usize, usize) { + let correct = self.trials.iter().filter(|t| t.flat_correct).count(); + (correct, self.trials.len()) + } + + fn flat_shuffled_accuracy(&self) -> (usize, usize) { + let correct = self + .trials + .iter() + .filter(|t| t.flat_shuffled_correct) + .count(); + (correct, self.trials.len()) + } + + fn pic_accuracy(&self) -> (usize, usize) { + let correct = self.trials.iter().filter(|t| t.pic_correct).count(); + (correct, self.trials.len()) + } + + fn shuffle_accuracy(&self) -> (usize, usize) { + let correct = self + .trials + .iter() + .filter(|t| t.pic_shuffled_correct) + .count(); + (correct, self.trials.len()) + } + + fn avg_token_f1(&self) -> f64 { + if self.trials.is_empty() { + return 0.0; + } + self.trials.iter().map(|t| t.token_f1).sum::<f64>() / self.trials.len() as f64 + } + + fn avg_llm_score(&self) -> Option<f64> { + let scores: Vec<f64> = self.trials.iter().filter_map(|t| t.llm_score).collect(); + if scores.is_empty() { + return None; + } + Some(scores.iter().sum::<f64>() / scores.len() as f64) + } +} + +/// Normalize text into lowercase word tokens for comparison. +fn normalize_tokens(text: &str) -> Vec<String> { + text.to_lowercase() + .split(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect() +} + +/// Token-level F1 between two texts (0-100). Measures word overlap regardless of order.
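+/// Computed as 2*P*R/(P+R)*100, where P = common/|candidate| tokens and R = common/|reference| tokens over multiset token counts.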
+fn token_f1(reference: &str, candidate: &str) -> f64 { + let ref_tokens = normalize_tokens(reference); + let cand_tokens = normalize_tokens(candidate); + if ref_tokens.is_empty() && cand_tokens.is_empty() { + return 100.0; + } + if ref_tokens.is_empty() || cand_tokens.is_empty() { + return 0.0; + } + + let ref_counts: HashMap<&str, usize> = ref_tokens.iter().fold(HashMap::new(), |mut m, t| { + *m.entry(t.as_str()).or_insert(0) += 1; + m + }); + let cand_counts: HashMap<&str, usize> = cand_tokens.iter().fold(HashMap::new(), |mut m, t| { + *m.entry(t.as_str()).or_insert(0) += 1; + m + }); + + let mut common = 0usize; + for (tok, &count) in &cand_counts { + common += count.min(*ref_counts.get(tok).unwrap_or(&0)); + } + + if common == 0 { + return 0.0; + } + + let precision = common as f64 / cand_tokens.len() as f64; + let recall = common as f64 / ref_tokens.len() as f64; + 2.0 * precision * recall / (precision + recall) * 100.0 +} + +/// Build an LLM-judge query that scores semantic equivalence of two responses. +fn build_equivalence_query(model: &str, reference: &str, candidate: &str) -> Query { + use spnl::ir::{Generate, GenerateMetadata}; + + let system_prompt = "You are a semantic equivalence evaluator. You will be given two responses to the same question with the same documents. Determine whether they convey the same meaning. Return ONLY a single integer 0-100. 100 means semantically identical, 0 means completely different meaning.".to_string(); + let user_prompt = format!( + "Response A (reference):\n{reference}\n\nResponse B (candidate):\n{candidate}\n\nSemantic equivalence score (0-100):" + ); + + Query::Generate(Generate { + metadata: GenerateMetadata { + model: model.to_string(), + max_tokens: Some(16), + temperature: Some(0.0), + }, + input: Box::new(Query::Cross(vec![ + spnl!(system system_prompt), + spnl!(user user_prompt), + ])), + }) +} + +/// Parse a numeric score from an LLM response. +fn parse_score(response: &str) -> f64 { + response + .trim() + .split(|c: char| !c.is_ascii_digit()) + .find(|s| !s.is_empty()) + .and_then(|s| s.parse::<f64>().ok()) + .map(|v| v.clamp(0.0, 100.0)) + .unwrap_or(0.0) +} + +async fn run_accuracy( + ctx: &RunCtx<'_>, + grading_model: Option<&str>, + accuracy_tokens: i32, + trials: usize, +) -> anyhow::Result<AccuracyResult> { + let RunCtx { + model, + num_docs, + doc_length, + label, + pb, + step_prefix, + verbose, + } = ctx; + let options = ExecuteOptions { + silent: true, + ..Default::default() + }; + let verbose = *verbose; + + let mut trial_results = Vec::new(); + + for trial in 0..trials { + // Generate inference-chain documents (definition before usage, then fillers) + let AccuracyDocs { + docs, + question, + expected, + } = make_accuracy_docs(*num_docs, *doc_length); + + // 1. Flat: causal attention, original doc order (def before use) + let flat_query = build_query_flat(model, &docs, &question, accuracy_tokens); + let (_, flat_response) = timed_request(&flat_query, &options, verbose).await?; + let flat_correct = check_answer(&flat_response, &expected); + pb.inc(1); + pb.set_message(format!( + "{step_prefix}{label} · accuracy {}/{trials} · flat · {}", + trial + 1, + if flat_correct { "Y" } else { "N" }, + )); + + // 2.
Flat-shuffled: causal attention, shuffled doc order (no PIC involvement) + let shuffled = shuffle_docs(&docs); + let flat_shuf_query = build_query_flat(model, &shuffled, &question, accuracy_tokens); + let (_, flat_shuf_response) = timed_request(&flat_shuf_query, &options, verbose).await?; + let flat_shuffled_correct = check_answer(&flat_shuf_response, &expected); + pb.inc(1); + pb.set_message(format!( + "{step_prefix}{label} · accuracy {}/{trials} · fshuf · {}", + trial + 1, + if flat_shuffled_correct { "Y" } else { "N" }, + )); + + // 3. PIC: Plus block attention, original doc order + let _ = take_pic_stats(); + let pic_query = build_query(model, &docs, &question, accuracy_tokens); + let (_, pic_response) = timed_request(&pic_query, &options, verbose).await?; + let pic_correct = check_answer(&pic_response, &expected); + if verbose { + let (h, m) = take_pic_stats(); + eprintln!( + "[pic] expected={expected:?} response={pic_response:?} correct={pic_correct} hits={h} misses={m}" + ); + } + pb.inc(1); + pb.set_message(format!( + "{step_prefix}{label} · accuracy {}/{trials} · pic · {}", + trial + 1, + if pic_correct { "Y" } else { "N" }, + )); + + // 4. PIC-shuffled: Plus block attention, shuffled doc order + // (reuses cached Plus blocks from step 3 — tests PIC cache correctness) + let pic_shuffled = shuffle_docs(&docs); + let pic_shuf_query = build_query(model, &pic_shuffled, &question, accuracy_tokens); + let (_, pic_shuf_response) = timed_request(&pic_shuf_query, &options, verbose).await?; + let pic_shuffled_correct = check_answer(&pic_shuf_response, &expected); + if verbose { + let (h, m) = take_pic_stats(); + eprintln!( + "[pshuf] expected={expected:?} response={pic_shuf_response:?} correct={pic_shuffled_correct} hits={h} misses={m}" + ); + } + pb.inc(1); + pb.set_message(format!( + "{step_prefix}{label} · accuracy {}/{trials} · pshuf · {}", + trial + 1, + if pic_shuffled_correct { "Y" } else { "N" }, + )); + + // Token F1 between flat and PIC (secondary metric) + let f1 = token_f1(&flat_response, &pic_response); + + // LLM-judge (if grading model provided) + let llm_score = if let Some(gm) = grading_model { + let judge_query = build_equivalence_query(gm, &flat_response, &pic_response); + match execute(&judge_query, &options).await { + Ok(Query::Message(Assistant(s))) => { + pb.inc(1); + Some(parse_score(&s)) + } + _ => { + pb.inc(1); + None + } + } + } else { + None + }; + + let score_str = llm_score + .map(|s| format!(" llm={s:.0}")) + .unwrap_or_default(); + pb.set_message(format!( + "{step_prefix}{label} · accuracy {}/{trials} · {}/{}/{}/{} f1={f1:.0}{score_str}", + trial + 1, + if flat_correct { "Y" } else { "N" }, + if flat_shuffled_correct { "Y" } else { "N" }, + if pic_correct { "Y" } else { "N" }, + if pic_shuffled_correct { "Y" } else { "N" }, + )); + + trial_results.push(AccuracyTrial { + flat_correct, + flat_shuffled_correct, + pic_correct, + pic_shuffled_correct, + token_f1: f1, + llm_score, + }); + } + + Ok(AccuracyResult { + label: label.to_string(), + doc_length: *doc_length, + trials: trial_results, + }) +} + +// --------------------------------------------------------------------------- +// Entry point +// --------------------------------------------------------------------------- + +pub async fn run(args: PicArgs) -> Result<(), SpnlError> { + // --- Resolve models and sizes (--full provides defaults) --- + let mut models: Vec<String> = if !args.model.is_empty() { + args.model.clone() + } else if args.full { + FULL_MODELS.iter().map(|s| s.to_string()).collect() + } else {
anyhow::bail!("At least one model is required (-m MODEL)"); + }; + // Deduplicate while preserving order + { + let mut seen = std::collections::HashSet::new(); + models.retain(|m| seen.insert(m.clone())); + } + + let spectrum: Vec<(usize, String)> = if !args.size.is_empty() { + resolve_spectrum(&args.size).map_err(|bad| { + anyhow::anyhow!( + "Unknown size '{}' in --size={}. Valid sizes: {}", + bad, + args.size.join(","), + ALL_SIZES + ) + })? + } else if args.full { + FULL_SPECTRUM + .iter() + .map(|&(len, name)| (len, name.to_string())) + .collect() + } else { + vec![(args.length, format!("{}w", args.length))] + }; + + // --- Prerequisite checks --- + for model in &models { + let is_remote = model.starts_with("ollama/") + || model.starts_with("openai/") + || model.starts_with("gemini/") + || model.starts_with("spnl/"); + if is_remote { + eprintln!("WARNING: Model '{}' uses a remote backend.", model); + eprintln!( + " PIC cross-request caching only works with the local mistral.rs backend." + ); + eprintln!(); + } + } + + // --- Derive which protocols to run --- + let wants_ttft = args.output.iter().any(|m| { + matches!( + m, + OutputMode::Speedup + | OutputMode::Iqr + | OutputMode::Hitrate + | OutputMode::Latency + | OutputMode::Json + ) + }); + let wants_accuracy = args.output.contains(&OutputMode::Accuracy); + + // --- Print config --- + eprintln!("=== PIC Cross-Request Cache Benchmark ==="); + if models.len() == 1 { + eprintln!(" Model: {}", models[0]); + } else { + eprintln!(" Models: {}", models.join(", ")); + } + eprintln!(" Documents: {} per request", args.num_docs); + let size_names: Vec<String> = spectrum + .iter() + .map(|(len, name)| format!("{name}({len}w)")) + .collect(); + eprintln!(" Doc sizes: {}", size_names.join(", ")); + let output_labels: Vec<&str> = args + .output + .iter() + .map(|m| match m { + OutputMode::Speedup => "speedup", + OutputMode::Iqr => "iqr (p25,p75 ms)", + OutputMode::Hitrate => "hitrate", + OutputMode::Latency => "latency (prefix,pic ms)", + OutputMode::Json => "json", + OutputMode::Accuracy => "accuracy", + }) + .collect(); + eprintln!(" Output: {}", output_labels.join(", ")); + if wants_ttft { + eprintln!(" Reuse iters: {} per trial", args.reuse_iters); + eprintln!(" Trials: {} per doc-length", args.trials); + eprintln!(" Max tokens: 1 (TTFT measurement)"); + } + if wants_accuracy { + let metric = match args.grading_model.as_deref() { + Some(m) => format!("secondary=token-f1+llm-judge({m})"), + None => "secondary=token-f1, llm-judge disabled (see --grading-model)".to_string(), + }; + eprintln!( + " Accuracy: {} tokens, {} trials, flat/fshuf/pic/pshuf ground-truth, {metric}", + args.length, args.trials + ); + } + eprintln!(); + + // --- Progress bar --- + let ttft_steps_per_len = if wants_ttft { + args.trials * (1 + args.reuse_iters) + } else { + 0 + }; + let accuracy_steps_per_trial = if args.grading_model.is_some() { 5 } else { 4 }; // flat + flat-shuf + PIC + PIC-shuf + optional judge + let accuracy_steps_per_len = if wants_accuracy { + args.trials * accuracy_steps_per_trial + } else { + 0 + }; + let total_steps = + (models.len() * spectrum.len() * (ttft_steps_per_len + accuracy_steps_per_len)) as u64; + let pb = ProgressBar::new(total_steps); + pb.set_style( + ProgressStyle::with_template("[{elapsed_precise}] {bar:30.cyan/dim} {pos}/{len} {msg}") + .unwrap() + .progress_chars("━╸─"), + ); + + // --- Run all models × sizes --- + let mut all_model_results: Vec<(String, Vec<RunResult>)> = Vec::new(); + let mut all_accuracy_results: Vec<(String, Vec<AccuracyResult>)> = Vec::new(); +
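// Sweep: outer loop over models, inner loop over doc sizes; the TTFT and accuracy protocols share the RunCtx built for each combination. +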
for (mi, model) in models.iter().enumerate() { + // Unload previous model to free GPU memory before loading the next one + if mi > 0 { + unload_models().await; + } + + let mut results = Vec::new(); + let mut accuracy_results = Vec::new(); + for (si, (doc_length, label)) in spectrum.iter().enumerate() { + let step_prefix = match (models.len() > 1, spectrum.len() > 1) { + (true, true) => format!("{} [{}/{}] ", model, si + 1, spectrum.len()), + (true, false) => format!("{} ", model), + (false, true) => format!("[{}/{}] ", si + 1, spectrum.len()), + (false, false) => String::new(), + }; + let ctx = RunCtx { + model, + num_docs: args.num_docs, + doc_length: *doc_length, + label, + pb: &pb, + step_prefix: &step_prefix, + verbose: args.verbose, + }; + + if wants_ttft { + let r = run_one(&ctx, args.reuse_iters, args.trials).await?; + results.push(r); + } + + if wants_accuracy { + let ar = run_accuracy( + &ctx, + args.grading_model.as_deref(), + args.length as i32, + args.trials, + ) + .await?; + accuracy_results.push(ar); + } + } + if wants_ttft { + all_model_results.push((model.clone(), results)); + } + if wants_accuracy { + all_accuracy_results.push((model.clone(), accuracy_results)); + } + } + + pb.finish_and_clear(); + eprintln!(); + + // --- Report --- + if !all_model_results.is_empty() { + // Collect the TTFT output modes for the tables (Json is handled separately) + let ttft_modes: Vec<&OutputMode> = args + .output + .iter() + .filter(|m| { + matches!( + m, + OutputMode::Speedup + | OutputMode::Iqr + | OutputMode::Hitrate + | OutputMode::Latency + ) + }) + .collect(); + if args.output.contains(&OutputMode::Json) { + print_results_json(&all_model_results); + } + for mode in ttft_modes { + print_results_table(&all_model_results, mode); + } + } + + // --- Accuracy report --- + if !all_accuracy_results.is_empty() { + print_accuracy_table(&all_accuracy_results); + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Results table (single unified output for all modes) +// --------------------------------------------------------------------------- + +fn format_cell(r: &RunResult, output: &OutputMode) -> String { + match output { + OutputMode::Speedup => { + let nocache_p50 = percentile(&r.nocache_ttfts, 50); + let reuse_p50 = percentile(&r.reuse_ttfts, 50); + let speedup = if reuse_p50 > 0.0 { + nocache_p50 / reuse_p50 + } else { + 0.0 + }; + format!("{speedup:.2}x") + } + OutputMode::Iqr => { + let p25 = percentile(&r.reuse_ttfts, 25); + let p75 = percentile(&r.reuse_ttfts, 75); + format!("{p25:.0},{p75:.0}") + } + OutputMode::Hitrate => { + let (rate, _) = compute_hit_rate(r.reuse_hits, r.reuse_misses); + format!("{rate:.0}%") + } + OutputMode::Latency => { + let nocache_p50 = percentile(&r.nocache_ttfts, 50); + let reuse_p50 = percentile(&r.reuse_ttfts, 50); + format!("{nocache_p50:.0},{reuse_p50:.0}") + } + OutputMode::Json | OutputMode::Accuracy => unreachable!(), + } +} + +fn print_results_table(all_model_results: &[(String, Vec<RunResult>)], output: &OutputMode) { + // Column labels from the first model's results + let size_labels: Vec<String> = all_model_results[0] + .1 + .iter() + .map(|r| { + if r.label.ends_with('w') { + r.label.clone() + } else { + format!("{} {}w", r.label, r.doc_length) + } + }) + .collect(); + + // Pre-format all cells to compute column widths + let rows: Vec<(&str, Vec<String>)> = all_model_results + .iter() + .map(|(model, results)| { + let cells: Vec<String> = results.iter().map(|r| format_cell(r, output)).collect(); + (model.as_str(), cells) + })
+ .collect(); + + let model_w = rows.iter().map(|(m, _)| m.len()).max().unwrap_or(5).max(5); + let col_ws: Vec = (0..size_labels.len()) + .map(|i| { + let header_w = size_labels[i].len(); + let data_w = rows + .iter() + .map(|(_, cells)| cells[i].len()) + .max() + .unwrap_or(0); + header_w.max(data_w).max(6) + }) + .collect(); + + // Header + let mut header = format!(" {:w$}", label, w = col_ws[i])); + sep.push_str(&format!(" {:>w$}", "─".repeat(col_ws[i]), w = col_ws[i])); + } + eprintln!("{header}"); + eprintln!("{sep}"); + + // Data rows + for (model, cells) in &rows { + let mut row = format!(" {:w$}", cell, w = col_ws[i])); + } + eprintln!("{row}"); + } + + eprintln!(); +} + +fn print_results_json(all_model_results: &[(String, Vec)]) { + let results: Vec = all_model_results + .iter() + .flat_map(|(model, runs)| { + runs.iter().map(move |r| { + let nocache_p50 = percentile(&r.nocache_ttfts, 50); + let reuse_p50 = percentile(&r.reuse_ttfts, 50); + let speedup = if reuse_p50 > 0.0 { + nocache_p50 / reuse_p50 + } else { + 0.0 + }; + let (hit_rate, _) = compute_hit_rate(r.reuse_hits, r.reuse_misses); + serde_json::json!({ + "model": model, + "size": r.label, + "doc_length": r.doc_length, + "speedup": (speedup * 100.0).round() / 100.0, + "prefix_p50_ms": (nocache_p50 * 100.0).round() / 100.0, + "pic_p50_ms": (reuse_p50 * 100.0).round() / 100.0, + "pic_p25_ms": (percentile(&r.reuse_ttfts, 25) * 100.0).round() / 100.0, + "pic_p75_ms": (percentile(&r.reuse_ttfts, 75) * 100.0).round() / 100.0, + "hit_rate": (hit_rate * 100.0).round() / 100.0, + "hits": r.reuse_hits, + "misses": r.reuse_misses, + "prefix_ttfts_ms": r.nocache_ttfts, + "pic_ttfts_ms": r.reuse_ttfts, + }) + }) + }) + .collect(); + println!("{}", serde_json::to_string_pretty(&results).unwrap()); +} + +// --------------------------------------------------------------------------- +// Accuracy table (same layout as TTFT tables: rows=models, columns=doc sizes) +// --------------------------------------------------------------------------- + +fn format_accuracy_cell(r: &AccuracyResult) -> String { + let (fc, _) = r.flat_accuracy(); + let (fsc, _) = r.flat_shuffled_accuracy(); + let (pc, _) = r.pic_accuracy(); + let (psc, total) = r.shuffle_accuracy(); + let base = format!("{fc}/{fsc}/{pc}/{psc}"); + // Append secondary metrics if available + let f1 = r.avg_token_f1(); + match r.avg_llm_score() { + Some(llm) => format!("{base} f1={f1:.0},llm={llm:.0}"), + None if total > 0 => format!("{base} f1={f1:.0}"), + _ => base, + } +} + +fn print_accuracy_table(all: &[(String, Vec)]) { + // Legend + let total = all[0].1.first().map(|r| r.trials.len()).unwrap_or(0); + eprintln!(" Accuracy: flat/fshuf/pic/pshuf correct out of {total} trials"); + eprintln!(" flat=causal fshuf=causal+shuffled pic=Plus blocks pshuf=Plus+shuffled"); + eprintln!(); + + let size_labels: Vec = all[0] + .1 + .iter() + .map(|r| { + if r.label.ends_with('w') { + r.label.clone() + } else { + format!("{} {}w", r.label, r.doc_length) + } + }) + .collect(); + + let rows: Vec<(&str, Vec)> = all + .iter() + .map(|(model, results)| { + let cells: Vec = results.iter().map(format_accuracy_cell).collect(); + (model.as_str(), cells) + }) + .collect(); + + let model_w = rows.iter().map(|(m, _)| m.len()).max().unwrap_or(5).max(5); + let col_ws: Vec = (0..size_labels.len()) + .map(|i| { + let header_w = size_labels[i].len(); + let data_w = rows + .iter() + .map(|(_, cells)| cells[i].len()) + .max() + .unwrap_or(0); + header_w.max(data_w).max(6) + }) + .collect(); + + // Header + let mut 
+
+// ---------------------------------------------------------------------------
+// Accuracy table (same layout as TTFT tables: rows=models, columns=doc sizes)
+// ---------------------------------------------------------------------------
+
+fn format_accuracy_cell(r: &AccuracyResult) -> String {
+    let (fc, _) = r.flat_accuracy();
+    let (fsc, _) = r.flat_shuffled_accuracy();
+    let (pc, _) = r.pic_accuracy();
+    let (psc, total) = r.shuffle_accuracy();
+    let base = format!("{fc}/{fsc}/{pc}/{psc}");
+    // Append secondary metrics if available
+    let f1 = r.avg_token_f1();
+    match r.avg_llm_score() {
+        Some(llm) => format!("{base} f1={f1:.0},llm={llm:.0}"),
+        None if total > 0 => format!("{base} f1={f1:.0}"),
+        _ => base,
+    }
+}
+
+fn print_accuracy_table(all: &[(String, Vec<AccuracyResult>)]) {
+    // Legend
+    let total = all[0].1.first().map(|r| r.trials.len()).unwrap_or(0);
+    eprintln!(" Accuracy: flat/fshuf/pic/pshuf correct out of {total} trials");
+    eprintln!(" flat=causal fshuf=causal+shuffled pic=Plus blocks pshuf=Plus+shuffled");
+    eprintln!();
+
+    let size_labels: Vec<String> = all[0]
+        .1
+        .iter()
+        .map(|r| {
+            if r.label.ends_with('w') {
+                r.label.clone()
+            } else {
+                format!("{} {}w", r.label, r.doc_length)
+            }
+        })
+        .collect();
+
+    let rows: Vec<(&str, Vec<String>)> = all
+        .iter()
+        .map(|(model, results)| {
+            let cells: Vec<String> = results.iter().map(format_accuracy_cell).collect();
+            (model.as_str(), cells)
+        })
+        .collect();
+
+    let model_w = rows.iter().map(|(m, _)| m.len()).max().unwrap_or(5).max(5);
+    let col_ws: Vec<usize> = (0..size_labels.len())
+        .map(|i| {
+            let header_w = size_labels[i].len();
+            let data_w = rows
+                .iter()
+                .map(|(_, cells)| cells[i].len())
+                .max()
+                .unwrap_or(0);
+            header_w.max(data_w).max(6)
+        })
+        .collect();
+
+    // Header
+    let mut header = format!(" {:<model_w$}", "Model");
+    let mut sep = format!(" {:<model_w$}", "─".repeat(model_w));
+    for (i, label) in size_labels.iter().enumerate() {
+        header.push_str(&format!(" {:>w$}", label, w = col_ws[i]));
+        sep.push_str(&format!(" {:>w$}", "─".repeat(col_ws[i]), w = col_ws[i]));
+    }
+    eprintln!("{header}");
+    eprintln!("{sep}");
+
+    // Data rows
+    for (model, cells) in &rows {
+        let mut row = format!(" {:<model_w$}", model);
+        for (i, cell) in cells.iter().enumerate() {
+            row.push_str(&format!(" {:>w$}", cell, w = col_ws[i]));
+        }
+        eprintln!("{row}");
+    }
+
+    eprintln!();
+}
+
+// ---------------------------------------------------------------------------
+// Math helpers
+// ---------------------------------------------------------------------------
+
+fn compute_hit_rate(hits: u64, misses: u64) -> (f64, u64) {
+    let total = hits + misses;
+    let rate = if total > 0 {
+        hits as f64 / total as f64 * 100.0
+    } else {
+        0.0
+    };
+    (rate, total)
+}
+
+fn percentile(values: &[f64], pct: usize) -> f64 {
+    if values.is_empty() {
+        return 0.0;
+    }
+    let mut sorted = values.to_vec();
+    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
+    let idx = (sorted.len() * pct / 100).min(sorted.len() - 1);
+    sorted[idx]
+}
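`percentile` is a nearest-rank estimator (index `len * pct / 100`, clamped to the last element) with no interpolation, which is plenty for benchmark reporting. For example:

```rust
let data = vec![10.0, 20.0, 30.0, 40.0];
assert_eq!(percentile(&data, 50), 30.0); // idx = 4 * 50 / 100 = 2
assert_eq!(percentile(&data, 100), 40.0); // idx clamped to len - 1
assert_eq!(percentile(&[], 50), 0.0); // empty input is defined as 0
```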
2 docs")] + fn make_accuracy_docs_panics_with_one() { + make_accuracy_docs(1, 50); + } + + #[test] + fn make_accuracy_docs_def_contains_alias_and_meaning() { + // Run several times to cover different random entries + for _ in 0..10 { + let result = make_accuracy_docs(2, 100); + let doc_def = &result.docs[0]; + // Definition doc should contain "local name for a" (our template) + assert!( + doc_def.contains("local name for a"), + "def doc should contain definition template: {doc_def}" + ); + } + } + + #[test] + fn make_accuracy_docs_use_contains_quantity() { + for _ in 0..10 { + let result = make_accuracy_docs(2, 100); + let doc_use = &result.docs[1]; + // Usage doc should contain the expected answer (the quantity) + assert!( + doc_use.contains(&result.expected), + "usage doc should contain quantity '{}': {doc_use}", + result.expected, + ); + } + } + + #[test] + fn make_accuracy_docs_question_uses_real_meaning() { + for _ in 0..10 { + let result = make_accuracy_docs(2, 50); + // Question should NOT contain the alias (it uses the real meaning) + // and should ask "How many" + assert!( + result.question.starts_with("How many"), + "question should ask 'How many': {}", + result.question, + ); + // Question should contain "ONLY the number" + assert!( + result.question.contains("ONLY the number"), + "question should include terse-answer instruction: {}", + result.question, + ); + } + } + + #[test] + fn make_accuracy_docs_question_does_not_leak_alias() { + // The question should use the real_meaning, not the alias. + // Collect all aliases to check none appear in the question. + let aliases: Vec<&str> = INFERENCE_BANK.iter().map(|e| e.alias).collect(); + for _ in 0..20 { + let result = make_accuracy_docs(2, 50); + for alias in &aliases { + assert!( + !result.question.contains(alias), + "question should not leak alias '{}': {}", + alias, + result.question, + ); + } + } + } + + // ---- resolve_spectrum ---- + + #[test] + fn resolve_spectrum_valid_sizes() { + let sizes: Vec = vec!["xs".into(), "m".into(), "xxl".into()]; + let result = resolve_spectrum(&sizes).unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result[0], (10, "xs".to_string())); + assert_eq!(result[1], (200, "m".to_string())); + assert_eq!(result[2], (2000, "xxl".to_string())); + } + + #[test] + fn resolve_spectrum_invalid_sizes() { + let sizes: Vec = vec!["bogus".into()]; + assert_eq!(resolve_spectrum(&sizes).unwrap_err(), "bogus"); + } + + #[test] + fn resolve_spectrum_empty() { + let sizes: Vec = vec![]; + assert_eq!(resolve_spectrum(&sizes).unwrap(), vec![]); + } + + #[test] + fn resolve_spectrum_mixed_rejects_bad() { + let sizes: Vec = vec!["xs".into(), "bogus".into(), "m".into()]; + assert_eq!(resolve_spectrum(&sizes).unwrap_err(), "bogus"); + } + + #[test] + fn resolve_spectrum_short_aliases() { + let sizes: Vec = vec!["s".into(), "m".into(), "l".into()]; + let result = resolve_spectrum(&sizes).unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result[0], (50, "sm".to_string())); + assert_eq!(result[1], (200, "m".to_string())); + assert_eq!(result[2], (500, "lg".to_string())); + } + + // ---- compute_hit_rate ---- + + #[test] + fn compute_hit_rate_zero_total() { + assert_eq!(compute_hit_rate(0, 0), (0.0, 0)); + } + + #[test] + fn compute_hit_rate_all_hits() { + assert_eq!(compute_hit_rate(10, 0), (100.0, 10)); + } + + #[test] + fn compute_hit_rate_half_and_half() { + assert_eq!(compute_hit_rate(5, 5), (50.0, 10)); + } + + // ---- percentile ---- + + #[test] + fn percentile_empty() { + assert_eq!(percentile(&[], 50), 0.0); 
+    }
+
+    #[test]
+    fn percentile_single() {
+        assert_eq!(percentile(&[42.0], 50), 42.0);
+    }
+
+    #[test]
+    fn percentile_known_data() {
+        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
+        assert_eq!(percentile(&data, 50), 3.0);
+        assert_eq!(percentile(&data, 25), 2.0);
+    }
+
+    // ---- integration test: benchmark produces speedup > 1 ----
+
+    #[cfg(all(feature = "bench", feature = "local"))]
+    #[tokio::test]
+    #[ignore] // requires local model and GPU
+    async fn pic_benchmark_shows_speedup() {
+        let model = std::env::var("BENCH_MODEL").unwrap_or_else(|_| "llama3.2:1b".to_string());
+        let pb = ProgressBar::hidden();
+        let ctx = RunCtx {
+            model: &model,
+            num_docs: 2,
+            doc_length: 50,
+            label: "test",
+            pb: &pb,
+            step_prefix: "",
+            verbose: false,
+        };
+        let result = run_one(&ctx, 3, 1).await.expect("benchmark should succeed");
+        let nocache_p50 = percentile(&result.nocache_ttfts, 50);
+        let reuse_p50 = percentile(&result.reuse_ttfts, 50);
+        let speedup = nocache_p50 / reuse_p50;
+        assert!(
+            speedup > 1.0,
+            "Expected PIC speedup > 1.0, got {speedup:.2}x"
+        );
+    }
+}
diff --git a/spnl/Cargo.toml b/spnl/Cargo.toml
index a8a35b0d..cc68b05a 100644
--- a/spnl/Cargo.toml
+++ b/spnl/Cargo.toml
@@ -77,8 +77,8 @@ serde_yaml2 = { version = "0.1.3", optional = true }
 sha2 = { version = "0.10.9", optional = true }
 pyo3 = { version = "0.28.0", features = ["macros"], optional = true }
 tokenizers = { version = "0.22.0", default-features = false, features = ["onig", "esaxx_fast", "hf-hub", "http"], optional = true }
-mistralrs = { version = "0.7.0", git = "https://github.com/EricLBuehler/mistral.rs", rev = "dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4", optional = true }
-mistralrs-core = { version = "0.7.0", git = "https://github.com/EricLBuehler/mistral.rs", rev = "dd8d0c6f9dc2fe3129ff4228de667e2a0efcf5b4", optional = true, default-features = false }
+mistralrs = { version = "0.7.0", git = "https://github.com/starpit/mistral.rs", branch = "pic-cache-reuse", optional = true }
+mistralrs-core = { version = "0.7.0", git = "https://github.com/starpit/mistral.rs", branch = "pic-cache-reuse", optional = true, default-features = false }
 hf-hub = { version = "0.4.3", features = ["tokio"], optional = true }
 derive_builder = "0.20.2"
 moka = { version = "0.12.10", features = ["sync"], optional = true }
diff --git a/spnl/src/execute/mod.rs b/spnl/src/execute/mod.rs
index f704650b..2d513a26 100644
--- a/spnl/src/execute/mod.rs
+++ b/spnl/src/execute/mod.rs
@@ -44,12 +41,9 @@ async fn plus(units: &[Query], rp: &ExecuteOptions) -> SpnlResult {
     let evaluated =
         futures::future::try_join_all(units.iter().map(|u| run_subtree(u, rp, Some(&m)))).await?;
-    if evaluated.len() == 1 {
-        // the unwrap() is safe here, due to the len() == 1 guard
-        Ok(evaluated.into_iter().next().unwrap())
-    } else {
-        Ok(Query::Plus(evaluated))
-    }
+    // Always keep Plus wrapping — even for single elements — so that PIC
+    // tagging is preserved (the generate queries emitted by prepare_fragment
+    // rely on the Plus sentinel).
+    Ok(Query::Plus(evaluated))
 }
 
 /// Intersperse a in-between every element of b
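The behavioral change in `plus` is easiest to see on a single-element block. Previously `Plus([x])` collapsed to `x`; now the wrapper survives, which is what the PIC tagging downstream keys on. A schematic (assuming the same `Query`/`User` imports the tests in this diff use):

```rust
// Before this change: Plus([x]) was unwrapped to x, erasing the sentinel.
// After: a single evaluated element stays inside Plus.
let evaluated = vec![Query::Message(User("doc".to_string()))];
let kept = Query::Plus(evaluated);
assert!(matches!(kept, Query::Plus(ref v) if v.len() == 1));
```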
@@ -103,9 +100,10 @@
         }
 
         Query::Monad(q) => {
-            // ignore output
+            // Execute for side effects (e.g. populating PIC cache), discard output.
+            // Return empty Seq to avoid phantom empty-message tokens.
             let _ = run_subtree(q, rp, m).await?;
-            Ok("".into())
+            Ok(Query::Seq(vec![]))
         }
 
         Query::Bulk(Bulk::Repeat(repeat)) => crate::generate::generate(repeat.clone(), m, rp).await,
diff --git a/spnl/src/generate/backend/capabilities.rs b/spnl/src/generate/backend/capabilities.rs
index f315e5a0..a30421c8 100644
--- a/spnl/src/generate/backend/capabilities.rs
+++ b/spnl/src/generate/backend/capabilities.rs
@@ -1,7 +1,15 @@
-/// Does the given provider support the spnl REST API?
+/// Does the given provider support PIC (Position-Independent Caching)?
+///
+/// This includes:
+/// - `spnl/` prefixed models (vLLM spans backend)
+/// - `local/` prefixed models (mistral.rs backend)
+/// - Pretty names that resolve to the local backend (e.g. `llama3.1:8b`)
 pub fn supports_spans(provider_slash_model: &str) -> bool {
-    // for now...
-    provider_slash_model.starts_with("spnl/")
+    if provider_slash_model.starts_with("spnl/") || provider_slash_model.starts_with("local/") {
+        return true;
+    }
+    // No recognized prefix → falls through to prettynames → local backend
+    !provider_slash_model.contains('/')
 }
 
 /// Does the given provider support the bulk-repeat API (generate with `n`)?
diff --git a/spnl/src/generate/backend/mistralrs/loader.rs b/spnl/src/generate/backend/mistralrs/loader.rs
index 4eca3b65..4a7edf22 100644
--- a/spnl/src/generate/backend/mistralrs/loader.rs
+++ b/spnl/src/generate/backend/mistralrs/loader.rs
@@ -3,7 +3,7 @@ use indicatif::{ProgressBar, ProgressStyle};
 use mistralrs::{
     Device, GgufModelBuilder, IsqType, Model, PagedAttentionMetaBuilder, TextModelBuilder,
-    best_device, paged_attn_supported,
+    TokenSource, best_device, paged_attn_supported,
 };
 use std::collections::HashMap;
 use std::path::PathBuf;
@@ -11,6 +11,27 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::sync::RwLock;
 
+/// Get the HuggingFace token from environment, checking HF_TOKEN first,
+/// then HUGGING_FACE_HUB_TOKEN.
+fn get_hf_token() -> Option<String> {
+    std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok()
+        .filter(|t| !t.is_empty())
+}
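`get_hf_token` above and `get_token_source` just below share one precedence order: `get_hf_token` yields the raw bearer token for direct REST calls, while `get_token_source` tells the mistral.rs builders where to look. A sketch of the shared contract, as implemented here:

```rust
// Resolution order:
//   1. HF_TOKEN set               -> TokenSource::EnvVar("HF_TOKEN")
//   2. HUGGING_FACE_HUB_TOKEN set -> TokenSource::EnvVar("HUGGING_FACE_HUB_TOKEN")
//   3. neither                    -> TokenSource::CacheToken (cached HF login)
let source = get_token_source(); // handed to GgufModelBuilder / TextModelBuilder
let bearer = get_hf_token(); // Option<String>; None if unset or empty
```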
+
+/// Get the appropriate TokenSource for mistral.rs builders.
+/// Prefers HF_TOKEN env var if set, otherwise falls back to cached token.
+fn get_token_source() -> TokenSource {
+    if std::env::var("HF_TOKEN").is_ok() {
+        TokenSource::EnvVar("HF_TOKEN".to_string())
+    } else if std::env::var("HUGGING_FACE_HUB_TOKEN").is_ok() {
+        TokenSource::EnvVar("HUGGING_FACE_HUB_TOKEN".to_string())
+    } else {
+        TokenSource::CacheToken
+    }
+}
+
 /// Get the HuggingFace cache directory
 /// This matches the cache used by hf_hub crate
 /// Returns the path to the hub directory where models are stored
@@ -34,8 +55,7 @@ fn get_hf_cache_dir() -> PathBuf {
 
 /// Get PagedAttention configuration from environment variables
 /// Returns None if PagedAttention is disabled or not supported
-fn get_paged_attn_config()
--> Option<impl FnOnce() -> anyhow::Result<PagedAttentionConfig>> {
+fn get_paged_attn_config() -> Option<PagedAttentionConfig> {
     // Check if explicitly enabled via environment variable (disabled by default for faster startup)
     let enabled = std::env::var("MISTRALRS_PAGED_ATTN")
         .unwrap_or_else(|_| "false".to_string())
@@ -47,7 +67,7 @@
     // Check if PagedAttention is supported on this platform
     if !paged_attn_supported() {
-        if should_enable_logging() {
+        if verbose() {
             eprintln!("PagedAttention not supported on this platform");
         }
         return None;
@@ -59,15 +79,14 @@
         .and_then(|s| s.parse::<usize>().ok())
         .unwrap_or(32);
 
-    if should_enable_logging() {
+    if verbose() {
         eprintln!("Enabling PagedAttention with block_size={}", block_size);
     }
 
-    Some(move || {
-        PagedAttentionMetaBuilder::default()
-            .with_block_size(block_size)
-            .build()
-    })
+    PagedAttentionMetaBuilder::default()
+        .with_block_size(block_size)
+        .build()
+        .ok()
 }
 
 /// Get prefix cache size from environment variables
@@ -76,20 +95,20 @@ fn get_prefix_cache_n() -> Option<usize> {
     match std::env::var("MISTRALRS_PREFIX_CACHE_N") {
         Ok(val) => {
             if val.to_lowercase() == "false" || val == "0" {
-                if should_enable_logging() {
+                if verbose() {
                     eprintln!("Prefix caching disabled via MISTRALRS_PREFIX_CACHE_N");
                 }
                 None
             } else {
                 match val.parse::<usize>() {
                     Ok(n) => {
-                        if should_enable_logging() {
+                        if verbose() {
                             eprintln!("Prefix caching enabled with n={}", n);
                         }
                         Some(n)
                     }
                     Err(_) => {
-                        if should_enable_logging() {
+                        if verbose() {
                             eprintln!("Invalid MISTRALRS_PREFIX_CACHE_N value, using default (16)");
                         }
                         Some(16)
@@ -104,9 +123,8 @@ fn get_prefix_cache_n() -> Option<usize> {
     }
 }
 
-/// Check if logging should be enabled
-/// Returns true if MISTRALRS_VERBOSE env var is set to "true" or "1"
-fn should_enable_logging() -> bool {
+/// Check if verbose logging is enabled (MISTRALRS_VERBOSE=true|1)
+fn verbose() -> bool {
     std::env::var("MISTRALRS_VERBOSE")
         .map(|v| v.to_lowercase() == "true" || v == "1")
         .unwrap_or(false)
@@ -152,13 +170,13 @@
     } else {
         match parse_isq_type(&val) {
             Ok(isq_type) => {
-                if should_enable_logging() {
+                if verbose() {
                     eprintln!("Enabling in-situ quantization: {:?}", isq_type);
                 }
                 Some(isq_type)
             }
             Err(e) => {
-                if should_enable_logging() {
+                if verbose() {
                     eprintln!("Warning: {}", e);
                 }
                 None
@@ -192,7 +210,7 @@ impl ModelPool {
     }
 
     /// Get or load a model
-    pub async fn get_or_load(&self, model_name: &str) -> anyhow::Result<Arc<Model>> {
+    pub async fn get_or_load(&self, model_name: &str, silent: bool) -> anyhow::Result<Arc<Model>> {
         // Check if model is already loaded
         {
             let models = self.models.read().await;
@@ -202,7 +220,7 @@
         }
 
         // Model not in cache, load it
-        let model = self.load_model(model_name).await?;
+        let model = self.load_model(model_name, silent).await?;
 
         // Cache the loaded model
         {
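The pool guarantees at most one load per model name; later calls get the cached handle. A hypothetical call site (model ID illustrative; `get_model_pool()` is the accessor the chat path uses later in this diff):

```rust
let pool = get_model_pool();
let first = pool.get_or_load("Qwen/Qwen2.5-0.5B-Instruct-GGUF", true).await?;
let again = pool.get_or_load("Qwen/Qwen2.5-0.5B-Instruct-GGUF", true).await?;
// Same Arc<Model>: the second call was a cache hit, no weights re-read.
assert!(std::sync::Arc::ptr_eq(&first, &again));
```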
@@ -269,14 +287,14 @@
             .collect();
 
         if !cached_files.is_empty() {
-            if should_enable_logging() {
+            if verbose() {
                 eprintln!("Found cached GGUF files: {:?}", cached_files);
             }
 
             // Return the first priority format that's cached
             for filename in &priority_formats {
                 if cached_files.contains(filename) {
-                    if should_enable_logging() {
+                    if verbose() {
                         eprintln!("Using cached GGUF file: {}", filename);
                     }
                     return Ok(vec![filename.clone()]);
@@ -289,16 +307,33 @@
         }
 
         // If not in cache, query HF API to find which file to download
-        if should_enable_logging() {
+        if verbose() {
             eprintln!("Model not in cache, querying HuggingFace API...");
         }
 
         let url = format!("https://huggingface.co/api/models/{}/tree/main", model_name);
-        let response = reqwest::get(&url).await?;
+        let client = reqwest::Client::new();
+        let mut request = client.get(&url);
+        if let Some(token) = get_hf_token() {
+            request = request.bearer_auth(token);
+        }
+        let response = request.send().await?;
 
         if !response.status().is_success() {
+            let status = response.status();
+            let hint = if status == 401 {
+                ". Set HF_TOKEN for gated/private models".to_string()
+            } else if status == 404 {
+                format!(
+                    ". Check that '{}' is a valid HuggingFace model ID",
+                    model_name
+                )
+            } else {
+                String::new()
+            };
             return Err(anyhow::anyhow!(
-                "Failed to fetch model files from HuggingFace: HTTP {}",
-                response.status()
+                "Failed to fetch model files from HuggingFace: HTTP {} ({}){hint}",
+                status,
+                url,
             ));
         }
 
@@ -313,14 +348,14 @@
             })
             .collect();
 
-        if should_enable_logging() {
+        if verbose() {
             eprintln!("Available GGUF files in repo: {:?}", available_files);
         }
 
         // Return the first priority format that exists
         for filename in &priority_formats {
             if available_files.contains(filename) {
-                if should_enable_logging() {
+                if verbose() {
                     eprintln!("Will download GGUF file: {}", filename);
                 }
                 return Ok(vec![filename.clone()]);
@@ -337,7 +372,7 @@
     }
 
     /// Load a model from HuggingFace using appropriate builder
-    async fn load_model(&self, model_name: &str) -> anyhow::Result<Arc<Model>> {
+    async fn load_model(&self, model_name: &str, silent: bool) -> anyhow::Result<Arc<Model>> {
         // Check if this is a GGUF model (contains "GGUF" in the name)
         let is_gguf = model_name.to_uppercase().contains("GGUF");
 
@@ -348,7 +383,7 @@
         let device = best_device(false).expect("Failed to detect device");
 
         // Log the selected device if logging is enabled
-        if should_enable_logging() {
+        if verbose() {
             match &device {
                 Device::Cuda(_) => eprintln!("Using CUDA GPU acceleration"),
                 Device::Metal(_) => eprintln!("Using Metal GPU acceleration"),
@@ -358,7 +393,7 @@
 
         // Build the model using the appropriate builder
         let model = if is_gguf {
-            if should_enable_logging() {
+            if verbose() {
                 eprintln!("Detected GGUF model, using GgufModelBuilder");
             }
 
@@ -366,29 +401,36 @@
             let gguf_files = self.select_gguf_files(model_name).await?;
 
             if let Some(first_file) = gguf_files.first()
-                && should_enable_logging()
+                && verbose()
             {
                 eprintln!("Using GGUF file: {}", first_file);
             }
 
             // Use GgufModelBuilder for GGUF models
-            let mut builder = GgufModelBuilder::new(model_name, gguf_files).with_device(device);
+            let mut builder = GgufModelBuilder::new(model_name, gguf_files)
+                .with_device(device)
+                .with_token_source(get_token_source());
 
             // Optionally enable logging
-            if should_enable_logging() {
+            if verbose() {
                 builder = builder.with_logging();
             }
 
             // Apply PagedAttention if available and enabled
             if let Some(paged_config) = get_paged_attn_config() {
-                builder = builder.with_paged_attn(paged_config)?;
+                builder = builder.with_paged_attn(paged_config);
             }
 
             // Apply prefix caching configuration
             builder = builder.with_prefix_cache_n(get_prefix_cache_n());
 
-            // Create spinner ONLY if model is cached (no download needed)
-            let spinner = if !should_enable_logging() && is_cached {
+            // silent: no output; verbose: text log; normal+cached: spinner
+            let spinner = if silent {
+                None
+            } else if verbose() {
+                eprintln!("Initializing model: {}", model_name);
+                None
+            } else if is_cached {
                 let pb = ProgressBar::new_spinner();
                 pb.set_style(
                     ProgressStyle::default_spinner()
@@ -398,9 +440,6 @@
                 pb.enable_steady_tick(Duration::from_millis(100));
                 pb.set_message(format!("Initializing {}", model_name));
                 Some(pb)
-            } else if should_enable_logging() {
-                eprintln!("Initializing model: {}", model_name);
-                None
             } else {
                 None
             };
@@ -413,17 +452,18 @@
             result
         } else {
-            if should_enable_logging() {
+            if verbose() {
                 eprintln!("Using TextModelBuilder for standard model");
             }
 
             // Use TextModelBuilder for normal models
             let mut builder = TextModelBuilder::new(model_name)
                 // .with_dtype(mistralrs::ModelDType::F32) // for future reference: might be needed for ISQ on metal
-                .with_device(device);
+                .with_device(device)
+                .with_token_source(get_token_source());
 
             // Optionally enable logging
-            if should_enable_logging() {
+            if verbose() {
                 builder = builder.with_logging();
             }
 
@@ -434,14 +474,19 @@
             // Apply PagedAttention if available and enabled
             if let Some(paged_config) = get_paged_attn_config() {
-                builder = builder.with_paged_attn(paged_config)?;
+                builder = builder.with_paged_attn(paged_config);
             }
 
             // Apply prefix caching configuration
             builder = builder.with_prefix_cache_n(get_prefix_cache_n());
 
-            // Create spinner ONLY if model is cached (no download needed)
-            let spinner = if !should_enable_logging() && is_cached {
+            // silent: no output; verbose: text log; normal+cached: spinner
+            let spinner = if silent {
+                None
+            } else if verbose() {
+                eprintln!("Initializing model: {}", model_name);
+                None
+            } else if is_cached {
                 let pb = ProgressBar::new_spinner();
                 pb.set_style(
                     ProgressStyle::default_spinner()
@@ -451,9 +496,6 @@
                 pb.enable_steady_tick(Duration::from_millis(100));
                 pb.set_message(format!("Initializing {}", model_name));
                 Some(pb)
-            } else if should_enable_logging() {
-                eprintln!("Initializing model: {}", model_name);
-                None
             } else {
                 None
             };
@@ -467,7 +509,7 @@
             result
         };
 
-        if should_enable_logging() {
+        if verbose() {
             eprintln!("Model loaded successfully: {}", model_name);
         }
diff --git a/spnl/src/generate/backend/mistralrs/mod.rs b/spnl/src/generate/backend/mistralrs/mod.rs
index 0d612fa6..277365d5 100644
--- a/spnl/src/generate/backend/mistralrs/mod.rs
+++ b/spnl/src/generate/backend/mistralrs/mod.rs
@@ -99,7 +99,7 @@ pub async fn generate_completion(
     // Get or load the model
     let pool = get_model_pool();
-    let model = pool.get_or_load(&model_name).await?;
+    let model = pool.get_or_load(&model_name, options.silent).await?;
 
     // Timing tracking
     let start_time = if options.time {
@@ -318,7 +318,7 @@ pub async fn generate_chat(
     // Get or load the model
     let pool = get_model_pool();
-    let model = pool.get_or_load(&model_name).await?;
+    let model = pool.get_or_load(&model_name, options.silent).await?;
 
     // Create semaphore for concurrency control
     let max_parallel = get_max_parallel();
@@ -492,8 +492,18 @@
     Ok(Query::Par(final_results))
 }
 
-/// Add messages from a Query to a RequestBuilder
+/// Add messages from a Query to a RequestBuilder.
+/// When `in_plus` is true, messages are tagged with an in-band `PIC_PLUS` sentinel
+/// so the engine can identify Plus blocks after tokenization.
 fn add_messages_from_query(builder: &mut RequestBuilder, query: &Query) -> anyhow::Result<()> {
+    add_messages_from_query_inner(builder, query, false)
+}
+
+fn add_messages_from_query_inner(
+    builder: &mut RequestBuilder,
+    query: &Query,
+    in_plus: bool,
+) -> anyhow::Result<()> {
     match query {
         Query::Message(msg) => {
             let (role, content) = match msg {
@@ -501,14 +511,25 @@
                 Assistant(content) => (TextMessageRole::Assistant, content),
                 System(content) => (TextMessageRole::System, content),
             };
-            *builder = builder.clone().add_message(role, content.clone());
+            if in_plus {
+                // Tag this message so the engine knows it's a Plus block.
+                // Prepends an in-band "\x00PIC_PLUS\x00" sentinel to the content
+                // that add_request.rs will extract and strip before templating.
+                *builder = builder
+                    .clone()
+                    .add_message(role, format!("\x00PIC_PLUS\x00{}", content));
+            } else {
+                *builder = builder.clone().add_message(role, content.clone());
+            }
+        }
+        Query::Plus(queries) => {
+            for q in queries {
+                add_messages_from_query_inner(builder, q, true)?;
+            }
         }
-        Query::Seq(queries)
-        | Query::Par(queries)
-        | Query::Plus(queries)
-        | Query::Cross(queries) => {
+        Query::Seq(queries) | Query::Par(queries) | Query::Cross(queries) => {
             for q in queries {
-                add_messages_from_query(builder, q)?;
+                add_messages_from_query_inner(builder, q, in_plus)?;
             }
         }
         _ => {
@@ -520,4 +541,93 @@
     }
 
     Ok(())
 }
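The sentinel is deliberately in-band: a NUL-delimited prefix, unambiguous because NUL bytes do not occur in normal chat text. A small sketch of the wire format:

```rust
// Tagged form: "\x00PIC_PLUS\x00" + original content.
let tagged = format!("\x00PIC_PLUS\x00{}", "doc text");
assert!(tagged.starts_with("\x00PIC_PLUS\x00"));
// The engine side strips the prefix again before prompt templating:
assert_eq!(&tagged["\x00PIC_PLUS\x00".len()..], "doc text");
```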
+
+#[cfg(all(test, feature = "local"))]
+mod tests {
+    use super::*;
+    use mistralrs::{MessageContent, RequestLike};
+
+    /// Extract the text content of a message (the "content" key, Left variant).
+    fn content_str(msg: &indexmap::IndexMap<String, MessageContent>) -> &str {
+        match msg.get("content").unwrap() {
+            either::Either::Left(s) => s.as_str(),
+            _ => panic!("expected Left(String) content"),
+        }
+    }
+
+    #[test]
+    fn plain_user_message() {
+        let query = Query::Message(User("hello".to_string()));
+        let mut builder = RequestBuilder::new();
+        add_messages_from_query_inner(&mut builder, &query, false).unwrap();
+        let msgs = builder.messages_ref();
+        assert_eq!(msgs.len(), 1);
+        assert_eq!(content_str(&msgs[0]), "hello");
+    }
+
+    #[test]
+    fn plus_tags_messages() {
+        let query = Query::Plus(vec![Query::Message(User("doc1".to_string()))]);
+        let mut builder = RequestBuilder::new();
+        add_messages_from_query(&mut builder, &query).unwrap();
+        let msgs = builder.messages_ref();
+        assert_eq!(msgs.len(), 1);
+        assert!(
+            content_str(&msgs[0]).starts_with("\x00PIC_PLUS\x00"),
+            "expected PIC_PLUS prefix, got: {:?}",
+            content_str(&msgs[0])
+        );
+    }
+
+    #[test]
+    fn seq_inside_plus_all_tagged() {
+        let query = Query::Plus(vec![Query::Seq(vec![
+            Query::Message(User("a".to_string())),
+            Query::Message(User("b".to_string())),
+        ])]);
+        let mut builder = RequestBuilder::new();
+        add_messages_from_query(&mut builder, &query).unwrap();
+        let msgs = builder.messages_ref();
+        assert_eq!(msgs.len(), 2);
+        for msg in msgs {
+            assert!(
+                content_str(msg).starts_with("\x00PIC_PLUS\x00"),
+                "expected PIC_PLUS prefix, got: {:?}",
+                content_str(msg)
+            );
+        }
+    }
+
+    #[test]
+    fn plus_inside_cross() {
+        let query = Query::Cross(vec![
+            Query::Message(User("q".to_string())),
+            Query::Plus(vec![Query::Message(User("doc".to_string()))]),
+        ]);
+        let mut builder = RequestBuilder::new();
+        add_messages_from_query(&mut builder, &query).unwrap();
+        let msgs = builder.messages_ref();
+        assert_eq!(msgs.len(), 2);
+        // First message: plain (no prefix)
+        assert_eq!(content_str(&msgs[0]), "q");
+        // Second message: tagged with PIC_PLUS
+        assert_eq!(content_str(&msgs[1]), "\x00PIC_PLUS\x00doc");
+    }
+
+    #[test]
+    fn unsupported_query_type() {
+        use crate::ir::{Generate, GenerateMetadata};
+        let query = Query::Generate(Generate {
+            metadata: GenerateMetadata {
+                model: "test".to_string(),
+                max_tokens: Some(1),
+                temperature: Some(0.0),
+            },
+            input: Box::new(Query::Message(User("hi".to_string()))),
+        });
+        let mut builder = RequestBuilder::new();
+        let result = add_messages_from_query(&mut builder, &query);
+        assert!(result.is_err());
+    }
+}
+
+// Made with Bob
diff --git a/spnl/src/lib.rs b/spnl/src/lib.rs
index f2c931c2..fe55e3e9 100644
--- a/spnl/src/lib.rs
+++ b/spnl/src/lib.rs
@@ -39,3 +39,13 @@ pub mod model_pool {
         crate::generate::backend::mistralrs::unload_all_models().await
     }
 }
+
+/// PIC cache hit/miss statistics (delegates to mistralrs-core).
+/// Only available with the `local` feature.
+#[cfg(feature = "local")]
+pub mod pic_stats {
+    /// Read and reset global PIC cache hit/miss counters. Returns `(hits, misses)`.
+    pub fn take_cache_stats() -> (u64, u64) {
+        mistralrs::pic::take_cache_stats()
+    }
+}
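A plausible way to consume the new `pic_stats` module from a caller (the workload function here is hypothetical; since `take_cache_stats` reads and resets the counters, bracket the region you care about):

```rust
let _ = spnl::pic_stats::take_cache_stats(); // reset counters
run_workload().await?; // hypothetical PIC-enabled queries
let (hits, misses) = spnl::pic_stats::take_cache_stats();
eprintln!("PIC cache: {hits} hits, {misses} misses");
```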
diff --git a/spnl/src/optimizer/hlo.rs b/spnl/src/optimizer/hlo.rs
index ca2fbca6..cda7b7f4 100644
--- a/spnl/src/optimizer/hlo.rs
+++ b/spnl/src/optimizer/hlo.rs
@@ -46,12 +46,35 @@ async fn optimize_vec_iter<'a>(
     )
 }
 
-/// Wrap a 1-token inner generate around each fragment
-#[cfg(feature = "rag")]
+/// Wrap a 1-token inner generate around each fragment.
+///
+/// The generate includes any Cross context (e.g. system prompt) from the
+/// parent generate so that the cached Plus block KV encodes system-prompt
+/// attention — making it usable by the main query without re-prefilling
+/// the Plus blocks.
 fn prepare_fragment(m: &Query, parent_generate: Option<&Generate>) -> Option<Query> {
     if let Some(g) = parent_generate
         && supports_spans(&g.metadata.model)
     {
+        // Extract leading Cross context (system prompt, etc.) from parent input,
+        // stopping before Plus blocks or the trailing question.
+        let cross_prefix: Vec<Query> = match &*g.input {
+            Query::Cross(items) => items
+                .iter()
+                .take_while(|q| matches!(q, Query::Message(crate::ir::Message::System(_))))
+                .cloned()
+                .collect(),
+            _ => vec![],
+        };
+
+        let input = if cross_prefix.is_empty() {
+            Query::Plus(vec![m.clone()])
+        } else {
+            let mut children = cross_prefix;
+            children.push(Query::Plus(vec![m.clone()]));
+            Query::Cross(children)
+        };
+
         Some(Query::Generate(
             GenerateBuilder::from(g)
                 .metadata(
@@ -61,7 +84,7 @@ fn prepare_fragment(m: &Query, parent_generate: Option<&Generate>) -> Option<Query>
-                .input(Query::Plus(vec![m.clone()]).into())
+                .input(input.into())
@@ -72,7 +95,6 @@ fn prepare_fragment(m: &Query, parent_generate: Option<&Generate>) -> Option<Query>
-#[cfg(feature = "rag")]
 fn prepare_monad(prepares: Vec<Query>) -> Option<Query> {
     if !prepares.is_empty() {
         Some(Query::Monad(Query::Plus(prepares).into()))
     }
@@ -86,7 +108,17 @@ async fn optimize_iter<'a>(
     attrs: &'a InheritedAttributes<'a>,
 ) -> anyhow::Result<Query> {
     match query {
-        Query::Plus(v) => Ok(Query::Plus(optimize_vec_iter(v, attrs).await?)),
+        Query::Plus(v) => {
+            let optimized = optimize_vec_iter(v, attrs).await?;
+            let prepares: Vec<_> = optimized
+                .iter()
+                .filter_map(|m| prepare_fragment(m, attrs.parent_generate))
+                .collect();
+            match prepare_monad(prepares) {
+                Some(monad) => Ok(Query::Seq(vec![monad, Query::Plus(optimized)])),
+                None => Ok(Query::Plus(optimized)),
+            }
+        }
         Query::Cross(v) => Ok(Query::Cross(optimize_vec_iter(v, attrs).await?)),
 
         #[cfg(feature = "rag")]
@@ -238,7 +270,7 @@ mod tests {
     #[tokio::test] // <-- needed for async tests
     async fn nested_gen_expect_no_span_optimization() -> anyhow::Result<()> {
-        let (outer_generate, _, _, _, _) = nested_gen_query("m")?;
+        let (outer_generate, _, _, _, _) = nested_gen_query("openai/m")?;
         assert_eq!(
             optimize(&outer_generate, &Options::default()).await?,
             outer_generate,
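Net effect of the optimizer change, sketched on a two-document block (assuming `doc1`, `doc2`, and `parent` are in scope): `Plus([d1, d2])` becomes `Seq([Monad(Plus(prepares)), Plus([d1, d2])])`, where the Monad primes the PIC cache and, per the execute change above, contributes no output tokens.

```rust
let docs = vec![doc1.clone(), doc2.clone()];
let prepares: Vec<Query> = docs
    .iter()
    .filter_map(|m| prepare_fragment(m, parent))
    .collect();
let rewritten = match prepare_monad(prepares) {
    // Monad runs first for side effects (cache fill), then the real Plus.
    Some(monad) => Query::Seq(vec![monad, Query::Plus(docs)]),
    None => Query::Plus(docs),
};
```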