From 1e868b47e840f3b66dfa52a757a153d518445939 Mon Sep 17 00:00:00 2001 From: Hunter Heidenreich Date: Sun, 24 May 2026 10:29:42 -0400 Subject: [PATCH 1/2] Fix Unigram trainer prune loss to use per-piece alternative count prune_sentence_pieces computed logsum_alt with alternatives.len() (the total piece count) instead of alternatives[id].len() (the alternatives for the piece being scored), diverging from SentencePiece's PruneSentencePieces. This inflates the loss for high-frequency pieces and changes which pieces survive pruning. Add the first tests for prune_sentence_pieces (previously uncovered): a regression test locking the per-piece keep decision, and a smoke test exercising the EM + prune loop end to end. Refs huggingface/tokenizers#2069 Co-Authored-By: Claude Opus 4.7 (1M context) --- tokenizers/src/models/unigram/trainer.rs | 101 ++++++++++++++++++++++- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs index ff5ca9428a..8715e5fb12 100644 --- a/tokenizers/src/models/unigram/trainer.rs +++ b/tokenizers/src/models/unigram/trainer.rs @@ -396,10 +396,10 @@ impl UnigramTrainer { // After removing the sentencepiece[i], its frequency freq[i] is // re-assigned to alternatives. - // new_sum = current_sum - freq[i] + freq[i] * alternatives.size() - // = current_sum + freq[i] (alternatives - 1) + // new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size() + // = current_sum + freq[i] * (alternatives[i] - 1) - let logsum_alt = (sum + freq[id] * (alternatives.len() - 1) as f64).ln(); + let logsum_alt = (sum + freq[id] * (alternatives[id].len() - 1) as f64).ln(); // The frequencies of alternatives are increased by freq[i]. let mut logprob_alt = 0.0; @@ -812,6 +812,101 @@ mod tests { assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0))); } + #[test] + fn test_prune_sentence_pieces_keeps_costly_alternative() { + // Two candidates compete for a single surviving slot. 
"xy" is expensive + // to replace (its parts never occur alone), so SentencePiece keeps it; + // "pq" is cheap (its parts are frequent). Scoring the loss with the total + // piece count instead of alternatives[id].len() over-credits "pq". + let trainer = UnigramTrainerBuilder::default() + .show_progress(false) + .vocab_size(1) + .shrinking_factor(0.9) + .build() + .unwrap(); + + let pieces: Vec<SentencePiece> = vec![ + ("<unk>".into(), 0.0), + ("p".into(), -10.0), + ("q".into(), -10.0), + ("x".into(), -10.0), + ("y".into(), -10.0), + ("m".into(), -10.0), + ("n".into(), -10.0), + ("pq".into(), -1.0), + ("xy".into(), -1.0), + ]; + let sentences: Vec<Sentence> = vec![ + ("pq".into(), 20), + ("xy".into(), 5), + ("p".into(), 50), + ("q".into(), 50), + ]; + let model = Unigram::from(pieces.clone(), Some(0), false).unwrap(); + + let kept: Vec<String> = trainer + .prune_sentence_pieces(&model, &pieces, &sentences) + .into_iter() + .filter(|(token, _)| token != "<unk>" && token.chars().count() >= 2) + .map(|(token, _)| token) + .collect(); + assert_eq!(kept, vec!["xy".to_string()]); + } + + #[test] + fn test_do_train_runs_prune() { + // A seed vocabulary larger than desired_vocab_size forces the EM loop to + // call prune_sentence_pieces, unlike the small-corpus tests above which + // break out of the loop before pruning. 
+ let vocab_size = 60; + let prefixes = ["", "re", "un", "in", "de", "pre", "over", "under"]; + let stems = [ + "nation", + "organ", + "real", + "general", + "central", + "local", + "global", + "modern", + "civil", + "token", + "format", + "structure", + "compute", + "special", + ]; + let suffixes = [ + "", "al", "ize", "ation", "ized", "izing", "ism", "ist", "ly", "s", + ]; + let mut sentences: Vec<Sentence> = Vec::new(); + for prefix in prefixes { + for stem in stems { + for suffix in suffixes { + sentences.push((format!("{}{}{}", prefix, stem, suffix), 5)); + } + } + } + + let trainer = UnigramTrainerBuilder::default() + .show_progress(false) + .vocab_size(vocab_size) + .build() + .unwrap(); + let mut model = Unigram::default(); + trainer.do_train(sentences.clone(), &mut model).unwrap(); + + assert!(model.iter().count() <= vocab_size as usize); + let vocab: AHashSet<String> = model.iter().map(|(token, _)| token.clone()).collect(); + for required in trainer.required_chars(&sentences) { + assert!( + vocab.contains(&required), + "missing required char {:?}", + required + ); + } + } + #[test] fn test_to_log_prob() { let mut a = vec![("".to_string(), 1.0), ("".to_string(), 2.0)]; From 4f6e3760f98e66dbf84d21fd3f0adbfa2bb1a585 Mon Sep 17 00:00:00 2001 From: Hunter Heidenreich Date: Sun, 24 May 2026 10:45:16 -0400 Subject: [PATCH 2/2] Document prune loss parity with SentencePiece Comment-only follow-up to the per-piece alternatives fix (no logic change): - Cite SentencePiece's Trainer::PruneSentencePieces at the fix site and note why the term must use alternatives[id], not alternatives.len(), to guard against silent re-drift of the ported comment. - Tidy the drifted comment: (alternatives[i] - 1) -> (alternatives[i].size() - 1). - Explain the slot arithmetic in test_prune_sentence_pieces_keeps_costly_alternative (why exactly one slot is contested, and what m/n padding is for). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tokenizers/src/models/unigram/trainer.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs index 8715e5fb12..6f604b8b31 100644 --- a/tokenizers/src/models/unigram/trainer.rs +++ b/tokenizers/src/models/unigram/trainer.rs @@ -397,8 +397,9 @@ impl UnigramTrainer { // After removing the sentencepiece[i], its frequency freq[i] is // re-assigned to alternatives. // new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size() - // = current_sum + freq[i] * (alternatives[i] - 1) - + // = current_sum + freq[i] * (alternatives[i].size() - 1) + // Per-piece alternatives[id], not the total piece count; mirrors SentencePiece + // src/unigram_model_trainer.cc, Trainer::PruneSentencePieces. let logsum_alt = (sum + freq[id] * (alternatives[id].len() - 1) as f64).ln(); // The frequencies of alternatives are increased by freq[i]. @@ -818,6 +819,11 @@ mod tests { // to replace (its parts never occur alone), so SentencePiece keeps it; // "pq" is cheap (its parts are frequent). Scoring the loss with the total // piece count instead of alternatives[id].len() over-credits "pq". + // + // <unk> and the six single chars are always kept (no alternative + // segmentation); m/n are padding so that pruning leaves exactly one slot + // for the pq-vs-xy contest. The loss ranking decides the survivor: the + // fix keeps "xy"; the old `alternatives.len()` line keeps "pq". let trainer = UnigramTrainerBuilder::default() .show_progress(false) .vocab_size(1)