Skip to content

Commit

Permalink
Add remove to vocabulary interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Jan 10, 2025
1 parent bf6e8a6 commit 5ea7e10
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 7 deletions.
2 changes: 1 addition & 1 deletion benchmarks/bench_regex_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def time_regex_to_guide_parallel(self, pattern_name):
def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name):
# Note: after moving to full rust implementation for index and guide creation, this experiment
# no longer shows the drastic difference it once showed when python was heavily involved,
# due to on average speedup ~100 times.
# due to a speedup of up to ~100 times.

# This test is to show, that if GIL's switch interval is set to be longer, then the parallel
# test's runtime on physical cores will be much closer to the one-threaded case.
Expand Down
3 changes: 3 additions & 0 deletions python/outlines_core/fsm/outlines_core_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class Vocabulary:
def insert(self, token: Union[str, bytes], token_id: int):
"""Inserts new token with token_id or extends list of token_ids if token already present."""
...
def remove(self, token: Union[str, bytes]) -> None:
    """Removes a token from the vocabulary.

    Removing a token that is not present is a no-op.
    """
    ...
def get_eos_token_id(self) -> Optional[int]:
"""Gets the end of sentence token id."""
...
Expand Down
15 changes: 15 additions & 0 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,21 @@ impl PyVocabulary {
)))
}

/// Removes a token given either as a Python `str` or `bytes`.
///
/// Raises `TypeError` for any other Python type.
fn remove(&mut self, py: Python<'_>, token: Py<PyAny>) -> PyResult<()> {
    // Try `str` first, mirroring `insert`; fall back to raw bytes.
    if let Ok(text) = token.extract::<String>(py) {
        self.0.remove(text);
        Ok(())
    } else if let Ok(bytes) = token.extract::<Token>(py) {
        self.0.remove(bytes);
        Ok(())
    } else {
        Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
            "Expected a token of type str or bytes, got {:?}",
            type_name!(token)
        )))
    }
}

fn get_eos_token_id(&self) -> TokenId {
self.0.eos_token_id()
}
Expand Down
15 changes: 15 additions & 0 deletions src/vocabulary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ impl Vocabulary {
Ok(())
}

/// Removes a token from the vocabulary.
///
/// Removing a token that is not present is a no-op.
pub fn remove(&mut self, token: impl Into<Token>) {
    self.tokens.remove(&token.into());
}

/// Creates the vocabulary of pre-trained model from Hugging Face Hub.
pub fn from_pretrained(
model: &str,
Expand Down Expand Up @@ -251,6 +257,15 @@ mod tests {
.try_insert("six".to_string(), 6)
.expect("Insert failed");
assert_eq!(vocabulary.token_to_ids("six"), Some(&vec![6]));

vocabulary.remove(b"four");
assert_eq!(vocabulary.token_to_ids("four"), None);

vocabulary.remove(b"five".to_vec());
assert_eq!(vocabulary.token_to_ids("five"), None);

vocabulary.remove("six".to_string());
assert_eq!(vocabulary.token_to_ids("six"), None);
}

#[test]
Expand Down
19 changes: 13 additions & 6 deletions tests/fsm/test_vocabulary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@ def vocabulary():
return Vocabulary(eos_token_id, tokens)


def test_basic_vocabulary_interface():
eos_token_id = 3
tokens = {"1": [1], "a": [2]}
vocabulary = Vocabulary(eos_token_id, tokens)

assert vocabulary.get_eos_token_id() == eos_token_id
def test_basic_vocabulary_interface(vocabulary):
assert vocabulary.get_eos_token_id() == 3
assert vocabulary.get("1") == vocabulary.get(b"1") == [1]
assert len(vocabulary) == 2

Expand All @@ -29,6 +25,17 @@ def test_basic_vocabulary_interface():
assert vocabulary.get("b") == vocabulary.get(b"b") == [4, 5]
assert len(vocabulary) == 3

vocabulary.remove("b")
assert vocabulary.get("b") is None

# a second remove of the same token does not fail either
vocabulary.remove("b")
assert vocabulary.get("b") is None

assert vocabulary.get("a") == [2]
vocabulary.remove(b"a")
assert vocabulary.get("a") is None


def test_string_and_bytes_as_tokens():
eos_token_id = 3
Expand Down

0 comments on commit 5ea7e10

Please sign in to comment.