Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 83 additions & 12 deletions tokenizers/src/tokenizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use serde::{Deserialize, Serialize};
use crate::utils::iter::ResultShunt;
use crate::utils::parallelism::*;
use crate::utils::progress::{ProgressBar, ProgressStyle};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

mod added_vocabulary;
mod encoding;
Expand Down Expand Up @@ -1309,10 +1310,8 @@ where
where
E: Into<EncodeInput<'s>> + Send,
{
let mut encodings = inputs
.into_maybe_par_iter()
.map(|input| self.encode(input, add_special_tokens))
.collect::<Result<Vec<Encoding>>>()?;
let mut encodings =
self.run_batch(inputs, |this, input| this.encode(input, add_special_tokens))?;

if let Some(params) = &self.padding {
// We do the padding here to make sure we handle the batch padding
Expand All @@ -1332,10 +1331,9 @@ where
where
E: Into<EncodeInput<'s>> + Send,
{
let mut encodings = inputs
.into_maybe_par_iter()
.map(|input| self.encode_char_offsets(input, add_special_tokens))
.collect::<Result<Vec<Encoding>>>()?;
let mut encodings = self.run_batch(inputs, |this, input| {
this.encode_char_offsets(input, add_special_tokens)
})?;

if let Some(params) = &self.padding {
// We do the padding here to make sure we handle the batch padding
Expand All @@ -1354,10 +1352,9 @@ where
where
E: Into<EncodeInput<'s>> + Send,
{
let mut encodings = inputs
.into_maybe_par_iter()
.map(|input| self.encode_fast(input, add_special_tokens))
.collect::<Result<Vec<Encoding>>>()?;
let mut encodings = self.run_batch(inputs, |this, input| {
this.encode_fast(input, add_special_tokens)
})?;

if let Some(params) = &self.padding {
// We do the padding here to make sure we handle the batch padding
Expand All @@ -1367,6 +1364,66 @@ where
Ok(encodings)
}

/// Common driver behind the batch-encode entry points.
///
/// Runs `encode_fn` over every element of `inputs` and gathers the
/// results, stopping with the first error. Work is spread over the rayon
/// pool only when parallelism is enabled and more than one worker can be
/// put to use; otherwise the inputs are encoded sequentially on the
/// calling thread.
///
/// On the parallel path the batch is split with rayon's
/// `IndexedParallelIterator::with_min_len`, which puts a floor on how
/// many items one rayon task handles sequentially. Left alone, rayon's
/// `bridge_producer_consumer` keeps splitting until each task holds a
/// single item, and at high core counts the per-task wake/join signaling
/// (a CAS / LDADD atomic on arm64, a `lock`-prefixed op on x86_64) ends
/// up dominating. Grouping `min_len` items per task amortizes that cost
/// by the same factor, while each worker still receives several tasks so
/// load stays balanced.
fn run_batch<'s, E, F>(&self, inputs: Vec<E>, encode_fn: F) -> Result<Vec<Encoding>>
where
    E: Into<EncodeInput<'s>> + Send,
    F: Fn(&Self, EncodeInput<'s>) -> Result<Encoding> + Sync,
{
    // Target number of tasks handed to each worker, for load balance.
    const WINDOWS_PER_THREAD: usize = 4;
    // Upper bound on the per-task floor, to avoid oversized tasks.
    const MAX_MIN_LEN: usize = 8;

    let total = inputs.len();
    if total == 0 {
        return Ok(Vec::new());
    }

    let workers = if get_parallelism() {
        // Same side effect as `into_maybe_par_iter`: flag that the rayon
        // pool may be used, so the Python `pthread_atfork` hook disables
        // parallelism in forked children.
        set_parallelism_used();
        current_num_threads().min(total)
    } else {
        1
    };

    // Per-item work, shared by the sequential and the parallel path.
    let encode_one = |input: E| encode_fn(self, input.into());

    if workers <= 1 {
        return inputs.into_iter().map(encode_one).collect();
    }

    // Roughly WINDOWS_PER_THREAD tasks per worker; every task holds at
    // least `min_len` items so the per-task wake/join atomics are
    // amortized, and never more than MAX_MIN_LEN as a floor.
    let target_tasks = workers.saturating_mul(WINDOWS_PER_THREAD).max(1);
    let min_len = total.div_ceil(target_tasks).clamp(1, MAX_MIN_LEN);

    inputs
        .into_par_iter()
        .with_min_len(min_len)
        .map(encode_one)
        .collect()
}

/// Decode all sentences in parallel
pub fn decode_batch(
&self,
Expand Down Expand Up @@ -1632,6 +1689,20 @@ mod tests {
tokenizer
}

#[test]
fn encode_batch_marks_parallelism_used() {
    // Regression guard: encoding a batch has to flag parallelism as used,
    // otherwise the Python pthread_atfork hook would not disable rayon in
    // forked children. USED_PARALLELISM is a sticky, process-wide flag, so
    // another test may already have set it; for a strict check run this
    // test on its own:
    //   cargo test encode_batch_marks_parallelism_used
    set_parallelism(true);

    let tokenizer = test_tokenizer();
    let batch: Vec<&str> = vec!["a b c", "d e f", "g h i"];
    // Only the side effect matters here; the encodings are discarded.
    let _ = tokenizer.encode_batch(batch, false).unwrap();

    assert!(has_parallelism_been_used());
}

#[test]
fn right_truncation_early_exit_matches_full_encode() {
// "a b c d e f g h i j" → 10 tokens [0,1,2,3,4,5,6,7,8,9]
Expand Down
9 changes: 9 additions & 0 deletions tokenizers/src/utils/parallelism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ pub fn has_parallelism_been_used() -> bool {
USED_PARALLELISM.load(Ordering::SeqCst)
}

/// Flag that parallel iteration has been (or may have been) used.
///
/// Stores `true` into the `USED_PARALLELISM` flag that
/// `has_parallelism_been_used` reads. `into_maybe_par_iter` raises the
/// flag on its own; any code path that drives rayon directly (for example
/// through `with_min_len`) has to call this function itself so that the
/// Python `pthread_atfork` hook can disable parallelism in forked
/// children.
pub fn set_parallelism_used() {
    USED_PARALLELISM.store(true, Ordering::SeqCst);
}

/// Get internally set parallelism
fn get_override_parallelism() -> Option<bool> {
match PARALLELISM.load(Ordering::SeqCst) {
Expand Down
Loading