Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion bindings/python/tests/bindings/test_normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from tokenizers import NormalizedString
from tokenizers import NormalizedString, Regex
from tokenizers.normalizers import (
BertNormalizer,
Lowercase,
Expand Down Expand Up @@ -201,6 +201,23 @@ def test_can_modify(self):
assert normalizer.prepend == "-"


class TestReplace:
def test_replace_with_groups(self):
normalizer = Replace(Regex(r"(l)(e)"), r"$1 $2")
result = normalizer.normalize_str("le travail")
assert result == "l e travail"

def test_replace_with_group_zero(self):
normalizer = Replace(Regex(r"(\w+)"), r"[$0]")
result = normalizer.normalize_str("hello world")
assert result == "[hello] [world]"

def test_replace_no_capture_unchanged(self):
normalizer = Replace(Regex(r"\s+"), " ")
result = normalizer.normalize_str("hello world")
assert result == "hello world"


class TestCustomNormalizer:
class BadCustomNormalizer:
def normalize(self, normalized, wrong):
Expand Down
66 changes: 51 additions & 15 deletions tokenizers/src/normalizers/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pub struct Replace {
pub content: String,
#[serde(skip)]
regex: SysRegex,
#[serde(skip)]
expansion_re: Option<regex::Regex>,
}

impl Clone for Replace {
Expand All @@ -70,38 +72,48 @@ impl Replace {
ReplacePattern::String(s) => SysRegex::new(&regex::escape(s))?,
ReplacePattern::Regex(r) => SysRegex::new(r)?,
};
let expansion_re = match &pattern {
ReplacePattern::String(_) => None,
ReplacePattern::Regex(r) => regex::Regex::new(r).ok(),
};

Ok(Self {
pattern,
content: content.into(),
regex,
expansion_re,
})
}
}

impl Normalizer for Replace {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
normalized.replace(&self.regex, &self.content)
normalized.replace_regex(&self.regex, self.expansion_re.as_ref(), &self.content)
}
}

impl Decoder for Replace {
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
tokens
.into_iter()
.map(|token| -> Result<String> {
let mut new_token = "".to_string();

for ((start, stop), is_match) in (&self.regex).find_matches(&token)? {
if is_match {
new_token.push_str(&self.content);
} else {
new_token.push_str(&token[start..stop]);
match &self.expansion_re {
Some(re) => tokens
.into_iter()
.map(|token| Ok(re.replace_all(&token, &self.content).to_string()))
.collect(),
None => tokens
.into_iter()
.map(|token| {
let mut new_token = String::new();
for ((start, stop), is_match) in (&self.regex).find_matches(&token)? {
if is_match {
new_token.push_str(&self.content);
} else {
new_token.push_str(&token[start..stop]);
}
}
}
Ok(new_token)
})
.collect()
Ok(new_token)
})
.collect(),
}
}
}

Expand Down Expand Up @@ -156,4 +168,28 @@ mod tests {
vec!["hello", " hello"]
);
}

#[test]
fn test_replace_regex_groups_dollar() {
let original = "le travail";
let expected = "l e travail";

let mut n = NormalizedString::from(original);
Replace::new(ReplacePattern::Regex(r"(l)(e)".into()), r"$1 $2")
.unwrap()
.normalize(&mut n)
.unwrap();

assert_eq!(&n.get(), &expected);
}

#[test]
fn test_replace_regex_groups_decode() {
let tokens = vec!["le".to_string(), "test".to_string()];
let replace = Replace::new(ReplacePattern::Regex(r"(l)(e)".into()), r"$1 $2").unwrap();
assert_eq!(
replace.decode_chain(tokens).unwrap(),
vec!["l e", "test"]
);
}
}
101 changes: 101 additions & 0 deletions tokenizers/src/tokenizer/normalizer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::pattern::Pattern;
use crate::utils::SysRegex;
use crate::{Offsets, Result};
use std::ops::{Bound, RangeBounds};
use unicode_normalization_alignments::UnicodeNormalization;
Expand Down Expand Up @@ -673,6 +674,106 @@ impl NormalizedString {
Ok(())
}

/// Replace anything that matches the given regex with the given content,
/// supporting backreferences (`$1`, `$2`, etc.) in the replacement string.
pub fn replace_regex(
&mut self,
regex: &SysRegex,
expansion_re: Option<&regex::Regex>,
content: &str,
) -> Result<()> {
let mut new_normalized = String::with_capacity(self.normalized.len()); // Initially allocate for the input size
let mut new_alignments: Vec<(usize, usize)> = Vec::with_capacity(self.alignments.len());
let mut last_end = 0; // Keep track of the last end position

for (start, end) in regex.find_iter(&self.normalized) {
let range = start..end;

let matched = &self.normalized[start..end];
let expanded = match (matched.is_empty(), expansion_re) {
(true, _) | (_, None) => content.to_string(),
(false, Some(re)) => re.replacen(matched, 1, content).to_string(),
};

let removed_chars = self.normalized[range.clone()].chars().count();

// Copy the part of the string that is before the match
new_normalized.push_str(&self.normalized[last_end..start]);
new_alignments.extend(self.alignments[last_end..start].iter().cloned());

let n_range = Range::Normalized(range).into_full_range(self.len());

// Retrieve the original characters that are being replaced. This let us
// compute the change in byte sizes along the way.
let mut replaced_normalized = self.normalized[n_range.clone()]
.chars()
.collect::<Vec<_>>()
.into_iter();
let initial_removed: usize = (&mut replaced_normalized)
.take(removed_chars)
.map(|c| c.len_utf8())
.sum();

let mut offset = (initial_removed + n_range.start) as isize;
let normalized = expanded
.chars()
.map(|c| (c, 1))
.map(|(c, changes): (char, i32)| {
let idx = offset as usize;
let align = if changes.is_positive() {
if idx < 1 {
(0, 0)
} else {
// This is a newly inserted character, so it shares the same alignment
// than the previous one
self.alignments[idx - 1]
}
} else {
self.alignments[idx]
};

// If we are replacing a character, find it and compute the change in size
let replaced_char = if !changes.is_positive() {
replaced_normalized.next()
} else {
None
};
let replaced_char_size = replaced_char.map_or(0, |c| c.len_utf8());

// If we are removing some characters, find them too
let total_bytes_to_remove = if changes.is_negative() {
(&mut replaced_normalized)
.take(-changes as usize)
.map(|c| c.len_utf8())
.sum()
} else {
0
};

// Keep track of the changes for next offsets
offset += replaced_char_size as isize;
offset += total_bytes_to_remove as isize;

new_alignments.extend((0..c.len_utf8()).map(|_| align));

// Then we keep only the char for string reconstruction
c
})
.collect::<String>();

new_normalized.push_str(&normalized);
last_end = end;
}

// Copy the remaining part of the input
new_normalized.push_str(&self.normalized[last_end..]);
new_alignments.extend(&self.alignments[last_end..]);

self.normalized = new_normalized;
self.alignments = new_alignments;
Ok(())
}

/// Clear the normalized part of the string
pub fn clear(&mut self) -> usize {
let len = self.len();
Expand Down