Skip to content

Commit

Permalink
feat: add a few usability functions to the interface (#169)
Browse files Browse the repository at this point in the history
* feat: add a few usability functions to the interface

* fix: fix clippy issue

* feat: add usize to nibbles

* docs: update changelog.md

---------

Co-authored-by: Vladimir Trifonov <[email protected]>
Co-authored-by: Ben <[email protected]>
  • Loading branch information
3 people authored and Nashtare committed May 20, 2024
1 parent 3afeba3 commit 827925d
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Changed
- Add a few QoL useability functions to the interface ([#169](https://github.com/0xPolygonZero/zk_evm/pull/169))

## [0.3.1] - 2024-04-22

Expand Down
56 changes: 56 additions & 0 deletions mpt_trie/src/nibbles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ macro_rules! impl_as_u64s_for_primitive {
};
}

impl_as_u64s_for_primitive!(usize);
impl_as_u64s_for_primitive!(u8);
impl_as_u64s_for_primitive!(u16);
impl_as_u64s_for_primitive!(u32);
Expand Down Expand Up @@ -178,6 +179,7 @@ macro_rules! impl_to_nibbles {
};
}

impl_to_nibbles!(usize);
impl_to_nibbles!(u8);
impl_to_nibbles!(u16);
impl_to_nibbles!(u32);
Expand Down Expand Up @@ -908,6 +910,23 @@ impl Nibbles {
}
}

/// Returns a slice of the internal bytes of packed nibbles.
/// Only the relevant bytes (up to `count` nibbles) are considered valid.
pub fn as_byte_slice(&self) -> &[u8] {
// Calculate the number of full bytes needed to cover 'count' nibbles
let bytes_needed = (self.count + 1) / 2; // each nibble is half a byte

// Safe because we are ensuring the slice size does not exceed the bounds of the
// array
unsafe {
// Convert the pointer to `packed` to a pointer to `u8`
let packed_ptr = self.packed.0.as_ptr() as *const u8;

// Create a slice from this pointer and the number of needed bytes
std::slice::from_raw_parts(packed_ptr, bytes_needed)
}
}

const fn nibble_append_safety_asserts(&self, n: Nibble) {
assert!(
self.count < 64,
Expand Down Expand Up @@ -1616,6 +1635,12 @@ mod tests {
format!("{:x}", 0x1234_u64.to_nibbles_byte_padded()),
"0x1234"
);

assert_eq!(format!("{:x}", 0x1234_usize.to_nibbles()), "0x1234");
assert_eq!(
format!("{:x}", 0x1234_usize.to_nibbles_byte_padded()),
"0x1234"
);
}

#[test]
Expand All @@ -1627,4 +1652,35 @@ mod tests {

Nibbles::from_hex_prefix_encoding(&buf).unwrap();
}

#[test]
fn nibbles_as_byte_slice_works() -> Result<(), StrToNibblesError> {
let cases = [
(0x0, vec![]),
(0x1, vec![0x01]),
(0x12, vec![0x12]),
(0x123, vec![0x23, 0x01]),
];

for case in cases.iter() {
let nibbles = Nibbles::from(case.0 as u64);
let byte_vec = nibbles.as_byte_slice().to_vec();
assert_eq!(byte_vec, case.1.clone(), "Failed for input 0x{:X}", case.0);
}

let input = "3ab76c381c0f8ea617ea96780ffd1e165c754b28a41a95922f9f70682c581351";
let nibbles = Nibbles::from_str(input)?;

let byte_vec = nibbles.as_byte_slice().to_vec();
let mut expected_vec: Vec<u8> = hex::decode(input).expect("Invalid hex string");
expected_vec.reverse();
assert_eq!(
byte_vec,
expected_vec.clone(),
"Failed for input 0x{}",
input
);

Ok(())
}
}
19 changes: 19 additions & 0 deletions mpt_trie/src/partial_trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ pub trait PartialTrie:
/// Returns an iterator over the trie that returns all values for every
/// `Leaf` and `Hash` node.
fn values(&self) -> impl Iterator<Item = ValOrHash>;

