Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 105 additions & 4 deletions tokenizers/src/models/unigram/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,11 @@ impl UnigramTrainer {

// After removing the sentencepiece[i], its frequency freq[i] is
// re-assigned to alternatives.
// new_sum = current_sum - freq[i] + freq[i] * alternatives.size()
// = current_sum + freq[i] (alternatives - 1)

let logsum_alt = (sum + freq[id] * (alternatives.len() - 1) as f64).ln();
// new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size()
// = current_sum + freq[i] * (alternatives[i].size() - 1)
// Per-piece alternatives[id], not the total piece count; mirrors SentencePiece
// src/unigram_model_trainer.cc, Trainer::PruneSentencePieces.
let logsum_alt = (sum + freq[id] * (alternatives[id].len() - 1) as f64).ln();

// The frequencies of alternatives are increased by freq[i].
let mut logprob_alt = 0.0;
Expand Down Expand Up @@ -812,6 +813,106 @@ mod tests {
assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
}

#[test]
fn test_prune_sentence_pieces_keeps_costly_alternative() {
// Two candidates compete for a single surviving slot. "xy" is expensive
// to replace (its parts never occur alone), so SentencePiece keeps it;
// "pq" is cheap (its parts are frequent). Scoring the loss with the total
// piece count instead of alternatives[id].len() over-credits "pq".
//
// <UNK> and the six single chars are always kept (no alternative
// segmentation); m/n are padding so that pruning leaves exactly one slot
// for the pq-vs-xy contest. The loss ranking decides the survivor: the
// fix keeps "xy"; the old `alternatives.len()` line keeps "pq".
let trainer = UnigramTrainerBuilder::default()
.show_progress(false)
.vocab_size(1)
.shrinking_factor(0.9)
.build()
.unwrap();

let pieces: Vec<SentencePiece> = vec![
("<UNK>".into(), 0.0),
("p".into(), -10.0),
("q".into(), -10.0),
("x".into(), -10.0),
("y".into(), -10.0),
("m".into(), -10.0),
("n".into(), -10.0),
("pq".into(), -1.0),
("xy".into(), -1.0),
];
let sentences: Vec<Sentence> = vec![
("pq".into(), 20),
("xy".into(), 5),
("p".into(), 50),
("q".into(), 50),
];
let model = Unigram::from(pieces.clone(), Some(0), false).unwrap();

let kept: Vec<String> = trainer
.prune_sentence_pieces(&model, &pieces, &sentences)
.into_iter()
.filter(|(token, _)| token != "<UNK>" && token.chars().count() >= 2)
.map(|(token, _)| token)
.collect();
assert_eq!(kept, vec!["xy".to_string()]);
}

#[test]
fn test_do_train_runs_prune() {
// A seed vocabulary larger than desired_vocab_size forces the EM loop to
// call prune_sentence_pieces, unlike the small-corpus tests above which
// break out of the loop before pruning.
let vocab_size = 60;
let prefixes = ["", "re", "un", "in", "de", "pre", "over", "under"];
let stems = [
"nation",
"organ",
"real",
"general",
"central",
"local",
"global",
"modern",
"civil",
"token",
"format",
"structure",
"compute",
"special",
];
let suffixes = [
"", "al", "ize", "ation", "ized", "izing", "ism", "ist", "ly", "s",
];
let mut sentences: Vec<Sentence> = Vec::new();
for prefix in prefixes {
for stem in stems {
for suffix in suffixes {
sentences.push((format!("{}{}{}", prefix, stem, suffix), 5));
}
}
}

let trainer = UnigramTrainerBuilder::default()
.show_progress(false)
.vocab_size(vocab_size)
.build()
.unwrap();
let mut model = Unigram::default();
trainer.do_train(sentences.clone(), &mut model).unwrap();

assert!(model.iter().count() <= vocab_size as usize);
let vocab: AHashSet<String> = model.iter().map(|(token, _)| token.clone()).collect();
for required in trainer.required_chars(&sentences) {
assert!(
vocab.contains(&required),
"missing required char {:?}",
required
);
}
}

#[test]
fn test_to_log_prob() {
let mut a = vec![("".to_string(), 1.0), ("".to_string(), 2.0)];
Expand Down