From 5dedaf06b57855f5ad9e6a95ed6e7a3ac7fddf02 Mon Sep 17 00:00:00 2001
From: Sebastian Pop
Date: Thu, 28 May 2026 16:20:01 -0500
Subject: [PATCH] Batch encode: coarsen rayon tasks with with_min_len
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

`encode_batch`, `encode_batch_char_offsets` and `encode_batch_fast` ran
`inputs.into_maybe_par_iter().map(...).collect()`. rayon's
`bridge_producer_consumer` splits that down toward one item per task, and
at high core counts the per-task wake/join signaling (an LSE CAS/LDADD on
aarch64) becomes a large fraction of cycles.

Route all three through a shared `run_batch` helper that:

- takes an explicit serial fast path (`into_iter().map().collect()`) when
  parallelism is disabled, only one worker is available, or the batch
  holds a single input, so single-threaded callers pay no rayon cost at
  all; and
- otherwise distributes work with
  `IndexedParallelIterator::with_min_len(min_len)`, which sets a floor on
  the items processed sequentially per task.

`min_len` aims for ~4 tasks per worker (load balance) and is capped at 8
(avoid oversized tasks):

    min_len = ceil(n / (threads*4)).clamp(1, 8)

This is intentionally the smallest change that coarsens the tasks: it is
one rayon method call rather than a hand-written work queue, and it keeps
rayon's scheduler.

Because `run_batch` drives rayon directly instead of going through
`into_maybe_par_iter`, it must reproduce that helper's side effect of
recording parallelism use. A new `parallelism::set_parallelism_used()` is
called whenever parallelism is enabled and the batch is non-empty (even
if the batch then takes the serial fast path), so the Python
`pthread_atfork` child hook still disables rayon in forked children (the
documented multiprocessing-deadlock protection). A regression test,
`encode_batch_marks_parallelism_used`, asserts
`has_parallelism_been_used()` after a parallel `encode_batch`.
Measured on three machines, `bpe_benchmark`/`bpe-encode/BPE GPT2 encode
batch` (data/big.txt, encode_batch through the post-processor),
baseline = current main, patched = this change.

NVIDIA Vera (aarch64 Olympus, 88 physical / 176 logical):

  threads   baseline       patched        change
      1      3.91 MiB/s     4.63 MiB/s    +18%   (serial fast path)
     88     18.03 MiB/s    19.57 MiB/s    +8.5%
    176     17.21 MiB/s    18.67 MiB/s    +8.5%

AMD EPYC 9124 (x86_64, 16 physical / 32 logical):

  threads   baseline       patched        change
      1      3.69 MiB/s     3.89 MiB/s    +5.5%
     16     23.38 MiB/s    25.05 MiB/s    +7.1%
     32     24.11 MiB/s    25.81 MiB/s    +7.0%

Apple M3 (aarch64, 12 cores, dev host):

  threads   baseline       patched        change
      1      4.66 MiB/s     4.70 MiB/s    ~0
      6     14.97 MiB/s    14.67 MiB/s    ~0 (within noise)
     12     19.08 MiB/s    17.89 MiB/s    within thermal noise

The M3 is a thermally-limited 12-core laptop; its criterion intervals are
wide (+/-5%) and vary across a long measurement session, and the result
is insensitive to the `min_len` cap (cap 8/4/2 all land in the same
band). Treat it as approximately neutral, not a regression — the reliable
signal is the two isolated servers.

Atomics / bottleneck, `perf record -g --call-graph fp -F 4999`:

aarch64 (Vera) at 88T -- LSE atomic outlined-call share (sum of
`__aarch64_*`):

  baseline   4.97%
  patched    0.91%   (~5.5x lower)

Fewer rayon tasks means fewer per-task atomic operations, which is why
the LSE share drops sharply on aarch64 where each LSE op is expensive.

x86_64 (EPYC) at 16T -- the same rayon/crossbeam-epoch machinery is cheap
here, so there is no comparable atomic bottleneck to remove:

  symbol                                           baseline   patched
  crossbeam_epoch::default::with_handle             1.19%      1.62%
  rayon_core::WorkerThread::wait_until_cold         0.69%      0.96%
  crossbeam_epoch::internal::Global::try_advance    0.19%      0.25%

The x86_64 gain therefore comes from reduced task-scheduling overhead
generally, not from atomics -- x86_64 never had the atomic problem that
aarch64 has.
Remaining ceiling: on aarch64, crossbeam-epoch dispatch (`try_advance` +
`with_handle` + `wait_until_cold`) is ~56% of cycles in both baseline and
patched -- with_min_len does not touch it. Removing that requires
replacing rayon's scheduler on the hot path and is left to a follow-up.

cargo test --lib --features http: 202 passed, 0 failed.
---
 tokenizers/src/tokenizer/mod.rs     | 95 +++++++++++++++++++++++++----
 tokenizers/src/utils/parallelism.rs |  9 +++
 2 files changed, 92 insertions(+), 12 deletions(-)

diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index a5bea1e4c5..ec09995f44 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -23,6 +23,7 @@ use serde::{Deserialize, Serialize};
 use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
 
 mod added_vocabulary;
 mod encoding;