/// Returns `true` if the trie contains an element with the given key.
fn contains<K>(&self, k: K) -> bool
where
K: Into<Nibbles>;
}

/// Part of the trait that is not really part of the public interface but
Expand Down Expand Up @@ -261,6 +266,13 @@ impl PartialTrie for StandardTrie {
fn values(&self) -> impl Iterator<Item = ValOrHash> {
self.0.trie_values()
}

fn contains<K>(&self, k: K) -> bool
where
K: Into<Nibbles>,
{
self.0.trie_has_item_by_key(k)
}
}

impl TrieNodeIntern for StandardTrie {
Expand Down Expand Up @@ -381,6 +393,13 @@ impl PartialTrie for HashedPartialTrie {
fn values(&self) -> impl Iterator<Item = ValOrHash> {
self.node.trie_values()
}

fn contains<K>(&self, k: K) -> bool
where
K: Into<Nibbles>,
{
self.node.trie_has_item_by_key(k)
}
}

impl TrieNodeIntern for HashedPartialTrie {
Expand Down
32 changes: 31 additions & 1 deletion mpt_trie/src/trie_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ impl<T: PartialTrie> Node<T> {
where
K: Into<Nibbles>,
{
let k = k.into();
let k: Nibbles = k.into();
trace!("Deleting a leaf node with key {} if it exists", k);

delete_intern(&self.clone(), k)?.map_or(Ok(None), |(updated_root, deleted_val)| {
Expand All @@ -391,6 +391,14 @@ impl<T: PartialTrie> Node<T> {
pub(crate) fn trie_values(&self) -> impl Iterator<Item = ValOrHash> {
self.trie_items().map(|(_, v)| v)
}

pub(crate) fn trie_has_item_by_key<K>(&self, k: K) -> bool
where
K: Into<Nibbles>,
{
let k = k.into();
self.trie_items().any(|(key, _)| key == k)
}
}

fn insert_into_trie_rec<N: PartialTrie>(
Expand Down Expand Up @@ -1105,6 +1113,28 @@ mod tests {
Ok(())
}

#[test]
fn existent_node_key_contains_returns_true() -> TrieOpResult<()> {
common_setup();

let mut trie = StandardTrie::default();
trie.insert(0x1234, vec![91])?;
assert!(trie.contains(0x1234));

Ok(())
}

#[test]
fn non_existent_node_key_contains_returns_false() -> TrieOpResult<()> {
common_setup();

let mut trie = StandardTrie::default();
trie.insert(0x1234, vec![91])?;
assert!(!trie.contains(0x5678));

Ok(())
}

#[test]
fn deleting_from_an_empty_trie_returns_none() -> TrieOpResult<()> {
common_setup();
Expand Down
18 changes: 18 additions & 0 deletions mpt_trie/src/trie_subsets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,24 @@ mod tests {
Ok(())
}

#[test]
fn sub_trie_existent_key_contains_returns_true() {
let trie = create_trie_with_large_entry_nodes(&[0x0]).unwrap();

let partial_trie = create_trie_subset(&trie, [0x1234]).unwrap();

assert!(partial_trie.contains(0x0));
}

#[test]
fn sub_trie_non_existent_key_contains_returns_false() {
let trie = create_trie_with_large_entry_nodes(&[0x0]).unwrap();

let partial_trie = create_trie_subset(&trie, [0x1234]).unwrap();

assert!(!partial_trie.contains(0x1));
}

fn assert_all_keys_do_not_exist(trie: &TrieType, ks: impl Iterator<Item = Nibbles>) {
for k in ks {
assert!(trie.get(k).is_none());
Expand Down

0 comments on commit 827925d

Please sign in to comment.