Skip to content

Commit

Permalink
SMT pruning (#215)
Browse files Browse the repository at this point in the history
* SMT pruning

* Comments

* Minor

* Minor

* Clippy
  • Loading branch information
wborgeaud authored May 6, 2024
1 parent 5f8e41a commit b5e9063
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 15 deletions.
64 changes: 49 additions & 15 deletions smt_trie/src/smt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(clippy::needless_range_loop)]

use std::collections::HashMap;
use std::borrow::Borrow;
use std::collections::{HashMap, HashSet};

use ethereum_types::U256;
use plonky2::field::goldilocks_field::GoldilocksField;
Expand All @@ -14,9 +15,9 @@ use crate::utils::{
f2limbs, get_unique_sibling, hash0, hash_key_hash, hashout2u, key2u, limbs2f, u2h, u2k,
};

const HASH_TYPE: u8 = 0;
const INTERNAL_TYPE: u8 = 1;
const LEAF_TYPE: u8 = 2;
pub(crate) const HASH_TYPE: u8 = 0;
pub(crate) const INTERNAL_TYPE: u8 = 1;
pub(crate) const LEAF_TYPE: u8 = 2;

pub type F = GoldilocksField;
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -400,32 +401,66 @@ impl<D: Db> Smt<D> {
self.root = new_root;
}

/// Serialize the SMT into a vector of U256.
/// Serialize and prune the SMT into a vector of U256.
/// Starts with a [0, 0] for convenience, that way `ptr=0` is a canonical
/// empty node. Therefore the root of the SMT is at `ptr=2`.
/// `keys` is a list of keys whose prefixes will not be hashed-out in the
/// serialization.
/// Serialization rules:
/// ```pseudocode
/// serialize( HashNode { h } ) = [HASH_TYPE, h]
/// serialize( InternalNode { left, right } ) = [INTERNAL_TYPE, serialize(left).ptr, serialize(right).ptr]
/// serialize( LeafNode { rem_key, value } ) = [LEAF_TYPE, rem_key, value]
/// ```
pub fn serialize(&self) -> Vec<U256> {
pub fn serialize_and_prune<K: Borrow<Key>, I: IntoIterator<Item = K>>(
&self,
keys: I,
) -> Vec<U256> {
let mut v = vec![U256::zero(); 2]; // For empty hash node.
let key = Key(self.root.elements);
serialize(self, key, &mut v);

let mut keys_to_include = HashSet::new();
for key in keys.into_iter() {
let mut bits = key.borrow().split();
loop {
keys_to_include.insert(bits);
if bits.is_empty() {
break;
}
bits.pop_next_bit();
}
}

serialize(self, key, &mut v, Bits::empty(), &keys_to_include);
if v.len() == 2 {
v.extend([U256::zero(); 2]);
}
v
}

/// Serialize the SMT into a vector of U256 without requesting any pruning.
/// Delegates to `serialize_and_prune`, passing every key of `kv_store` as
/// the set of keys whose prefixes must not be hashed-out.
pub fn serialize(&self) -> Vec<U256> {
// Include all keys.
self.serialize_and_prune(self.kv_store.keys())
}
}

fn serialize<D: Db>(smt: &Smt<D>, key: Key, v: &mut Vec<U256>) -> usize {
fn serialize<D: Db>(
smt: &Smt<D>,
key: Key,
v: &mut Vec<U256>,
cur_bits: Bits,
keys_to_include: &HashSet<Bits>,
) -> usize {
if key.0.iter().all(F::is_zero) {
return 0; // `ptr=0` is an empty node.
}

if let Some(node) = smt.db.get_node(&key) {
if !keys_to_include.contains(&cur_bits) || smt.db.get_node(&key).is_none() {
let index = v.len();
v.push(HASH_TYPE.into());
v.push(key2u(key));
index
} else if let Some(node) = smt.db.get_node(&key) {
if node.0.iter().all(F::is_zero) {
panic!("wtf?");
}
Expand All @@ -449,17 +484,16 @@ fn serialize<D: Db>(smt: &Smt<D>, key: Key, v: &mut Vec<U256>) -> usize {
v.push(INTERNAL_TYPE.into());
v.push(U256::zero());
v.push(U256::zero());
let i_left = serialize(smt, key_left, v).into();
let i_left =
serialize(smt, key_left, v, cur_bits.add_bit(false), keys_to_include).into();
v[index + 1] = i_left;
let i_right = serialize(smt, key_right, v).into();
let i_right =
serialize(smt, key_right, v, cur_bits.add_bit(true), keys_to_include).into();
v[index + 2] = i_right;
index
}
} else {
let index = v.len();
v.push(HASH_TYPE.into());
v.push(key2u(key));
index
unreachable!()
}
}

Expand Down
41 changes: 41 additions & 0 deletions smt_trie/src/smt_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use rand::{random, thread_rng, Rng};

use crate::bits::Bits;
use crate::db::Db;
use crate::smt::HASH_TYPE;
use crate::utils::hashout2u;
use crate::{
db::MemoryDb,
smt::{hash_serialize, Key, Smt, F},
Expand Down Expand Up @@ -366,3 +368,42 @@ fn test_set_hash_order() {
let ser = second_smt.serialize();
assert_eq!(hash_serialize(&ser), second_smt.root);
}

#[test]
fn test_serialize_and_prune() {
    // Build a trie holding 128 random key/value pairs.
    let mut smt = Smt::<MemoryDb>::default();
    for _ in 0..128 {
        smt.set(Key(F::rand_array()), U256(random()));
    }

    // The unpruned serialization must hash back to the trie root.
    let full_ser = smt.serialize();
    assert_eq!(hash_serialize(&full_ser), smt.root);

    // Keep a random subset of the stored keys: bit `i` of `mask` decides
    // whether the `i`-th key of `kv_store` is retained.
    let mask: u128 = random();
    let kept_keys: Vec<_> = smt
        .kv_store
        .keys()
        .enumerate()
        .filter(|&(i, _)| (mask >> i) & 1 == 1)
        .map(|(_, k)| *k)
        .collect();

    // Pruning must preserve the root hash and never grow the serialization.
    let pruned_ser = smt.serialize_and_prune(kept_keys);
    assert_eq!(hash_serialize(&pruned_ser), smt.root);
    assert!(pruned_ser.len() <= full_ser.len());

    // With no keys to keep, everything collapses into one hash node for the
    // root, placed right after the two leading zero words.
    let trivial_ser = smt.serialize_and_prune(Vec::<Key>::new());
    let expected_trivial = vec![
        U256::zero(),
        U256::zero(),
        HASH_TYPE.into(),
        hashout2u(smt.root),
    ];
    assert_eq!(trivial_ser, expected_trivial);
    assert_eq!(hash_serialize(&trivial_ser), smt.root);
}

0 comments on commit b5e9063

Please sign in to comment.