From 1fab8726ba8f2f20739e58fbb0b8c7772d98e454 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Thu, 9 Jan 2025 12:51:30 +0000
Subject: [PATCH] Disallow insert of eos token into Vocabulary

---
 src/error.rs                 |  4 ++
 src/index.rs                 | 16 ++++++--
 src/python_bindings/mod.rs   | 10 ++---
 src/vocabulary/mod.rs        | 75 +++++++++++++++++++++++++-----------
 src/vocabulary/processor.rs  | 12 +++---
 tests/fsm/test_vocabulary.py |  7 ++++
 6 files changed, 85 insertions(+), 39 deletions(-)

diff --git a/src/error.rs b/src/error.rs
index 4ffe7ed7..e5781e84 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -4,12 +4,16 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
 
 #[derive(Error, Debug)]
 pub enum Error {
+    // Index Errors
     #[error("The vocabulary does not allow to build an index that matches the input")]
     InsufficientVocabulary,
     #[error("Failed to build DFA {0}")]
     IndexDfaError(#[from] Box<regex_automata::dfa::dense::BuildError>),
     #[error("Index failed since anchored universal start state doesn't exist")]
     DfaHasNoStartState,
+    // Vocabulary Errors
+    #[error("EOS token should not be inserted into Vocabulary")]
+    EOSTokenDisallowed,
     #[error(transparent)]
     TokenizersError(#[from] tokenizers::Error),
     #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")]
diff --git a/src/index.rs b/src/index.rs
index c005587c..407e4711 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -142,7 +142,9 @@ mod tests {
         let regex = "0|[1-9][0-9]*";
         let mut vocabulary = Vocabulary::new(4);
         for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
-            vocabulary.insert(token, token_id as u32);
+            vocabulary
+                .try_insert(token, token_id as u32)
+                .expect("Insert failed");
         }
 
         let index = Index::new(regex, &vocabulary).expect("Index failed");
@@ -163,7 +165,9 @@ mod tests {
         let regex = "`\\n(\\.\\n)?`\\n";
         let mut vocabulary = Vocabulary::new(104);
         for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
-            vocabulary.insert(token, token_id as u32);
+            vocabulary
+                .try_insert(token, token_id as u32)
+                .expect("Insert failed");
         }
 
         let index = Index::new(regex, &vocabulary).expect("Index failed");
@@ -179,14 +183,18 @@
         let mut vocabulary = Vocabulary::new(8);
         for (token, token_id) in [(" 😍", 5), ("blah", 0), ("πŸ˜‡", 2), ("😈a", 1), ("😍", 3)] {
-            vocabulary.insert(token, token_id as u32);
+            vocabulary
+                .try_insert(token, token_id as u32)
+                .expect("Insert failed");
         }
 
         for (token, token_id) in [
             (vec![32, 240, 159, 152], 7),
             (vec![32, 240, 159, 152, 141], 6),
             (vec![240, 159, 152, 141], 4),
         ] {
-            vocabulary.insert(token, token_id as u32);
+            vocabulary
+                .try_insert(token, token_id as u32)
+                .expect("Insert failed");
         }
 
         let index = Index::new(regex, &vocabulary).expect("Index failed");
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 3fbdd305..eca8e7bc 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -209,10 +209,10 @@ impl PyVocabulary {
     #[new]
     fn __new__(py: Python<'_>, eos_token_id: TokenId, map: Py<PyAny>) -> PyResult<Self> {
         if let Ok(dict) = map.extract::<HashMap<String, Vec<TokenId>>>(py) {
-            return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict))));
+            return Ok(PyVocabulary(Vocabulary::try_from((eos_token_id, dict))?));
         }
         if let Ok(dict) = map.extract::<HashMap<Vec<u8>, Vec<TokenId>>>(py) {
-            return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict))));
+            return Ok(PyVocabulary(Vocabulary::try_from((eos_token_id, dict))?));
         }
 
         let message = "Expected a dict with keys of type str or bytes and values of type list[int]";
@@ -248,12 +248,10 @@ impl PyVocabulary {
 
     fn insert(&mut self, py: Python<'_>, token: Py<PyAny>, token_id: TokenId) -> PyResult<()> {
         if let Ok(t) = token.extract::<String>(py) {
-            self.0.insert(t, token_id);
-            return Ok(());
+            return Ok(self.0.try_insert(t, token_id)?);
         }
         if let Ok(t) = token.extract::<Token>(py) {
-            self.0.insert(t, token_id);
-            return Ok(());
+            return Ok(self.0.try_insert(t, token_id)?);
         }
         Err(PyErr::new::<PyTypeError, _>(format!(
             "Expected a token of type str or bytes, got {:?}",
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index dbae9b6c..71f2c424 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -24,12 +24,13 @@ mod processor;
 /// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None);
 /// ```
 ///
-/// ### Create an empty vocabulary.
+/// ### Create an empty vocabulary and manually insert tokens.
 /// ```rust
 /// # use outlines_core::prelude::*;
 /// #
-/// let mut vocabulary = Vocabulary::new(1);
-/// vocabulary.insert("token", 0);
+/// let eos_token_id = 1;
+/// let mut vocabulary = Vocabulary::new(eos_token_id);
+/// vocabulary.try_insert("token", 0).expect("New token inserted");
 /// ```
 #[derive(Clone, Debug, Default, PartialEq, Encode, Decode)]
 pub struct Vocabulary {
@@ -47,9 +48,13 @@ impl Vocabulary {
     }
 
     /// Inserts a token to the vocabulary with the specified identifier.
-    pub fn insert(&mut self, token: impl Into<Token>, id: TokenId) {
+    pub fn try_insert(&mut self, token: impl Into<Token>, id: TokenId) -> Result<(), Error> {
+        if id == self.eos_token_id {
+            return Err(Error::EOSTokenDisallowed);
+        }
         let token = token.into();
         self.tokens.entry(token).or_default().push(id);
+        Ok(())
     }
 
     /// Creates the vocabulary of pre-trained model from Hugging Face Hub.
@@ -81,8 +86,8 @@ impl Vocabulary {
         // Start building the vocabulary from eos_token_id and added tokens.
         let mut vocabulary = Vocabulary::new(eos_token_id);
         for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
-            if !added_token.special {
-                vocabulary.insert(added_token.content.clone(), *id);
+            if !added_token.special && id != &eos_token_id {
+                vocabulary.try_insert(added_token.content.clone(), *id)?
             }
         }
 
@@ -94,8 +99,10 @@ impl Vocabulary {
             });
         };
         for (token, token_id) in tokenizer.get_vocab(false) {
-            let processed_token = processor.process(token)?;
-            vocabulary.insert(processed_token, token_id);
+            if token_id != eos_token_id {
+                let processed_token = processor.process(&token)?;
+                vocabulary.try_insert(processed_token, token_id)?;
+            }
         }
 
         Ok(vocabulary)
@@ -169,26 +176,39 @@ impl std::fmt::Display for Vocabulary {
     }
 }
 
-impl From<(TokenId, HashMap<Token, Vec<TokenId>>)> for Vocabulary {
-    fn from(values: (TokenId, HashMap<Token, Vec<TokenId>>)) -> Vocabulary {
+impl TryFrom<(TokenId, HashMap<Token, Vec<TokenId>>)> for Vocabulary {
+    type Error = Error;
+
+    fn try_from(values: (TokenId, HashMap<Token, Vec<TokenId>>)) -> Result<Self, Self::Error> {
         let (eos_token_id, tokens) = values;
-        Vocabulary {
+        if tokens.iter().any(|(_, ids)| ids.contains(&eos_token_id)) {
+            return Err(Error::EOSTokenDisallowed);
+        }
+        Ok(Vocabulary {
             eos_token_id,
             tokens,
-        }
+        })
     }
 }
 
-impl From<(TokenId, HashMap<String, Vec<TokenId>>)> for Vocabulary {
-    fn from(values: (TokenId, HashMap<String, Vec<TokenId>>)) -> Vocabulary {
+impl TryFrom<(TokenId, HashMap<String, Vec<TokenId>>)> for Vocabulary {
+    type Error = Error;
+
+    fn try_from(values: (TokenId, HashMap<String, Vec<TokenId>>)) -> Result<Self, Self::Error> {
         let (eos_token_id, tokens) = values;
-        Vocabulary {
+        Ok(Vocabulary {
             eos_token_id,
             tokens: tokens
                 .into_iter()
-                .map(|(k, v)| (k.as_bytes().to_vec(), v))
-                .collect::<HashMap<Token, Vec<TokenId>>>(),
-        }
+                .map(|(k, v)| {
+                    if v.contains(&eos_token_id) {
+                        Err(Error::EOSTokenDisallowed)
+                    } else {
+                        Ok((k.as_bytes().to_vec(), v))
+                    }
+                })
+                .collect::<Result<HashMap<Token, Vec<TokenId>>, _>>()?,
+        })
     }
 }
 
@@ -202,32 +222,41 @@ mod tests {
         let eos_token_id = 3;
         let mut vocabulary = Vocabulary::new(eos_token_id);
 
+        match vocabulary.try_insert("eos-token", eos_token_id) {
+            Err(Error::EOSTokenDisallowed) => {}
+            _ => unreachable!(),
+        }
+
         // New empty vocabulary.
         assert_eq!(vocabulary.eos_token_id, eos_token_id);
         assert!(vocabulary.tokens.is_empty());
 
         for (token, id) in [("zero", 0), ("one", 1), ("two", 2)] {
-            vocabulary.insert(token, id);
+            vocabulary.try_insert(token, id).expect("Insert failed");
             assert_eq!(vocabulary.token_to_ids(token), Some(&vec![id]));
         }
         assert_eq!(vocabulary.tokens.len(), 3);
         assert_eq!(vocabulary.tokens_to_ids().len(), 3);
 
         // Confirm different types.
-        vocabulary.insert(b"four", 4);
+        vocabulary.try_insert(b"four", 4).expect("Insert failed");
         assert_eq!(vocabulary.token_to_ids("four"), Some(&vec![4]));
 
-        vocabulary.insert(b"five".to_vec(), 5);
+        vocabulary
+            .try_insert(b"five".to_vec(), 5)
+            .expect("Insert failed");
         assert_eq!(vocabulary.token_to_ids("five"), Some(&vec![5]));
 
-        vocabulary.insert("six".to_string(), 6);
+        vocabulary
+            .try_insert("six".to_string(), 6)
+            .expect("Insert failed");
         assert_eq!(vocabulary.token_to_ids("six"), Some(&vec![6]));
     }
 
     #[test]
     fn new_empty_vocabulary_from_hashmap() {
         let map: HashMap<String, Vec<TokenId>> = HashMap::default();
-        let vocabulary = Vocabulary::from((1_u32, map));
+        let vocabulary = Vocabulary::try_from((1_u32, map)).expect("Vocabulary failed");
         assert_eq!(vocabulary.eos_token_id, 1);
         assert!(vocabulary.tokens.is_empty());
     }
diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index 7426f249..1e10bf0e 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -102,7 +102,7 @@ impl Default for Mods {
 
 impl Mods {
     /// Apply default modifications to each token.
-    fn apply_default(&self, token: String) -> String {
+    fn apply_default(&self, token: &str) -> String {
         let to = Self::default().spacechar.to_string();
         token.replace(self.spacechar, &to)
     }
@@ -190,7 +190,7 @@ impl TokenProcessor {
     }
 
     /// Operates on each token based on the level of `TokenProcessor`.
-    pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
+    pub(crate) fn process(&self, token: &str) -> Result<Vec<u8>> {
         match &self.level {
             TokenProcessorLevel::Byte => token
                 .chars()
@@ -275,7 +275,7 @@ mod tests {
             ('ΓΎ', 0xFE),
             ('ΓΏ', 0xFF),
         ] {
-            let processed = processor.process(ch.to_string()).expect("Not processed");
+            let processed = processor.process(&ch.to_string()).expect("Not processed");
             assert_eq!(processed, [byte]);
         }
     }
@@ -304,7 +304,7 @@ mod tests {
                 vec![0x20, 0x20, 0x20],
             ),
         ] {
-            let processed = processor.process(input.to_string()).expect("Not processed");
+            let processed = processor.process(input).expect("Not processed");
             assert_eq!(processed, expected);
         }
     }
@@ -328,7 +328,7 @@ mod tests {
         let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
 
         for token in ["π’œπ’·π’Έπ’Ÿπ“”", "πŸ¦„πŸŒˆπŸŒπŸ”₯πŸŽ‰", "δΊ¬δΈœθ΄­η‰©"] {
-            let result = processor.process(token.to_string());
+            let result = processor.process(token);
             match result {
                 Err(Error::ByteProcessorFailed) => {}
                 _ => unreachable!(),
@@ -342,7 +342,7 @@ mod tests {
         let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
         let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
 
-        let result = processor.process("<0x6y>".to_string());
+        let result = processor.process("<0x6y>");
         match result {
             Err(Error::ByteFallbackProcessorFailed) => {}
             _ => unreachable!(),
diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py
index 447b260a..f4879d62 100644
--- a/tests/fsm/test_vocabulary.py
+++ b/tests/fsm/test_vocabulary.py
@@ -77,6 +77,13 @@ def test_insert_bad_type(vocabulary):
         vocabulary.insert(1, 6)
 
 
+def test_insert_eos_token(vocabulary):
+    with pytest.raises(
+        ValueError, match="EOS token should not be inserted into Vocabulary"
+    ):
+        vocabulary.insert("eos-token", 3)
+
+
 def test_from_pretrained():
     vocabulary = Vocabulary.from_pretrained("gpt2")
     assert vocabulary.get_eos_token_id() == 50256
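
Caller-side illustration of the behaviour this patch introduces (not part of the diff to apply; a minimal sketch that assumes `Vocabulary` is reachable through `outlines_core::prelude`, as in the doc-test updated above):

use outlines_core::prelude::*;

fn main() {
    let eos_token_id = 3;
    let mut vocabulary = Vocabulary::new(eos_token_id);

    // Ordinary tokens still insert, but the call now returns a Result instead of ().
    vocabulary.try_insert("token", 0).expect("New token inserted");

    // Any insert that reuses the EOS id is rejected (Error::EOSTokenDisallowed),
    // mirroring the Rust and Python tests added in this patch.
    assert!(vocabulary.try_insert("eos-token", eos_token_id).is_err());
}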