Skip to content
This repository has been archived by the owner on Jan 9, 2025. It is now read-only.

Commit

Permalink
Add test for default_dict_copy and case starknet contract account is …
Browse files Browse the repository at this point in the history
…not deployed
  • Loading branch information
ClementWalter committed Oct 23, 2023
1 parent fcd5e8f commit 006aed2
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 19 deletions.
47 changes: 31 additions & 16 deletions src/kakarot/account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ namespace Account {
alloc_locals;
let storage = self.storage;
let (local storage_key) = hash_felts{hash_ptr=pedersen_ptr}(cast(key, felt*), 2);

let (pointer) = dict_read{dict_ptr=storage}(key=storage_key);

// Case reading from local storage
if (pointer != 0) {
// Return from local storage if found
let value_ptr = cast(pointer, Uint256*);
Expand All @@ -198,25 +198,40 @@ namespace Account {
self.selfdestruct,
);
return (self, [value_ptr]);
} else {
// Otherwise regular read value from contract storage
}

// Case reading from Starknet storage
let (local registered_starknet_account) = Accounts.get_starknet_address(address.evm);
let starknet_account_exists = is_not_zero(registered_starknet_account);
if (starknet_account_exists != 0) {
let (value) = IContractAccount.storage(
contract_address=address.starknet, key=storage_key
);
// Cache for possible later use (almost free and can save a lot)
tempvar new_value = new Uint256(value.low, value.high);
dict_write{dict_ptr=storage}(key=storage_key, new_value=cast(new_value, felt));
tempvar self = new model.Account(
self.address,
self.code_len,
self.code,
self.storage_start,
storage,
self.nonce,
self.selfdestruct,
);
return (self, value);
tempvar value_ptr = new Uint256(value.low, value.high);
tempvar syscall_ptr = syscall_ptr;
tempvar pedersen_ptr = pedersen_ptr;
tempvar range_check_ptr = range_check_ptr;
// Otherwise returns 0
} else {
tempvar value_ptr = new Uint256(0, 0);
tempvar syscall_ptr = syscall_ptr;
tempvar pedersen_ptr = pedersen_ptr;
tempvar range_check_ptr = range_check_ptr;
}

// Cache for possible later use (almost free and can save a syscall later on)
dict_write{dict_ptr=storage}(key=storage_key, new_value=cast(value_ptr, felt));

tempvar self = new model.Account(
self.address,
self.code_len,
self.code,
self.storage_start,
storage,
self.nonce,
self.selfdestruct,
);
return (self, [value_ptr]);
}

// @notice Update a storage key with the given value
Expand Down
21 changes: 19 additions & 2 deletions src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,31 @@ func default_dict_copy{range_check_ptr}(start: DictAccess*, end: DictAccess*) ->
return (new_start, new_ptr);
}

tempvar squashed_start = squashed_start;
tempvar dict_len = dict_len;
tempvar new_ptr = new_ptr;

loop:
dict_write{dict_ptr=new_ptr}(key=squashed_start.key, new_value=squashed_start.new_value);
let squashed_start = cast([ap - 3], DictAccess*);
let dict_len = [ap - 2];
let new_ptr = cast([ap - 1], DictAccess*);

let key = [squashed_start].key;
let new_value = [squashed_start].new_value;

dict_write{dict_ptr=new_ptr}(key=key, new_value=new_value);

tempvar squashed_start = squashed_start + DictAccess.SIZE;
tempvar dict_len = dict_len - DictAccess.SIZE;
tempvar new_ptr = new_ptr + DictAccess.SIZE;

static_assert squashed_start == [ap - 3];
static_assert dict_len == [ap - 2];
static_assert new_ptr == [ap - 1];

jmp loop if dict_len != 0;

return (new_start, new_ptr);
return (new_start, new_ptr - DictAccess.SIZE);
}

func dict_keys{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*) -> (
Expand Down
24 changes: 23 additions & 1 deletion tests/src/utils/test_dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin
from starkware.cairo.common.default_dict import default_dict_new, default_dict_finalize
from starkware.cairo.common.dict import dict_write, dict_read

from utils.dict import dict_keys
from utils.dict import dict_keys, default_dict_copy

@external
func test__dict_keys__should_return_keys{
Expand Down Expand Up @@ -44,3 +44,25 @@ func test__dict_keys__should_return_keys{

return ();
}

@external
func test__default_dict_copy__should_return_copied_dict{range_check_ptr}() {
let default_value = 0xdead;
let (dict_ptr_start) = default_dict_new(default_value);
let dict_ptr = dict_ptr_start;
let key = 0x7e1;
with dict_ptr {
let (value) = dict_read(key);
assert value = default_value;
dict_write(key, 0xff);
let (value) = dict_read(key);
assert value = 0xff;
}
let (new_start, new_ptr) = default_dict_copy(dict_ptr_start, dict_ptr);
let (value) = dict_read{dict_ptr=new_ptr}(key);
assert value = 0xff;
let (value) = dict_read{dict_ptr=new_ptr}(key + 1);
assert value = default_value;

return ();
}
4 changes: 4 additions & 0 deletions tests/src/utils/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@ class TestDict:
class TestDictKeys:
async def test_should_return_keys(self, dict_):
await dict_.test__dict_keys__should_return_keys().call()

class TestDefaultDictCopy:
async def test_should_return_copied_dict(self, dict_):
await dict_.test__default_dict_copy__should_return_copied_dict().call()

0 comments on commit 006aed2

Please sign in to comment.