From 96a83842facde6f1bc75b534ad4689ea82d29abd Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 28 May 2024 04:48:43 -0700 Subject: [PATCH] refactor[tool]: refactor storage layout export (#3789) refactor storage layout allocator. separate concerns of allocating the storage layout and exporting the storage layout into separate functions. this is intended to make it easier to add features to the storage layout export in the future fix several bugs in storage layout overrides, including: - handle stateful modules - add a sanity check that the override file roundtrips - ignore non-storage variables in override files - set nonreentrant lock properly for all functions instead of panicking misc: add `n_slots` to each storage layout item in the export --- tests/unit/cli/storage_layout/__init__.py | 0 .../cli/storage_layout/test_storage_layout.py | 80 ++--- .../test_storage_layout_overrides.py | 296 +++++++++++++++++- tests/unit/cli/storage_layout/utils.py | 17 + vyper/compiler/phases.py | 5 +- vyper/semantics/analysis/base.py | 6 +- vyper/semantics/analysis/data_positions.py | 247 ++++++++++----- 7 files changed, 499 insertions(+), 152 deletions(-) create mode 100644 tests/unit/cli/storage_layout/__init__.py create mode 100644 tests/unit/cli/storage_layout/utils.py diff --git a/tests/unit/cli/storage_layout/__init__.py b/tests/unit/cli/storage_layout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index ece2743b81..d490d2008f 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -1,21 +1,6 @@ from vyper.compiler import compile_code -from vyper.evm.opcodes import version_check - -def _adjust_storage_layout_for_cancun(layout): - def _go(layout): - for _varname, item in layout.items(): - if "slot" in item and isinstance(item["slot"], int): - item["slot"] -= 1 - else: - # recurse to submodule - _go(item) - - if version_check(begin="cancun"): - layout["transient_storage_layout"] = { - "$.nonreentrant_key": layout["storage_layout"].pop("$.nonreentrant_key") - } - _go(layout["storage_layout"]) +from .utils import adjust_storage_layout_for_cancun def test_storage_layout(): @@ -55,19 +40,18 @@ def public_foo3(): pass """ - out = compile_code(code, output_formats=["layout"]) - expected = { "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "foo": {"slot": 1, "type": "HashMap[address, uint256]"}, - "arr": {"slot": 2, "type": "DynArray[uint256, 3]"}, - "baz": {"slot": 6, "type": "Bytes[65]"}, - "bar": {"slot": 10, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock", "n_slots": 1}, + "foo": {"slot": 1, "type": "HashMap[address, uint256]", "n_slots": 1}, + "arr": {"slot": 2, "type": "DynArray[uint256, 3]", "n_slots": 4}, + "baz": {"slot": 6, "type": "Bytes[65]", "n_slots": 4}, + "bar": {"slot": 10, "type": "uint256", "n_slots": 1}, } } - _adjust_storage_layout_for_cancun(expected) + adjust_storage_layout_for_cancun(expected) + out = compile_code(code, output_formats=["layout"]) assert out["layout"] == expected @@ -88,12 +72,9 @@ def __init__(): "SYMBOL": {"length": 64, "offset": 0, "type": "String[32]"}, "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, }, - "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "name": {"slot": 1, "type": "String[32]"}, - }, + "storage_layout": {"name": {"slot": 1, "type": "String[32]", "n_slots": 2}}, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, output_formats=["layout"]) assert out["layout"] == expected_layout @@ -137,13 +118,12 @@ def __init__(): }, }, "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "counter": {"slot": 1, "type": "uint256"}, - "counter2": {"slot": 2, "type": "uint256"}, - "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + "counter": {"slot": 1, "type": "uint256", "n_slots": 1}, + "counter2": {"slot": 2, "type": "uint256", "n_slots": 1}, + "a_library": {"supply": {"slot": 3, "type": "uint256", "n_slots": 1}}, }, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) assert out["layout"] == expected_layout @@ -187,13 +167,12 @@ def __init__(): }, }, "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "counter": {"slot": 1, "type": "uint256"}, - "a_library": {"supply": {"slot": 2, "type": "uint256"}}, - "counter2": {"slot": 3, "type": "uint256"}, + "counter": {"slot": 1, "type": "uint256", "n_slots": 1}, + "a_library": {"supply": {"slot": 2, "type": "uint256", "n_slots": 1}}, + "counter2": {"slot": 3, "type": "uint256", "n_slots": 1}, }, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) assert out["layout"] == expected_layout @@ -271,14 +250,14 @@ def bar(): }, }, "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "counter": {"slot": 1, "type": "uint256"}, - "lib2": {"storage_variable": {"slot": 2, "type": "uint256"}}, - "counter2": {"slot": 3, "type": "uint256"}, - "a_library": {"supply": {"slot": 4, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock", "n_slots": 1}, + "counter": {"slot": 1, "type": "uint256", "n_slots": 1}, + "lib2": {"storage_variable": {"slot": 2, "type": "uint256", "n_slots": 1}}, + "counter2": {"slot": 3, "type": "uint256", "n_slots": 1}, + "a_library": {"supply": {"slot": 4, "type": "uint256", "n_slots": 1}}, }, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) assert out["layout"] == expected_layout @@ -351,16 +330,15 @@ def foo() -> uint256: }, }, "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "counter": {"slot": 1, "type": "uint256"}, + "counter": {"slot": 1, "type": "uint256", "n_slots": 1}, "lib2": { - "lib1": {"supply": {"slot": 2, "type": "uint256"}}, - "storage_variable": {"slot": 3, "type": "uint256"}, + "lib1": {"supply": {"slot": 2, "type": "uint256", "n_slots": 1}}, + "storage_variable": {"slot": 3, "type": "uint256", "n_slots": 1}, }, - "counter2": {"slot": 4, "type": "uint256"}, + "counter2": {"slot": 4, "type": "uint256", "n_slots": 1}, }, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) assert out["layout"] == expected_layout diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index 707c94c3fc..f02a8471e2 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -3,6 +3,7 @@ import pytest from vyper.compiler import compile_code +from vyper.evm.opcodes import version_check from vyper.exceptions import StorageLayoutException @@ -12,11 +13,11 @@ def test_storage_layout_overrides(): b: uint256""" storage_layout_overrides = { - "a": {"type": "uint256", "slot": 1}, - "b": {"type": "uint256", "slot": 0}, + "a": {"type": "uint256", "slot": 1, "n_slots": 1}, + "b": {"type": "uint256", "slot": 0, "n_slots": 1}, } - expected_output = {"storage_layout": storage_layout_overrides, "code_layout": {}} + expected_output = {"storage_layout": storage_layout_overrides} out = compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_overrides @@ -61,18 +62,26 @@ def public_foo3(): """ storage_layout_override = { - "$.nonreentrant_key": {"type": "nonreentrant lock", "slot": 8}, - "foo": {"type": "HashMap[address, uint256]", "slot": 1}, - "baz": {"type": "Bytes[65]", "slot": 2}, - "bar": {"type": "uint256", "slot": 6}, + "$.nonreentrant_key": {"type": "nonreentrant lock", "slot": 8, "n_slots": 1}, + "foo": {"type": "HashMap[address, uint256]", "slot": 1, "n_slots": 1}, + "baz": {"type": "Bytes[65]", "slot": 2, "n_slots": 4}, + "bar": {"type": "uint256", "slot": 6, "n_slots": 1}, } + if version_check(begin="cancun"): + del storage_layout_override["$.nonreentrant_key"] - expected_output = {"storage_layout": storage_layout_override, "code_layout": {}} + expected_output = {"storage_layout": storage_layout_override} out = compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_override ) + # adjust transient storage layout + if version_check(begin="cancun"): + expected_output["transient_storage_layout"] = { + "$.nonreentrant_key": {"n_slots": 1, "slot": 0, "type": "nonreentrant lock"} + } + assert out["layout"] == expected_output @@ -118,16 +127,55 @@ def test_override_nonreentrant_slot(): def foo(): pass """ - storage_layout_override = {"$.nonreentrant_key": {"slot": 2**256, "type": "nonreentrant key"}} - exception_regex = re.escape( - f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}" - ) - with pytest.raises(StorageLayoutException, match=exception_regex): - compile_code( - code, output_formats=["layout"], storage_layout_override=storage_layout_override + if version_check(begin="cancun"): + del storage_layout_override["$.nonreentrant_key"] + assert ( + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + is not None + ) + + else: + exception_regex = re.escape( + f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}" ) + with pytest.raises(StorageLayoutException, match=exception_regex): + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + + +def test_override_missing_nonreentrant_key(): + code = """ +@nonreentrant +@external +def foo(): + pass + """ + + storage_layout_override = {} + + if version_check(begin="cancun"): + assert ( + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + is not None + ) + # in cancun, nonreentrant key is allocated in transient storage and can't be overridden + return + else: + exception_regex = re.escape( + "Could not find storage slot for $.nonreentrant_key." + " Have you used the correct storage layout file?" + ) + with pytest.raises(StorageLayoutException, match=exception_regex): + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) def test_incomplete_overrides(): @@ -139,9 +187,225 @@ def test_incomplete_overrides(): with pytest.raises( StorageLayoutException, - match="Could not find storage_slot for symbol. " + match="Could not find storage slot for symbol. " "Have you used the correct storage layout file?", ): compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_override ) + + +@pytest.mark.requires_evm_version("cancun") +def test_override_with_immutables_and_transient(): + code = """ +some_local: transient(uint256) +some_immutable: immutable(uint256) +name: public(String[64]) + +@deploy +def __init__(): + some_immutable = 5 + """ + + storage_layout_override = {"name": {"slot": 10, "type": "String[64]", "n_slots": 3}} + + out = compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + + expected_output = { + "storage_layout": storage_layout_override, + "transient_storage_layout": {"some_local": {"slot": 1, "type": "uint256", "n_slots": 1}}, + "code_layout": {"some_immutable": {"offset": 0, "type": "uint256", "length": 32}}, + } + + assert out["layout"] == expected_output + + +def test_override_modules(make_input_bundle): + # test module storage layout, with initializes in an imported module + # note code repetition with test_storage_layout.py; maybe refactor to + # some fixtures + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +counter: uint256 +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + lib1.__init__() + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 # test shadowing +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2 + +counter2: uint256 + +uses: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + + lib2.__init__(17) + +@external +def foo() -> uint256: + return a_library.supply + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + override = { + "counter": {"slot": 5, "type": "uint256", "n_slots": 1}, + "lib2": { + "lib1": {"supply": {"slot": 12, "type": "uint256", "n_slots": 1}}, + "storage_variable": {"slot": 34, "type": "uint256", "n_slots": 1}, + "counter": {"slot": 15, "type": "uint256", "n_slots": 1}, + }, + "counter2": {"slot": 171, "type": "uint256", "n_slots": 1}, + } + out = compile_code( + code, output_formats=["layout"], input_bundle=input_bundle, storage_layout_override=override + ) + + expected_output = { + "storage_layout": override, + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": { + "lib1": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + "immutable_variable": {"length": 32, "offset": 448, "type": "uint256"}, + }, + }, + } + + assert out["layout"] == expected_output + + +def test_module_collision(make_input_bundle): + # test collisions between modules which are "siblings" in the import tree + # some fixtures + lib1 = """ +supply: uint256 + """ + lib2 = """ +counter: uint256 + """ + code = """ +import lib1 as a_library +import lib2 + +# for fun: initialize lib2 in front of lib1 +initializes: lib2 +initializes: a_library + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + override = { + "lib2": {"counter": {"slot": 15, "type": "uint256", "n_slots": 1}}, + "a_library": {"supply": {"slot": 15, "type": "uint256", "n_slots": 1}}, + } + + with pytest.raises( + StorageLayoutException, + match="Storage collision! Tried to assign 'a_library.supply' to" + " slot 15 but it has already been reserved by 'lib2.counter'", + ): + compile_code( + code, + output_formats=["layout"], + input_bundle=input_bundle, + storage_layout_override=override, + ) + + +def test_module_collision2(make_input_bundle): + # test "parent-child" collisions + lib1 = """ +supply: uint256 + """ + code = """ +import lib1 + +counter: uint256 + +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + override = { + "counter": {"slot": 15, "type": "uint256", "n_slots": 1}, + "lib1": {"supply": {"slot": 15, "type": "uint256", "n_slots": 1}}, + } + + with pytest.raises( + StorageLayoutException, + match="Storage collision! Tried to assign 'lib1.supply' to" + " slot 15 but it has already been reserved by 'counter'", + ): + compile_code( + code, + output_formats=["layout"], + input_bundle=input_bundle, + storage_layout_override=override, + ) + + +def test_module_overlap(make_input_bundle): + # test a collision which only overlaps on one word + lib1 = """ +supply: uint256[2] + """ + code = """ +import lib1 + +counter: uint256[2] + +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + override = { + "counter": {"slot": 15, "type": "uint256[2]", "n_slots": 2}, + "lib1": {"supply": {"slot": 16, "type": "uint256[2]", "n_slots": 2}}, + } + + with pytest.raises( + StorageLayoutException, + match="Storage collision! Tried to assign 'lib1.supply' to" + " slot 16 but it has already been reserved by 'counter'", + ): + compile_code( + code, + output_formats=["layout"], + input_bundle=input_bundle, + storage_layout_override=override, + ) diff --git a/tests/unit/cli/storage_layout/utils.py b/tests/unit/cli/storage_layout/utils.py new file mode 100644 index 0000000000..6e67886b0d --- /dev/null +++ b/tests/unit/cli/storage_layout/utils.py @@ -0,0 +1,17 @@ +from vyper.evm.opcodes import version_check + + +def adjust_storage_layout_for_cancun(layout): + def _go(layout): + for _varname, item in layout.items(): + if "slot" in item and isinstance(item["slot"], int): + item["slot"] -= 1 + else: + # recurse to submodule + _go(item) + + if version_check(begin="cancun"): + nonreentrant = layout["storage_layout"].pop("$.nonreentrant_key", None) + if nonreentrant is not None: + layout["transient_storage_layout"] = {"$.nonreentrant_key": nonreentrant} + _go(layout["storage_layout"]) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 10b4833e67..6f437395c6 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -12,6 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings, anchor_settings, merge_settings from vyper.ir import compile_ir, optimizer from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target +from vyper.semantics.analysis.data_positions import generate_layout_export from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout @@ -180,7 +181,9 @@ def compilation_target(self): @cached_property def storage_layout(self) -> StorageLayout: module_ast = self.compilation_target - return set_data_positions(module_ast, self.storage_layout_override) + set_data_positions(module_ast, self.storage_layout_override) + + return generate_layout_export(module_ast) @property def global_ctx(self) -> ModuleT: diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 718581c20c..026e0626e7 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -194,7 +194,7 @@ def getter_ast(self) -> Optional[vy_ast.VyperNode]: def set_position(self, position: VarOffset) -> None: if self.position is not None: - raise CompilerPanic("Position was already assigned") + raise CompilerPanic(f"Position was already assigned: {self}") assert isinstance(position, VarOffset) # sanity check self.position = position @@ -207,6 +207,10 @@ def is_state_variable(self): def get_size(self) -> int: return self.typ.get_size_in(self.location) + @property + def is_storage(self): + return self.location == DataLocation.STORAGE + @property def is_transient(self): return self.location == DataLocation.TRANSIENT diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index e5e8b998ca..5f6702668f 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -1,5 +1,6 @@ +import json from collections import defaultdict -from typing import Generic, TypeVar +from typing import Generic, Optional, TypeVar from vyper import ast as vy_ast from vyper.evm.opcodes import version_check @@ -11,7 +12,7 @@ def set_data_positions( vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout = None -) -> StorageLayout: +) -> None: """ Parse the annotated Vyper AST, determine data positions for all variables, and annotate the AST nodes with the position data. @@ -22,14 +23,19 @@ def set_data_positions( Top-level Vyper AST node that has already been annotated with type data. """ if storage_layout_overrides is not None: - # extract code layout with no overrides - code_offsets = _allocate_layout_r(vyper_module, immutables_only=True)["code_layout"] - storage_slots = set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) - return {"storage_layout": storage_slots, "code_layout": code_offsets} + # allocate code layout with no overrides + _allocate_layout_r(vyper_module, no_storage=True) + _allocate_with_overrides(vyper_module, storage_layout_overrides) - ret = _allocate_layout_r(vyper_module) - assert isinstance(ret, defaultdict) - return dict(ret) # convert back to dict + # sanity check that generated layout file is the same as the input. + roundtrip = generate_layout_export(vyper_module).get(_LAYOUT_KEYS[DataLocation.STORAGE], {}) + if roundtrip != storage_layout_overrides: + msg = "Computed storage layout does not match override file!\n" + msg += f"expected: {json.dumps(storage_layout_overrides)}\n\n" + msg += f"got:\n{json.dumps(roundtrip)}" + raise CompilerPanic(msg) + else: + _allocate_layout_r(vyper_module) _T = TypeVar("_T") @@ -45,6 +51,7 @@ def __setitem__(self, k, v): # some name that the user cannot assign to a variable GLOBAL_NONREENTRANT_KEY = "$.nonreentrant_key" +NONREENTRANT_KEY_SIZE = 1 class SimpleAllocator: @@ -55,7 +62,7 @@ def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): self._slot = starting_slot self._max_slot = max_slot - def allocate_slot(self, n, var_name, node=None): + def allocate_slot(self, n, node=None): ret = self._slot if self._slot + n >= self._max_slot: raise StorageLayoutException( @@ -67,7 +74,7 @@ def allocate_slot(self, n, var_name, node=None): return ret def allocate_global_nonreentrancy_slot(self): - slot = self.allocate_slot(1, GLOBAL_NONREENTRANT_KEY) + slot = self.allocate_slot(NONREENTRANT_KEY_SIZE) assert slot == self._starting_slot return slot @@ -141,74 +148,105 @@ def _reserve_slot(self, slot: int, var_name: str) -> None: self.occupied_slots[slot] = var_name -def set_storage_slots_with_overrides( - vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout -) -> StorageLayout: +def _fetch_path(path: list[str], layout: StorageLayout, node: vy_ast.VyperNode): + tmp = layout + qualified_path = ".".join(path) + + for segment in path: + if segment not in tmp: + raise StorageLayoutException( + f"Could not find storage slot for {qualified_path}. " + "Have you used the correct storage layout file?", + node, + ) + tmp = tmp[segment] + + try: + ret = tmp["slot"] + except KeyError as e: + raise StorageLayoutException(f"no storage slot for {qualified_path}", node) from e + + return ret + + +def _allocate_with_overrides(vyper_module: vy_ast.Module, layout: StorageLayout): """ Set storage layout given a layout override file. - Returns the layout as a dict of variable name -> variable info - (Doesn't handle modules, or transient storage) """ - ret: InsertableOnceDict[str, dict] = InsertableOnceDict() - reserved_slots = OverridingStorageAllocator() + allocator = OverridingStorageAllocator() + + nonreentrant_slot = None + if GLOBAL_NONREENTRANT_KEY in layout: + nonreentrant_slot = layout[GLOBAL_NONREENTRANT_KEY]["slot"] + + _allocate_with_overrides_r(vyper_module, layout, allocator, nonreentrant_slot, []) + +def _allocate_with_overrides_r( + vyper_module: vy_ast.Module, + layout: StorageLayout, + allocator: OverridingStorageAllocator, + global_nonreentrant_slot: Optional[int], + path: list[str], +): # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["func_type"] + fn_t = node._metadata["func_type"] # Ignore functions without non-reentrant - if not type_.nonreentrant: + if not fn_t.nonreentrant: continue - variable_name = GLOBAL_NONREENTRANT_KEY - - # re-entrant key was already identified - if variable_name in ret: + # if reentrancy keys get allocated in transient storage, we don't + # override them + if get_reentrancy_key_location() == DataLocation.TRANSIENT: continue # Expect to find this variable within the storage layout override - if variable_name in storage_layout_overrides: - reentrant_slot = storage_layout_overrides[variable_name]["slot"] - # Ensure that this slot has not been used, and prevents other storage variables - # from using the same slot - reserved_slots.reserve_slot_range(reentrant_slot, 1, variable_name) - - type_.set_reentrancy_key_position(VarOffset(reentrant_slot)) - - ret[variable_name] = {"type": "nonreentrant lock", "slot": reentrant_slot} - else: + if global_nonreentrant_slot is None: raise StorageLayoutException( - f"Could not find storage_slot for {variable_name}. " + f"Could not find storage slot for {GLOBAL_NONREENTRANT_KEY}. " "Have you used the correct storage layout file?", node, ) - # Iterate through variables - for node in vyper_module.get_children(vy_ast.VariableDecl): - # Ignore immutable parameters - if node.get("annotation.func.id") == "immutable": + # prevent other storage variables from using the same slot + if allocator.occupied_slots.get(global_nonreentrant_slot) != GLOBAL_NONREENTRANT_KEY: + allocator.reserve_slot_range( + global_nonreentrant_slot, NONREENTRANT_KEY_SIZE, GLOBAL_NONREENTRANT_KEY + ) + + fn_t.set_reentrancy_key_position(VarOffset(global_nonreentrant_slot)) + + for node in _get_allocatable(vyper_module): + if isinstance(node, vy_ast.InitializesDecl): + module_info = node._metadata["initializes_info"].module_info + + sub_path = [*path, module_info.alias] + _allocate_with_overrides_r( + module_info.module_node, layout, allocator, global_nonreentrant_slot, sub_path + ) continue + # Iterate through variables + # Ignore immutables and transient variables varinfo = node.target._metadata["varinfo"] + if not varinfo.is_storage: + continue + # Expect to find this variable within the storage layout overrides - if node.target.id in storage_layout_overrides: - var_slot = storage_layout_overrides[node.target.id]["slot"] - storage_length = varinfo.typ.storage_size_in_words - # Ensure that all required storage slots are reserved, and prevents other variables - # from using these slots - reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) - varinfo.set_position(VarOffset(var_slot)) - - ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} - else: - raise StorageLayoutException( - f"Could not find storage_slot for {node.target.id}. " - "Have you used the correct storage layout file?", - node, - ) + varname = node.target.id + varpath = [*path, varname] + qualified_varname = ".".join(varpath) - return ret + var_slot = _fetch_path(varpath, layout, node) + + storage_length = varinfo.typ.storage_size_in_words + # Ensure that all required storage slots are reserved, and + # prevent other variables from using these slots + allocator.reserve_slot_range(var_slot, storage_length, qualified_varname) + varinfo.set_position(VarOffset(var_slot)) def _get_allocatable(vyper_module: vy_ast.Module) -> list[vy_ast.VyperNode]: @@ -229,7 +267,7 @@ def get_reentrancy_key_location() -> DataLocation: } -def _allocate_nonreentrant_keys(vyper_module, allocators): +def _set_nonreentrant_keys(vyper_module, allocators): SLOT = allocators.get_global_nonreentrant_key_slot() for node in vyper_module.get_children(vy_ast.FunctionDef): @@ -244,73 +282,116 @@ def _allocate_nonreentrant_keys(vyper_module, allocators): def _allocate_layout_r( - vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False -) -> StorageLayout: + vyper_module: vy_ast.Module, allocators: Allocators = None, no_storage=False +): """ Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ - global_ = False if allocators is None: - global_ = True allocators = Allocators() # always allocate nonreentrancy slot, so that adding or removing # reentrancy protection from a contract does not change its layout allocators.allocate_global_nonreentrancy_slot() - ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) - # tag functions with the global nonreentrant key - if not immutables_only: - _allocate_nonreentrant_keys(vyper_module, allocators) - - layout_key = _LAYOUT_KEYS[get_reentrancy_key_location()] - # TODO this could have better typing but leave it untyped until - # we nail down the format better - if global_ and GLOBAL_NONREENTRANT_KEY not in ret[layout_key]: - slot = allocators.get_global_nonreentrant_key_slot() - ret[layout_key][GLOBAL_NONREENTRANT_KEY] = {"type": "nonreentrant lock", "slot": slot} + if not no_storage or get_reentrancy_key_location() == DataLocation.TRANSIENT: + _set_nonreentrant_keys(vyper_module, allocators) for node in _get_allocatable(vyper_module): if isinstance(node, vy_ast.InitializesDecl): module_info = node._metadata["initializes_info"].module_info - module_layout = _allocate_layout_r(module_info.module_node, allocators) - module_alias = module_info.alias - for layout_key in module_layout.keys(): - assert layout_key in _LAYOUT_KEYS.values() - ret[layout_key][module_alias] = module_layout[layout_key] + _allocate_layout_r(module_info.module_node, allocators, no_storage) continue assert isinstance(node, vy_ast.VariableDecl) - # skip non-state variables varinfo = node.target._metadata["varinfo"] + + # skip things we don't need to allocate, like constants if not varinfo.is_state_variable(): continue - location = varinfo.location - if immutables_only and location != DataLocation.CODE: + if no_storage and varinfo.is_storage: continue - allocator = allocators.get_allocator(location) + allocator = allocators.get_allocator(varinfo.location) size = varinfo.get_size() # CMC 2021-07-23 note that HashMaps get assigned a slot here # using the same allocator (even though there is not really # any risk of physical overlap) - offset = allocator.allocate_slot(size, node.target.id, node) - + offset = allocator.allocate_slot(size, node) varinfo.set_position(VarOffset(offset)) + +# get the layout for export +def generate_layout_export(vyper_module: vy_ast.Module): + return _generate_layout_export_r(vyper_module) + + +def _generate_layout_export_r(vyper_module): + ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) + + for node in _get_allocatable(vyper_module): + if isinstance(node, vy_ast.InitializesDecl): + module_info = node._metadata["initializes_info"].module_info + module_layout = _generate_layout_export_r(module_info.module_node) + module_alias = module_info.alias + for layout_key in module_layout.keys(): + assert layout_key in _LAYOUT_KEYS.values() + + # lift the nonreentrancy key (if any) into the outer dict + # note that lifting can leave the inner dict empty, which + # should be filtered (below) for cleanliness + nonreentrant = module_layout[layout_key].pop(GLOBAL_NONREENTRANT_KEY, None) + if nonreentrant is not None and GLOBAL_NONREENTRANT_KEY not in ret[layout_key]: + ret[layout_key][GLOBAL_NONREENTRANT_KEY] = nonreentrant + + # add the module as a nested dict, but only if it is non-empty + if len(module_layout[layout_key]) != 0: + ret[layout_key][module_alias] = module_layout[layout_key] + + continue + + assert isinstance(node, vy_ast.VariableDecl) + varinfo = node.target._metadata["varinfo"] + # skip non-state variables + if not varinfo.is_state_variable(): + continue + + location = varinfo.location layout_key = _LAYOUT_KEYS[location] type_ = varinfo.typ + size = varinfo.get_size() + offset = varinfo.position.position + # this could have better typing but leave it untyped until # we understand the use case better if location == DataLocation.CODE: item = {"type": str(type_), "length": size, "offset": offset} elif location in (DataLocation.STORAGE, DataLocation.TRANSIENT): - item = {"type": str(type_), "slot": offset} + item = {"type": str(type_), "n_slots": size, "slot": offset} else: # pragma: nocover raise CompilerPanic("unreachable") ret[layout_key][node.target.id] = item + for fn in vyper_module.get_children(vy_ast.FunctionDef): + fn_t = fn._metadata["func_type"] + if not fn_t.nonreentrant: + continue + + location = get_reentrancy_key_location() + layout_key = _LAYOUT_KEYS[location] + + if GLOBAL_NONREENTRANT_KEY in ret[layout_key]: + break + + slot = fn_t.reentrancy_key_position.position + ret[layout_key][GLOBAL_NONREENTRANT_KEY] = { + "type": "nonreentrant lock", + "slot": slot, + "n_slots": NONREENTRANT_KEY_SIZE, + } + break + return ret