From 006aed2281808113d8dbf47ab19a202405d4b44a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Walter?= Date: Mon, 23 Oct 2023 16:54:04 +0200 Subject: [PATCH] Add test for default_dict_copy and case starknet contract account is not deployed --- src/kakarot/account.cairo | 47 ++++++++++++++++++++++----------- src/utils/dict.cairo | 21 +++++++++++++-- tests/src/utils/test_dict.cairo | 24 ++++++++++++++++- tests/src/utils/test_dict.py | 4 +++ 4 files changed, 77 insertions(+), 19 deletions(-) diff --git a/src/kakarot/account.cairo b/src/kakarot/account.cairo index cc1a304897..b9f7b3f193 100644 --- a/src/kakarot/account.cairo +++ b/src/kakarot/account.cairo @@ -182,9 +182,9 @@ namespace Account { alloc_locals; let storage = self.storage; let (local storage_key) = hash_felts{hash_ptr=pedersen_ptr}(cast(key, felt*), 2); - let (pointer) = dict_read{dict_ptr=storage}(key=storage_key); + // Case reading from local storage if (pointer != 0) { // Return from local storage if found let value_ptr = cast(pointer, Uint256*); @@ -198,25 +198,40 @@ namespace Account { self.selfdestruct, ); return (self, [value_ptr]); - } else { - // Otherwise regular read value from contract storage + } + + // Case reading from Starknet storage + let (local registered_starknet_account) = Accounts.get_starknet_address(address.evm); + let starknet_account_exists = is_not_zero(registered_starknet_account); + if (starknet_account_exists != 0) { let (value) = IContractAccount.storage( contract_address=address.starknet, key=storage_key ); - // Cache for possible later use (almost free and can save a lot) - tempvar new_value = new Uint256(value.low, value.high); - dict_write{dict_ptr=storage}(key=storage_key, new_value=cast(new_value, felt)); - tempvar self = new model.Account( - self.address, - self.code_len, - self.code, - self.storage_start, - storage, - self.nonce, - self.selfdestruct, - ); - return (self, value); + tempvar value_ptr = new Uint256(value.low, value.high); + tempvar syscall_ptr = syscall_ptr; + tempvar pedersen_ptr = pedersen_ptr; + tempvar range_check_ptr = range_check_ptr; + // Otherwise returns 0 + } else { + tempvar value_ptr = new Uint256(0, 0); + tempvar syscall_ptr = syscall_ptr; + tempvar pedersen_ptr = pedersen_ptr; + tempvar range_check_ptr = range_check_ptr; } + + // Cache for possible later use (almost free and can save a syscall later on) + dict_write{dict_ptr=storage}(key=storage_key, new_value=cast(value_ptr, felt)); + + tempvar self = new model.Account( + self.address, + self.code_len, + self.code, + self.storage_start, + storage, + self.nonce, + self.selfdestruct, + ); + return (self, [value_ptr]); } // @notice Update a storage key with the given value diff --git a/src/utils/dict.cairo b/src/utils/dict.cairo index 1013b5045f..9edfe7fe30 100644 --- a/src/utils/dict.cairo +++ b/src/utils/dict.cairo @@ -30,14 +30,31 @@ func default_dict_copy{range_check_ptr}(start: DictAccess*, end: DictAccess*) -> return (new_start, new_ptr); } + tempvar squashed_start = squashed_start; + tempvar dict_len = dict_len; + tempvar new_ptr = new_ptr; + loop: - dict_write{dict_ptr=new_ptr}(key=squashed_start.key, new_value=squashed_start.new_value); + let squashed_start = cast([ap - 3], DictAccess*); + let dict_len = [ap - 2]; + let new_ptr = cast([ap - 1], DictAccess*); + + let key = [squashed_start].key; + let new_value = [squashed_start].new_value; + + dict_write{dict_ptr=new_ptr}(key=key, new_value=new_value); + tempvar squashed_start = squashed_start + DictAccess.SIZE; tempvar dict_len = dict_len - DictAccess.SIZE; + tempvar new_ptr = new_ptr + DictAccess.SIZE; + + static_assert squashed_start == [ap - 3]; + static_assert dict_len == [ap - 2]; + static_assert new_ptr == [ap - 1]; jmp loop if dict_len != 0; - return (new_start, new_ptr); + return (new_start, new_ptr - DictAccess.SIZE); } func dict_keys{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*) -> ( diff --git a/tests/src/utils/test_dict.cairo b/tests/src/utils/test_dict.cairo index 8c2c057879..650dc211b4 100644 --- a/tests/src/utils/test_dict.cairo +++ b/tests/src/utils/test_dict.cairo @@ -6,7 +6,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin from starkware.cairo.common.default_dict import default_dict_new, default_dict_finalize from starkware.cairo.common.dict import dict_write, dict_read -from utils.dict import dict_keys +from utils.dict import dict_keys, default_dict_copy @external func test__dict_keys__should_return_keys{ @@ -44,3 +44,25 @@ func test__dict_keys__should_return_keys{ return (); } + +@external +func test__default_dict_copy__should_return_copied_dict{range_check_ptr}() { + let default_value = 0xdead; + let (dict_ptr_start) = default_dict_new(default_value); + let dict_ptr = dict_ptr_start; + let key = 0x7e1; + with dict_ptr { + let (value) = dict_read(key); + assert value = default_value; + dict_write(key, 0xff); + let (value) = dict_read(key); + assert value = 0xff; + } + let (new_start, new_ptr) = default_dict_copy(dict_ptr_start, dict_ptr); + let (value) = dict_read{dict_ptr=new_ptr}(key); + assert value = 0xff; + let (value) = dict_read{dict_ptr=new_ptr}(key + 1); + assert value = default_value; + + return (); +} diff --git a/tests/src/utils/test_dict.py b/tests/src/utils/test_dict.py index 422b4f81a7..635de20d3d 100644 --- a/tests/src/utils/test_dict.py +++ b/tests/src/utils/test_dict.py @@ -18,3 +18,7 @@ class TestDict: class TestDictKeys: async def test_should_return_keys(self, dict_): await dict_.test__dict_keys__should_return_keys().call() + + class TestDefaultDictCopy: + async def test_should_return_copied_dict(self, dict_): + await dict_.test__default_dict_copy__should_return_copied_dict().call()