Stabilize Index interfaces
torymur committed Jan 9, 2025
1 parent 1fab872 commit 15a45c0
Showing 4 changed files with 53 additions and 25 deletions.
5 changes: 4 additions & 1 deletion python/outlines_core/fsm/outlines_core_rs.pyi
@@ -68,6 +68,9 @@ class Vocabulary:
     def __eq__(self, other: object) -> bool:
         """Compares whether two vocabularies are the same."""
         ...
+    def __len__(self) -> int:
+        """Returns length of Vocabulary's tokens, excluding EOS token."""
+        ...
     def __deepcopy__(self, memo: dict) -> "Vocabulary":
         """Makes a deep copy of the Vocabulary."""
         ...
@@ -85,7 +88,7 @@ class Index:
     def is_final_state(self, state: int) -> bool:
         """Determines whether the current state is a final state."""
         ...
-    def final_states(self) -> List[int]:
+    def get_final_states(self) -> List[int]:
        """Get all final states."""
         ...
     def get_transitions(self) -> Dict[int, Dict[int, int]]:
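
For orientation, a minimal sketch of how the stabilized Python-facing interface fits together, mirroring the calls exercised by the new test in tests/fsm/test_index.py further down. Only the call names (Index(regex, vocabulary), get_initial_state, get_allowed_tokens, get_next_state, is_final_state, get_final_states, and len() on Vocabulary) come from this commit; the import path, the greedy_walk helper, and how the vocabulary is obtained are assumptions for illustration.

from typing import List

# Assumed import path, mirroring the .pyi location above; the installed package
# may re-export Index and Vocabulary elsewhere.
from outlines_core.fsm.outlines_core_rs import Index, Vocabulary

def greedy_walk(regex: str, vocabulary: Vocabulary) -> List[int]:
    """Hypothetical helper: walk an Index for `regex`, taking the first allowed token at each step."""
    index = Index(regex, vocabulary)
    print(f"vocabulary size (excluding EOS): {len(vocabulary)}")  # uses the new Vocabulary.__len__

    state = index.get_initial_state()
    picked: List[int] = []
    while not index.is_final_state(state):
        allowed = index.get_allowed_tokens(state)  # token ids permitted from this state
        token_id = allowed[0]                      # arbitrary pick, purely for illustration
        picked.append(token_id)
        state = index.get_next_state(state, token_id)

    assert state in index.get_final_states()  # renamed from final_states()
    return picked
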
44 changes: 22 additions & 22 deletions src/index.rs
@@ -10,14 +10,14 @@ use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

 #[derive(Clone, Debug, PartialEq, Encode, Decode)]
 pub struct Index {
-    initial: StateId,
-    finals: HashSet<StateId>,
-    states_to_token_subsets: HashMap<StateId, HashMap<TokenId, StateId>>,
+    initial_state: StateId,
+    final_states: HashSet<StateId>,
+    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
     eos_token_id: TokenId,
 }

 impl Index {
-    pub(crate) fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
+    pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
         let eos_token_id = vocabulary.eos_token_id();
         let dfa = DFA::new(regex).map_err(Box::new)?;
         let start_state = match dfa.universal_start_state(Anchored::Yes) {
@@ -83,50 +83,50 @@ impl Index {

         if is_valid {
             Ok(Self {
-                initial: start_state.as_u32(),
-                finals: final_states,
-                states_to_token_subsets: transitions,
+                initial_state: start_state.as_u32(),
+                final_states,
+                transitions,
                 eos_token_id,
             })
         } else {
             Err(Error::InsufficientVocabulary)
         }
     }

-    pub(crate) fn allowed_tokens(&self, state: StateId) -> Option<Vec<TokenId>> {
-        self.states_to_token_subsets
+    pub fn allowed_tokens(&self, state: StateId) -> Option<Vec<TokenId>> {
+        self.transitions
             .get(&state)
             .map_or_else(|| None, |res| Some(res.keys().cloned().collect()))
     }

-    pub(crate) fn next_state(&self, state: StateId, token_id: TokenId) -> Option<StateId> {
+    pub fn next_state(&self, state: StateId, token_id: TokenId) -> Option<StateId> {
         if token_id == self.eos_token_id {
             return None;
         }
-        Some(*self.states_to_token_subsets.get(&state)?.get(&token_id)?)
+        Some(*self.transitions.get(&state)?.get(&token_id)?)
     }

-    pub(crate) fn initial(&self) -> StateId {
-        self.initial
+    pub fn initial_state(&self) -> StateId {
+        self.initial_state
     }

-    pub(crate) fn is_final(&self, state: StateId) -> bool {
-        self.finals.contains(&state)
+    pub fn is_final(&self, state: StateId) -> bool {
+        self.final_states.contains(&state)
     }

-    pub(crate) fn final_states(&self) -> &HashSet<StateId> {
-        &self.finals
+    pub fn final_states(&self) -> &HashSet<StateId> {
+        &self.final_states
     }

-    pub(crate) fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
-        &self.states_to_token_subsets
+    pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
+        &self.transitions
     }
 }

 impl std::fmt::Display for Index {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         writeln!(f, "Index object with transitions:")?;
-        for (state_id, token_ids) in self.states_to_token_subsets.iter() {
+        for (state_id, token_ids) in self.transitions.iter() {
             writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
         }
         Ok(())
@@ -148,7 +148,7 @@ mod tests {
         }

         let index = Index::new(regex, &vocabulary).expect("Index failed");
-        assert_eq!(index.initial(), 40);
+        assert_eq!(index.initial_state(), 40);
         assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));

         let expected = HashMap::from_iter([
@@ -172,7 +172,7 @@

         let index = Index::new(regex, &vocabulary).expect("Index failed");
         let allowed = index
-            .allowed_tokens(index.initial())
+            .allowed_tokens(index.initial_state())
             .expect("No allowed tokens");
         assert!(allowed.contains(&101));
     }
4 changes: 2 additions & 2 deletions src/python_bindings/mod.rs
@@ -133,7 +133,7 @@ impl PyIndex {
         self.0.is_final(state)
     }

-    fn final_states(&self) -> HashSet<StateId> {
+    fn get_final_states(&self) -> HashSet<StateId> {
         self.0.final_states().clone()
     }

@@ -142,7 +142,7 @@ impl PyIndex {
     }

     fn get_initial_state(&self) -> StateId {
-        self.0.initial()
+        self.0.initial_state()
     }
     fn __repr__(&self) -> String {
         format!("{:#?}", self.0)
25 changes: 25 additions & 0 deletions tests/fsm/test_index.py
@@ -18,6 +18,31 @@ def index() -> Index:
     return Index(regex, vocabulary)


+def test_basic_interface(index):
+    init_state = index.get_initial_state()
+    assert init_state == 12
+    assert index.is_final_state(init_state) is False
+
+    allowed_tokens = index.get_allowed_tokens(init_state)
+    assert allowed_tokens == [1, 2]
+
+    next_state = index.get_next_state(init_state, allowed_tokens[-1])
+    assert next_state == 20
+    assert index.is_final_state(next_state) is True
+    assert index.get_final_states() == {20}
+
+    expected_transitions = {
+        12: {
+            1: 20,
+            2: 20,
+        },
+        20: {
+            3: 20,
+        },
+    }
+    assert index.get_transitions() == expected_transitions
+
+
 def test_pickling(index):
     serialized = pickle.dumps(index)
     deserialized = pickle.loads(serialized)
