diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 9c7403044..b0c2d81e6 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ -* @muursh @Nashtare -/evm_arithmetization/ @wborgeaud @muursh @Nashtare +* @muursh @Nashtare @cpubot +/evm_arithmetization/ @wborgeaud @muursh @Nashtare @cpubot diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fbf89674..ca08f80a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed +- Add a few QoL useability functions to the interface ([#169](https://github.com/0xPolygonZero/zk_evm/pull/169)) ## [0.3.1] - 2024-04-22 diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm b/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm index 5d0512a12..f89938326 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm @@ -63,17 +63,22 @@ global init_access_lists: POP %endmacro -// Multiply the ptr at the top of the stack by 2 -// and abort if 2*ptr - @SEGMENT_ACCESSED_ADDRESSES >= @GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN -// In this way ptr must be pointing to the begining of a node. +// Multiply the value at the top of the stack, denoted by ptr/2, by 2 +// and abort if ptr/2 >= mem[@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN]/2 +// In this way 2*ptr/2 must be pointing to the begining of a node. %macro get_valid_addr_ptr - // stack: ptr + // stack: ptr/2 + DUP1 + // stack: ptr/2, ptr/2 + %mload_global_metadata(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) + // @GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN must be an even number because + // both @SEGMENT_ACCESSED_ADDRESSES and the unscaled access addresses list len + // must be even numbers + %div_const(2) + // stack: scaled_len/2, ptr/2, ptr/2 + %assert_gt %mul_const(2) - PUSH @SEGMENT_ACCESSED_ADDRESSES - DUP2 - SUB - %assert_lt_const(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) - // stack: 2*ptr + // stack: ptr %endmacro @@ -205,17 +210,20 @@ global remove_accessed_addresses: // stack: cold_access, value_ptr %endmacro -// Multiply the ptr at the top of the stack by 4 -// and abort if 4*ptr - SEGMENT_ACCESSED_STORAGE_KEYS >= @GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN -// In this way ptr must be pointing to the beginning of a node. +// Multiply the ptr at the top of the stack, denoted by ptr/4, by 4 +// and abort if ptr/4 >= @GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN/4 +// In this way 4*ptr/4 be pointing to the beginning of a node. %macro get_valid_storage_ptr - // stack: ptr + // stack: ptr/4 + DUP1 + %mload_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) + // By construction, both @SEGMENT_ACCESSED_STORAGE_KEYS and the unscaled list len + // must be multiples of 4 + %div_const(4) + // stack: scaled_len/4, ptr/4, ptr/4 + %assert_gt %mul_const(4) - PUSH @SEGMENT_ACCESSED_STORAGE_KEYS - DUP2 - SUB - %assert_lt_const(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) - // stack: 2*ptr + // stack: ptr %endmacro /// Inserts the storage key into the access list if it is not already present. diff --git a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/compression.asm b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/compression.asm index a9467a00b..6ff84301b 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/compression.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/compression.asm @@ -65,10 +65,11 @@ compression_loop: %mload_kernel_code_u32 // stack: K[i], W[i], a[i], b[i], c[i], d[i], e[i], f[i], g[i], h[i], num_blocks, scratch_space_addr, message_schedule_addr, i, a[0]..h[0], retdest DUP10 - DUP10 - DUP10 - DUP10 - // stack: e[i], f[i], g[i], h[i], K[i], W[i], a[i], b[i], c[i], d[i], e[i], f[i], g[i], h[i], num_blocks, scratch_space_addr, message_schedule_addr, i, a[0]..h[0], retdest + DUP8 + DUP11 + DUP11 + DUP11 + // stack: e[i], f[i], g[i], e[i], h[i], K[i], W[i], a[i], b[i], c[i], d[i], e[i], f[i], g[i], h[i], num_blocks, scratch_space_addr, message_schedule_addr, i, a[0]..h[0], retdest %sha2_temp_word1 // stack: T1[i], a[i], b[i], c[i], d[i], e[i], f[i], g[i], h[i], num_blocks, scratch_space_addr, message_schedule_addr, i, a[0]..h[0], retdest DUP4 diff --git a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/message_schedule.asm b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/message_schedule.asm index 66fa67a9b..3bcd7dbfc 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/message_schedule.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/message_schedule.asm @@ -13,13 +13,11 @@ gen_message_schedule_from_block: // stack: block_addr, output_addr, retdest DUP1 // stack: block_addr, block_addr, output_addr, retdest - %add_const(32) - // stack: block_addr + 32, block_addr, output_addr, retdest - SWAP1 - // stack: block_addr, block_addr + 32, output_addr, retdest %mload_u256 - // stack: block[0], block_addr + 32, output_addr, retdest + // stack: block[0], block_addr, output_addr, retdest SWAP1 + // stack: block_addr, block[0], output_addr, retdest + %add_const(32) // stack: block_addr + 32, block[0], output_addr, retdest %mload_u256 // stack: block[1], block[0], output_addr, retdest @@ -45,27 +43,22 @@ gen_message_schedule_from_block_0_loop: // stack: output_addr, block[0] % (1 << 32), block[0] >> 32, output_addr, counter, block[1], retdest %mstore_u32 // stack: block[0] >> 32, output_addr, counter, block[1], retdest - SWAP1 - // stack: output_addr, block[0] >> 32, counter, block[1], retdest - %sub_const(4) - // stack: output_addr - 4, block[0] >> 32, counter, block[1], retdest - SWAP1 - // stack: block[0] >> 32, output_addr - 4, counter, block[1], retdest + %stack (block0_shifted, output_addr, counter) -> (output_addr, 4, 1, counter, block0_shifted) + SUB + // stack: output_addr - 4, 1, counter, block[0] >> 32, block[1], retdest SWAP2 - // stack: counter, output_addr - 4, block[0] >> 32, block[1], retdest - %decrement + SUB + // stack: counter - 1, output_addr - 4, block[0] >> 32, block[1], retdest DUP1 %jumpi(gen_message_schedule_from_block_0_loop) gen_message_schedule_from_block_0_end: // stack: old counter=0, output_addr, block[0], block[1], retdest POP // stack: output_addr, block[0], block[1], retdest - %stack (out, b0, b1) -> (out, 8, b1, b0) - // stack: output_addr, counter=8, block[1], block[0], retdest %add_const(64) - // stack: output_addr + 64, counter, block[1], block[0], retdest - SWAP1 - // stack: counter, output_addr + 64, block[1], block[0], retdest + // stack: output_addr + 64, block[0], block[1], retdest + %stack (out, b0, b1) -> (8, out, b1, b0) + // stack: counter=8, output_addr + 64, block[1], block[0], retdest gen_message_schedule_from_block_1_loop: // Split the second half (256 bits) of the block into the next eight (32-bit) chunks of the message sdchedule. // stack: counter, output_addr, block[1], block[0], retdest @@ -83,29 +76,22 @@ gen_message_schedule_from_block_1_loop: // stack: output_addr, block[1] % (1 << 32), block[1] >> 32, output_addr, counter, block[0], retdest %mstore_u32 // stack: block[1] >> 32, output_addr, counter, block[0], retdest - SWAP1 - // stack: output_addr, block[1] >> 32, counter, block[0], retdest - %sub_const(4) - // stack: output_addr - 4, block[1] >> 32, counter, block[0], retdest - SWAP1 - // stack: block[1] >> 32, output_addr - 4, counter, block[0], retdest + %stack (block1_shifted, output_addr, counter) -> (output_addr, 4, 1, counter, block1_shifted) + SUB + // stack: output_addr - 4, 1, counter, block[1] >> 32, block[0], retdest SWAP2 - // stack: counter, output_addr - 4, block[1] >> 32, block[0], retdest - %decrement + SUB + // stack: counter - 1, output_addr - 4, block[1] >> 32, block[0], retdest DUP1 %jumpi(gen_message_schedule_from_block_1_loop) gen_message_schedule_from_block_1_end: // stack: old counter=0, output_addr, block[1], block[0], retdest POP // stack: output_addr, block[0], block[1], retdest - PUSH 48 - // stack: counter=48, output_addr, block[0], block[1], retdest - SWAP1 - // stack: output_addr, counter, block[0], block[1], retdest %add_const(36) - // stack: output_addr + 36, counter, block[0], block[1], retdest - SWAP1 - // stack: counter, output_addr + 36, block[0], block[1], retdest + // stack: output_addr + 36, block[0], block[1], retdest + PUSH 48 + // stack: counter=48, output_addr + 36, block[0], block[1], retdest gen_message_schedule_remaining_loop: // Generate the next 48 chunks of the message schedule, one at a time, from prior chunks. // stack: counter, output_addr, block[0], block[1], retdest @@ -153,9 +139,10 @@ gen_message_schedule_remaining_loop: // stack: output_addr, x[output_addr - 16*4], sigma_0(x[output_addr - 15*4]), x[output_addr - 7*4], sigma_1(x[output_addr - 2*4]), counter, block[0], block[1], retdest SWAP4 // stack: sigma_1(x[output_addr - 2*4]), x[output_addr - 16*4], sigma_0(x[output_addr - 15*4]), x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest - %add_u32 - %add_u32 - %add_u32 + ADD + ADD + ADD + %as_u32 // stack: sigma_1(x[output_addr - 2*4]) + x[output_addr - 16*4] + sigma_0(x[output_addr - 15*4]) + x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest DUP2 // stack: output_addr, sigma_1(x[output_addr - 2*4]) + x[output_addr - 16*4] + sigma_0(x[output_addr - 15*4]) + x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest @@ -174,7 +161,7 @@ gen_message_schedule_remaining_end: %pop4 JUMP -// Precodition: memory, starting at 0, contains num_blocks, block0[0], ..., block0[63], block1[0], ..., blocklast[63] +// Precondition: memory, starting at 0, contains num_blocks, block0[0], ..., block0[63], block1[0], ..., blocklast[63] // stack contains output_addr // Postcondition: starting at output_addr, set of 256 bytes per block // each contains the 64 32-bit chunks of the message schedule for that block (in four-byte increments) @@ -182,12 +169,14 @@ global sha2_gen_all_message_schedules: // stack: output_addr, retdest DUP1 // stack: output_addr, output_addr, retdest - %mload_current_general_no_offset - // stack: num_blocks, output_addr, output_addr, retdest - PUSH 1 - // stack: cur_offset = 1, counter = num_blocks, output_addr, output_addr, retdest - %build_current_general_address - // stack: cur_addr, counter, output_addr, output_addr, retdest + %build_current_general_address_no_offset + DUP1 + // stack: base_addr, base_addr, output_addr, output_addr, retdest + MLOAD_GENERAL + // stack: num_blocks, base_addr, output_addr, output_addr, retdest + SWAP1 + %increment + // stack: cur_addr (offset = 1), counter = num_blocks, output_addr, output_addr, retdest gen_all_message_schedules_loop: // stack: cur_addr, counter, cur_output_addr, output_addr, retdest PUSH gen_all_message_schedules_loop_end diff --git a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/ops.asm b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/ops.asm index d50e5c9a8..f0da871a5 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/ops.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/ops.asm @@ -1,14 +1,14 @@ // 32-bit right rotation %macro rotr(rot) // stack: value + DUP1 + // stack: value, value PUSH $rot - // stack: rot, value - DUP2 - DUP2 - // stack: rot, value, rot, value + // stack: rot, value, value SHR - // stack: value >> rot, rot, value - %stack (shifted, rot, value) -> (rot, value, shifted) + // stack: value >> rot, value + SWAP1 + PUSH $rot // stack: rot, value, value >> rot PUSH 32 SUB @@ -26,16 +26,14 @@ // stack: x, x %rotr(7) // stack: rotr(x, 7), x - SWAP1 - // stack: x, rotr(x, 7) DUP1 - // stack: x, x, rotr(x, 7) - %rotr(18) - // stack: rotr(x, 18), x, rotr(x, 7) - SWAP1 - // stack: x, rotr(x, 18), rotr(x, 7) + // stack: rotr(x, 7), rotr(x, 7), x + %rotr(11) + // stack: rotr(x, 18), rotr(x, 7), x + SWAP2 + // stack: x, rotr(x, 7), rotr(x, 18) %shr_const(3) - // stack: shr(x, 3), rotr(x, 18), rotr(x, 7) + // stack: shr(x, 3), rotr(x, 7), rotr(x, 18) XOR XOR %endmacro @@ -46,36 +44,30 @@ // stack: x, x %rotr(17) // stack: rotr(x, 17), x - SWAP1 - // stack: x, rotr(x, 17) DUP1 - // stack: x, x, rotr(x, 17) - %rotr(19) - // stack: rotr(x, 19), x, rotr(x, 17) - SWAP1 - // stack: x, rotr(x, 19), rotr(x, 17) + // stack: rotr(x, 17), rotr(x, 17), x + %rotr(2) + // stack: rotr(x, 19), rotr(x, 17), x + SWAP2 + // stack: x, rotr(x, 17), rotr(x, 19) PUSH 10 SHR - // stack: shr(x, 10), rotr(x, 19), rotr(x, 17) + // stack: shr(x, 10), rotr(x, 17), rotr(x, 19) XOR XOR %endmacro %macro sha2_bigsigma_0 // stack: x - DUP1 - // stack: x, x %rotr(2) - // stack: rotr(x, 2), x - SWAP1 - // stack: x, rotr(x, 2) + // stack: rotr(x, 2) DUP1 - // stack: x, x, rotr(x, 2) - %rotr(13) - // stack: rotr(x, 13), x, rotr(x, 2) - SWAP1 - // stack: x, rotr(x, 13), rotr(x, 2) - %rotr(22) + // stack: rotr(x, 2), rotr(x, 2) + %rotr(11) + // stack: rotr(x, 13), rotr(x, 2) + DUP1 + // stack: rotr(x, 13), rotr(x, 13), rotr(x, 2) + %rotr(9) // stack: rotr(x, 22), rotr(x, 13), rotr(x, 2) XOR XOR @@ -83,19 +75,15 @@ %macro sha2_bigsigma_1 // stack: x - DUP1 - // stack: x, x %rotr(6) - // stack: rotr(x, 6), x - SWAP1 - // stack: x, rotr(x, 6) + // stack: rotr(x, 6) DUP1 - // stack: x, x, rotr(x, 6) - %rotr(11) - // stack: rotr(x, 11), x, rotr(x, 6) - SWAP1 - // stack: x, rotr(x, 11), rotr(x, 6) - %rotr(25) + // stack: rotr(x, 6), rotr(x, 6) + %rotr(5) + // stack: rotr(x, 11), rotr(x, 6) + DUP1 + // stack: rotr(x, 11), rotr(x, 11), rotr(x, 6) + %rotr(14) // stack: rotr(x, 25), rotr(x, 11), rotr(x, 6) XOR XOR @@ -103,41 +91,31 @@ %macro sha2_choice // stack: x, y, z - DUP1 - // stack: x, x, y, z - NOT - // stack: not x, x, y, z SWAP1 - // stack: x, not x, y, z - SWAP3 - // stack: z, not x, y, x - AND - // stack: (not x) and z, y, x - SWAP2 - // stack: x, y, (not x) and z + // stack: y, x, z + DUP3 + // stack: z, y, x, z + XOR + // stack: z xor y, x, z AND - // stack: x and y, (not x) and z - OR + // stack: (z xor y) and x, z + XOR + // stack: ((z xor y) and x) xor z == (x and y) xor (not x and z) %endmacro %macro sha2_majority // stack: x, y, z - DUP1 - // stack: x, x, y, z - DUP3 - // stack: y, x, x, y, z - DUP5 - // stack: z, y, x, x, y, z - AND - // stack: z and y, x, x, y, z - SWAP4 - // stack: z, x, x, y, z and y + DUP2 + DUP2 AND - // stack: z and x, x, y, z and y + // stack: x and y, x, y, z SWAP2 - // stack: y, x, z and x, z and y - AND - // stack: y and x, z and x, z and y + // stack: y, x, x and y, z OR + // stack: y or x, x and y, z + %stack(y_or_x, x_and_y, z) -> (z, y_or_x, x_and_y) + AND + // stack: z and (y or x), x and y OR + // stack: (z and (y or x) or (x and y) == (x and y) or (x and z) or (y and z) %endmacro diff --git a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/temp_words.asm b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/temp_words.asm index ed610947f..0f3fb4b7a 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/temp_words.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/temp_words.asm @@ -1,18 +1,16 @@ // "T_1" in the SHA-256 spec %macro sha2_temp_word1 - // stack: e, f, g, h, K[i], W[i] - DUP1 - // stack: e, e, f, g, h, K[i], W[i] - %sha2_bigsigma_1 - // stack: Sigma_1(e), e, f, g, h, K[i], W[i] - %stack (sig, e, f, g) -> (e, f, g, sig) - // stack: e, f, g, Sigma_1(e), h, K[i], W[i] + // stack: e, f, g, e, h, K[i], W[i] %sha2_choice - // stack: Ch(e, f, g), Sigma_1(e), h, K[i], W[i] - %add_u32 - %add_u32 - %add_u32 - %add_u32 + // stack: Ch(e, f, g), e, h, K[i], W[i] + SWAP1 + // stack: e, Ch(e, f, g), h, K[i], W[i] + %sha2_bigsigma_1 + // stack: Sigma_1(e), Ch(e, f, g), h, K[i], W[i] + ADD + ADD + ADD + ADD // stack: Ch(e, f, g) + Sigma_1(e) + h + K[i] + W[i] %endmacro @@ -27,6 +25,6 @@ // stack: c, a, b, Sigma_0(a) %sha2_majority // stack: Maj(c, a, b), Sigma_0(a) - %add_u32 + ADD // stack: Maj(c, a, b) + Sigma_0(a) %endmacro diff --git a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/write_length.asm b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/write_length.asm index 9c2707b8d..a2a0216a6 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/write_length.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/hash/sha2/write_length.asm @@ -3,14 +3,13 @@ %build_current_general_address SWAP1 // stack: length, last_addr - DUP1 - // stack: length, length, last_addr + DUP2 + DUP2 + // stack: length, last_addr, length, last_addr %and_const(0xff) - // stack: length % (1 << 8), length, last_addr - DUP3 - // stack: last_addr, length % (1 << 8), length, last_addr - %swap_mstore - + // stack: length % (1 << 8), last_addr, length, last_addr + MSTORE_GENERAL + %rep 7 // For i = 0 to 6 // stack: length >> (8 * i), last_addr - i - 1 @@ -20,14 +19,13 @@ // stack: length >> (8 * i), last_addr - i - 2 %shr_const(8) // stack: length >> (8 * (i + 1)), last_addr - i - 2 - PUSH 256 DUP2 - // stack: length >> (8 * (i + 1)), 256, length >> (8 * (i + 1)), last_addr - i - 2 - MOD - // stack: (length >> (8 * (i + 1))) % (1 << 8), length >> (8 * (i + 1)), last_addr - i - 2 + PUSH 256 DUP3 - // stack: last_addr - i - 2, (length >> (8 * (i + 1))) % (1 << 8), length >> (8 * (i + 1)), last_addr - i - 2 - %swap_mstore + // stack: length >> (8 * (i + 1)), 256, last_addr - i - 2, length >> (8 * (i + 1)), last_addr - i - 2 + MOD + // stack: (length >> (8 * (i + 1))) % (1 << 8), last_addr - i - 2, length >> (8 * (i + 1)), last_addr - i - 2 + MSTORE_GENERAL %endrep %pop2 diff --git a/evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm b/evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm index dc73721b3..1b3439475 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm @@ -39,8 +39,8 @@ global panic: %endmacro %macro assert_lt(ret) - GE - %assert_zero($ret) + LT + %assert_nonzero($ret) %endmacro %macro assert_le @@ -56,10 +56,8 @@ global panic: %endmacro %macro assert_gt - // %assert_zero is cheaper than %assert_nonzero, so we will leverage the - // fact that (x > y) == !(x <= y). - LE - %assert_zero + GT + %assert_nonzero %endmacro %macro assert_gt(ret) diff --git a/mpt_trie/src/nibbles.rs b/mpt_trie/src/nibbles.rs index 509f813c7..e388cb78e 100644 --- a/mpt_trie/src/nibbles.rs +++ b/mpt_trie/src/nibbles.rs @@ -118,6 +118,7 @@ macro_rules! impl_as_u64s_for_primitive { }; } +impl_as_u64s_for_primitive!(usize); impl_as_u64s_for_primitive!(u8); impl_as_u64s_for_primitive!(u16); impl_as_u64s_for_primitive!(u32); @@ -178,6 +179,7 @@ macro_rules! impl_to_nibbles { }; } +impl_to_nibbles!(usize); impl_to_nibbles!(u8); impl_to_nibbles!(u16); impl_to_nibbles!(u32); @@ -908,6 +910,23 @@ impl Nibbles { } } + /// Returns a slice of the internal bytes of packed nibbles. + /// Only the relevant bytes (up to `count` nibbles) are considered valid. + pub fn as_byte_slice(&self) -> &[u8] { + // Calculate the number of full bytes needed to cover 'count' nibbles + let bytes_needed = (self.count + 1) / 2; // each nibble is half a byte + + // Safe because we are ensuring the slice size does not exceed the bounds of the + // array + unsafe { + // Convert the pointer to `packed` to a pointer to `u8` + let packed_ptr = self.packed.0.as_ptr() as *const u8; + + // Create a slice from this pointer and the number of needed bytes + std::slice::from_raw_parts(packed_ptr, bytes_needed) + } + } + const fn nibble_append_safety_asserts(&self, n: Nibble) { assert!( self.count < 64, @@ -1616,6 +1635,12 @@ mod tests { format!("{:x}", 0x1234_u64.to_nibbles_byte_padded()), "0x1234" ); + + assert_eq!(format!("{:x}", 0x1234_usize.to_nibbles()), "0x1234"); + assert_eq!( + format!("{:x}", 0x1234_usize.to_nibbles_byte_padded()), + "0x1234" + ); } #[test] @@ -1627,4 +1652,35 @@ mod tests { Nibbles::from_hex_prefix_encoding(&buf).unwrap(); } + + #[test] + fn nibbles_as_byte_slice_works() -> Result<(), StrToNibblesError> { + let cases = [ + (0x0, vec![]), + (0x1, vec![0x01]), + (0x12, vec![0x12]), + (0x123, vec![0x23, 0x01]), + ]; + + for case in cases.iter() { + let nibbles = Nibbles::from(case.0 as u64); + let byte_vec = nibbles.as_byte_slice().to_vec(); + assert_eq!(byte_vec, case.1.clone(), "Failed for input 0x{:X}", case.0); + } + + let input = "3ab76c381c0f8ea617ea96780ffd1e165c754b28a41a95922f9f70682c581351"; + let nibbles = Nibbles::from_str(input)?; + + let byte_vec = nibbles.as_byte_slice().to_vec(); + let mut expected_vec: Vec = hex::decode(input).expect("Invalid hex string"); + expected_vec.reverse(); + assert_eq!( + byte_vec, + expected_vec.clone(), + "Failed for input 0x{}", + input + ); + + Ok(()) + } } diff --git a/mpt_trie/src/partial_trie.rs b/mpt_trie/src/partial_trie.rs index 3d29e8c05..a593f57c8 100644 --- a/mpt_trie/src/partial_trie.rs +++ b/mpt_trie/src/partial_trie.rs @@ -107,6 +107,11 @@ pub trait PartialTrie: /// Returns an iterator over the trie that returns all values for every /// `Leaf` and `Hash` node. fn values(&self) -> impl Iterator; + + /// Returns `true` if the trie contains an element with the given key. + fn contains(&self, k: K) -> bool + where + K: Into; } /// Part of the trait that is not really part of the public interface but @@ -261,6 +266,13 @@ impl PartialTrie for StandardTrie { fn values(&self) -> impl Iterator { self.0.trie_values() } + + fn contains(&self, k: K) -> bool + where + K: Into, + { + self.0.trie_has_item_by_key(k) + } } impl TrieNodeIntern for StandardTrie { @@ -381,6 +393,13 @@ impl PartialTrie for HashedPartialTrie { fn values(&self) -> impl Iterator { self.node.trie_values() } + + fn contains(&self, k: K) -> bool + where + K: Into, + { + self.node.trie_has_item_by_key(k) + } } impl TrieNodeIntern for HashedPartialTrie { diff --git a/mpt_trie/src/trie_ops.rs b/mpt_trie/src/trie_ops.rs index 64f7f70b4..2f0e794d7 100644 --- a/mpt_trie/src/trie_ops.rs +++ b/mpt_trie/src/trie_ops.rs @@ -364,7 +364,7 @@ impl Node { where K: Into, { - let k = k.into(); + let k: Nibbles = k.into(); trace!("Deleting a leaf node with key {} if it exists", k); delete_intern(&self.clone(), k)?.map_or(Ok(None), |(updated_root, deleted_val)| { @@ -391,6 +391,14 @@ impl Node { pub(crate) fn trie_values(&self) -> impl Iterator { self.trie_items().map(|(_, v)| v) } + + pub(crate) fn trie_has_item_by_key(&self, k: K) -> bool + where + K: Into, + { + let k = k.into(); + self.trie_items().any(|(key, _)| key == k) + } } fn insert_into_trie_rec( @@ -1105,6 +1113,28 @@ mod tests { Ok(()) } + #[test] + fn existent_node_key_contains_returns_true() -> TrieOpResult<()> { + common_setup(); + + let mut trie = StandardTrie::default(); + trie.insert(0x1234, vec![91])?; + assert!(trie.contains(0x1234)); + + Ok(()) + } + + #[test] + fn non_existent_node_key_contains_returns_false() -> TrieOpResult<()> { + common_setup(); + + let mut trie = StandardTrie::default(); + trie.insert(0x1234, vec![91])?; + assert!(!trie.contains(0x5678)); + + Ok(()) + } + #[test] fn deleting_from_an_empty_trie_returns_none() -> TrieOpResult<()> { common_setup(); diff --git a/mpt_trie/src/trie_subsets.rs b/mpt_trie/src/trie_subsets.rs index 871eef2e9..74f50d769 100644 --- a/mpt_trie/src/trie_subsets.rs +++ b/mpt_trie/src/trie_subsets.rs @@ -805,6 +805,24 @@ mod tests { Ok(()) } + #[test] + fn sub_trie_existent_key_contains_returns_true() { + let trie = create_trie_with_large_entry_nodes(&[0x0]).unwrap(); + + let partial_trie = create_trie_subset(&trie, [0x1234]).unwrap(); + + assert!(partial_trie.contains(0x0)); + } + + #[test] + fn sub_trie_non_existent_key_contains_returns_false() { + let trie = create_trie_with_large_entry_nodes(&[0x0]).unwrap(); + + let partial_trie = create_trie_subset(&trie, [0x1234]).unwrap(); + + assert!(!partial_trie.contains(0x1)); + } + fn assert_all_keys_do_not_exist(trie: &TrieType, ks: impl Iterator) { for k in ks { assert!(trie.get(k).is_none());