diff --git a/bindings/python/tests/bindings/test_normalizers.py b/bindings/python/tests/bindings/test_normalizers.py index 99ab07d393..f2500eac39 100644 --- a/bindings/python/tests/bindings/test_normalizers.py +++ b/bindings/python/tests/bindings/test_normalizers.py @@ -2,7 +2,7 @@ import pytest -from tokenizers import NormalizedString +from tokenizers import NormalizedString, Regex from tokenizers.normalizers import ( BertNormalizer, Lowercase, @@ -201,6 +201,23 @@ def test_can_modify(self): assert normalizer.prepend == "-" +class TestReplace: + def test_replace_with_groups(self): + normalizer = Replace(Regex(r"(l)(e)"), r"$1 $2") + result = normalizer.normalize_str("le travail") + assert result == "l e travail" + + def test_replace_with_group_zero(self): + normalizer = Replace(Regex(r"(\w+)"), r"[$0]") + result = normalizer.normalize_str("hello world") + assert result == "[hello] [world]" + + def test_replace_no_capture_unchanged(self): + normalizer = Replace(Regex(r"\s+"), " ") + result = normalizer.normalize_str("hello world") + assert result == "hello world" + + class TestCustomNormalizer: class BadCustomNormalizer: def normalize(self, normalized, wrong): diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index 5657574830..9bf48471b6 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -49,6 +49,8 @@ pub struct Replace { pub content: String, #[serde(skip)] regex: SysRegex, + #[serde(skip)] + expansion_re: Option, } impl Clone for Replace { @@ -70,38 +72,48 @@ impl Replace { ReplacePattern::String(s) => SysRegex::new(®ex::escape(s))?, ReplacePattern::Regex(r) => SysRegex::new(r)?, }; + let expansion_re = match &pattern { + ReplacePattern::String(_) => None, + ReplacePattern::Regex(r) => regex::Regex::new(r).ok(), + }; Ok(Self { pattern, content: content.into(), regex, + expansion_re, }) } } impl Normalizer for Replace { fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> { - normalized.replace(&self.regex, &self.content) + normalized.replace_regex(&self.regex, self.expansion_re.as_ref(), &self.content) } } impl Decoder for Replace { fn decode_chain(&self, tokens: Vec) -> Result> { - tokens - .into_iter() - .map(|token| -> Result { - let mut new_token = "".to_string(); - - for ((start, stop), is_match) in (&self.regex).find_matches(&token)? { - if is_match { - new_token.push_str(&self.content); - } else { - new_token.push_str(&token[start..stop]); + match &self.expansion_re { + Some(re) => tokens + .into_iter() + .map(|token| Ok(re.replace_all(&token, &self.content).to_string())) + .collect(), + None => tokens + .into_iter() + .map(|token| { + let mut new_token = String::new(); + for ((start, stop), is_match) in (&self.regex).find_matches(&token)? { + if is_match { + new_token.push_str(&self.content); + } else { + new_token.push_str(&token[start..stop]); + } } - } - Ok(new_token) - }) - .collect() + Ok(new_token) + }) + .collect(), + } } } @@ -156,4 +168,28 @@ mod tests { vec!["hello", " hello"] ); } + + #[test] + fn test_replace_regex_groups_dollar() { + let original = "le travail"; + let expected = "l e travail"; + + let mut n = NormalizedString::from(original); + Replace::new(ReplacePattern::Regex(r"(l)(e)".into()), r"$1 $2") + .unwrap() + .normalize(&mut n) + .unwrap(); + + assert_eq!(&n.get(), &expected); + } + + #[test] + fn test_replace_regex_groups_decode() { + let tokens = vec!["le".to_string(), "test".to_string()]; + let replace = Replace::new(ReplacePattern::Regex(r"(l)(e)".into()), r"$1 $2").unwrap(); + assert_eq!( + replace.decode_chain(tokens).unwrap(), + vec!["l e", "test"] + ); + } } diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 5bebd5f7b4..61b825f6fd 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -1,4 +1,5 @@ use crate::pattern::Pattern; +use crate::utils::SysRegex; use crate::{Offsets, Result}; use std::ops::{Bound, RangeBounds}; use unicode_normalization_alignments::UnicodeNormalization; @@ -673,6 +674,106 @@ impl NormalizedString { Ok(()) } + /// Replace anything that matches the given regex with the given content, + /// supporting backreferences (`$1`, `$2`, etc.) in the replacement string. + pub fn replace_regex( + &mut self, + regex: &SysRegex, + expansion_re: Option<®ex::Regex>, + content: &str, + ) -> Result<()> { + let mut new_normalized = String::with_capacity(self.normalized.len()); // Initially allocate for the input size + let mut new_alignments: Vec<(usize, usize)> = Vec::with_capacity(self.alignments.len()); + let mut last_end = 0; // Keep track of the last end position + + for (start, end) in regex.find_iter(&self.normalized) { + let range = start..end; + + let matched = &self.normalized[start..end]; + let expanded = match (matched.is_empty(), expansion_re) { + (true, _) | (_, None) => content.to_string(), + (false, Some(re)) => re.replacen(matched, 1, content).to_string(), + }; + + let removed_chars = self.normalized[range.clone()].chars().count(); + + // Copy the part of the string that is before the match + new_normalized.push_str(&self.normalized[last_end..start]); + new_alignments.extend(self.alignments[last_end..start].iter().cloned()); + + let n_range = Range::Normalized(range).into_full_range(self.len()); + + // Retrieve the original characters that are being replaced. This let us + // compute the change in byte sizes along the way. + let mut replaced_normalized = self.normalized[n_range.clone()] + .chars() + .collect::>() + .into_iter(); + let initial_removed: usize = (&mut replaced_normalized) + .take(removed_chars) + .map(|c| c.len_utf8()) + .sum(); + + let mut offset = (initial_removed + n_range.start) as isize; + let normalized = expanded + .chars() + .map(|c| (c, 1)) + .map(|(c, changes): (char, i32)| { + let idx = offset as usize; + let align = if changes.is_positive() { + if idx < 1 { + (0, 0) + } else { + // This is a newly inserted character, so it shares the same alignment + // than the previous one + self.alignments[idx - 1] + } + } else { + self.alignments[idx] + }; + + // If we are replacing a character, find it and compute the change in size + let replaced_char = if !changes.is_positive() { + replaced_normalized.next() + } else { + None + }; + let replaced_char_size = replaced_char.map_or(0, |c| c.len_utf8()); + + // If we are removing some characters, find them too + let total_bytes_to_remove = if changes.is_negative() { + (&mut replaced_normalized) + .take(-changes as usize) + .map(|c| c.len_utf8()) + .sum() + } else { + 0 + }; + + // Keep track of the changes for next offsets + offset += replaced_char_size as isize; + offset += total_bytes_to_remove as isize; + + new_alignments.extend((0..c.len_utf8()).map(|_| align)); + + // Then we keep only the char for string reconstruction + c + }) + .collect::(); + + new_normalized.push_str(&normalized); + last_end = end; + } + + // Copy the remaining part of the input + new_normalized.push_str(&self.normalized[last_end..]); + new_alignments.extend(&self.alignments[last_end..]); + + self.normalized = new_normalized; + self.alignments = new_alignments; + Ok(()) + } + /// Clear the normalized part of the string pub fn clear(&mut self) -> usize { let len = self.len();