Disallow insert of eos token into Vocabulary
torymur committed Jan 9, 2025
1 parent 7b6781b commit 1fab872
Showing 6 changed files with 85 additions and 39 deletions.
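
In short, Vocabulary::insert becomes the fallible Vocabulary::try_insert, which refuses the EOS token id instead of silently accepting it. A minimal sketch of the new behaviour, following the updated doc example in src/vocabulary/mod.rs below:

    use outlines_core::prelude::*;

    fn main() {
        let eos_token_id = 3;
        let mut vocabulary = Vocabulary::new(eos_token_id);

        // Ordinary tokens still go in, now through the fallible API.
        vocabulary.try_insert("token", 0).expect("insert should succeed");

        // Inserting anything under the EOS token id is now rejected.
        assert!(vocabulary.try_insert("eos-token", eos_token_id).is_err());
    }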
4 changes: 4 additions & 0 deletions src/error.rs
@@ -4,12 +4,16 @@ pub type Result&lt;T, E = crate::Error&gt; = std::result::Result&lt;T, E&gt;;
 
 #[derive(Error, Debug)]
 pub enum Error {
+    // Index Errors
     #[error("The vocabulary does not allow to build an index that matches the input")]
     InsufficientVocabulary,
     #[error("Failed to build DFA {0}")]
     IndexDfaError(#[from] Box<regex_automata::dfa::dense::BuildError>),
     #[error("Index failed since anchored universal start state doesn't exist")]
     DfaHasNoStartState,
+    // Vocabulary Errors
+    #[error("EOS token should not be inserted into Vocabulary")]
+    EOSTokenDisallowed,
     #[error(transparent)]
     TokenizersError(#[from] tokenizers::Error),
     #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")]
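
Callers that want to handle this case specifically can match on the new variant, much as the vocabulary tests below do; a brief sketch (assuming Error is importable from the crate root, where src/error.rs defines it):

    use outlines_core::prelude::*;
    use outlines_core::Error;

    fn add_token(vocabulary: &mut Vocabulary, token: &str, id: u32) {
        match vocabulary.try_insert(token, id) {
            Ok(()) => {}
            Err(Error::EOSTokenDisallowed) => eprintln!("refusing to insert the EOS token"),
            Err(e) => eprintln!("insert failed: {e}"),
        }
    }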
16 changes: 12 additions & 4 deletions src/index.rs
@@ -142,7 +142,9 @@ mod tests {
         let regex = "0|[1-9][0-9]*";
         let mut vocabulary = Vocabulary::new(4);
         for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
-            vocabulary.insert(token, token_id as u32);
+            vocabulary
+                .try_insert(token, token_id as u32)
+                .expect("Insert failed");
         }
 
         let index = Index::new(regex, &vocabulary).expect("Index failed");
@@ -163,7 +165,9 @@
         let regex = "`\\n(\\.\\n)?`\\n";
         let mut vocabulary = Vocabulary::new(104);
         for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
-            vocabulary.insert(token, token_id as u32);
+            vocabulary
+                .try_insert(token, token_id as u32)
+                .expect("Insert failed");
         }
 
         let index = Index::new(regex, &vocabulary).expect("Index failed");
@@ -179,14 +183,18 @@
         let mut vocabulary = Vocabulary::new(8);
         for (token, token_id) in [(" 😍", 5), ("blah", 0), ("πŸ˜‡", 2), ("😈a", 1), ("😍", 3)]
         {
-            vocabulary.insert(token, token_id as u32);
+            vocabulary
+                .try_insert(token, token_id as u32)
+                .expect("Insert failed");
         }
         for (token, token_id) in [
             (vec![32, 240, 159, 152], 7),
             (vec![32, 240, 159, 152, 141], 6),
             (vec![240, 159, 152, 141], 4),
         ] {
-            vocabulary.insert(token, token_id as u32);
+            vocabulary
+                .try_insert(token, token_id as u32)
+                .expect("Insert failed");
         }
 
         let index = Index::new(regex, &vocabulary).expect("Index failed");
10 changes: 4 additions & 6 deletions src/python_bindings/mod.rs
@@ -209,10 +209,10 @@ impl PyVocabulary {
     #[new]
     fn __new__(py: Python<'_>, eos_token_id: TokenId, map: Py<PyAny>) -> PyResult<PyVocabulary> {
         if let Ok(dict) = map.extract::<HashMap<String, Vec<TokenId>>>(py) {
-            return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict))));
+            return Ok(PyVocabulary(Vocabulary::try_from((eos_token_id, dict))?));
         }
         if let Ok(dict) = map.extract::<HashMap<Vec<u8>, Vec<TokenId>>>(py) {
-            return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict))));
+            return Ok(PyVocabulary(Vocabulary::try_from((eos_token_id, dict))?));
         }
 
         let message = "Expected a dict with keys of type str or bytes and values of type list[int]";
@@ -248,12 +248,10 @@ impl PyVocabulary {
 
     fn insert(&mut self, py: Python<'_>, token: Py<PyAny>, token_id: TokenId) -> PyResult<()> {
         if let Ok(t) = token.extract::<String>(py) {
-            self.0.insert(t, token_id);
-            return Ok(());
+            return Ok(self.0.try_insert(t, token_id)?);
         }
         if let Ok(t) = token.extract::<Token>(py) {
-            self.0.insert(t, token_id);
-            return Ok(());
+            return Ok(self.0.try_insert(t, token_id)?);
         }
         Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
             "Expected a token of type str or bytes, got {:?}",
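
For the `?` in these bindings to compile, the crate's Error must convert into a PyErr; the Python test at the bottom of this commit expects a ValueError, so the conversion presumably looks roughly like the sketch below (the actual impl lives elsewhere in the bindings and is not part of this diff):

    use pyo3::exceptions::PyValueError;
    use pyo3::PyErr;

    // Hypothetical conversion: lets `self.0.try_insert(t, token_id)?` surface
    // Error::EOSTokenDisallowed to Python as a ValueError carrying the same message.
    impl From<crate::Error> for PyErr {
        fn from(e: crate::Error) -> Self {
            PyValueError::new_err(e.to_string())
        }
    }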
75 changes: 52 additions & 23 deletions src/vocabulary/mod.rs
@@ -24,12 +24,13 @@ mod processor;
 /// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None);
 /// ```
 ///
-/// ### Create an empty vocabulary.
+/// ### Create an empty vocabulary and manually insert tokens.
 /// ```rust
 /// # use outlines_core::prelude::*;
 /// #
-/// let mut vocabulary = Vocabulary::new(1);
-/// vocabulary.insert("token", 0);
+/// let eos_token_id = 1;
+/// let mut vocabulary = Vocabulary::new(eos_token_id);
+/// vocabulary.try_insert("token", 0).expect("New token inserted");
 /// ```
 #[derive(Clone, Debug, Default, PartialEq, Encode, Decode)]
 pub struct Vocabulary {
@@ -47,9 +48,13 @@ impl Vocabulary {
     }
 
     /// Inserts a token to the vocabulary with the specified identifier.
-    pub fn insert(&mut self, token: impl Into<Token>, id: TokenId) {
+    pub fn try_insert(&mut self, token: impl Into<Token>, id: TokenId) -> Result<(), Error> {
+        if id == self.eos_token_id {
+            return Err(Error::EOSTokenDisallowed);
+        }
         let token = token.into();
         self.tokens.entry(token).or_default().push(id);
+        Ok(())
     }
 
     /// Creates the vocabulary of pre-trained model from Hugging Face Hub.
@@ -81,8 +86,8 @@ impl Vocabulary {
         // Start building the vocabulary from eos_token_id and added tokens.
         let mut vocabulary = Vocabulary::new(eos_token_id);
         for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
-            if !added_token.special {
-                vocabulary.insert(added_token.content.clone(), *id);
+            if !added_token.special && id != &eos_token_id {
+                vocabulary.try_insert(added_token.content.clone(), *id)?
             }
         }
 
@@ -94,8 +99,10 @@
             });
         };
         for (token, token_id) in tokenizer.get_vocab(false) {
-            let processed_token = processor.process(token)?;
-            vocabulary.insert(processed_token, token_id);
+            if token_id != eos_token_id {
+                let processed_token = processor.process(&token)?;
+                vocabulary.try_insert(processed_token, token_id)?;
+            }
         }
 
         Ok(vocabulary)
@@ -169,26 +176,39 @@ impl std::fmt::Display for Vocabulary {
     }
 }
 
-impl From<(TokenId, HashMap<Token, Vec<TokenId>>)> for Vocabulary {
-    fn from(values: (TokenId, HashMap<Token, Vec<TokenId>>)) -> Vocabulary {
+impl TryFrom<(TokenId, HashMap<Token, Vec<TokenId>>)> for Vocabulary {
+    type Error = Error;
+
+    fn try_from(values: (TokenId, HashMap<Token, Vec<TokenId>>)) -> Result<Self, Self::Error> {
         let (eos_token_id, tokens) = values;
-        Vocabulary {
+        if tokens.iter().any(|(_, ids)| ids.contains(&eos_token_id)) {
+            return Err(Error::EOSTokenDisallowed);
+        }
+        Ok(Vocabulary {
             eos_token_id,
             tokens,
-        }
+        })
     }
 }
 
-impl From<(TokenId, HashMap<String, Vec<TokenId>>)> for Vocabulary {
-    fn from(values: (TokenId, HashMap<String, Vec<TokenId>>)) -> Vocabulary {
+impl TryFrom<(TokenId, HashMap<String, Vec<TokenId>>)> for Vocabulary {
+    type Error = Error;
+
+    fn try_from(values: (TokenId, HashMap<String, Vec<TokenId>>)) -> Result<Self, Self::Error> {
         let (eos_token_id, tokens) = values;
-        Vocabulary {
+        Ok(Vocabulary {
             eos_token_id,
             tokens: tokens
                 .into_iter()
-                .map(|(k, v)| (k.as_bytes().to_vec(), v))
-                .collect::<HashMap<Token, Vec<TokenId>>>(),
-        }
+                .map(|(k, v)| {
+                    if v.contains(&eos_token_id) {
+                        Err(Error::EOSTokenDisallowed)
+                    } else {
+                        Ok((k.as_bytes().to_vec(), v))
+                    }
+                })
+                .collect::<Result<HashMap<Token, Vec<TokenId>>, _>>()?,
+        })
    }
 }
 
@@ -202,32 +222,41 @@ mod tests {
         let eos_token_id = 3;
         let mut vocabulary = Vocabulary::new(eos_token_id);
 
+        match vocabulary.try_insert("eos-token", eos_token_id) {
+            Err(Error::EOSTokenDisallowed) => {}
+            _ => unreachable!(),
+        }
+
         // New empty vocabulary.
         assert_eq!(vocabulary.eos_token_id, eos_token_id);
         assert!(vocabulary.tokens.is_empty());
 
         for (token, id) in [("zero", 0), ("one", 1), ("two", 2)] {
-            vocabulary.insert(token, id);
+            vocabulary.try_insert(token, id).expect("Insert failed");
             assert_eq!(vocabulary.token_to_ids(token), Some(&vec![id]));
         }
         assert_eq!(vocabulary.tokens.len(), 3);
         assert_eq!(vocabulary.tokens_to_ids().len(), 3);
 
         // Confirm different types.
-        vocabulary.insert(b"four", 4);
+        vocabulary.try_insert(b"four", 4).expect("Insert failed");
         assert_eq!(vocabulary.token_to_ids("four"), Some(&vec![4]));
 
-        vocabulary.insert(b"five".to_vec(), 5);
+        vocabulary
+            .try_insert(b"five".to_vec(), 5)
+            .expect("Insert failed");
         assert_eq!(vocabulary.token_to_ids("five"), Some(&vec![5]));
 
-        vocabulary.insert("six".to_string(), 6);
+        vocabulary
+            .try_insert("six".to_string(), 6)
+            .expect("Insert failed");
         assert_eq!(vocabulary.token_to_ids("six"), Some(&vec![6]));
     }
 
     #[test]
     fn new_empty_vocabulary_from_hashmap() {
         let map: HashMap<Token, Vec<TokenId>> = HashMap::default();
-        let vocabulary = Vocabulary::from((1_u32, map));
+        let vocabulary = Vocabulary::try_from((1_u32, map)).expect("Vocabulary failed");
         assert_eq!(vocabulary.eos_token_id, 1);
         assert!(vocabulary.tokens.is_empty());
     }
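
Building a Vocabulary from a prebuilt map is now fallible as well; a small sketch of the TryFrom path added above, using the (TokenId, HashMap<String, Vec<TokenId>>) flavour (the exact path of the error type is an assumption):

    use std::collections::HashMap;
    use outlines_core::prelude::*;

    fn build_vocabulary() -> Result<Vocabulary, outlines_core::Error> {
        let eos_token_id: u32 = 3;
        let mut map: HashMap<String, Vec<u32>> = HashMap::new();
        map.insert("zero".to_string(), vec![0]);
        map.insert("one".to_string(), vec![1]);
        // If any entry mentioned eos_token_id, try_from would return
        // Err(Error::EOSTokenDisallowed) instead of a Vocabulary.
        Vocabulary::try_from((eos_token_id, map))
    }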
12 changes: 6 additions & 6 deletions src/vocabulary/processor.rs
@@ -102,7 +102,7 @@ impl Default for Mods {
 
 impl Mods {
     /// Apply default modifications to each token.
-    fn apply_default(&self, token: String) -> String {
+    fn apply_default(&self, token: &str) -> String {
         let to = Self::default().spacechar.to_string();
         token.replace(self.spacechar, &to)
     }
@@ -190,7 +190,7 @@ impl TokenProcessor {
     }
 
     /// Operates on each token based on the level of `TokenProcessor`.
-    pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
+    pub(crate) fn process(&self, token: &str) -> Result<Vec<u8>> {
         match &self.level {
             TokenProcessorLevel::Byte => token
                 .chars()
@@ -275,7 +275,7 @@ mod tests {
             ('ΓΎ', 0xFE),
             ('ΓΏ', 0xFF),
         ] {
-            let processed = processor.process(ch.to_string()).expect("Not processed");
+            let processed = processor.process(&ch.to_string()).expect("Not processed");
             assert_eq!(processed, [byte]);
         }
     }
@@ -304,7 +304,7 @@
                 vec![0x20, 0x20, 0x20],
             ),
         ] {
-            let processed = processor.process(input.to_string()).expect("Not processed");
+            let processed = processor.process(input).expect("Not processed");
             assert_eq!(processed, expected);
         }
     }
@@ -328,7 +328,7 @@
         let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
 
         for token in ["π’œπ’·π’Έπ’Ÿπ“”", "πŸ¦„πŸŒˆπŸŒπŸ”₯πŸŽ‰", "δΊ¬δΈœθ΄­η‰©"] {
-            let result = processor.process(token.to_string());
+            let result = processor.process(token);
             match result {
                 Err(Error::ByteProcessorFailed) => {}
                 _ => unreachable!(),
@@ -342,7 +342,7 @@
         let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
         let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
 
-        let result = processor.process("<0x6y>".to_string());
+        let result = processor.process("<0x6y>");
         match result {
             Err(Error::ByteFallbackProcessorFailed) => {}
             _ => unreachable!(),
7 changes: 7 additions & 0 deletions tests/fsm/test_vocabulary.py
@@ -77,6 +77,13 @@ def test_insert_bad_type(vocabulary):
         vocabulary.insert(1, 6)
 
 
+def test_insert_eos_token(vocabulary):
+    with pytest.raises(
+        ValueError, match="EOS token should not be inserted into Vocabulary"
+    ):
+        vocabulary.insert("eos-token", 3)
+
+
 def test_from_pretrained():
     vocabulary = Vocabulary.from_pretrained("gpt2")
     assert vocabulary.get_eos_token_id() == 50256