@@ -1342,10 +1343,8 @@ where
     where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings =
+            self.run_batch(inputs, |this, input| this.encode(input, add_special_tokens))?;
 
         if let Some(params) = &self.padding {
             // We do the padding here to make sure we handle the batch padding
@@ -1365,10 +1364,9 @@ where
     where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode_char_offsets(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings = self.run_batch(inputs, |this, input| {
+            this.encode_char_offsets(input, add_special_tokens)
+        })?;
 
         if let Some(params) = &self.padding {
             // We do the padding here to make sure we handle the batch padding
@@ -1387,10 +1385,9 @@ where
     where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode_fast(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings = self.run_batch(inputs, |this, input| {
+            this.encode_fast(input, add_special_tokens)
+        })?;
 
         if let Some(params) = &self.padding {
             // We do the padding here to make sure we handle the batch padding
@@ -1400,6 +1397,66 @@ where
         Ok(encodings)
     }
 
+    /// Shared implementation for the batch-encode entry points.
+    ///
+    /// Applies `encode_fn` to every input, in parallel when parallelism is
+    /// enabled and more than one worker is available.
+    ///
+    /// Parallel work is distributed with rayon's
+    /// `IndexedParallelIterator::with_min_len`, which sets a floor on the
+    /// number of items a single rayon task processes sequentially. Without
+    /// that floor rayon's `bridge_producer_consumer` splits the batch down
+    /// to one item per task; at high core counts the per-task wake/join
+    /// signaling (a CAS / LDADD atomic on arm64, a `lock`-prefixed op on
+    /// x86_64) then dominates. Coarsening tasks to `min_len` items
+    /// amortizes that signaling by the same factor while still handing each
+    /// worker several tasks for load balance.
+    fn run_batch<'s, E, F>(&self, inputs: Vec<E>, encode_fn: F) -> Result<Vec<Encoding>>
+    where
+        E: Into<EncodeInput<'s>> + Send,
+        F: Fn(&Self, EncodeInput<'s>) -> Result<Encoding> + Sync,
+    {
+        let n = inputs.len();
+        if n == 0 {
+            return Ok(vec![]);
+        }
+
+        let parallelism = get_parallelism();
+        if parallelism {
+            // Mirror the side effect of `into_maybe_par_iter`: record that
+            // the rayon pool may be used, so the Python `pthread_atfork`
+            // hook disables parallelism in forked children.
+            set_parallelism_used();
+        }
+        let num_threads = if parallelism {
+            current_num_threads().min(n)
+        } else {
+            1
+        };
+
+        if num_threads <= 1 {
+            return inputs
+                .into_iter()
+                .map(|input| encode_fn(self, input.into()))
+                .collect();
+        }
+
+        // Aim for ~WINDOWS_PER_THREAD tasks per worker for load balance,
+        // each task at least `min_len` items so per-task wake/join atomics
+        // are amortized, capped at MAX_MIN_LEN to avoid oversized tasks.
+        const WINDOWS_PER_THREAD: usize = 4;
+        const MAX_MIN_LEN: usize = 8;
+        let min_len = n
+            .div_ceil(num_threads.saturating_mul(WINDOWS_PER_THREAD).max(1))
+            .clamp(1, MAX_MIN_LEN);
+
+        inputs
+            .into_par_iter()
+            .with_min_len(min_len)
+            .map(|input| encode_fn(self, input.into()))
+            .collect()
+    }
+
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
@@ -1665,6 +1722,20 @@ mod tests {
         tokenizer
     }
 
+    #[test]
+    fn encode_batch_marks_parallelism_used() {
+        // Regression: the batch-encode path must record parallelism usage
+        // so the Python pthread_atfork hook disables rayon in forked
+        // children. USED_PARALLELISM is a sticky process-global; for a
+        // strict check run this test in isolation:
+        //   cargo test encode_batch_marks_parallelism_used
+        set_parallelism(true);
+        let tok = test_tokenizer();
+        let inputs: Vec<&str> = vec!["a b c", "d e f", "g h i"];
+        let _ = tok.encode_batch(inputs, false).unwrap();
+        assert!(has_parallelism_been_used());
+    }
+
     #[test]
     fn right_truncation_early_exit_matches_full_encode() {
         // "a b c d e f g h i j" → 10 tokens [0,1,2,3,4,5,6,7,8,9]
diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs
index ea2fd331a6..4d81b65b76 100644
--- a/tokenizers/src/utils/parallelism.rs
+++ b/tokenizers/src/utils/parallelism.rs
@@ -27,6 +27,15 @@ pub fn has_parallelism_been_used() -> bool {
     USED_PARALLELISM.load(Ordering::SeqCst)
 }
 
+/// Record that a parallel iterator has been used.
+///
+/// `into_maybe_par_iter` sets this automatically; code paths that drive
+/// rayon directly (e.g. `with_min_len`) must call this so the Python
+/// `pthread_atfork` hook can disable parallelism in forked children.
+pub fn set_parallelism_used() {
+    USED_PARALLELISM.store(true, Ordering::SeqCst);
+}
+
 /// Get internally set parallelism
 fn get_override_parallelism() -> Option<bool> {
     match PARALLELISM.load(Ordering::SeqCst) {