46 changes: 46 additions & 0 deletions README.md
@@ -57,6 +57,52 @@ let next_state = index.next_state(&initial_state, token_id);
let final_states = index.final_states();
```

### Vocabulary

You can create a `Vocabulary` in three ways:

1. **`Vocabulary::from_pretrained(model, parameters)`** - Loads the vocabulary from a pretrained model (as in the example above).

2. **`Vocabulary::new(eos_token_id)`** - Creates an empty vocabulary manually; tokens are then added with `try_insert()`:

```rust
let mut vocabulary = Vocabulary::new(50256);
vocabulary.try_insert("hello", 0)?;
vocabulary.try_insert(vec![32], 1)?;
```

3. **`Vocabulary::try_from((eos_token_id, tokens))`** - Creates a vocabulary by directly providing the token mappings.

- Either with the tokens as string keys:

```rust
use rustc_hash::FxHashMap as HashMap;

let eos_token_id: u32 = 50256;
let mut tokens: HashMap<String, Vec<u32>> = HashMap::default();
tokens.insert("hello".to_string(), vec![0]);
tokens.insert("world".to_string(), vec![1]);

let vocabulary = Vocabulary::try_from((eos_token_id, tokens))?;
```

- Or with the tokens as byte-vector keys:

```rust
use rustc_hash::FxHashMap as HashMap;

let eos_token_id: u32 = 50256;
let mut tokens: HashMap<Vec<u8>, Vec<u32>> = HashMap::default();
tokens.insert(b"hello".to_vec(), vec![0]);
tokens.insert(b"world".to_vec(), vec![1]);

let vocabulary = Vocabulary::try_from((eos_token_id, tokens))?;
```
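
Whichever way it is built, the resulting vocabulary can be passed to `Index::new` just like a pretrained one. A minimal sketch (the regex is arbitrary here, chosen so the example tokens above can actually match it):

```rust
// Build an index for a pattern that the "hello"/"world" vocabulary above can satisfy.
let index = Index::new("(hello|world)+", &vocabulary)?;
```

If the vocabulary cannot provide a token for some reachable state of the pattern, `Index::new` returns an `Error::IncompatibleVocabulary` describing which characters are missing.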

**Important**: When creating a `Vocabulary` manually from tokenizer data, make sure tokens are first converted to their actual string (or byte) representations; special tokenizer placeholders (for example, byte-level markers for spaces) would not be recognized by the DFA.
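
For instance, a minimal sketch of such a conversion, assuming a GPT-2-style byte-level tokenizer that stores a leading space as the placeholder character "Ġ" (the placeholder and the `(token, id)` pairs below are illustrative, not taken from a real tokenizer):

```rust
let mut vocabulary = Vocabulary::new(50256);
for (token, token_id) in [("hello", 0_u32), ("Ġworld", 1)] {
    // Replace the byte-level space placeholder with a real space so the DFA
    // sees the bytes it will actually be asked to match.
    let normalized = token.replace('Ġ', " ");
    vocabulary.try_insert(normalized.into_bytes(), token_id)?;
}
```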

## Python Bindings

Additionally, the project provides interfaces that integrate the crate's functionality with Python.
6 changes: 6 additions & 0 deletions src/error.rs
@@ -69,6 +69,12 @@ pub enum Error {
InvalidRefecencePath(Box<str>),
#[error("Ref recusion limit reached: {0}")]
RefRecursionLimitReached(usize),
#[error("The vocabulary provided is incompatible with the regex '{regex}'. Found no transitions from state {error_state}, missing tokens corresponding to at least one of the following characters: {missing_tokens:?}. This may be due to an encoding issue in your vocabulary.")]
IncompatibleVocabulary {
regex: String,
error_state: u32,
missing_tokens: Vec<String>,
},
}

impl Error {
80 changes: 78 additions & 2 deletions src/index.rs
@@ -116,8 +116,11 @@ impl Index {
let mut next_states: Vec<AutomataStateId> = vec![start_state];

while let Some(current_state) = next_states.pop() {
let mut has_valid_transitions = false;

if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
final_states.insert(current_state.as_u32());
has_valid_transitions = true;
}

'token_loop: for (token, ids) in vocabulary.tokens().iter() {
@@ -136,6 +139,7 @@ let is_intermediate_state = !dfa.is_match_state(next_state);
let is_intermediate_state = !dfa.is_match_state(next_state);
let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
if is_intermediate_state || is_full_match_state {
has_valid_transitions = true;
for token_id in ids {
transitions
.entry(current_state.as_u32())
@@ -148,6 +152,28 @@
next_states.push(next_state);
}
}

// If the current state has no valid transitions and is not a match state,
// it means the vocabulary is incompatible with the regex.
if !has_valid_transitions && !dfa.is_match_state(current_state) {
let mut valid_characters = Vec::new();
for byte in 0..=255u8 {
let test_state = dfa.next_state(current_state, byte);
if !dfa.is_dead_state(test_state) && !dfa.is_quit_state(test_state) {
if byte.is_ascii() {
valid_characters.push(char::from(byte).to_string());
} else {
valid_characters.push(format!("\\x{:02x}", byte));
}
}
}

return Err(Error::IncompatibleVocabulary {
regex: regex.to_string(),
error_state: current_state.as_u32(),
missing_tokens: valid_characters,
});
}
}

// Populate `transitions` with mappings from `final_states` to `eos_token_id`
@@ -290,7 +316,7 @@ mod tests {
.expect("Insert failed");
}
for (token, token_id) in [
(vec![32, 240, 159, 152], 7),
(vec![32, 240, 159, 152, 136], 7),
(vec![32, 240, 159, 152, 141], 6),
(vec![240, 159, 152, 141], 4),
] {
@@ -309,10 +335,60 @@
),
(
80,
HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]),
HashMap::from_iter([(2, 128), (7, 208), (5, 208), (6, 208)]),
),
(128, HashMap::from_iter([(8, 128)])),
]);
assert_eq!(index.transitions(), &expected);
}

#[test]
fn index_incompatible_vocabulary_error() {
let regex = "0 1";
let mut vocabulary = Vocabulary::new(3);
for (token, token_id) in [("0", 0), ("0 ", 1), ("1", 2)] {
vocabulary
.try_insert(token, token_id as u32)
.expect("Insert failed");
}

let result = Index::new(regex, &vocabulary);
assert!(result.is_err());

if let Err(Error::IncompatibleVocabulary {
regex: _,
missing_tokens,
..
}) = result
{
assert!(missing_tokens.contains(&" ".to_string()));
} else {
panic!("Expected IncompatibleVocabulary error");
}
}

#[test]
fn index_incompatible_vocabulary_error_non_ascii() {
let regex = "😈😍";
let mut vocabulary = Vocabulary::new(3);
for (token, token_id) in [("😈", 0), (" ", 1), ("b", 2)] {
vocabulary
.try_insert(token, token_id as u32)
.expect("Insert failed");
}

let result = Index::new(regex, &vocabulary);
assert!(result.is_err());

if let Err(Error::IncompatibleVocabulary {
regex: _,
missing_tokens,
..
}) = result
{
assert!(missing_tokens.contains(&"\\xf0".to_string()));
} else {
panic!("Expected IncompatibleVocabulary error");
}
}
}