diff --git a/scripts/constants.py b/scripts/constants.py index a87071402..068649471 100644 --- a/scripts/constants.py +++ b/scripts/constants.py @@ -149,6 +149,7 @@ class ArtifactType(Enum): {"contract_name": "EVM", "is_account_contract": False}, {"contract_name": "OpenzeppelinAccount", "is_account_contract": True}, {"contract_name": "ERC20", "is_account_contract": False}, + {"contract_name": "replace_class", "is_account_contract": False}, ] DECLARED_CONTRACTS = [ {"contract_name": "kakarot", "cairo_version": ArtifactType.cairo0}, @@ -158,6 +159,7 @@ class ArtifactType(Enum): {"contract_name": "EVM", "cairo_version": ArtifactType.cairo0}, {"contract_name": "OpenzeppelinAccount", "cairo_version": ArtifactType.cairo0}, {"contract_name": "Precompiles", "cairo_version": ArtifactType.cairo1}, + {"contract_name": "replace_class", "cairo_version": ArtifactType.cairo0}, ] EVM_PRIVATE_KEY = os.getenv("EVM_PRIVATE_KEY") diff --git a/scripts/utils/starknet.py b/scripts/utils/starknet.py index 55a61bf41..755f11e17 100644 --- a/scripts/utils/starknet.py +++ b/scripts/utils/starknet.py @@ -79,7 +79,7 @@ async def get_starknet_account( raise ValueError( "address was not given in arg nor in env variable, see README.md#Deploy" ) - address = int(address, 16) + address = int(address, 16) if isinstance(address, str) else address private_key = private_key or NETWORK["private_key"] if private_key is None: raise ValueError( @@ -348,7 +348,7 @@ def _convert_offset_to_hex(obj): ) -async def deploy_starknet_account(class_hash, private_key=None, amount=1): +async def deploy_starknet_account(class_hash=None, private_key=None, amount=1): salt = random.randint(0, 2**251) private_key = private_key or NETWORK["private_key"] if private_key is None: @@ -357,6 +357,7 @@ async def deploy_starknet_account(class_hash, private_key=None, amount=1): ) key_pair = KeyPair.from_private_key(int(private_key, 16)) constructor_calldata = [key_pair.public_key] + class_hash = class_hash or get_declarations()["OpenzeppelinAccount"] address = compute_address( salt=salt, class_hash=class_hash, diff --git a/src/kakarot/events.cairo b/src/kakarot/events.cairo index a16e23e10..a1c515ca6 100644 --- a/src/kakarot/events.cairo +++ b/src/kakarot/events.cairo @@ -5,3 +5,7 @@ @event func evm_contract_deployed(evm_contract_address: felt, starknet_contract_address: felt) { } + +@event +func kakarot_upgraded(new_class_hash: felt) { +} diff --git a/src/kakarot/kakarot.cairo b/src/kakarot/kakarot.cairo index 81f378f10..acf6d5312 100644 --- a/src/kakarot/kakarot.cairo +++ b/src/kakarot/kakarot.cairo @@ -8,14 +8,16 @@ from starkware.cairo.common.bool import FALSE, TRUE from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin from starkware.cairo.common.math_cmp import is_not_zero from starkware.cairo.common.uint256 import Uint256 -from starkware.starknet.common.syscalls import get_caller_address +from starkware.starknet.common.syscalls import get_caller_address, replace_class from starkware.cairo.common.registers import get_fp_and_pc +from openzeppelin.access.ownable.library import Ownable // Local dependencies from backend.starknet import Starknet from kakarot.account import Account -from kakarot.model import model +from kakarot.events import kakarot_upgraded from kakarot.library import Kakarot +from kakarot.model import model from utils.utils import Helpers // Constructor @@ -38,6 +40,19 @@ func constructor{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr ); } +// @notive Upgrade the contract +// @dev Use the replace_hash syscall to upgrade the contract +// @param new_class_hash The new class hash +@external +func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + new_class_hash: felt +) { + Ownable.assert_only_owner(); + replace_class(new_class_hash); + kakarot_upgraded.emit(new_class_hash); + return (); +} + // @notice Set the native token used by kakarot // @dev Set the native token which will emulate the role of ETH on Ethereum // @param native_token_address_ The address of the native token diff --git a/tests/end_to_end/conftest.py b/tests/end_to_end/conftest.py index d7b32d267..2548087e2 100644 --- a/tests/end_to_end/conftest.py +++ b/tests/end_to_end/conftest.py @@ -336,3 +336,23 @@ def eth_get_code(): from scripts.utils.kakarot import eth_get_code return eth_get_code + + +@pytest.fixture +def call(): + """ + Send a Starknet call. + """ + from scripts.utils.starknet import call + + return call + + +@pytest.fixture +def invoke(): + """ + Send a Starknet transaction. + """ + from scripts.utils.starknet import invoke + + return invoke diff --git a/tests/end_to_end/test_kakarot.py b/tests/end_to_end/test_kakarot.py index df928a8fe..4e76659bf 100644 --- a/tests/end_to_end/test_kakarot.py +++ b/tests/end_to_end/test_kakarot.py @@ -32,6 +32,27 @@ def evm(get_contract): return get_contract("EVM") +@pytest.fixture(scope="session") +async def other(): + """ + Just another Starknet contract. + """ + from scripts.utils.starknet import deploy_starknet_account, get_starknet_account + + account_info = await deploy_starknet_account() + return await get_starknet_account(account_info["address"]) + + +@pytest.fixture(scope="session") +async def class_hashes(): + """ + All declared class hashes. + """ + from scripts.utils.starknet import get_declarations + + return get_declarations() + + @pytest_asyncio.fixture(scope="session") async def origin(evm: Contract, addresses): """ @@ -184,3 +205,31 @@ async def test_eth_call_should_succeed( assert result.success == 1 assert result.return_data == [] assert result.gas_used == 21_000 + + class TestUpgrade: + + async def test_should_raise_when_caller_is_not_owner( + self, starknet, kakarot, invoke, other, class_hashes + ): + prev_class_hash = await starknet.get_class_hash_at(kakarot.address) + await invoke("kakarot", "upgrade", class_hashes["EVM"], account=other) + new_class_hash = await starknet.get_class_hash_at(kakarot.address) + assert prev_class_hash == new_class_hash + + async def test_should_raise_when_class_hash_is_not_declared( + self, starknet, kakarot, invoke + ): + prev_class_hash = await starknet.get_class_hash_at(kakarot.address) + await invoke("kakarot", "upgrade", 0xDEAD) + new_class_hash = await starknet.get_class_hash_at(kakarot.address) + assert prev_class_hash == new_class_hash + + async def test_should_upgrade_class_hash( + self, starknet, kakarot, invoke, class_hashes + ): + prev_class_hash = await starknet.get_class_hash_at(kakarot.address) + await invoke("kakarot", "upgrade", class_hashes["replace_class"]) + new_class_hash = await starknet.get_class_hash_at(kakarot.address) + assert prev_class_hash != new_class_hash + assert new_class_hash == class_hashes["replace_class"] + await invoke("kakarot", "upgrade", prev_class_hash) diff --git a/tests/fixtures/replace_class.cairo b/tests/fixtures/replace_class.cairo new file mode 100644 index 000000000..3f520017e --- /dev/null +++ b/tests/fixtures/replace_class.cairo @@ -0,0 +1,12 @@ +%lang starknet + +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.starknet.common.syscalls import replace_class + +@external +func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + new_class_hash: felt +) { + replace_class(new_class_hash); + return (); +}