From 20432c505c706ed28d0fa5f743e33a9cfa2dd8c3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 2 Apr 2024 09:55:05 -0400 Subject: [PATCH 1/4] fix[lang]: pure access analysis (#3895) this commit fixes the ability to access module variables in pure functions. the `_validate_self_reference()` utility function was hard-coded to check the "self" name; remove it and replace with an analysis-based check. this commit also fixes the pure access check for immutable variables, and address members (e.g. `.codesize`) misc/refactor: * rename `VarInfo.is_module_variable()` to more fitting `VarInfo.is_state_variable()`. * refactor pure decorator tests to be in line with recent best practices --- .../codegen/features/decorators/test_pure.py | 154 ++++++++++++------ vyper/ast/nodes.pyi | 2 +- vyper/semantics/analysis/base.py | 8 +- vyper/semantics/analysis/data_positions.py | 4 +- vyper/semantics/analysis/local.py | 52 +++--- vyper/semantics/analysis/module.py | 2 +- 6 files changed, 146 insertions(+), 76 deletions(-) diff --git a/tests/functional/codegen/features/decorators/test_pure.py b/tests/functional/codegen/features/decorators/test_pure.py index 5f4c168687..7c49c2091b 100644 --- a/tests/functional/codegen/features/decorators/test_pure.py +++ b/tests/functional/codegen/features/decorators/test_pure.py @@ -1,21 +1,22 @@ +import pytest + +from vyper.compiler import compile_code from vyper.exceptions import FunctionDeclarationException, StateAccessViolation -def test_pure_operation(get_contract_with_gas_estimation_for_constants): - c = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_operation(get_contract): + code = """ @pure @external def foo() -> int128: return 5 """ - ) + c = get_contract(code) assert c.foo() == 5 -def test_pure_call(get_contract_with_gas_estimation_for_constants): - c = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_call(get_contract): + code = """ @pure @internal def _foo() -> int128: @@ -26,21 +27,18 @@ def 
_foo() -> int128: def foo() -> int128: return self._foo() """ - ) + c = get_contract(code) assert c.foo() == 5 -def test_pure_interface(get_contract_with_gas_estimation_for_constants): - c1 = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_interface(get_contract): + code1 = """ @pure @external def foo() -> int128: return 5 """ - ) - c2 = get_contract_with_gas_estimation_for_constants( - """ + code2 = """ interface Foo: def foo() -> int128: pure @@ -49,28 +47,35 @@ def foo() -> int128: pure def foo(a: address) -> int128: return staticcall Foo(a).foo() """ - ) + c1 = get_contract(code1) + c2 = get_contract(code2) assert c2.foo(c1.address) == 5 -def test_invalid_envar_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_envar_access(get_contract): + code = """ @pure @external def foo() -> uint256: return chain.id """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_codesize_access(get_contract): + code = """ +@pure +@external +def foo(s: address) -> uint256: + return s.codesize + """ + with pytest.raises(StateAccessViolation): + compile_code(code) def test_invalid_state_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ + code = """ x: uint256 @pure @@ -78,29 +83,84 @@ def test_invalid_state_access(get_contract, assert_compile_failed): def foo() -> uint256: return self.x """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_immutable_access(): + code = """ +COUNTER: immutable(uint256) + +@deploy +def __init__(): + COUNTER = 1234 + +@pure +@external +def foo() -> uint256: + return COUNTER + """ + with pytest.raises(StateAccessViolation): + compile_code(code) -def test_invalid_self_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def 
test_invalid_self_access(): + code = """ @pure @external def foo() -> address: return self """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_module_variable_access(make_input_bundle): + lib1 = """ +counter: uint256 + """ + code = """ +import lib1 +initializes: lib1 + +@pure +@external +def foo() -> uint256: + return lib1.counter + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StateAccessViolation): + compile_code(code, input_bundle=input_bundle) + + +def test_invalid_module_immutable_access(make_input_bundle): + lib1 = """ +COUNTER: immutable(uint256) + +@deploy +def __init__(): + COUNTER = 123 + """ + code = """ +import lib1 +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__() + +@pure +@external +def foo() -> uint256: + return lib1.COUNTER + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StateAccessViolation): + compile_code(code, input_bundle=input_bundle) -def test_invalid_call(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_call(): + code = """ @view @internal def _foo() -> uint256: @@ -111,21 +171,17 @@ def _foo() -> uint256: def foo() -> uint256: return self._foo() # Fails because of calling non-pure fn """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) -def test_invalid_conflicting_decorators(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_conflicting_decorators(): + code = """ @pure @external @payable def foo() -> uint256: return 5 """ - ), - FunctionDeclarationException, - ) + with pytest.raises(FunctionDeclarationException): + compile_code(code) diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index f673bb765c..fe01bf9260 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -141,7 +141,7 @@ class Dict(VyperNode): 
keys: list = ... values: list = ... -class Name(VyperNode): +class Name(ExprNode): id: str = ... _type: str = ... diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index e424f94e19..82b8dbe359 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -8,6 +8,7 @@ from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.semantics.types.primitives import SelfT from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: @@ -199,8 +200,11 @@ def set_position(self, position: VarOffset) -> None: assert isinstance(position, VarOffset) # sanity check self.position = position - def is_module_variable(self): - return self.location not in (DataLocation.UNSET, DataLocation.MEMORY) + def is_state_variable(self): + non_state_locations = (DataLocation.UNSET, DataLocation.MEMORY, DataLocation.CALLDATA) + # `self` gets a VarInfo, but it is not considered a state + # variable (it is magic), so we ignore it here. 
+ return self.location not in non_state_locations and not isinstance(self.typ, SelfT) def get_size(self) -> int: return self.typ.get_size_in(self.location) diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index bb4322c7b2..e5e8b998ca 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -282,9 +282,9 @@ def _allocate_layout_r( continue assert isinstance(node, vy_ast.VariableDecl) - # skip non-storage variables + # skip non-state variables varinfo = node.target._metadata["varinfo"] - if not varinfo.is_module_variable(): + if not varinfo.is_state_variable(): continue location = varinfo.location diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 33dcce3645..1b2e3252c8 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -38,7 +38,7 @@ from vyper.semantics.data_locations import DataLocation # TODO consolidate some of these imports -from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS, MUTABLE_ENVIRONMENT_VARS +from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS from vyper.semantics.namespace import get_namespace from vyper.semantics.types import ( TYPE_T, @@ -51,6 +51,7 @@ HashMapT, IntegerT, SArrayT, + SelfT, StringT, StructT, TupleT, @@ -164,21 +165,32 @@ def _validate_msg_value_access(node: vy_ast.Attribute) -> None: raise NonPayableViolation("msg.value is not allowed in non-payable functions", node) -def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: - env_vars = set(CONSTANT_ENVIRONMENT_VARS.keys()) | set(MUTABLE_ENVIRONMENT_VARS.keys()) - if isinstance(node.value, vy_ast.Name) and node.value.id in env_vars: - if isinstance(typ, ContractFunctionT) and typ.mutability == StateMutability.PURE: - return +def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name) -> None: + info = get_expr_info(node) - raise StateAccessViolation( - "not 
allowed to query contract or environment variables in pure functions", node - ) + env_vars = CONSTANT_ENVIRONMENT_VARS + # check env variable access like `block.number` + if isinstance(node, vy_ast.Attribute): + if node.get("value.id") in env_vars: + raise StateAccessViolation( + "not allowed to query environment variables in pure functions" + ) + parent_info = get_expr_info(node.value) + if isinstance(parent_info.typ, AddressT) and node.attr in AddressT._type_members: + raise StateAccessViolation("not allowed to query address members in pure functions") + if (varinfo := info.var_info) is None: + return + # self is magic. we only need to check it if it is not the root of an Attribute + # node. (i.e. it is bare like `self`, not `self.foo`) + is_naked_self = isinstance(varinfo.typ, SelfT) and not isinstance( + node.get_ancestor(), vy_ast.Attribute + ) + if is_naked_self: + raise StateAccessViolation("not allowed to query `self` in pure functions") -def _validate_self_reference(node: vy_ast.Name) -> None: - # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through - if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): - raise StateAccessViolation("not allowed to query self in pure functions", node) + if varinfo.is_state_variable() or is_naked_self: + raise StateAccessViolation("not allowed to query state variables in pure functions") # analyse the variable access for the attribute chain for a node @@ -429,7 +441,7 @@ def _handle_modification(self, target: vy_ast.ExprNode): info._writes.add(var_access) def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode): - if not var_access.variable.is_module_variable(): + if not var_access.variable.is_state_variable(): return root_module_info = check_module_uses(target) @@ -693,7 +705,7 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_value_access(node) if self.func and self.func.mutability == 
StateMutability.PURE: - _validate_pure_access(node, typ) + _validate_pure_access(node) value_type = get_exact_type_from_node(node.value) @@ -765,10 +777,10 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if not func_type.from_interface: for s in func_type.get_variable_writes(): - if s.variable.is_module_variable(): + if s.variable.is_state_variable(): func_info._writes.add(s) for s in func_type.get_variable_reads(): - if s.variable.is_module_variable(): + if s.variable.is_state_variable(): func_info._reads.add(s) if self.function_analyzer: @@ -873,10 +885,8 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func: - # TODO: refactor to use expr_info mutability - if self.func.mutability == StateMutability.PURE: - _validate_self_reference(node) + if self.func and self.func.mutability == StateMutability.PURE: + _validate_pure_access(node) def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index b4f4381444..84199ec82c 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -536,7 +536,7 @@ def visit_ExportsDecl(self, node): # check module uses var_accesses = func_t.get_variable_accesses() - if any(s.variable.is_module_variable() for s in var_accesses): + if any(s.variable.is_state_variable() for s in var_accesses): module_info = check_module_uses(item) assert module_info is not None # guaranteed by above checks used_modules.add(module_info) From 5bdd174959f1aadb1138fdf0d44c44c0e36a4f28 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 2 Apr 2024 13:18:37 -0400 Subject: [PATCH 2/4] fix[ci]: pin hexbytes to pre-1.0.0 (#3903) hexbytes introduces new behavior how it interacts with strings/bytes. 
this commit pins hexbytes as a stopgap solution until we update tests. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index b403405122..933e8bfa4b 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ "hypothesis[lark]>=6.0,<7.0", "eth-stdlib==0.2.7", "setuptools", + "hexbytes<1.0", "typing_extensions", # we can remove this once dependencies upgrade to eth-rlp>=2.0 ], "lint": [ From 91ef0dd130de91081f044541c87f20d901d87d89 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Tue, 2 Apr 2024 20:42:49 +0300 Subject: [PATCH 3/4] feat[ir]: add `make_ssa` pass to venom pipeline (#3825) this commit updates the venom pipeline to be capable of translating 100% of the original vyper IR, and successfully passes the entire test suite. to accomplish this, the translation pass from the old IR to Venom is simplified, moving several optimization and analysis steps to separate passes. the most significant of these is the `make_ssa` pass, which converts any Venom code into SSA form, therefore letting us write non-SSA code in the translation pass, simplifying the translation. to support the `make_ssa` pass, this commit also adds a dominator tree implementation, along with implementations of dominance frontier and other utility functions. these should also be useful for additional passes that will be contributed in the future. to facilitate the development process, this commit also adds two more output formats: `cfg` and `cfg_runtime`, which provide a graph representation of the Venom code. 
--------- Co-authored-by: Charles Cooper --- .github/workflows/test.yml | 19 +- setup.cfg | 1 + tests/conftest.py | 77 +- .../builtins/codegen/test_abi_decode.py | 4 +- .../builtins/codegen/test_abi_encode.py | 4 + .../codegen/features/test_clampers.py | 2 + .../codegen/features/test_constructor.py | 4 + .../codegen/features/test_immutable.py | 2 + .../codegen/types/test_dynamic_array.py | 9 +- .../examples/factory/test_factory.py | 6 +- tests/functional/syntax/test_address_code.py | 4 +- tests/functional/syntax/test_codehash.py | 4 +- .../unit/cli/vyper_json/test_compile_json.py | 4 +- tests/unit/compiler/asm/test_asm_optimizer.py | 3 +- .../compiler/venom/test_dominator_tree.py | 73 ++ .../compiler/venom/test_duplicate_operands.py | 3 +- .../venom/test_liveness_simple_loop.py | 16 + tests/unit/compiler/venom/test_make_ssa.py | 48 + .../compiler/venom/test_multi_entry_block.py | 18 +- tests/unit/compiler/venom/test_variables.py | 8 + vyper/ast/pre_parser.py | 6 + .../function_definitions/external_function.py | 4 - .../function_definitions/internal_function.py | 6 +- vyper/codegen/return_.py | 1 - vyper/codegen/self_call.py | 2 - vyper/compiler/__init__.py | 2 + vyper/compiler/output.py | 8 + vyper/compiler/phases.py | 2 - vyper/exceptions.py | 4 + vyper/ir/compile_ir.py | 6 +- vyper/utils.py | 71 +- vyper/venom/__init__.py | 23 +- vyper/venom/analysis.py | 75 +- vyper/venom/basicblock.py | 173 +++- vyper/venom/bb_optimizer.py | 19 + vyper/venom/dominators.py | 166 ++++ vyper/venom/function.py | 160 ++- vyper/venom/ir_node_to_venom.py | 919 ++++++------------ vyper/venom/passes/base_pass.py | 4 +- vyper/venom/passes/dft.py | 48 +- vyper/venom/passes/make_ssa.py | 174 ++++ vyper/venom/passes/normalization.py | 10 +- vyper/venom/passes/simplify_cfg.py | 82 ++ vyper/venom/passes/stack_reorder.py | 24 + vyper/venom/stack_model.py | 20 +- vyper/venom/venom_to_assembly.py | 215 +++- 46 files changed, 1684 insertions(+), 849 deletions(-) create mode 100644 
tests/unit/compiler/venom/test_dominator_tree.py create mode 100644 tests/unit/compiler/venom/test_liveness_simple_loop.py create mode 100644 tests/unit/compiler/venom/test_make_ssa.py create mode 100644 tests/unit/compiler/venom/test_variables.py create mode 100644 vyper/venom/dominators.py create mode 100644 vyper/venom/passes/make_ssa.py create mode 100644 vyper/venom/passes/simplify_cfg.py create mode 100644 vyper/venom/passes/stack_reorder.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8bd03de79b..10312413e9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -67,8 +67,9 @@ jobs: matrix: python-version: [["3.11", "311"]] opt-mode: ["gas", "none", "codesize"] - evm-version: [shanghai] debug: [true, false] + evm-version: [shanghai] + experimental-codegen: [false] memorymock: [false] # https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs#expanding-or-adding-matrix-configurations @@ -94,6 +95,14 @@ jobs: opt-mode: gas evm-version: cancun + # test experimental pipeline + - python-version: ["3.11", "311"] + opt-mode: gas + debug: false + evm-version: shanghai + experimental-codegen: true + # TODO: test experimental_codegen + -Ocodesize + # run with `--memorymock`, but only need to do it one configuration # TODO: consider removing the memorymock tests - python-version: ["3.11", "311"] @@ -108,12 +117,14 @@ jobs: opt-mode: gas debug: false evm-version: shanghai + - python-version: ["3.12", "312"] opt-mode: gas debug: false evm-version: shanghai - name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }}${{ matrix.memorymock && '-memorymock' || '' }}-${{ matrix.evm-version }} + + name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }}${{ matrix.memorymock && '-memorymock' || '' }}${{ matrix.experimental-codegen && '-experimental' || '' }}-${{ matrix.evm-version }} steps: - uses: actions/checkout@v4 @@ 
-141,6 +152,7 @@ jobs: --evm-version ${{ matrix.evm-version }} \ ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} \ ${{ matrix.memorymock && '--memorymock' || '' }} \ + ${{ matrix.experimental-codegen && '--experimental-codegen' || '' }} \ --cov-branch \ --cov-report xml:coverage.xml \ --cov=vyper \ @@ -193,8 +205,7 @@ jobs: # NOTE: if the tests get poorly distributed, run this and commit the resulting `.test_durations` file to the `vyper-test-durations` repo. # `pytest -m "fuzzing" --store-durations -r aR tests/` - name: Fetch test-durations - run: | - curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/master/test_durations" -o .test_durations + run: curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/master/test_durations" -o .test_durations - name: Run tests run: | diff --git a/setup.cfg b/setup.cfg index 467c6a372b..f84c947981 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,3 +31,4 @@ testpaths = tests xfail_strict = true markers = fuzzing: Run Hypothesis fuzz test suite (deselect with '-m "not fuzzing"') + venom_xfail: mark a test case as a regression (expected to fail) under the venom pipeline diff --git a/tests/conftest.py b/tests/conftest.py index 2e5f11b9b8..d0681cdf42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,7 @@ def pytest_addoption(parser): help="change optimization mode", ) parser.addoption("--enable-compiler-debug-mode", action="store_true") + parser.addoption("--experimental-codegen", action="store_true") parser.addoption( "--evm-version", @@ -73,6 +74,8 @@ def output_formats(): output_formats = compiler.OUTPUT_FORMATS.copy() del output_formats["bb"] del output_formats["bb_runtime"] + del output_formats["cfg"] + del output_formats["cfg_runtime"] return output_formats @@ -89,6 +92,36 @@ def debug(pytestconfig): _set_debug_mode(debug) +@pytest.fixture(scope="session") +def experimental_codegen(pytestconfig): + ret = 
pytestconfig.getoption("experimental_codegen") + assert isinstance(ret, bool) + return ret + + +@pytest.fixture(autouse=True) +def check_venom_xfail(request, experimental_codegen): + if not experimental_codegen: + return + + marker = request.node.get_closest_marker("venom_xfail") + if marker is None: + return + + # https://github.com/okken/pytest-runtime-xfail?tab=readme-ov-file#alternatives + request.node.add_marker(pytest.mark.xfail(strict=True, **marker.kwargs)) + + +@pytest.fixture +def venom_xfail(request, experimental_codegen): + def _xfail(*args, **kwargs): + if not experimental_codegen: + return + request.node.add_marker(pytest.mark.xfail(*args, strict=True, **kwargs)) + + return _xfail + + @pytest.fixture(scope="session", autouse=True) def evm_version(pytestconfig): # note: we configure the evm version that we emit code for, @@ -108,6 +141,7 @@ def chdir_tmp_path(tmp_path): yield +# CMC 2024-03-01 this doesn't need to be a fixture @pytest.fixture def keccak(): return Web3.keccak @@ -321,6 +355,7 @@ def _get_contract( w3, source_code, optimize, + experimental_codegen, output_formats, *args, override_opt_level=None, @@ -329,6 +364,7 @@ def _get_contract( ): settings = Settings() settings.optimize = override_opt_level or optimize + settings.experimental_codegen = experimental_codegen out = compiler.compile_code( source_code, # test that all output formats can get generated @@ -352,17 +388,21 @@ def _get_contract( @pytest.fixture(scope="module") -def get_contract(w3, optimize, output_formats): +def get_contract(w3, optimize, experimental_codegen, output_formats): def fn(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return fn @pytest.fixture -def get_contract_with_gas_estimation(tester, w3, optimize, output_formats): +def get_contract_with_gas_estimation(tester, w3, optimize, 
experimental_codegen, output_formats): def get_contract_with_gas_estimation(source_code, *args, **kwargs): - contract = _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + contract = _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) for abi_ in contract._classic_contract.functions.abi: if abi_["type"] == "function": set_decorator_to_contract_function(w3, tester, contract, source_code, abi_["name"]) @@ -372,15 +412,19 @@ def get_contract_with_gas_estimation(source_code, *args, **kwargs): @pytest.fixture -def get_contract_with_gas_estimation_for_constants(w3, optimize, output_formats): +def get_contract_with_gas_estimation_for_constants( + w3, optimize, experimental_codegen, output_formats +): def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return get_contract_with_gas_estimation_for_constants @pytest.fixture(scope="module") -def get_contract_module(optimize, output_formats): +def get_contract_module(optimize, experimental_codegen, output_formats): """ This fixture is used for Hypothesis tests to ensure that the same contract is called over multiple runs of the test. 
@@ -393,16 +437,25 @@ def get_contract_module(optimize, output_formats): w3.eth.set_gas_price_strategy(zero_gas_price_strategy) def get_contract_module(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return get_contract_module def _deploy_blueprint_for( - w3, source_code, optimize, output_formats, initcode_prefix=ERC5202_PREFIX, **kwargs + w3, + source_code, + optimize, + experimental_codegen, + output_formats, + initcode_prefix=ERC5202_PREFIX, + **kwargs, ): settings = Settings() settings.optimize = optimize + settings.experimental_codegen = experimental_codegen out = compiler.compile_code( source_code, output_formats=output_formats, @@ -438,9 +491,11 @@ def factory(address): @pytest.fixture(scope="module") -def deploy_blueprint_for(w3, optimize, output_formats): +def deploy_blueprint_for(w3, optimize, experimental_codegen, output_formats): def deploy_blueprint_for(source_code, *args, **kwargs): - return _deploy_blueprint_for(w3, source_code, optimize, output_formats, *args, **kwargs) + return _deploy_blueprint_for( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return deploy_blueprint_for diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index d281851f8e..213738957b 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -3,7 +3,7 @@ import pytest from eth.codecs import abi -from vyper.exceptions import ArgumentException, StructureException +from vyper.exceptions import ArgumentException, StackTooDeep, StructureException TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex() @@ -196,6 +196,7 @@ def abi_decode(x: Bytes[{len}]) -> DynArray[DynArray[uint256, 3], 3]: 
@pytest.mark.parametrize("args", nested_3d_array_args) @pytest.mark.parametrize("unwrap_tuple", (True, False)) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_decode_nested_dynarray2(get_contract, args, unwrap_tuple): if unwrap_tuple is True: encoded = abi.encode("(uint256[][][])", (args,)) @@ -273,6 +274,7 @@ def foo(bs: Bytes[160]) -> (uint256, DynArray[uint256, 3]): assert c.foo(encoded) == [2**256 - 1, bs] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_decode_private_nested_dynarray(get_contract): code = """ bytez: DynArray[DynArray[DynArray[uint256, 3], 3], 3] diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index f014c47a19..305c4b1356 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -3,6 +3,8 @@ import pytest from eth.codecs import abi +from vyper.exceptions import StackTooDeep + # @pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) def test_abi_encode(get_contract): @@ -226,6 +228,7 @@ def abi_encode( @pytest.mark.parametrize("args", nested_3d_array_args) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_encode_nested_dynarray_2(get_contract, args): code = """ @external @@ -330,6 +333,7 @@ def foo(bs: DynArray[uint256, 3]) -> (uint256, Bytes[160]): assert c.foo(bs) == [2**256 - 1, abi.encode("(uint256[])", (bs,))] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_encode_private_nested_dynarray(get_contract): code = """ bytez: Bytes[1696] diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 578413a8f4..fe51c026fe 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ 
-4,6 +4,7 @@ from eth.codecs import abi from eth_utils import keccak +from vyper.exceptions import StackTooDeep from vyper.utils import int_bounds @@ -506,6 +507,7 @@ def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]: @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_multidimension_dynarray_clamper_passing(w3, get_contract, value): code = """ @external diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index 9146ace8a6..d96a889497 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -1,6 +1,8 @@ import pytest from web3.exceptions import ValidationError +from vyper.exceptions import StackTooDeep + def test_init_argument_test(get_contract_with_gas_estimation): init_argument_test = """ @@ -163,6 +165,7 @@ def get_foo() -> uint256: assert c.get_foo() == 39 +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estimation): code = """ foo: int128 @@ -208,6 +211,7 @@ def get_foo() -> DynArray[DynArray[uint256, 3], 3]: assert c.get_foo() == [[37, 41, 73], [37041, 41073, 73037], [146, 123, 148]] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation): code = """ foo: DynArray[DynArray[DynArray[int128, 3], 3], 3] diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index 49ff54b353..874600633a 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -1,6 +1,7 @@ import pytest from vyper.compiler.settings import OptimizationLevel +from vyper.exceptions import StackTooDeep 
@pytest.mark.parametrize( @@ -198,6 +199,7 @@ def get_idx_two() -> uint256: assert c.get_idx_two() == expected_values[2][2] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3]) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index b55f07639b..efa2799480 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -8,6 +8,7 @@ ArrayIndexException, ImmutableViolation, OverflowException, + StackTooDeep, StateAccessViolation, TypeMismatch, ) @@ -60,6 +61,7 @@ def loo(x: DynArray[DynArray[int128, 2], 2]) -> int128: print("Passed list tests") +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_string_list(get_contract): code = """ @external @@ -732,6 +734,7 @@ def test_array_decimal_return3() -> DynArray[DynArray[decimal, 2], 2]: assert c.test_array_decimal_return3() == [[1.0, 2.0], [3.0]] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_mult_list(get_contract_with_gas_estimation): code = """ nest3: DynArray[DynArray[DynArray[uint256, 2], 2], 2] @@ -1478,6 +1481,7 @@ def foo(x: int128) -> int128: assert c.foo(7) == 392 +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_struct_of_lists(get_contract): code = """ struct Foo: @@ -1566,6 +1570,7 @@ def bar(x: int128) -> DynArray[int128, 3]: assert c.bar(7) == [7, 14] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_struct_of_lists(get_contract, assert_compile_failed, optimize): code = """ struct nestedFoo: @@ -1695,7 +1700,9 @@ def __init__(): ("DynArray[DynArray[DynArray[uint256, 5], 5], 5]", [[[], []], []]), ], ) -def 
test_empty_nested_dynarray(get_contract, typ, val): +def test_empty_nested_dynarray(get_contract, typ, val, venom_xfail): + if val == [[[], []], []]: + venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") code = f""" @external def foo() -> {typ}: diff --git a/tests/functional/examples/factory/test_factory.py b/tests/functional/examples/factory/test_factory.py index 0c5cf61b04..18f6222c20 100644 --- a/tests/functional/examples/factory/test_factory.py +++ b/tests/functional/examples/factory/test_factory.py @@ -31,12 +31,14 @@ def create_exchange(token, factory): @pytest.fixture -def factory(get_contract, optimize): +def factory(get_contract, optimize, experimental_codegen): with open("examples/factory/Exchange.vy") as f: code = f.read() exchange_interface = vyper.compile_code( - code, output_formats=["bytecode_runtime"], settings=Settings(optimize=optimize) + code, + output_formats=["bytecode_runtime"], + settings=Settings(optimize=optimize, experimental_codegen=experimental_codegen), ) exchange_deployed_bytecode = exchange_interface["bytecode_runtime"] diff --git a/tests/functional/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py index 6556fc90b9..6be50a509b 100644 --- a/tests/functional/syntax/test_address_code.py +++ b/tests/functional/syntax/test_address_code.py @@ -161,7 +161,7 @@ def test_address_code_compile_success(code: str): compiler.compile_code(code) -def test_address_code_self_success(get_contract, optimize): +def test_address_code_self_success(get_contract, optimize, experimental_codegen): code = """ code_deployment: public(Bytes[32]) @@ -174,7 +174,7 @@ def code_runtime() -> Bytes[32]: return slice(self.code, 0, 32) """ contract = get_contract(code) - settings = Settings(optimize=optimize) + settings = Settings(optimize=optimize, experimental_codegen=experimental_codegen) code_compiled = compiler.compile_code( code, output_formats=["bytecode", "bytecode_runtime"], settings=settings ) diff --git 
a/tests/functional/syntax/test_codehash.py b/tests/functional/syntax/test_codehash.py index d351981946..7aa01a68e9 100644 --- a/tests/functional/syntax/test_codehash.py +++ b/tests/functional/syntax/test_codehash.py @@ -3,7 +3,7 @@ from vyper.utils import keccak256 -def test_get_extcodehash(get_contract, optimize): +def test_get_extcodehash(get_contract, optimize, experimental_codegen): code = """ a: address @@ -28,7 +28,7 @@ def foo3() -> bytes32: def foo4() -> bytes32: return self.a.codehash """ - settings = Settings(optimize=optimize) + settings = Settings(optimize=optimize, experimental_codegen=experimental_codegen) compiled = compile_code(code, output_formats=["bytecode_runtime"], settings=settings) bytecode = bytes.fromhex(compiled["bytecode_runtime"][2:]) hash_ = keccak256(bytecode) diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 4fe2111f43..62a799db65 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -113,11 +113,13 @@ def test_keyerror_becomes_jsonerror(input_json): def test_compile_json(input_json, input_bundle): foo_input = input_bundle.load_file("contracts/foo.vy") - # remove bb and bb_runtime from output formats + # remove venom related from output formats # because they require venom (experimental) output_formats = OUTPUT_FORMATS.copy() del output_formats["bb"] del output_formats["bb_runtime"] + del output_formats["cfg"] + del output_formats["cfg_runtime"] foo = compile_from_file_input( foo_input, output_formats=output_formats, input_bundle=input_bundle ) diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index ce32249202..5742f7c8df 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -95,7 +95,7 @@ def test_dead_code_eliminator(code): assert all(ctor_only not in instr for instr in runtime_asm) -def 
test_library_code_eliminator(make_input_bundle): +def test_library_code_eliminator(make_input_bundle, experimental_codegen): library = """ @internal def unused1(): @@ -120,5 +120,6 @@ def foo(): res = compile_code(code, input_bundle=input_bundle, output_formats=["asm"]) asm = res["asm"] assert "some_function()" in asm + assert "unused1()" not in asm assert "unused2()" not in asm diff --git a/tests/unit/compiler/venom/test_dominator_tree.py b/tests/unit/compiler/venom/test_dominator_tree.py new file mode 100644 index 0000000000..dc27380796 --- /dev/null +++ b/tests/unit/compiler/venom/test_dominator_tree.py @@ -0,0 +1,73 @@ +from typing import Optional + +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRLiteral, IRVariable +from vyper.venom.dominators import DominatorTree +from vyper.venom.function import IRFunction +from vyper.venom.passes.make_ssa import MakeSSA + + +def _add_bb( + ctx: IRFunction, label: IRLabel, cfg_outs: [IRLabel], bb: Optional[IRBasicBlock] = None +) -> IRBasicBlock: + bb = bb if bb is not None else IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + cfg_outs_len = len(cfg_outs) + if cfg_outs_len == 0: + bb.append_instruction("stop") + elif cfg_outs_len == 1: + bb.append_instruction("jmp", cfg_outs[0]) + elif cfg_outs_len == 2: + bb.append_instruction("jnz", IRLiteral(1), cfg_outs[0], cfg_outs[1]) + else: + raise CompilerPanic("Invalid number of CFG outs") + return bb + + +def _make_test_ctx(): + lab = [IRLabel(str(i)) for i in range(0, 9)] + + ctx = IRFunction(lab[1]) + + bb1 = ctx.basic_blocks[0] + bb1.append_instruction("jmp", lab[2]) + + _add_bb(ctx, lab[7], []) + _add_bb(ctx, lab[6], [lab[7], lab[2]]) + _add_bb(ctx, lab[5], [lab[6], lab[3]]) + _add_bb(ctx, lab[4], [lab[6]]) + _add_bb(ctx, lab[3], [lab[5]]) + _add_bb(ctx, lab[2], [lab[3], lab[4]]) + + return ctx + + +def 
test_deminator_frontier_calculation(): + ctx = _make_test_ctx() + bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [ctx.get_basic_block(str(i)) for i in range(1, 8)] + + calculate_cfg(ctx) + dom = DominatorTree.build_dominator_tree(ctx, bb1) + df = dom.dominator_frontiers + + assert len(df[bb1]) == 0, df[bb1] + assert df[bb2] == OrderedSet({bb2}), df[bb2] + assert df[bb3] == OrderedSet({bb3, bb6}), df[bb3] + assert df[bb4] == OrderedSet({bb6}), df[bb4] + assert df[bb5] == OrderedSet({bb3, bb6}), df[bb5] + assert df[bb6] == OrderedSet({bb2}), df[bb6] + assert len(df[bb7]) == 0, df[bb7] + + +def test_phi_placement(): + ctx = _make_test_ctx() + bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [ctx.get_basic_block(str(i)) for i in range(1, 8)] + + x = IRVariable("%x") + bb1.insert_instruction(IRInstruction("mload", [IRLiteral(0)], x), 0) + bb2.insert_instruction(IRInstruction("add", [x, IRLiteral(1)], x), 0) + bb7.insert_instruction(IRInstruction("mstore", [x, IRLiteral(0)]), 0) + + MakeSSA.run_pass(ctx, bb1) diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index 437185cc72..7cc58e6f5c 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -23,5 +23,4 @@ def test_duplicate_operands(): bb.append_instruction("stop") asm = generate_assembly_experimental(ctx, optimize=OptimizationLevel.GAS) - - assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP"] + assert asm == ["PUSH1", 10, "DUP1", "DUP1", "ADD", "MUL", "STOP"] diff --git a/tests/unit/compiler/venom/test_liveness_simple_loop.py b/tests/unit/compiler/venom/test_liveness_simple_loop.py new file mode 100644 index 0000000000..e725518179 --- /dev/null +++ b/tests/unit/compiler/venom/test_liveness_simple_loop.py @@ -0,0 +1,16 @@ +import vyper +from vyper.compiler.settings import Settings + +source = """ +@external +def foo(a: uint256): + _numBids: uint256 = 20 + b: uint256 = 10 + + for i: 
uint256 in range(128): + b = 1 + _numBids +""" + + +def test_liveness_simple_loop(): + vyper.compile_code(source, ["opcodes"], settings=Settings(experimental_codegen=True)) diff --git a/tests/unit/compiler/venom/test_make_ssa.py b/tests/unit/compiler/venom/test_make_ssa.py new file mode 100644 index 0000000000..2a04dfc134 --- /dev/null +++ b/tests/unit/compiler/venom/test_make_ssa.py @@ -0,0 +1,48 @@ +from vyper.venom.analysis import calculate_cfg, calculate_liveness +from vyper.venom.basicblock import IRBasicBlock, IRLabel +from vyper.venom.function import IRFunction +from vyper.venom.passes.make_ssa import MakeSSA + + +def test_phi_case(): + ctx = IRFunction(IRLabel("_global")) + + bb = ctx.get_basic_block() + + bb_cont = IRBasicBlock(IRLabel("condition"), ctx) + bb_then = IRBasicBlock(IRLabel("then"), ctx) + bb_else = IRBasicBlock(IRLabel("else"), ctx) + bb_if_exit = IRBasicBlock(IRLabel("if_exit"), ctx) + ctx.append_basic_block(bb_cont) + ctx.append_basic_block(bb_then) + ctx.append_basic_block(bb_else) + ctx.append_basic_block(bb_if_exit) + + v = bb.append_instruction("mload", 64) + bb_cont.append_instruction("jnz", v, bb_then.label, bb_else.label) + + bb_if_exit.append_instruction("add", v, 1, ret=v) + bb_if_exit.append_instruction("jmp", bb_cont.label) + + bb_then.append_instruction("assert", bb_then.append_instruction("mload", 96)) + bb_then.append_instruction("jmp", bb_if_exit.label) + bb_else.append_instruction("jmp", bb_if_exit.label) + + bb.append_instruction("jmp", bb_cont.label) + + calculate_cfg(ctx) + MakeSSA.run_pass(ctx, ctx.basic_blocks[0]) + calculate_liveness(ctx) + + condition_block = ctx.get_basic_block("condition") + assert len(condition_block.instructions) == 2 + + phi_inst = condition_block.instructions[0] + assert phi_inst.opcode == "phi" + assert phi_inst.operands[0].name == "_global" + assert phi_inst.operands[1].name == "%1" + assert phi_inst.operands[2].name == "if_exit" + assert phi_inst.operands[3].name == "%1" + assert 
phi_inst.output.name == "%1" + assert phi_inst.output.value != phi_inst.operands[1].value + assert phi_inst.output.value != phi_inst.operands[3].value diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 6d8b074994..47f4b88707 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -39,10 +39,10 @@ def test_multi_entry_block_1(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" # more complicated one @@ -91,10 +91,10 @@ def test_multi_entry_block_2(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" def test_multi_entry_block_with_dynamic_jump(): @@ -132,7 +132,7 @@ def test_multi_entry_block_with_dynamic_jump(): assert ctx.normalized, "CFG should be normalized" finish_bb = 
ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" diff --git a/tests/unit/compiler/venom/test_variables.py b/tests/unit/compiler/venom/test_variables.py new file mode 100644 index 0000000000..cded8d0e1a --- /dev/null +++ b/tests/unit/compiler/venom/test_variables.py @@ -0,0 +1,8 @@ +from vyper.venom.basicblock import IRVariable + + +def test_variable_equality(): + v1 = IRVariable("%x") + v2 = IRVariable("%x") + assert v1 == v2 + assert v1 != IRVariable("%y") diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 227b639ad5..f0c339cca7 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -203,6 +203,12 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: if evm_version not in EVM_VERSIONS: raise StructureException("Invalid evm version: `{evm_version}`", start) settings.evm_version = evm_version + elif pragma.startswith("experimental-codegen"): + if settings.experimental_codegen is not None: + raise StructureException( + "pragma experimental-codegen specified twice!", start + ) + settings.experimental_codegen = True else: raise StructureException(f"Unknown pragma `{pragma.split()[0]}`") diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 6f783bb9c5..fe706699bb 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -154,10 +154,6 @@ def 
_adjust_gas_estimate(func_t, common_ir): common_ir.add_gas_estimate += mem_expansion_cost func_t._ir_info.gas_estimate = common_ir.gas - # pass metadata through for venom pipeline: - common_ir.passthrough_metadata["func_t"] = func_t - common_ir.passthrough_metadata["frame_info"] = frame_info - def generate_ir_for_external_function(code, compilation_target): # TODO type hints: diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 0cf9850b70..cde1ec5c87 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -80,10 +80,6 @@ def generate_ir_for_internal_function( # tag gas estimate and frame info func_t._ir_info.gas_estimate = ir_node.gas - frame_info = tag_frame_info(func_t, context) - - # pass metadata through for venom pipeline: - ir_node.passthrough_metadata["frame_info"] = frame_info - ir_node.passthrough_metadata["func_t"] = func_t + tag_frame_info(func_t, context) return InternalFuncIR(ir_node) diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index da585ff0a1..a8dac640db 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -42,7 +42,6 @@ def finalize(fill_return_buffer): # NOTE: because stack analysis is incomplete, cleanup_repeat must # come after fill_return_buffer otherwise the stack will break jump_to_exit_ir = IRnode.from_list(jump_to_exit) - jump_to_exit_ir.passthrough_metadata["func_t"] = func_t return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit_ir]) if context.return_type is None: diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f53e4a81b4..2363de3641 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -112,6 +112,4 @@ def ir_for_self_call(stmt_expr, context): add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True - o.passthrough_metadata["func_t"] = func_t - 
o.passthrough_metadata["args_ir"] = args_ir return o diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index ee909a57d4..84aea73071 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -25,6 +25,8 @@ "interface": output.build_interface_output, "bb": output.build_bb_output, "bb_runtime": output.build_bb_runtime_output, + "cfg": output.build_cfg_output, + "cfg_runtime": output.build_cfg_runtime_output, "ir": output.build_ir_output, "ir_runtime": output.build_ir_runtime_output, "ir_dict": output.build_ir_dict_output, diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index de8e34370d..f8beb9d11b 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -90,6 +90,14 @@ def build_bb_runtime_output(compiler_data: CompilerData) -> IRnode: return compiler_data.venom_functions[1] +def build_cfg_output(compiler_data: CompilerData) -> str: + return compiler_data.venom_functions[0].as_graph() + + +def build_cfg_runtime_output(compiler_data: CompilerData) -> str: + return compiler_data.venom_functions[1].as_graph() + + def build_ir_output(compiler_data: CompilerData) -> IRnode: if compiler_data.show_gas_estimates: IRnode.repr_show_gas = True diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index e343938021..d794185195 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -97,8 +97,6 @@ def __init__( no_bytecode_metadata: bool, optional Do not add metadata to bytecode. 
Defaults to False """ - # to force experimental codegen, uncomment: - # settings.experimental_codegen = True if isinstance(file_input, str): file_input = FileInput( diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3897f0ea41..996a1ddbd9 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -386,6 +386,10 @@ class CodegenPanic(VyperInternalException): """Invalid code generated during codegen phase""" +class StackTooDeep(CodegenPanic): + """Stack too deep""" # (should not happen) + + class UnexpectedNodeType(VyperInternalException): """Unexpected AST node type.""" diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index e4a4cc60f7..191803295e 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -1020,7 +1020,11 @@ def _stack_peephole_opts(assembly): changed = True del assembly[i] continue - if assembly[i : i + 2] == ["SWAP1", "SWAP1"]: + if ( + isinstance(assembly[i], str) + and assembly[i].startswith("SWAP") + and assembly[i] == assembly[i + 1] + ): changed = True del assembly[i : i + 2] if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: diff --git a/vyper/utils.py b/vyper/utils.py index ba615e58d7..114ddf97c2 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -15,7 +15,7 @@ _T = TypeVar("_T") -class OrderedSet(Generic[_T], dict[_T, None]): +class OrderedSet(Generic[_T]): """ a minimal "ordered set" class. 
this is needed in some places because, while dict guarantees you can recover insertion order @@ -25,57 +25,82 @@ class OrderedSet(Generic[_T], dict[_T, None]): """ def __init__(self, iterable=None): - super().__init__() + self._data = dict() if iterable is not None: - for item in iterable: - self.add(item) + self.update(iterable) def __repr__(self): - keys = ", ".join(repr(k) for k in self.keys()) + keys = ", ".join(repr(k) for k in self) return f"{{{keys}}}" - def get(self, *args, **kwargs): - raise RuntimeError("can't call get() on OrderedSet!") + def __iter__(self): + return iter(self._data) + + def __contains__(self, item): + return self._data.__contains__(item) + + def __len__(self): + return len(self._data) def first(self): return next(iter(self)) def add(self, item: _T) -> None: - self[item] = None + self._data[item] = None def remove(self, item: _T) -> None: - del self[item] + del self._data[item] + + def drop(self, item: _T): + # friendly version of remove + self._data.pop(item, None) + + def dropmany(self, iterable): + for item in iterable: + self._data.pop(item, None) def difference(self, other): ret = self.copy() - for k in other.keys(): - if k in ret: - ret.remove(k) + ret.dropmany(other) return ret + def update(self, other): + # CMC 2024-03-22 for some reason, this is faster than dict.update? 
+ # (maybe size dependent) + for item in other: + self._data[item] = None + def union(self, other): return self | other - def update(self, other): - super().update(self.__class__.fromkeys(other)) + def __ior__(self, other): + self.update(other) + return self def __or__(self, other): - return self.__class__(super().__or__(other)) + ret = self.copy() + ret.update(other) + return ret + + def __eq__(self, other): + return self._data == other._data def copy(self): - return self.__class__(super().copy()) + cls = self.__class__ + ret = cls.__new__(cls) + ret._data = self._data.copy() + return ret @classmethod def intersection(cls, *sets): - res = OrderedSet() if len(sets) == 0: raise ValueError("undefined: intersection of no sets") - if len(sets) == 1: - return sets[0].copy() - for e in sets[0].keys(): - if all(e in s for s in sets[1:]): - res.add(e) - return res + + ret = sets[0].copy() + for e in sets[0]: + if any(e not in s for s in sets[1:]): + ret.remove(e) + return ret class StringEnum(enum.Enum): diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index d1c2d0c342..2efd58ad6c 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -11,10 +11,14 @@ ir_pass_optimize_unused_variables, ir_pass_remove_unreachable_blocks, ) +from vyper.venom.dominators import DominatorTree from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation from vyper.venom.passes.dft import DFTPass +from vyper.venom.passes.make_ssa import MakeSSA +from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.passes.simplify_cfg import SimplifyCFGPass from vyper.venom.venom_to_assembly import VenomCompiler DEFAULT_OPT_LEVEL = OptimizationLevel.default() @@ -38,6 +42,24 @@ def generate_assembly_experimental( def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: # Run passes on Venom IR # TODO: Add support for 
optimization levels + + ir_pass_optimize_empty_blocks(ctx) + ir_pass_remove_unreachable_blocks(ctx) + + internals = [ + bb + for bb in ctx.basic_blocks + if bb.label.value.startswith("internal") and len(bb.cfg_in) == 0 + ] + + SimplifyCFGPass.run_pass(ctx, ctx.basic_blocks[0]) + for entry in internals: + SimplifyCFGPass.run_pass(ctx, entry) + + MakeSSA.run_pass(ctx, ctx.basic_blocks[0]) + for entry in internals: + MakeSSA.run_pass(ctx, entry) + while True: changes = 0 @@ -51,7 +73,6 @@ def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: calculate_cfg(ctx) calculate_liveness(ctx) - changes += ir_pass_constant_propagation(ctx) changes += DFTPass.run_pass(ctx) calculate_cfg(ctx) diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index daebd2560c..066a60f45e 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -1,3 +1,5 @@ +from typing import Optional + from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet from vyper.venom.basicblock import ( @@ -38,6 +40,7 @@ def calculate_cfg(ctx: IRFunction) -> None: def _reset_liveness(ctx: IRFunction) -> None: for bb in ctx.basic_blocks: + bb.out_vars = OrderedSet() for inst in bb.instructions: inst.liveness = OrderedSet() @@ -50,16 +53,15 @@ def _calculate_liveness(bb: IRBasicBlock) -> bool: orig_liveness = bb.instructions[0].liveness.copy() liveness = bb.out_vars.copy() for instruction in reversed(bb.instructions): - ops = instruction.get_inputs() + ins = instruction.get_inputs() + outs = instruction.get_outputs() - for op in ops: - if op in liveness: - instruction.dup_requirements.add(op) + if ins or outs: + # perf: only copy if changed + liveness = liveness.copy() + liveness.update(ins) + liveness.dropmany(outs) - liveness = liveness.union(OrderedSet.fromkeys(ops)) - out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None - if out in liveness: - liveness.remove(out) instruction.liveness = liveness return orig_liveness != 
bb.instructions[0].liveness @@ -89,6 +91,18 @@ def calculate_liveness(ctx: IRFunction) -> None: break +def calculate_dup_requirements(ctx: IRFunction) -> None: + for bb in ctx.basic_blocks: + last_liveness = bb.out_vars + for inst in reversed(bb.instructions): + inst.dup_requirements = OrderedSet() + ops = inst.get_inputs() + for op in ops: + if op in last_liveness: + inst.dup_requirements.add(op) + last_liveness = inst.liveness + + # calculate the input variables into self from source def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: liveness = target.instructions[0].liveness.copy() @@ -104,19 +118,17 @@ def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IR # will arbitrarily choose either %12 or %14 to be in the liveness # set, and then during instruction selection, after this instruction, # %12 will be replaced by %56 in the liveness set - source1, source2 = inst.operands[0], inst.operands[2] - phi1, phi2 = inst.operands[1], inst.operands[3] - if source.label == source1: - liveness.add(phi1) - if phi2 in liveness: - liveness.remove(phi2) - elif source.label == source2: - liveness.add(phi2) - if phi1 in liveness: - liveness.remove(phi1) - else: - # bad path into this phi node - raise CompilerPanic(f"unreachable: {inst}") + + # bad path into this phi node + if source.label not in inst.operands: + raise CompilerPanic(f"unreachable: {inst} from {source.label}") + + for label, var in inst.phi_operands: + if label == source.label: + liveness.add(var) + else: + if var in liveness: + liveness.remove(var) return liveness @@ -137,8 +149,8 @@ def get_uses(self, op: IRVariable) -> list[IRInstruction]: return self._dfg_inputs.get(op, []) # the instruction which produces this variable. 
- def get_producing_instruction(self, op: IRVariable) -> IRInstruction: - return self._dfg_outputs[op] + def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]: + return self._dfg_outputs.get(op) @classmethod def build_dfg(cls, ctx: IRFunction) -> "DFG": @@ -163,3 +175,20 @@ def build_dfg(cls, ctx: IRFunction) -> "DFG": dfg._dfg_outputs[op] = inst return dfg + + def as_graph(self) -> str: + """ + Generate a graphviz representation of the dfg + """ + lines = ["digraph dfg_graph {"] + for var, inputs in self._dfg_inputs.items(): + for input in inputs: + for op in input.get_outputs(): + if isinstance(op, IRVariable): + lines.append(f' " {var.name} " -> " {op.name} "') + + lines.append("}") + return "\n".join(lines) + + def __repr__(self) -> str: + return self.as_graph() diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index ed70a5eaa0..6c509d8f95 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,10 +1,9 @@ -from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional, Union from vyper.utils import OrderedSet # instructions which can terminate a basic block -BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop"]) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop", "exit"]) VOLATILE_INSTRUCTIONS = frozenset( [ @@ -12,15 +11,22 @@ "alloca", "call", "staticcall", + "delegatecall", "invoke", "sload", "sstore", "iload", "istore", + "tload", + "tstore", "assert", + "assert_unreachable", "mstore", "mload", "calldatacopy", + "mcopy", + "extcodecopy", + "returndatacopy", "codecopy", "dloadbytes", "dload", @@ -39,11 +45,15 @@ "istore", "dloadbytes", "calldatacopy", + "mcopy", + "returndatacopy", "codecopy", + "extcodecopy", "return", "ret", "revert", "assert", + "assert_unreachable", "selfdestruct", "stop", "invalid", @@ -52,6 +62,7 @@ "djmp", 
"jnz", "log", + "exit", ] ) @@ -87,6 +98,10 @@ class IROperand: value: Any + @property + def name(self) -> str: + return self.value + class IRValue(IROperand): """ @@ -109,13 +124,16 @@ def __init__(self, value: int) -> None: assert isinstance(value, int), "value must be an int" self.value = value - def __repr__(self) -> str: - return str(self.value) + def __hash__(self) -> int: + return self.value.__hash__() + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value -class MemType(Enum): - OPERAND_STACK = auto() - MEMORY = auto() + def __repr__(self) -> str: + return str(self.value) class IRVariable(IRValue): @@ -126,18 +144,34 @@ class IRVariable(IRValue): value: str offset: int = 0 - # some variables can be in memory for conversion from legacy IR to venom - mem_type: MemType = MemType.OPERAND_STACK - mem_addr: Optional[int] = None - - def __init__( - self, value: str, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: int = None - ) -> None: + def __init__(self, value: str, version: Optional[str | int] = None) -> None: assert isinstance(value, str) + assert ":" not in value, "Variable name cannot contain ':'" + if version: + assert isinstance(version, (str, int)), "version must be a str or int" + value = f"{value}:{version}" + if value[0] != "%": + value = f"%{value}" self.value = value self.offset = 0 - self.mem_type = mem_type - self.mem_addr = mem_addr + + @property + def name(self) -> str: + return self.value.split(":")[0] + + @property + def version(self) -> int: + if ":" not in self.value: + return 0 + return int(self.value.split(":")[1]) + + def __hash__(self) -> int: + return self.value.__hash__() + + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value def __repr__(self) -> str: return self.value @@ -158,6 +192,14 @@ def __init__(self, value: str, is_symbol: bool = False) -> None: self.value = value 
self.is_symbol = is_symbol + def __hash__(self) -> int: + return hash(self.value) + + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value + def __repr__(self) -> str: return self.value @@ -182,6 +224,8 @@ class IRInstruction: parent: Optional["IRBasicBlock"] fence_id: int annotation: Optional[str] + ast_source: Optional[int] + error_msg: Optional[str] def __init__( self, @@ -200,6 +244,8 @@ def __init__( self.parent = None self.fence_id = -1 self.annotation = None + self.ast_source = None + self.error_msg = None def get_label_operands(self) -> list[IRLabel]: """ @@ -246,22 +292,37 @@ def replace_label_operands(self, replacements: dict) -> None: if isinstance(operand, IRLabel) and operand.value in replacements: self.operands[i] = replacements[operand.value] + @property + def phi_operands(self) -> Generator[tuple[IRLabel, IRVariable], None, None]: + """ + Get phi operands for instruction. + """ + assert self.opcode == "phi", "instruction must be a phi" + for i in range(0, len(self.operands), 2): + label = self.operands[i] + var = self.operands[i + 1] + assert isinstance(label, IRLabel), "phi operand must be a label" + assert isinstance(var, IRVariable), "phi operand must be a variable" + yield label, var + def __repr__(self) -> str: s = "" if self.output: s += f"{self.output} = " opcode = f"{self.opcode} " if self.opcode != "store" else "" s += opcode - operands = ", ".join( - [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in self.operands] + operands = self.operands + if self.opcode not in ["jmp", "jnz", "invoke"]: + operands = reversed(operands) # type: ignore + s += ", ".join( + [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in operands] ) - s += operands if self.annotation: s += f" <{self.annotation}>" - # if self.liveness: - # return f"{s: <30} # {self.liveness}" + if self.liveness: + return f"{s: <30} # {self.liveness}" return s @@ -307,6 +368,9 @@ class 
IRBasicBlock: # stack items which this basic block produces out_vars: OrderedSet[IRVariable] + reachable: OrderedSet["IRBasicBlock"] + is_reachable: bool = False + def __init__(self, label: IRLabel, parent: "IRFunction") -> None: assert isinstance(label, IRLabel), "label must be an IRLabel" self.label = label @@ -315,6 +379,8 @@ def __init__(self, label: IRLabel, parent: "IRFunction") -> None: self.cfg_in = OrderedSet() self.cfg_out = OrderedSet() self.out_vars = OrderedSet() + self.reachable = OrderedSet() + self.is_reachable = False def add_cfg_in(self, bb: "IRBasicBlock") -> None: self.cfg_in.add(bb) @@ -333,23 +399,26 @@ def remove_cfg_out(self, bb: "IRBasicBlock") -> None: assert bb in self.cfg_out self.cfg_out.remove(bb) - @property - def is_reachable(self) -> bool: - return len(self.cfg_in) > 0 - - def append_instruction(self, opcode: str, *args: Union[IROperand, int]) -> Optional[IRVariable]: + def append_instruction( + self, opcode: str, *args: Union[IROperand, int], ret: IRVariable = None + ) -> Optional[IRVariable]: """ Append an instruction to the basic block Returns the output variable if the instruction supports one """ - ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None + assert not self.is_terminated, self + + if ret is None: + ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None # Wrap raw integers in IRLiterals inst_args = [_ir_operand_from_value(arg) for arg in args] inst = IRInstruction(opcode, inst_args, ret) inst.parent = self + inst.ast_source = self.parent.ast_source + inst.error_msg = self.parent.error_msg self.instructions.append(inst) return ret @@ -357,10 +426,9 @@ def append_invoke_instruction( self, args: list[IROperand | int], returns: bool ) -> Optional[IRVariable]: """ - Append an instruction to the basic block - - Returns the output variable if the instruction supports one + Append an invoke to the basic block """ + assert not self.is_terminated, self ret = 
None if returns: ret = self.parent.get_next_variable() @@ -368,16 +436,30 @@ def append_invoke_instruction( # Wrap raw integers in IRLiterals inst_args = [_ir_operand_from_value(arg) for arg in args] + assert isinstance(inst_args[0], IRLabel), "Invoked non label" + inst = IRInstruction("invoke", inst_args, ret) inst.parent = self + inst.ast_source = self.parent.ast_source + inst.error_msg = self.parent.error_msg self.instructions.append(inst) return ret - def insert_instruction(self, instruction: IRInstruction, index: int) -> None: + def insert_instruction(self, instruction: IRInstruction, index: Optional[int] = None) -> None: assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + + if index is None: + assert not self.is_terminated, self + index = len(self.instructions) instruction.parent = self + instruction.ast_source = self.parent.ast_source + instruction.error_msg = self.parent.error_msg self.instructions.insert(index, instruction) + def remove_instruction(self, instruction: IRInstruction) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + self.instructions.remove(instruction) + def clear_instructions(self) -> None: self.instructions = [] @@ -388,6 +470,19 @@ def replace_operands(self, replacements: dict) -> None: for instruction in self.instructions: instruction.replace_operands(replacements) + def get_assignments(self): + """ + Get all assignments in basic block. + """ + return [inst.output for inst in self.instructions if inst.output] + + @property + def is_empty(self) -> bool: + """ + Check if the basic block is empty, i.e. it has no instructions. + """ + return len(self.instructions) == 0 + @property def is_terminated(self) -> bool: """ @@ -399,6 +494,20 @@ def is_terminated(self) -> bool: return False return self.instructions[-1].opcode in BB_TERMINATORS + @property + def is_terminal(self) -> bool: + """ + Check if the basic block is terminal. 
+ """ + return len(self.cfg_out) == 0 + + @property + def in_vars(self) -> OrderedSet[IRVariable]: + for inst in self.instructions: + if inst.opcode != "phi": + return inst.liveness + return OrderedSet() + def copy(self): bb = IRBasicBlock(self.label, self.parent) bb.instructions = self.instructions.copy() diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py index 620ee66d15..60dd8bbee1 100644 --- a/vyper/venom/bb_optimizer.py +++ b/vyper/venom/bb_optimizer.py @@ -56,6 +56,25 @@ def _optimize_empty_basicblocks(ctx: IRFunction) -> int: return count +def _daisychain_empty_basicblocks(ctx: IRFunction) -> int: + count = 0 + i = 0 + while i < len(ctx.basic_blocks): + bb = ctx.basic_blocks[i] + i += 1 + if bb.is_terminated: + continue + + if i < len(ctx.basic_blocks) - 1: + bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) + else: + bb.append_instruction("stop") + + count += 1 + + return count + + @ir_pass def ir_pass_optimize_empty_blocks(ctx: IRFunction) -> int: changes = _optimize_empty_basicblocks(ctx) diff --git a/vyper/venom/dominators.py b/vyper/venom/dominators.py new file mode 100644 index 0000000000..b69c17e1d8 --- /dev/null +++ b/vyper/venom/dominators.py @@ -0,0 +1,166 @@ +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction + + +class DominatorTree: + """ + Dominator tree implementation. This class computes the dominator tree of a + function and provides methods to query the tree. The tree is computed using + the Lengauer-Tarjan algorithm. 
+ """ + + ctx: IRFunction + entry_block: IRBasicBlock + dfs_order: dict[IRBasicBlock, int] + dfs_walk: list[IRBasicBlock] + dominators: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + immediate_dominators: dict[IRBasicBlock, IRBasicBlock] + dominated: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + dominator_frontiers: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + + @classmethod + def build_dominator_tree(cls, ctx, entry): + ret = DominatorTree() + ret.compute(ctx, entry) + return ret + + def compute(self, ctx: IRFunction, entry: IRBasicBlock): + """ + Compute the dominator tree. + """ + self.ctx = ctx + self.entry_block = entry + self.dfs_order = {} + self.dfs_walk = [] + self.dominators = {} + self.immediate_dominators = {} + self.dominated = {} + self.dominator_frontiers = {} + + self._compute_dfs(self.entry_block, OrderedSet()) + self._compute_dominators() + self._compute_idoms() + self._compute_df() + + def dominates(self, bb1, bb2): + """ + Check if bb1 dominates bb2. + """ + return bb2 in self.dominators[bb1] + + def immediate_dominator(self, bb): + """ + Return the immediate dominator of a basic block. 
+ """ + return self.immediate_dominators.get(bb) + + def _compute_dominators(self): + """ + Compute dominators + """ + basic_blocks = list(self.dfs_order.keys()) + self.dominators = {bb: OrderedSet(basic_blocks) for bb in basic_blocks} + self.dominators[self.entry_block] = OrderedSet([self.entry_block]) + changed = True + count = len(basic_blocks) ** 2 # TODO: find a proper bound for this + while changed: + count -= 1 + if count < 0: + raise CompilerPanic("Dominators computation failed to converge") + changed = False + for bb in basic_blocks: + if bb == self.entry_block: + continue + preds = bb.cfg_in + if len(preds) == 0: + continue + new_dominators = OrderedSet.intersection(*[self.dominators[pred] for pred in preds]) + new_dominators.add(bb) + if new_dominators != self.dominators[bb]: + self.dominators[bb] = new_dominators + changed = True + + def _compute_idoms(self): + """ + Compute immediate dominators + """ + self.immediate_dominators = {bb: None for bb in self.dfs_order.keys()} + self.immediate_dominators[self.entry_block] = self.entry_block + for bb in self.dfs_walk: + if bb == self.entry_block: + continue + doms = sorted(self.dominators[bb], key=lambda x: self.dfs_order[x]) + self.immediate_dominators[bb] = doms[1] + + self.dominated = {bb: OrderedSet() for bb in self.dfs_walk} + for dom, target in self.immediate_dominators.items(): + self.dominated[target].add(dom) + + def _compute_df(self): + """ + Compute dominance frontier + """ + basic_blocks = self.dfs_walk + self.dominator_frontiers = {bb: OrderedSet() for bb in basic_blocks} + + for bb in self.dfs_walk: + if len(bb.cfg_in) > 1: + for pred in bb.cfg_in: + runner = pred + while runner != self.immediate_dominators[bb]: + self.dominator_frontiers[runner].add(bb) + runner = self.immediate_dominators[runner] + + def dominance_frontier(self, basic_blocks: list[IRBasicBlock]) -> OrderedSet[IRBasicBlock]: + """ + Compute dominance frontier of a set of basic blocks. 
+ """ + df = OrderedSet[IRBasicBlock]() + for bb in basic_blocks: + df.update(self.dominator_frontiers[bb]) + return df + + def _intersect(self, bb1, bb2): + """ + Find the nearest common dominator of two basic blocks. + """ + dfs_order = self.dfs_order + while bb1 != bb2: + while dfs_order[bb1] < dfs_order[bb2]: + bb1 = self.immediate_dominators[bb1] + while dfs_order[bb1] > dfs_order[bb2]: + bb2 = self.immediate_dominators[bb2] + return bb1 + + def _compute_dfs(self, entry: IRBasicBlock, visited): + """ + Depth-first search to compute the DFS order of the basic blocks. This + is used to compute the dominator tree. The sequence of basic blocks in + the DFS order is stored in `self.dfs_walk`. The DFS order of each basic + block is stored in `self.dfs_order`. + """ + visited.add(entry) + + for bb in entry.cfg_out: + if bb not in visited: + self._compute_dfs(bb, visited) + + self.dfs_walk.append(entry) + self.dfs_order[entry] = len(self.dfs_walk) + + def as_graph(self) -> str: + """ + Generate a graphviz representation of the dominator tree. 
+ """ + lines = ["digraph dominator_tree {"] + for bb in self.ctx.basic_blocks: + if bb == self.entry_block: + continue + idom = self.immediate_dominator(bb) + if idom is None: + continue + lines.append(f' " {idom.label} " -> " {bb.label} "') + lines.append("}") + return "\n".join(lines) diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 771dcf73ce..d1680385f5 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -1,12 +1,14 @@ -from typing import Optional +from typing import Iterator, Optional +from vyper.codegen.ir_node import IRnode +from vyper.utils import OrderedSet from vyper.venom.basicblock import ( + CFG_ALTERING_INSTRUCTIONS, IRBasicBlock, IRInstruction, IRLabel, IROperand, IRVariable, - MemType, ) GLOBAL_LABEL = IRLabel("__global") @@ -27,6 +29,11 @@ class IRFunction: last_label: int last_variable: int + # Used during code generation + _ast_source_stack: list[int] + _error_msg_stack: list[str] + _bb_index: dict[str, int] + def __init__(self, name: IRLabel = None) -> None: if name is None: name = GLOBAL_LABEL @@ -40,6 +47,10 @@ def __init__(self, name: IRLabel = None) -> None: self.last_label = 0 self.last_variable = 0 + self._ast_source_stack = [] + self._error_msg_stack = [] + self._bb_index = {} + self.add_entry_point(name) self.append_basic_block(IRBasicBlock(name, self)) @@ -62,10 +73,22 @@ def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" self.basic_blocks.append(bb) - # TODO add sanity check somewhere that basic blocks have unique labels - return self.basic_blocks[-1] + def _get_basicblock_index(self, label: str): + # perf: keep an "index" of labels to block indices to + # perform fast lookup. + # TODO: maybe better just to throw basic blocks in an ordered + # dict of some kind. 
+ ix = self._bb_index.get(label, -1) + if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label: + return ix + # do a reindex + self._bb_index = dict((bb.label.name, ix) for ix, bb in enumerate(self.basic_blocks)) + # sanity check - no duplicate labels + assert len(self._bb_index) == len(self.basic_blocks) + return self._bb_index[label] + def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ Get basic block by label. @@ -73,49 +96,97 @@ def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ if label is None: return self.basic_blocks[-1] - for bb in self.basic_blocks: - if bb.label.value == label: - return bb - raise AssertionError(f"Basic block '{label}' not found") + ix = self._get_basicblock_index(label) + return self.basic_blocks[ix] def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: """ Get basic block after label. """ - for i, bb in enumerate(self.basic_blocks[:-1]): - if bb.label.value == label.value: - return self.basic_blocks[i + 1] + ix = self._get_basicblock_index(label.value) + if 0 <= ix < len(self.basic_blocks) - 1: + return self.basic_blocks[ix + 1] raise AssertionError(f"Basic block after '{label}' not found") + def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]: + """ + Get basic blocks that are terminal. + """ + for bb in self.basic_blocks: + if bb.is_terminal: + yield bb + def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: """ - Get basic blocks that contain label. 
+ Get basic blocks that point to the given basic block """ return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] - def get_next_label(self) -> IRLabel: + def get_next_label(self, suffix: str = "") -> IRLabel: + if suffix != "": + suffix = f"_{suffix}" self.last_label += 1 - return IRLabel(f"{self.last_label}") + return IRLabel(f"{self.last_label}{suffix}") - def get_next_variable( - self, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: Optional[int] = None - ) -> IRVariable: + def get_next_variable(self) -> IRVariable: self.last_variable += 1 - return IRVariable(f"%{self.last_variable}", mem_type, mem_addr) + return IRVariable(f"%{self.last_variable}") def get_last_variable(self) -> str: return f"%{self.last_variable}" def remove_unreachable_blocks(self) -> int: - removed = 0 + self._compute_reachability() + + removed = [] new_basic_blocks = [] + + # Remove unreachable basic blocks for bb in self.basic_blocks: - if not bb.is_reachable and bb.label not in self.entry_points: - removed += 1 + if not bb.is_reachable: + removed.append(bb) else: new_basic_blocks.append(bb) self.basic_blocks = new_basic_blocks - return removed + + # Remove phi instructions that reference removed basic blocks + for bb in removed: + for out_bb in bb.cfg_out: + out_bb.remove_cfg_in(bb) + for inst in out_bb.instructions: + if inst.opcode != "phi": + continue + in_labels = inst.get_label_operands() + if bb.label in in_labels: + out_bb.remove_instruction(inst) + + return len(removed) + + def _compute_reachability(self) -> None: + """ + Compute reachability of basic blocks. + """ + for bb in self.basic_blocks: + bb.reachable = OrderedSet() + bb.is_reachable = False + + for entry in self.entry_points: + entry_bb = self.get_basic_block(entry.value) + self._compute_reachability_from(entry_bb) + + def _compute_reachability_from(self, bb: IRBasicBlock) -> None: + """ + Compute reachability of basic blocks from bb. 
+ """ + if bb.is_reachable: + return + bb.is_reachable = True + for inst in bb.instructions: + if inst.opcode in CFG_ALTERING_INSTRUCTIONS or inst.opcode == "invoke": + for op in inst.get_label_operands(): + out_bb = self.get_basic_block(op.value) + bb.reachable.add(out_bb) + self._compute_reachability_from(out_bb) def append_data(self, opcode: str, args: list[IROperand]) -> None: """ @@ -147,6 +218,25 @@ def normalized(self) -> bool: # The function is normalized return True + def push_source(self, ir): + if isinstance(ir, IRnode): + self._ast_source_stack.append(ir.ast_source) + self._error_msg_stack.append(ir.error_msg) + + def pop_source(self): + assert len(self._ast_source_stack) > 0, "Empty source stack" + self._ast_source_stack.pop() + assert len(self._error_msg_stack) > 0, "Empty error stack" + self._error_msg_stack.pop() + + @property + def ast_source(self) -> Optional[int]: + return self._ast_source_stack[-1] if len(self._ast_source_stack) > 0 else None + + @property + def error_msg(self) -> Optional[str]: + return self._error_msg_stack[-1] if len(self._error_msg_stack) > 0 else None + def copy(self): new = IRFunction(self.name) new.basic_blocks = self.basic_blocks.copy() @@ -155,6 +245,32 @@ def copy(self): new.last_variable = self.last_variable return new + def as_graph(self) -> str: + import html + + def _make_label(bb): + ret = '<' + ret += f'\n' + for inst in bb.instructions: + ret += f'\n' + ret += "
{html.escape(str(bb.label))}
{html.escape(str(inst))}
>" + + return ret + # return f"{bb.label.value}:\n" + "\n".join([f" {inst}" for inst in bb.instructions]) + + ret = "digraph G {\n" + + for bb in self.basic_blocks: + for out_bb in bb.cfg_out: + ret += f' "{bb.label.value}" -> "{out_bb.label.value}"\n' + + for bb in self.basic_blocks: + ret += f' "{bb.label.value}" [shape=plaintext, ' + ret += f'label={_make_label(bb)}, fontname="Courier" fontsize="8"]\n' + + ret += "}\n" + return ret + def __repr__(self) -> str: str = f"IRFunction: {self.name}\n" for bb in self.basic_blocks: diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index b3ac3c1ad7..f610e17f58 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -1,12 +1,10 @@ +import functools +import re from typing import Optional -from vyper.codegen.context import VariableRecord from vyper.codegen.ir_node import IRnode from vyper.evm.opcodes import get_opcodes -from vyper.exceptions import CompilerPanic -from vyper.ir.compile_ir import is_mem_sym, is_symbol -from vyper.semantics.types.function import ContractFunctionT -from vyper.utils import MemoryPositions, OrderedSet +from vyper.utils import MemoryPositions from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -14,12 +12,17 @@ IRLiteral, IROperand, IRVariable, - MemType, ) from vyper.venom.function import IRFunction -_BINARY_IR_INSTRUCTIONS = frozenset( +# Instructions that are mapped to their inverse +INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} + +# Instructions that have a direct EVM opcode equivalent and can +# be passed through to the EVM assembly without special handling +PASS_THROUGH_INSTRUCTIONS = frozenset( [ + # binary instructions "eq", "gt", "lt", @@ -27,6 +30,7 @@ "sgt", "shr", "shl", + "sar", "or", "xor", "and", @@ -34,98 +38,104 @@ "sub", "mul", "div", + "smul", + "sdiv", "mod", + "smod", "exp", "sha3", "sha3_64", "signextend", + "chainid", + "basefee", + "timestamp", + 
"blockhash", + "caller", + "selfbalance", + "calldatasize", + "callvalue", + "address", + "origin", + "codesize", + "gas", + "gasprice", + "gaslimit", + "returndatasize", + "iload", + "sload", + "tload", + "coinbase", + "number", + "prevrandao", + "difficulty", + "iszero", + "not", + "calldataload", + "extcodesize", + "extcodehash", + "balance", + "msize", + "basefee", + "invalid", + "stop", + "selfdestruct", + "assert", + "assert_unreachable", + "exit", + "calldatacopy", + "mcopy", + "extcodecopy", + "codecopy", + "returndatacopy", + "revert", + "istore", + "sstore", + "tstore", + "create", + "create2", + "addmod", + "mulmod", + "call", + "delegatecall", + "staticcall", ] ) -# Instructions that are mapped to their inverse -INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} - -# Instructions that have a direct EVM opcode equivalent and can -# be passed through to the EVM assembly without special handling -PASS_THROUGH_INSTRUCTIONS = [ - "chainid", - "basefee", - "timestamp", - "blockhash", - "caller", - "selfbalance", - "calldatasize", - "callvalue", - "address", - "origin", - "codesize", - "gas", - "gasprice", - "gaslimit", - "returndatasize", - "coinbase", - "number", - "iszero", - "not", - "calldataload", - "extcodesize", - "extcodehash", - "balance", -] +NOOP_INSTRUCTIONS = frozenset(["pass", "cleanup_repeat", "var_list", "unique_symbol"]) SymbolTable = dict[str, Optional[IROperand]] -def _get_symbols_common(a: dict, b: dict) -> dict: - ret = {} - # preserves the ordering in `a` - for k in a.keys(): - if k not in b: - continue - if a[k] == b[k]: - continue - ret[k] = a[k], b[k] - return ret - - # convert IRnode directly to venom def ir_node_to_venom(ir: IRnode) -> IRFunction: ctx = IRFunction() - _convert_ir_bb(ctx, ir, {}, OrderedSet(), {}) + _convert_ir_bb(ctx, ir, {}) # Patch up basic blocks. Connect unterminated blocks to the next with # a jump. terminate final basic block with STOP. 
for i, bb in enumerate(ctx.basic_blocks): if not bb.is_terminated: - if i < len(ctx.basic_blocks) - 1: - bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) + if len(ctx.basic_blocks) - 1 > i: + # TODO: revisit this. When contructor calls internal functions they + # are linked to the last ctor block. Should separate them before this + # so we don't have to handle this here + if ctx.basic_blocks[i + 1].label.value.startswith("internal"): + bb.append_instruction("stop") + else: + bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) else: - bb.append_instruction("stop") + bb.append_instruction("exit") return ctx -def _convert_binary_op( - ctx: IRFunction, - ir: IRnode, - symbols: SymbolTable, - variables: OrderedSet, - allocated_variables: dict[str, IRVariable], - swap: bool = False, -) -> Optional[IRVariable]: - ir_args = ir.args[::-1] if swap else ir.args - arg_0, arg_1 = _convert_ir_bb_list(ctx, ir_args, symbols, variables, allocated_variables) - - assert isinstance(ir.value, str) # mypy hint - return ctx.get_basic_block().append_instruction(ir.value, arg_1, arg_0) - - def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: - ctx.get_basic_block().append_instruction("jmp", label) + bb = ctx.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("jmp_target"), ctx) + ctx.append_basic_block(bb) - label = ctx.get_next_label() - bb = IRBasicBlock(label, ctx) - ctx.append_basic_block(bb) + bb.append_instruction("jmp", label) def _new_block(ctx: IRFunction) -> IRBasicBlock: @@ -134,65 +144,46 @@ def _new_block(ctx: IRFunction) -> IRBasicBlock: return bb -def _handle_self_call( - ctx: IRFunction, - ir: IRnode, - symbols: SymbolTable, - variables: OrderedSet, - allocated_variables: dict[str, IRVariable], -) -> Optional[IRVariable]: - func_t = ir.passthrough_metadata.get("func_t", None) - args_ir = ir.passthrough_metadata["args_ir"] +def _append_return_args(ctx: IRFunction, ofst: int = 0, size: int = 0): + bb = ctx.get_basic_block() 
+ if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("exit_to"), ctx) + ctx.append_basic_block(bb) + ret_ofst = IRVariable("ret_ofst") + ret_size = IRVariable("ret_size") + bb.append_instruction("store", ofst, ret=ret_ofst) + bb.append_instruction("store", size, ret=ret_size) + + +def _handle_self_call(ctx: IRFunction, ir: IRnode, symbols: SymbolTable) -> Optional[IRVariable]: + setup_ir = ir.args[1] goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] target_label = goto_ir.args[0].value # goto - return_buf = goto_ir.args[1] # return buffer + return_buf_ir = goto_ir.args[1] # return buffer ret_args: list[IROperand] = [IRLabel(target_label)] # type: ignore - for arg in args_ir: - if arg.is_literal: - sym = symbols.get(f"&{arg.value}", None) - if sym is None: - ret = _convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) - ret_args.append(ret) - else: - ret_args.append(sym) # type: ignore - else: - ret = _convert_ir_bb(ctx, arg._optimized, symbols, variables, allocated_variables) - if arg.location and arg.location.load_op == "calldataload": - bb = ctx.get_basic_block() - ret = bb.append_instruction(arg.location.load_op, ret) - ret_args.append(ret) + if setup_ir != goto_ir: + _convert_ir_bb(ctx, setup_ir, symbols) - if return_buf.is_literal: - ret_args.append(return_buf.value) # type: ignore + return_buf = _convert_ir_bb(ctx, return_buf_ir, symbols) bb = ctx.get_basic_block() - do_ret = func_t.return_type is not None - if do_ret: - invoke_ret = bb.append_invoke_instruction(ret_args, returns=True) # type: ignore - allocated_variables["return_buffer"] = invoke_ret # type: ignore - return invoke_ret - else: - bb.append_invoke_instruction(ret_args, returns=False) # type: ignore - return None + if len(goto_ir.args) > 2: + ret_args.append(return_buf) # type: ignore + + bb.append_invoke_instruction(ret_args, returns=False) # type: ignore + + return return_buf def _handle_internal_func( - ctx: IRFunction, ir: IRnode, func_t: ContractFunctionT, 
symbols: SymbolTable -) -> IRnode: + ctx: IRFunction, ir: IRnode, does_return_data: bool, symbols: SymbolTable +): bb = IRBasicBlock(IRLabel(ir.args[0].args[0].value, True), ctx) # type: ignore bb = ctx.append_basic_block(bb) - old_ir_mempos = 0 - old_ir_mempos += 64 - - for arg in func_t.arguments: - symbols[f"&{old_ir_mempos}"] = bb.append_instruction("param") - bb.instructions[-1].annotation = arg.name - old_ir_mempos += 32 # arg.typ.memory_bytes_required - # return buffer - if func_t.return_type is not None: + if does_return_data: symbols["return_buffer"] = bb.append_instruction("param") bb.instructions[-1].annotation = "return_buffer" @@ -200,17 +191,16 @@ def _handle_internal_func( symbols["return_pc"] = bb.append_instruction("param") bb.instructions[-1].annotation = "return_pc" - return ir.args[0].args[2] + _convert_ir_bb(ctx, ir.args[0].args[2], symbols) def _convert_ir_simple_node( - ctx: IRFunction, - ir: IRnode, - symbols: SymbolTable, - variables: OrderedSet, - allocated_variables: dict[str, IRVariable], + ctx: IRFunction, ir: IRnode, symbols: SymbolTable ) -> Optional[IRVariable]: - args = [_convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args] + # execute in order + args = _convert_ir_bb_list(ctx, ir.args, symbols) + # reverse output variables for stack + args.reverse() return ctx.get_basic_block().append_instruction(ir.value, *args) # type: ignore @@ -218,265 +208,162 @@ def _convert_ir_simple_node( _continue_target: Optional[IRBasicBlock] = None -def _get_variable_from_address( - variables: OrderedSet[VariableRecord], addr: int -) -> Optional[VariableRecord]: - assert isinstance(addr, int), "non-int address" - for var in variables.keys(): - if var.location.name != "memory": - continue - if addr >= var.pos and addr < var.pos + var.size: # type: ignore - return var - return None - - -def _append_return_for_stack_operand( - ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable -) -> None: - bb = 
ctx.get_basic_block() - if isinstance(ret_ir, IRLiteral): - sym = symbols.get(f"&{ret_ir.value}", None) - new_var = bb.append_instruction("alloca", 32, ret_ir) - bb.append_instruction("mstore", sym, new_var) # type: ignore - else: - sym = symbols.get(ret_ir.value, None) - if sym is None: - # FIXME: needs real allocations - new_var = bb.append_instruction("alloca", 32, 0) - bb.append_instruction("mstore", ret_ir, new_var) # type: ignore - else: - new_var = ret_ir - bb.append_instruction("return", last_ir, new_var) # type: ignore - - -def _convert_ir_bb_list(ctx, ir, symbols, variables, allocated_variables): +def _convert_ir_bb_list(ctx, ir, symbols): ret = [] for ir_node in ir: - venom = _convert_ir_bb(ctx, ir_node, symbols, variables, allocated_variables) - assert venom is not None, ir_node + venom = _convert_ir_bb(ctx, ir_node, symbols) ret.append(venom) return ret -def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): - assert isinstance(ir, IRnode), ir - assert isinstance(variables, OrderedSet) - global _break_target, _continue_target +current_func = None +var_list: list[str] = [] - frame_info = ir.passthrough_metadata.get("frame_info", None) - if frame_info is not None: - local_vars = OrderedSet[VariableRecord](frame_info.frame_vars.values()) - variables |= local_vars - assert isinstance(variables, OrderedSet) +def pop_source_on_return(func): + @functools.wraps(func) + def pop_source(*args, **kwargs): + ctx = args[0] + ret = func(*args, **kwargs) + ctx.pop_source() + return ret - if ir.value in _BINARY_IR_INSTRUCTIONS: - return _convert_binary_op( - ctx, ir, symbols, variables, allocated_variables, ir.value in ["sha3_64"] - ) + return pop_source + + +@pop_source_on_return +def _convert_ir_bb(ctx, ir, symbols): + assert isinstance(ir, IRnode), ir + global _break_target, _continue_target, current_func, var_list + + ctx.push_source(ir) - elif ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: + if ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: org_value = 
ir.value ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] - new_var = _convert_binary_op(ctx, ir, symbols, variables, allocated_variables) + new_var = _convert_ir_simple_node(ctx, ir, symbols) ir.value = org_value return ctx.get_basic_block().append_instruction("iszero", new_var) - elif ir.value in PASS_THROUGH_INSTRUCTIONS: - return _convert_ir_simple_node(ctx, ir, symbols, variables, allocated_variables) - - elif ir.value in ["pass", "stop", "return"]: - pass + return _convert_ir_simple_node(ctx, ir, symbols) + elif ir.value == "return": + ctx.get_basic_block().append_instruction( + "return", IRVariable("ret_size"), IRVariable("ret_ofst") + ) elif ir.value == "deploy": ctx.ctor_mem_size = ir.args[0].value ctx.immutables_len = ir.args[2].value return None elif ir.value == "seq": - func_t = ir.passthrough_metadata.get("func_t", None) + if len(ir.args) == 0: + return None if ir.is_self_call: - return _handle_self_call(ctx, ir, symbols, variables, allocated_variables) - elif func_t is not None: - symbols = {} - allocated_variables = {} - variables = OrderedSet( - {v: True for v in ir.passthrough_metadata["frame_info"].frame_vars.values()} - ) - if func_t.is_internal: - ir = _handle_internal_func(ctx, ir, func_t, symbols) - # fallthrough - - ret = None - for ir_node in ir.args: # NOTE: skip the last one - ret = _convert_ir_bb(ctx, ir_node, symbols, variables, allocated_variables) - - return ret - elif ir.value in ["staticcall", "call"]: # external call - idx = 0 - gas = _convert_ir_bb(ctx, ir.args[idx], symbols, variables, allocated_variables) - address = _convert_ir_bb(ctx, ir.args[idx + 1], symbols, variables, allocated_variables) - - value = None - if ir.value == "call": - value = _convert_ir_bb(ctx, ir.args[idx + 2], symbols, variables, allocated_variables) - else: - idx -= 1 - - argsOffset, argsSize, retOffset, retSize = _convert_ir_bb_list( - ctx, ir.args[idx + 3 : idx + 7], symbols, variables, allocated_variables - ) - - if isinstance(argsOffset, IRLiteral): 
- offset = int(argsOffset.value) - addr = offset - 32 + 4 if offset > 0 else 0 - argsOffsetVar = symbols.get(f"&{addr}", None) - if argsOffsetVar is None: - argsOffsetVar = argsOffset - elif isinstance(argsOffsetVar, IRVariable): - argsOffsetVar.mem_type = MemType.MEMORY - argsOffsetVar.mem_addr = addr - argsOffsetVar.offset = 32 - 4 if offset > 0 else 0 - else: # pragma: nocover - raise CompilerPanic("unreachable") + return _handle_self_call(ctx, ir, symbols) + elif ir.args[0].value == "label": + current_func = ir.args[0].args[0].value + is_external = current_func.startswith("external") + is_internal = current_func.startswith("internal") + if is_internal or len(re.findall(r"external.*__init__\(.*_deploy", current_func)) > 0: + # Internal definition + var_list = ir.args[0].args[1] + does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args + symbols = {} + _handle_internal_func(ctx, ir, does_return_data, symbols) + for ir_node in ir.args[1:]: + ret = _convert_ir_bb(ctx, ir_node, symbols) + + return ret + elif is_external: + ret = _convert_ir_bb(ctx, ir.args[0], symbols) + _append_return_args(ctx) else: - argsOffsetVar = argsOffset - - retOffsetValue = int(retOffset.value) if retOffset else 0 - retVar = ctx.get_next_variable(MemType.MEMORY, retOffsetValue) - symbols[f"&{retOffsetValue}"] = retVar + bb = ctx.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("seq"), ctx) + ctx.append_basic_block(bb) + ret = _convert_ir_bb(ctx, ir.args[0], symbols) - bb = ctx.get_basic_block() + for ir_node in ir.args[1:]: + ret = _convert_ir_bb(ctx, ir_node, symbols) - if ir.value == "call": - args = [retSize, retOffset, argsSize, argsOffsetVar, value, address, gas] - return bb.append_instruction(ir.value, *args) - else: - args = [retSize, retOffset, argsSize, argsOffsetVar, address, gas] - return bb.append_instruction(ir.value, *args) + return ret elif ir.value == "if": cond = ir.args[0] # convert the condition - cont_ret = 
_convert_ir_bb(ctx, cond, symbols, variables, allocated_variables) - current_bb = ctx.get_basic_block() + cont_ret = _convert_ir_bb(ctx, cond, symbols) + cond_block = ctx.get_basic_block() + + cond_symbols = symbols.copy() - else_block = IRBasicBlock(ctx.get_next_label(), ctx) + else_block = IRBasicBlock(ctx.get_next_label("else"), ctx) ctx.append_basic_block(else_block) # convert "else" else_ret_val = None - else_syms = symbols.copy() if len(ir.args) == 3: - else_ret_val = _convert_ir_bb( - ctx, ir.args[2], else_syms, variables, allocated_variables.copy() - ) + else_ret_val = _convert_ir_bb(ctx, ir.args[2], cond_symbols) if isinstance(else_ret_val, IRLiteral): assert isinstance(else_ret_val.value, int) # help mypy else_ret_val = ctx.get_basic_block().append_instruction("store", else_ret_val) - after_else_syms = else_syms.copy() - else_block = ctx.get_basic_block() + + else_block_finish = ctx.get_basic_block() # convert "then" - then_block = IRBasicBlock(ctx.get_next_label(), ctx) + cond_symbols = symbols.copy() + + then_block = IRBasicBlock(ctx.get_next_label("then"), ctx) ctx.append_basic_block(then_block) - then_ret_val = _convert_ir_bb(ctx, ir.args[1], symbols, variables, allocated_variables) + then_ret_val = _convert_ir_bb(ctx, ir.args[1], cond_symbols) if isinstance(then_ret_val, IRLiteral): then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val) - current_bb.append_instruction("jnz", cont_ret, then_block.label, else_block.label) + cond_block.append_instruction("jnz", cont_ret, then_block.label, else_block.label) - after_then_syms = symbols.copy() - then_block = ctx.get_basic_block() + then_block_finish = ctx.get_basic_block() # exit bb - exit_label = ctx.get_next_label() - exit_bb = IRBasicBlock(exit_label, ctx) + exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), ctx) exit_bb = ctx.append_basic_block(exit_bb) - if_ret = None + if_ret = ctx.get_next_variable() if then_ret_val is not None and else_ret_val is not None: - if_ret = 
exit_bb.append_instruction( - "phi", then_block.label, then_ret_val, else_block.label, else_ret_val - ) - - common_symbols = _get_symbols_common(after_then_syms, after_else_syms) - for sym, val in common_symbols.items(): - ret = exit_bb.append_instruction( - "phi", then_block.label, val[0], else_block.label, val[1] - ) - old_var = symbols.get(sym, None) - symbols[sym] = ret - if old_var is not None: - for idx, var_rec in allocated_variables.items(): # type: ignore - if var_rec.value == old_var.value: - allocated_variables[idx] = ret # type: ignore - - if not else_block.is_terminated: - else_block.append_instruction("jmp", exit_bb.label) - - if not then_block.is_terminated: - then_block.append_instruction("jmp", exit_bb.label) + then_block_finish.append_instruction("store", then_ret_val, ret=if_ret) + else_block_finish.append_instruction("store", else_ret_val, ret=if_ret) + + if not else_block_finish.is_terminated: + else_block_finish.append_instruction("jmp", exit_bb.label) + + if not then_block_finish.is_terminated: + then_block_finish.append_instruction("jmp", exit_bb.label) return if_ret elif ir.value == "with": - ret = _convert_ir_bb( - ctx, ir.args[1], symbols, variables, allocated_variables - ) # initialization + ret = _convert_ir_bb(ctx, ir.args[1], symbols) # initialization + + ret = ctx.get_basic_block().append_instruction("store", ret) # Handle with nesting with same symbol with_symbols = symbols.copy() sym = ir.args[0] - if isinstance(ret, IRLiteral): - new_var = ctx.get_basic_block().append_instruction("store", ret) # type: ignore - with_symbols[sym.value] = new_var - else: - with_symbols[sym.value] = ret # type: ignore + with_symbols[sym.value] = ret - return _convert_ir_bb(ctx, ir.args[2], with_symbols, variables, allocated_variables) # body + return _convert_ir_bb(ctx, ir.args[2], with_symbols) # body elif ir.value == "goto": _append_jmp(ctx, IRLabel(ir.args[0].value)) elif ir.value == "djump": - args = [_convert_ir_bb(ctx, ir.args[0], symbols, 
variables, allocated_variables)] + args = [_convert_ir_bb(ctx, ir.args[0], symbols)] for target in ir.args[1:]: args.append(IRLabel(target.value)) ctx.get_basic_block().append_instruction("djmp", *args) _new_block(ctx) elif ir.value == "set": sym = ir.args[0] - arg_1 = _convert_ir_bb(ctx, ir.args[1], symbols, variables, allocated_variables) - new_var = ctx.get_basic_block().append_instruction("store", arg_1) # type: ignore - symbols[sym.value] = new_var - - elif ir.value == "calldatacopy": - arg_0, arg_1, size = _convert_ir_bb_list( - ctx, ir.args, symbols, variables, allocated_variables - ) - - new_v = arg_0 - var = ( - _get_variable_from_address(variables, int(arg_0.value)) - if isinstance(arg_0, IRLiteral) - else None - ) - bb = ctx.get_basic_block() - if var is not None: - if allocated_variables.get(var.name, None) is None: - new_v = bb.append_instruction("alloca", var.size, var.pos) # type: ignore - allocated_variables[var.name] = new_v # type: ignore - bb.append_instruction("calldatacopy", size, arg_1, new_v) # type: ignore - symbols[f"&{var.pos}"] = new_v # type: ignore - else: - bb.append_instruction("calldatacopy", size, arg_1, new_v) # type: ignore - - return new_v - elif ir.value == "codecopy": - arg_0, arg_1, size = _convert_ir_bb_list( - ctx, ir.args, symbols, variables, allocated_variables - ) - - ctx.get_basic_block().append_instruction("codecopy", size, arg_1, arg_0) # type: ignore + arg_1 = _convert_ir_bb(ctx, ir.args[1], symbols) + ctx.get_basic_block().append_instruction("store", arg_1, ret=symbols[sym.value]) elif ir.value == "symbol": return IRLabel(ir.args[0].value, True) elif ir.value == "data": @@ -486,15 +373,11 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): if isinstance(c, int): assert 0 <= c <= 255, "data with invalid size" ctx.append_data("db", [c]) # type: ignore - elif isinstance(c, bytes): - ctx.append_data("db", [c]) # type: ignore + elif isinstance(c.value, bytes): + ctx.append_data("db", [c.value]) # 
type: ignore elif isinstance(c, IRnode): - data = _convert_ir_bb(ctx, c, symbols, variables, allocated_variables) + data = _convert_ir_bb(ctx, c, symbols) ctx.append_data("db", [data]) # type: ignore - elif ir.value == "assert": - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) - current_bb = ctx.get_basic_block() - current_bb.append_instruction("assert", arg_0) elif ir.value == "label": label = IRLabel(ir.args[0].value, True) bb = ctx.get_basic_block() @@ -502,97 +385,30 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): bb.append_instruction("jmp", label) bb = IRBasicBlock(label, ctx) ctx.append_basic_block(bb) - _convert_ir_bb(ctx, ir.args[2], symbols, variables, allocated_variables) + code = ir.args[2] + if code.value == "pass": + bb.append_instruction("exit") + else: + _convert_ir_bb(ctx, code, symbols) elif ir.value == "exit_to": - func_t = ir.passthrough_metadata.get("func_t", None) - assert func_t is not None, "exit_to without func_t" - - if func_t.is_external: - # Hardcoded constructor special case - bb = ctx.get_basic_block() - if func_t.name == "__init__": - label = IRLabel(ir.args[0].value, True) - bb.append_instruction("jmp", label) - return None - if func_t.return_type is None: - bb.append_instruction("stop") - return None - else: - last_ir = None - ret_var = ir.args[1] - deleted = None - if ret_var.is_literal and symbols.get(f"&{ret_var.value}", None) is not None: - deleted = symbols[f"&{ret_var.value}"] - del symbols[f"&{ret_var.value}"] - for arg in ir.args[2:]: - last_ir = _convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) - if deleted is not None: - symbols[f"&{ret_var.value}"] = deleted - - ret_ir = _convert_ir_bb(ctx, ret_var, symbols, variables, allocated_variables) - - bb = ctx.get_basic_block() - - var = ( - _get_variable_from_address(variables, int(ret_ir.value)) - if isinstance(ret_ir, IRLiteral) - else None - ) - if var is not None: - allocated_var = 
allocated_variables.get(var.name, None) - assert allocated_var is not None, "unallocated variable" - new_var = symbols.get(f"&{ret_ir.value}", allocated_var) # type: ignore - - if var.size and int(var.size) > 32: - offset = int(ret_ir.value) - var.pos # type: ignore - if offset > 0: - ptr_var = bb.append_instruction("add", var.pos, offset) - else: - ptr_var = allocated_var - bb.append_instruction("return", last_ir, ptr_var) - else: - _append_return_for_stack_operand(ctx, symbols, new_var, last_ir) - else: - if isinstance(ret_ir, IRLiteral): - sym = symbols.get(f"&{ret_ir.value}", None) - if sym is None: - bb.append_instruction("return", last_ir, ret_ir) - else: - if func_t.return_type.memory_bytes_required > 32: - new_var = bb.append_instruction("alloca", 32, ret_ir) - bb.append_instruction("mstore", sym, new_var) - bb.append_instruction("return", last_ir, new_var) - else: - bb.append_instruction("return", last_ir, ret_ir) - else: - if last_ir and int(last_ir.value) > 32: - bb.append_instruction("return", last_ir, ret_ir) - else: - ret_buf = 128 # TODO: need allocator - new_var = bb.append_instruction("alloca", 32, ret_buf) - bb.append_instruction("mstore", ret_ir, new_var) - bb.append_instruction("return", last_ir, new_var) - - ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) - + args = _convert_ir_bb_list(ctx, ir.args[1:], symbols) + var_list = args + _append_return_args(ctx, *var_list) + bb = ctx.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("exit_to"), ctx) + ctx.append_basic_block(bb) bb = ctx.get_basic_block() - if func_t.is_internal: - assert ir.args[1].value == "return_pc", "return_pc not found" - if func_t.return_type is None: - bb.append_instruction("ret", symbols["return_pc"]) - else: - if func_t.return_type.memory_bytes_required > 32: - bb.append_instruction("ret", symbols["return_buffer"], symbols["return_pc"]) - else: - ret_by_value = bb.append_instruction("mload", symbols["return_buffer"]) - 
bb.append_instruction("ret", ret_by_value, symbols["return_pc"]) - elif ir.value == "revert": - arg_0, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) - ctx.get_basic_block().append_instruction("revert", arg_1, arg_0) + label = IRLabel(ir.args[0].value) + if label.value == "return_pc": + label = symbols.get("return_pc") + bb.append_instruction("ret", label) + else: + bb.append_instruction("jmp", label) elif ir.value == "dload": - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols) bb = ctx.get_basic_block() src = bb.append_instruction("add", arg_0, IRLabel("code_end")) @@ -600,9 +416,7 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): return bb.append_instruction("mload", MemoryPositions.FREE_VAR_SPACE) elif ir.value == "dloadbytes": - dst, src_offset, len_ = _convert_ir_bb_list( - ctx, ir.args, symbols, variables, allocated_variables - ) + dst, src_offset, len_ = _convert_ir_bb_list(ctx, ir.args, symbols) bb = ctx.get_basic_block() src = bb.append_instruction("add", src_offset, IRLabel("code_end")) @@ -610,210 +424,106 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): return None elif ir.value == "mload": - sym_ir = ir.args[0] - var = ( - _get_variable_from_address(variables, int(sym_ir.value)) if sym_ir.is_literal else None - ) + arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols) bb = ctx.get_basic_block() - if var is not None: - if var.size and var.size > 32: - if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = bb.append_instruction( - "alloca", var.size, var.pos - ) - - offset = int(sym_ir.value) - var.pos - if offset > 0: - ptr_var = bb.append_instruction("add", var.pos, offset) - else: - ptr_var = allocated_variables[var.name] + if isinstance(arg_0, IRVariable): + return bb.append_instruction("mload", arg_0) - return bb.append_instruction("mload", ptr_var) - else: - if 
sym_ir.is_literal: - sym = symbols.get(f"&{sym_ir.value}", None) - if sym is None: - new_var = _convert_ir_bb( - ctx, sym_ir, symbols, variables, allocated_variables - ) - symbols[f"&{sym_ir.value}"] = new_var - if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = new_var - return new_var - else: - return sym - - sym = symbols.get(f"&{sym_ir.value}", None) - assert sym is not None, "unallocated variable" - return sym - else: - if sym_ir.is_literal: - new_var = symbols.get(f"&{sym_ir.value}", None) - if new_var is not None: - return bb.append_instruction("mload", new_var) - else: - return bb.append_instruction("mload", sym_ir.value) - else: - new_var = _convert_ir_bb(ctx, sym_ir, symbols, variables, allocated_variables) - # - # Old IR gets it's return value as a reference in the stack - # New IR gets it's return value in stack in case of 32 bytes or less - # So here we detect ahead of time if this mload leads a self call and - # and we skip the mload - # - if sym_ir.is_self_call: - return new_var - return bb.append_instruction("mload", new_var) + if isinstance(arg_0, IRLiteral): + avar = symbols.get(f"%{arg_0.value}") + if avar is not None: + return bb.append_instruction("mload", avar) + return bb.append_instruction("mload", arg_0) elif ir.value == "mstore": - sym_ir, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) - - bb = ctx.get_basic_block() - - var = None - if isinstance(sym_ir, IRLiteral): - var = _get_variable_from_address(variables, int(sym_ir.value)) + # some upstream code depends on reversed order of evaluation -- + # to fix upstream. 
+ arg_1, arg_0 = _convert_ir_bb_list(ctx, reversed(ir.args), symbols) - if var is not None and var.size is not None: - if var.size and var.size > 32: - if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = bb.append_instruction( - "alloca", var.size, var.pos - ) - - offset = int(sym_ir.value) - var.pos - if offset > 0: - ptr_var = bb.append_instruction("add", var.pos, offset) - else: - ptr_var = allocated_variables[var.name] + if isinstance(arg_1, IRVariable): + symbols[f"&{arg_0.value}"] = arg_1 - bb.append_instruction("mstore", arg_1, ptr_var) - else: - if isinstance(sym_ir, IRLiteral): - new_var = bb.append_instruction("store", arg_1) - symbols[f"&{sym_ir.value}"] = new_var - # if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = new_var - return new_var - else: - if not isinstance(sym_ir, IRLiteral): - bb.append_instruction("mstore", arg_1, sym_ir) - return None - - sym = symbols.get(f"&{sym_ir.value}", None) - if sym is None: - bb.append_instruction("mstore", arg_1, sym_ir) - if arg_1 and not isinstance(sym_ir, IRLiteral): - symbols[f"&{sym_ir.value}"] = arg_1 - return None - - if isinstance(sym_ir, IRLiteral): - bb.append_instruction("mstore", arg_1, sym) - return None - else: - symbols[sym_ir.value] = arg_1 - return arg_1 + ctx.get_basic_block().append_instruction("mstore", arg_1, arg_0) elif ir.value == "ceil32": x = ir.args[0] expanded = IRnode.from_list(["and", ["add", x, 31], ["not", 31]]) - return _convert_ir_bb(ctx, expanded, symbols, variables, allocated_variables) + return _convert_ir_bb(ctx, expanded, symbols) elif ir.value == "select": - # b ^ ((a ^ b) * cond) where cond is 1 or 0 cond, a, b = ir.args - expanded = IRnode.from_list(["xor", b, ["mul", cond, ["xor", a, b]]]) - return _convert_ir_bb(ctx, expanded, symbols, variables, allocated_variables) - - elif ir.value in ["sload", "iload"]: - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) - return 
ctx.get_basic_block().append_instruction(ir.value, arg_0) - elif ir.value in ["sstore", "istore"]: - arg_0, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) - ctx.get_basic_block().append_instruction(ir.value, arg_1, arg_0) - elif ir.value == "unique_symbol": - sym = ir.args[0] - new_var = ctx.get_next_variable() - symbols[f"&{sym.value}"] = new_var - return new_var + expanded = IRnode.from_list( + [ + "with", + "cond", + cond, + [ + "with", + "a", + a, + ["with", "b", b, ["xor", "b", ["mul", "cond", ["xor", "a", "b"]]]], + ], + ] + ) + return _convert_ir_bb(ctx, expanded, symbols) elif ir.value == "repeat": - # - # repeat(sym, start, end, bound, body) - # 1) entry block ] - # 2) init counter block ] -> same block - # 3) condition block (exit block, body block) - # 4) body block - # 5) increment block - # 6) exit block - # TODO: Add the extra bounds check after clarify + def emit_body_blocks(): global _break_target, _continue_target old_targets = _break_target, _continue_target - _break_target, _continue_target = exit_block, increment_block - _convert_ir_bb(ctx, body, symbols, variables, allocated_variables) + _break_target, _continue_target = exit_block, incr_block + _convert_ir_bb(ctx, body, symbols.copy()) _break_target, _continue_target = old_targets sym = ir.args[0] - start, end, _ = _convert_ir_bb_list( - ctx, ir.args[1:4], symbols, variables, allocated_variables - ) + start, end, _ = _convert_ir_bb_list(ctx, ir.args[1:4], symbols) + + assert ir.args[3].is_literal, "repeat bound expected to be literal" + + bound = ir.args[3].value + if ( + isinstance(end, IRLiteral) + and isinstance(start, IRLiteral) + and end.value + start.value <= bound + ): + bound = None body = ir.args[4] - entry_block = ctx.get_basic_block() - cond_block = IRBasicBlock(ctx.get_next_label(), ctx) - body_block = IRBasicBlock(ctx.get_next_label(), ctx) - jump_up_block = IRBasicBlock(ctx.get_next_label(), ctx) - increment_block = 
IRBasicBlock(ctx.get_next_label(), ctx) - exit_block = IRBasicBlock(ctx.get_next_label(), ctx) + entry_block = IRBasicBlock(ctx.get_next_label("repeat"), ctx) + cond_block = IRBasicBlock(ctx.get_next_label("condition"), ctx) + body_block = IRBasicBlock(ctx.get_next_label("body"), ctx) + incr_block = IRBasicBlock(ctx.get_next_label("incr"), ctx) + exit_block = IRBasicBlock(ctx.get_next_label("exit"), ctx) - counter_inc_var = ctx.get_next_variable() + bb = ctx.get_basic_block() + bb.append_instruction("jmp", entry_block.label) + ctx.append_basic_block(entry_block) - counter_var = ctx.get_basic_block().append_instruction("store", start) + counter_var = entry_block.append_instruction("store", start) symbols[sym.value] = counter_var - ctx.get_basic_block().append_instruction("jmp", cond_block.label) - - ret = cond_block.append_instruction( - "phi", entry_block.label, counter_var, increment_block.label, counter_inc_var - ) - symbols[sym.value] = ret + end = entry_block.append_instruction("add", start, end) + if bound: + bound = entry_block.append_instruction("add", start, bound) + entry_block.append_instruction("jmp", cond_block.label) - xor_ret = cond_block.append_instruction("xor", ret, end) + xor_ret = cond_block.append_instruction("xor", counter_var, end) cont_ret = cond_block.append_instruction("iszero", xor_ret) ctx.append_basic_block(cond_block) - start_syms = symbols.copy() ctx.append_basic_block(body_block) - emit_body_blocks() - end_syms = symbols.copy() - diff_syms = _get_symbols_common(start_syms, end_syms) - - replacements = {} - for sym, val in diff_syms.items(): - new_var = ctx.get_next_variable() - symbols[sym] = new_var - replacements[val[0]] = new_var - replacements[val[1]] = new_var - cond_block.insert_instruction( - IRInstruction( - "phi", [entry_block.label, val[0], increment_block.label, val[1]], new_var - ), - 1, - ) - - body_block.replace_operands(replacements) + if bound: + xor_ret = body_block.append_instruction("xor", counter_var, bound) + 
body_block.append_instruction("assert", xor_ret) + emit_body_blocks() body_end = ctx.get_basic_block() - if not body_end.is_terminated: - body_end.append_instruction("jmp", jump_up_block.label) + if body_end.is_terminated is False: + body_end.append_instruction("jmp", incr_block.label) - jump_up_block.append_instruction("jmp", increment_block.label) - ctx.append_basic_block(jump_up_block) - - increment_block.insert_instruction( - IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var), 0 + ctx.append_basic_block(incr_block) + incr_block.insert_instruction( + IRInstruction("add", [counter_var, IRLiteral(1)], counter_var) ) - - increment_block.append_instruction("jmp", cond_block.label) - ctx.append_basic_block(increment_block) + incr_block.append_instruction("jmp", cond_block.label) ctx.append_basic_block(exit_block) @@ -826,32 +536,15 @@ def emit_body_blocks(): assert _continue_target is not None, "Continue with no contrinue target" ctx.get_basic_block().append_instruction("jmp", _continue_target.label) ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) - elif ir.value == "gas": - return ctx.get_basic_block().append_instruction("gas") - elif ir.value == "returndatasize": - return ctx.get_basic_block().append_instruction("returndatasize") - elif ir.value == "returndatacopy": - assert len(ir.args) == 3, "returndatacopy with wrong number of arguments" - arg_0, arg_1, size = _convert_ir_bb_list( - ctx, ir.args, symbols, variables, allocated_variables - ) - - new_var = ctx.get_basic_block().append_instruction("returndatacopy", arg_1, size) - - symbols[f"&{arg_0.value}"] = new_var - return new_var - elif ir.value == "selfdestruct": - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) - ctx.get_basic_block().append_instruction("selfdestruct", arg_0) + elif ir.value in NOOP_INSTRUCTIONS: + pass elif isinstance(ir.value, str) and ir.value.startswith("log"): - args = reversed( - [_convert_ir_bb(ctx, arg, symbols, variables, 
allocated_variables) for arg in ir.args] - ) + args = reversed(_convert_ir_bb_list(ctx, ir.args, symbols)) topic_count = int(ir.value[3:]) assert topic_count >= 0 and topic_count <= 4, "invalid topic count" ctx.get_basic_block().append_instruction("log", topic_count, *args) elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): - _convert_ir_opcode(ctx, ir, symbols, variables, allocated_variables) + _convert_ir_opcode(ctx, ir, symbols) elif isinstance(ir.value, str) and ir.value in symbols: return symbols[ir.value] elif ir.is_literal: @@ -862,28 +555,10 @@ def emit_body_blocks(): return None -def _convert_ir_opcode( - ctx: IRFunction, - ir: IRnode, - symbols: SymbolTable, - variables: OrderedSet, - allocated_variables: dict[str, IRVariable], -) -> None: +def _convert_ir_opcode(ctx: IRFunction, ir: IRnode, symbols: SymbolTable) -> None: opcode = ir.value.upper() # type: ignore inst_args = [] for arg in ir.args: if isinstance(arg, IRnode): - inst_args.append(_convert_ir_bb(ctx, arg, symbols, variables, allocated_variables)) + inst_args.append(_convert_ir_bb(ctx, arg, symbols)) ctx.get_basic_block().append_instruction(opcode, *inst_args) - - -def _data_ofst_of(sym, ofst, height_): - # e.g. 
_OFST _sym_foo 32 - assert is_symbol(sym) or is_mem_sym(sym) - if isinstance(ofst.value, int): - # resolve at compile time using magic _OFST op - return ["_OFST", sym, ofst.value] - else: - # if we can't resolve at compile time, resolve at runtime - # ofst = _compile_to_assembly(ofst, withargs, existing_labels, break_dest, height_) - return ofst + [sym, "ADD"] diff --git a/vyper/venom/passes/base_pass.py b/vyper/venom/passes/base_pass.py index 11da80ac66..3fbbdef6df 100644 --- a/vyper/venom/passes/base_pass.py +++ b/vyper/venom/passes/base_pass.py @@ -9,11 +9,13 @@ def run_pass(cls, *args, **kwargs): t = cls() count = 0 - while True: + for _ in range(1000): changes_count = t._run_pass(*args, **kwargs) or 0 count += changes_count if changes_count == 0: break + else: + raise Exception("Too many iterations in IR pass!", t.__class__) return count diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index 26994bd27f..5d149cf003 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -1,13 +1,31 @@ from vyper.utils import OrderedSet from vyper.venom.analysis import DFG -from vyper.venom.basicblock import IRBasicBlock, IRInstruction +from vyper.venom.basicblock import BB_TERMINATORS, IRBasicBlock, IRInstruction, IRVariable from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass -# DataFlow Transformation class DFTPass(IRPass): - def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction): + inst_order: dict[IRInstruction, int] + inst_order_num: int + + def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset: int = 0): + for op in inst.get_outputs(): + assert isinstance(op, IRVariable), f"expected variable, got {op}" + uses = self.dfg.get_uses(op) + + for uses_this in uses: + if uses_this.parent != inst.parent or uses_this.fence_id != inst.fence_id: + # don't reorder across basic block or fence boundaries + continue + + # if the instruction is a terminator, we need to 
place + # it at the end of the basic block + # along with all the instructions that "lead" to it + if uses_this.opcode in BB_TERMINATORS: + offset = len(bb.instructions) + self._process_instruction_r(bb, uses_this, offset) + if inst in self.visited_instructions: return self.visited_instructions.add(inst) @@ -15,35 +33,43 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction): if inst.opcode == "phi": # phi instructions stay at the beginning of the basic block # and no input processing is needed - bb.instructions.append(inst) + # bb.instructions.append(inst) + self.inst_order[inst] = 0 return for op in inst.get_inputs(): target = self.dfg.get_producing_instruction(op) + assert target is not None, f"no producing instruction for {op}" if target.parent != inst.parent or target.fence_id != inst.fence_id: # don't reorder across basic block or fence boundaries continue - self._process_instruction_r(bb, target) + self._process_instruction_r(bb, target, offset) - bb.instructions.append(inst) + self.inst_order_num += 1 + self.inst_order[inst] = self.inst_order_num + offset def _process_basic_block(self, bb: IRBasicBlock) -> None: self.ctx.append_basic_block(bb) - instructions = bb.instructions - bb.instructions = [] - - for inst in instructions: + for inst in bb.instructions: inst.fence_id = self.fence_id if inst.volatile: self.fence_id += 1 - for inst in instructions: + # We go through the instructions and calculate the order in which they should be executed + # based on the data flow graph. This order is stored in the inst_order dictionary. + # We then sort the instructions based on this order.
+ self.inst_order = {} + self.inst_order_num = 0 + for inst in bb.instructions: self._process_instruction_r(bb, inst) + bb.instructions.sort(key=lambda x: self.inst_order[x]) + def _run_pass(self, ctx: IRFunction) -> None: self.ctx = ctx self.dfg = DFG.build_dfg(ctx) + self.fence_id = 0 self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() diff --git a/vyper/venom/passes/make_ssa.py b/vyper/venom/passes/make_ssa.py new file mode 100644 index 0000000000..06c61c9ea7 --- /dev/null +++ b/vyper/venom/passes/make_ssa.py @@ -0,0 +1,174 @@ +from vyper.utils import OrderedSet +from vyper.venom.analysis import calculate_cfg, calculate_liveness +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IROperand, IRVariable +from vyper.venom.dominators import DominatorTree +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class MakeSSA(IRPass): + """ + This pass converts the function into Static Single Assignment (SSA) form. + """ + + dom: DominatorTree + defs: dict[IRVariable, OrderedSet[IRBasicBlock]] + + def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> int: + self.ctx = ctx + + calculate_cfg(ctx) + self.dom = DominatorTree.build_dominator_tree(ctx, entry) + + calculate_liveness(ctx) + self._add_phi_nodes() + + self.var_name_counters = {var.name: 0 for var in self.defs.keys()} + self.var_name_stacks = {var.name: [0] for var in self.defs.keys()} + self._rename_vars(entry) + self._remove_degenerate_phis(entry) + + return 0 + + def _add_phi_nodes(self): + """ + Add phi nodes to the function. 
+ """ + self._compute_defs() + work = {var: 0 for var in self.dom.dfs_walk} + has_already = {var: 0 for var in self.dom.dfs_walk} + i = 0 + + # Iterate over all variables + for var, d in self.defs.items(): + i += 1 + defs = list(d) + while len(defs) > 0: + bb = defs.pop() + for dom in self.dom.dominator_frontiers[bb]: + if has_already[dom] >= i: + continue + + self._place_phi(var, dom) + has_already[dom] = i + if work[dom] < i: + work[dom] = i + defs.append(dom) + + def _place_phi(self, var: IRVariable, basic_block: IRBasicBlock): + if var not in basic_block.in_vars: + return + + args: list[IROperand] = [] + for bb in basic_block.cfg_in: + if bb == basic_block: + continue + + args.append(bb.label) # type: ignore + args.append(var) # type: ignore + + basic_block.insert_instruction(IRInstruction("phi", args, var), 0) + + def _add_phi(self, var: IRVariable, basic_block: IRBasicBlock) -> bool: + for inst in basic_block.instructions: + if inst.opcode == "phi" and inst.output is not None and inst.output.name == var.name: + return False + + args: list[IROperand] = [] + for bb in basic_block.cfg_in: + if bb == basic_block: + continue + + args.append(bb.label) + args.append(var) + + phi = IRInstruction("phi", args, var) + basic_block.instructions.insert(0, phi) + + return True + + def _rename_vars(self, basic_block: IRBasicBlock): + """ + Rename variables. This follows the placement of phi nodes. 
+ """ + outs = [] + + # Pre-action + for inst in basic_block.instructions: + new_ops = [] + if inst.opcode != "phi": + for op in inst.operands: + if not isinstance(op, IRVariable): + new_ops.append(op) + continue + + new_ops.append(IRVariable(op.name, version=self.var_name_stacks[op.name][-1])) + + inst.operands = new_ops + + if inst.output is not None: + v_name = inst.output.name + i = self.var_name_counters[v_name] + + self.var_name_stacks[v_name].append(i) + self.var_name_counters[v_name] = i + 1 + + inst.output = IRVariable(v_name, version=i) + # note - after previous line, inst.output.name != v_name + outs.append(inst.output.name) + + for bb in basic_block.cfg_out: + for inst in bb.instructions: + if inst.opcode != "phi": + continue + assert inst.output is not None, "Phi instruction without output" + for i, op in enumerate(inst.operands): + if op == basic_block.label: + inst.operands[i + 1] = IRVariable( + inst.output.name, version=self.var_name_stacks[inst.output.name][-1] + ) + + for bb in self.dom.dominated[basic_block]: + if bb == basic_block: + continue + self._rename_vars(bb) + + # Post-action + for op_name in outs: + # NOTE: each pop corresponds to an append in the pre-action above + self.var_name_stacks[op_name].pop() + + def _remove_degenerate_phis(self, entry: IRBasicBlock): + for inst in entry.instructions.copy(): + if inst.opcode != "phi": + continue + + new_ops = [] + for label, op in inst.phi_operands: + if op == inst.output: + continue + new_ops.extend([label, op]) + new_ops_len = len(new_ops) + if new_ops_len == 0: + entry.instructions.remove(inst) + elif new_ops_len == 2: + entry.instructions.remove(inst) + else: + inst.operands = new_ops + + for bb in self.dom.dominated[entry]: + if bb == entry: + continue + self._remove_degenerate_phis(bb) + + def _compute_defs(self): + """ + Compute the definition points of variables in the function. 
+ """ + self.defs = {} + for bb in self.dom.dfs_walk: + assignments = bb.get_assignments() + for var in assignments: + if var not in self.defs: + self.defs[var] = OrderedSet() + self.defs[var].add(bb) diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 26699099b2..9ca8127b2d 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -28,7 +28,7 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB source = in_bb.label.value target = bb.label.value - split_label = IRLabel(f"{target}_split_{source}") + split_label = IRLabel(f"{source}_split_{target}") in_terminal = in_bb.instructions[-1] in_terminal.replace_label_operands({bb.label: split_label}) @@ -36,6 +36,13 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB split_bb.append_instruction("jmp", bb.label) self.ctx.append_basic_block(split_bb) + for inst in bb.instructions: + if inst.opcode != "phi": + continue + for i in range(0, len(inst.operands), 2): + if inst.operands[i] == in_bb.label: + inst.operands[i] = split_bb.label + # Update the labels in the data segment for inst in self.ctx.data_segment: if inst.opcode == "db" and inst.operands[0] == bb.label: @@ -55,5 +62,6 @@ def _run_pass(self, ctx: IRFunction) -> int: # If we made changes, recalculate the cfg if self.changes > 0: calculate_cfg(ctx) + ctx.remove_unreachable_blocks() return self.changes diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py new file mode 100644 index 0000000000..7f02ccf819 --- /dev/null +++ b/vyper/venom/passes/simplify_cfg.py @@ -0,0 +1,82 @@ +from vyper.utils import OrderedSet +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class SimplifyCFGPass(IRPass): + visited: OrderedSet + + def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): + a.instructions.pop() + for 
inst in b.instructions: + assert inst.opcode != "phi", "Not implemented yet" + if inst.opcode == "phi": + a.instructions.insert(0, inst) + else: + inst.parent = a + a.instructions.append(inst) + + # Update CFG + a.cfg_out = b.cfg_out + if len(b.cfg_out) > 0: + next_bb = b.cfg_out.first() + next_bb.remove_cfg_in(b) + next_bb.add_cfg_in(a) + + self.ctx.basic_blocks.remove(b) + + def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): + next_bb = b.cfg_out.first() + jump_inst = a.instructions[-1] + assert b.label in jump_inst.operands, f"{b.label} {jump_inst.operands}" + jump_inst.operands[jump_inst.operands.index(b.label)] = next_bb.label + + # Update CFG + a.remove_cfg_out(b) + a.add_cfg_out(next_bb) + next_bb.remove_cfg_in(b) + next_bb.add_cfg_in(a) + + self.ctx.basic_blocks.remove(b) + + def _collapse_chained_blocks_r(self, bb: IRBasicBlock): + """ + DFS into the cfg and collapse blocks with a single predecessor to the predecessor + """ + if len(bb.cfg_out) == 1: + next_bb = bb.cfg_out.first() + if len(next_bb.cfg_in) == 1: + self._merge_blocks(bb, next_bb) + self._collapse_chained_blocks_r(bb) + return + elif len(bb.cfg_out) == 2: + bb_out = bb.cfg_out.copy() + for next_bb in bb_out: + if ( + len(next_bb.cfg_in) == 1 + and len(next_bb.cfg_out) == 1 + and len(next_bb.instructions) == 1 + ): + self._merge_jump(bb, next_bb) + self._collapse_chained_blocks_r(bb) + return + + if bb in self.visited: + return + self.visited.add(bb) + + for bb_out in bb.cfg_out: + self._collapse_chained_blocks_r(bb_out) + + def _collapse_chained_blocks(self, entry: IRBasicBlock): + """ + Collapse blocks with a single predecessor to their predecessor + """ + self.visited = OrderedSet() + self._collapse_chained_blocks_r(entry) + + def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> None: + self.ctx = ctx + + self._collapse_chained_blocks(entry) diff --git a/vyper/venom/passes/stack_reorder.py b/vyper/venom/passes/stack_reorder.py new file mode 100644 index 0000000000..b32ec4abde 
--- /dev/null +++ b/vyper/venom/passes/stack_reorder.py @@ -0,0 +1,24 @@ +from vyper.utils import OrderedSet +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class StackReorderPass(IRPass): + visited: OrderedSet + + def _reorder_stack(self, bb: IRBasicBlock): + pass + + def _visit(self, bb: IRBasicBlock): + if bb in self.visited: + return + self.visited.add(bb) + + for bb_out in bb.cfg_out: + self._visit(bb_out) + + def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock): + self.ctx = ctx + self.visited = OrderedSet() + self._visit(entry) diff --git a/vyper/venom/stack_model.py b/vyper/venom/stack_model.py index 66c62b74d2..a98e5bb25b 100644 --- a/vyper/venom/stack_model.py +++ b/vyper/venom/stack_model.py @@ -30,34 +30,36 @@ def push(self, op: IROperand) -> None: def pop(self, num: int = 1) -> None: del self._stack[len(self._stack) - num :] - def get_depth(self, op: IROperand) -> int: + def get_depth(self, op: IROperand, n: int = 1) -> int: """ - Returns the depth of the first matching operand in the stack map. + Returns the depth of the n-th matching operand in the stack map. If the operand is not in the stack map, returns NOT_IN_STACK. """ assert isinstance(op, IROperand), f"{type(op)}: {op}" for i, stack_op in enumerate(reversed(self._stack)): if stack_op.value == op.value: - return -i + if n <= 1: + return -i + else: + n -= 1 return StackModel.NOT_IN_STACK # type: ignore - def get_phi_depth(self, phi1: IRVariable, phi2: IRVariable) -> int: + def get_phi_depth(self, phis: list[IRVariable]) -> int: """ Returns the depth of the first matching phi variable in the stack map. If the none of the phi operands are in the stack, returns NOT_IN_STACK. - Asserts that exactly one of phi1 and phi2 is found. + Asserts that exactly one of phis is found. 
""" - assert isinstance(phi1, IRVariable) - assert isinstance(phi2, IRVariable) + assert isinstance(phis, list) ret = StackModel.NOT_IN_STACK for i, stack_item in enumerate(reversed(self._stack)): - if stack_item in (phi1, phi2): + if stack_item in phis: assert ( ret is StackModel.NOT_IN_STACK - ), f"phi argument is not unique! {phi1}, {phi2}, {self._stack}" + ), f"phi argument is not unique! {phis}, {self._stack}" ret = -i return ret # type: ignore diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 608e100cd1..0cb13becf2 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -1,8 +1,22 @@ +from collections import Counter from typing import Any -from vyper.ir.compile_ir import PUSH, DataHeader, RuntimeHeader, optimize_assembly +from vyper.exceptions import CompilerPanic, StackTooDeep +from vyper.ir.compile_ir import ( + PUSH, + DataHeader, + Instruction, + RuntimeHeader, + mksymbol, + optimize_assembly, +) from vyper.utils import MemoryPositions, OrderedSet -from vyper.venom.analysis import calculate_cfg, calculate_liveness, input_vars_from +from vyper.venom.analysis import ( + calculate_cfg, + calculate_dup_requirements, + calculate_liveness, + input_vars_from, +) from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -10,7 +24,6 @@ IRLiteral, IROperand, IRVariable, - MemType, ) from vyper.venom.function import IRFunction from vyper.venom.passes.normalization import NormalizationPass @@ -23,15 +36,18 @@ "coinbase", "calldatasize", "calldatacopy", + "mcopy", "calldataload", "gas", "gasprice", "gaslimit", + "chainid", "address", "origin", "number", "extcodesize", "extcodehash", + "extcodecopy", "returndatasize", "returndatacopy", "callvalue", @@ -40,13 +56,17 @@ "sstore", "mload", "mstore", + "tload", + "tstore", "timestamp", "caller", + "blockhash", "selfdestruct", "signextend", "stop", "shr", "shl", + "sar", "and", "xor", "or", @@ -54,8 +74,13 @@ "sub", "mul", "div", + "smul", + 
"sdiv", "mod", + "smod", "exp", + "addmod", + "mulmod", "eq", "iszero", "not", @@ -63,12 +88,34 @@ "lt", "slt", "sgt", + "create", + "create2", + "msize", + "balance", + "call", + "staticcall", + "delegatecall", + "codesize", + "basefee", + "prevrandao", + "difficulty", + "invalid", ] ) _REVERT_POSTAMBLE = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] +def apply_line_numbers(inst: IRInstruction, asm) -> list[str]: + ret = [] + for op in asm: + if isinstance(op, str) and not isinstance(op, Instruction): + ret.append(Instruction(op, inst.ast_source, inst.error_msg)) + else: + ret.append(op) + return ret # type: ignore + + # TODO: "assembly" gets into the recursion due to how the original # IR was structured recursively in regards with the deploy instruction. # There, recursing into the deploy instruction was by design, and @@ -105,18 +152,18 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: # This is a side-effect of how dynamic jumps are temporarily being used # to support the O(1) dispatcher. -> look into calculate_cfg() for ctx in self.ctxs: - calculate_cfg(ctx) NormalizationPass.run_pass(ctx) + calculate_cfg(ctx) calculate_liveness(ctx) + calculate_dup_requirements(ctx) assert ctx.normalized, "Non-normalized CFG!" 
self._generate_evm_for_basicblock_r(asm, ctx.basic_blocks[0], StackModel()) # TODO make this property on IRFunction + asm.extend(["_sym__ctor_exit", "JUMPDEST"]) if ctx.immutables_len is not None and ctx.ctor_mem_size is not None: - while asm[-1] != "JUMPDEST": - asm.pop() asm.extend( ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] ) @@ -139,7 +186,11 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: label = inst.operands[0].value data_segments[label] = [DataHeader(f"_sym_{label}")] elif inst.opcode == "db": - data_segments[label].append(f"_sym_{inst.operands[0].value}") + data = inst.operands[0] + if isinstance(data, IRLabel): + data_segments[label].append(f"_sym_{data.value}") + else: + data_segments[label].append(data) asm.extend(list(data_segments.values())) @@ -149,20 +200,27 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: return top_asm def _stack_reorder( - self, assembly: list, stack: StackModel, _stack_ops: OrderedSet[IRVariable] + self, assembly: list, stack: StackModel, stack_ops: list[IRVariable] ) -> None: - # make a list so we can index it - stack_ops = [x for x in _stack_ops.keys()] - stack_ops_count = len(_stack_ops) + stack_ops_count = len(stack_ops) + + counts = Counter(stack_ops) for i in range(stack_ops_count): op = stack_ops[i] final_stack_depth = -(stack_ops_count - i - 1) - depth = stack.get_depth(op) # type: ignore + depth = stack.get_depth(op, counts[op]) # type: ignore + counts[op] -= 1 + + if depth == StackModel.NOT_IN_STACK: + raise CompilerPanic(f"Variable {op} not in stack") if depth == final_stack_depth: continue + if op == stack.peek(final_stack_depth): + continue + self.swap(assembly, stack, depth) self.swap(assembly, stack, final_stack_depth) @@ -192,23 +250,20 @@ def _emit_input_operands( continue if isinstance(op, IRLiteral): - assembly.extend([*PUSH(op.value)]) + if op.value < -(2**255): + raise Exception(f"Value too low: {op.value}") + elif op.value >= 2**256: + 
raise Exception(f"Value too high: {op.value}") + assembly.extend(PUSH(op.value % 2**256)) stack.push(op) continue - if op in inst.dup_requirements: + if op in inst.dup_requirements and op not in emitted_ops: self.dup_op(assembly, stack, op) if op in emitted_ops: self.dup_op(assembly, stack, op) - # REVIEW: this seems like it can be reordered across volatile - # boundaries (which includes memory fences). maybe just - # remove it entirely at this point - if isinstance(op, IRVariable) and op.mem_type == MemType.MEMORY: - assembly.extend([*PUSH(op.mem_addr)]) - assembly.append("MLOAD") - emitted_ops.add(op) def _generate_evm_for_basicblock_r( @@ -224,12 +279,34 @@ def _generate_evm_for_basicblock_r( self.clean_stack_from_cfg_in(asm, basicblock, stack) - for inst in basicblock.instructions: - asm = self._generate_evm_for_instruction(asm, inst, stack) + param_insts = [inst for inst in basicblock.instructions if inst.opcode == "param"] + main_insts = [inst for inst in basicblock.instructions if inst.opcode != "param"] + + for inst in param_insts: + asm.extend(self._generate_evm_for_instruction(inst, stack)) + + self._clean_unused_params(asm, basicblock, stack) + + for i, inst in enumerate(main_insts): + next_liveness = main_insts[i + 1].liveness if i + 1 < len(main_insts) else OrderedSet() - for bb in basicblock.cfg_out: + asm.extend(self._generate_evm_for_instruction(inst, stack, next_liveness)) + + for bb in basicblock.reachable: self._generate_evm_for_basicblock_r(asm, bb, stack.copy()) + def _clean_unused_params(self, asm: list, bb: IRBasicBlock, stack: StackModel) -> None: + for i, inst in enumerate(bb.instructions): + if inst.opcode != "param": + break + if inst.volatile and i + 1 < len(bb.instructions): + liveness = bb.instructions[i + 1].liveness + if inst.output is not None and inst.output not in liveness: + depth = stack.get_depth(inst.output) + if depth != 0: + self.swap(asm, stack, depth) + self.pop(asm, stack) + # pop values from stack at entry to bb # note 
this produces the same result(!) no matter which basic block # we enter from in the CFG. @@ -258,12 +335,14 @@ def clean_stack_from_cfg_in( continue if depth != 0: - stack.swap(depth) + self.swap(asm, stack, depth) self.pop(asm, stack) def _generate_evm_for_instruction( - self, assembly: list, inst: IRInstruction, stack: StackModel + self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet = None ) -> list[str]: + assembly: list[str | int] = [] + next_liveness = next_liveness or OrderedSet() opcode = inst.opcode # @@ -276,10 +355,22 @@ def _generate_evm_for_instruction( operands = inst.get_non_label_operands() elif opcode == "alloca": operands = inst.operands[1:2] + + # iload and istore are special cases because they can take a literal + # that is handled specially with the _OFST macro. Look below, after the + # stack reordering. elif opcode == "iload": - operands = [] + addr = inst.operands[0] + if isinstance(addr, IRLiteral): + operands = [] + else: + operands = inst.operands elif opcode == "istore": - operands = inst.operands[0:1] + addr = inst.operands[1] + if isinstance(addr, IRLiteral): + operands = inst.operands[:1] + else: + operands = inst.operands elif opcode == "log": log_topic_count = inst.operands[0].value assert log_topic_count in [0, 1, 2, 3, 4], "Invalid topic count" @@ -289,8 +380,8 @@ def _generate_evm_for_instruction( if opcode == "phi": ret = inst.get_outputs()[0] - phi1, phi2 = inst.get_inputs() - depth = stack.get_phi_depth(phi1, phi2) + phis = inst.get_inputs() + depth = stack.get_phi_depth(phis) # collapse the arguments to the phi node in the stack. # example, for `%56 = %label1 %13 %label2 %14`, we will # find an instance of %13 *or* %14 in the stack and replace it with %56.
@@ -301,7 +392,7 @@ def _generate_evm_for_instruction( stack.poke(0, ret) else: stack.poke(depth, ret) - return assembly + return apply_line_numbers(inst, assembly) # Step 2: Emit instruction's input operands self._emit_input_operands(assembly, inst, operands, stack) @@ -313,11 +404,15 @@ def _generate_evm_for_instruction( b = next(iter(inst.parent.cfg_out)) target_stack = input_vars_from(inst.parent, b) # TODO optimize stack reordering at entry and exit from basic blocks - self._stack_reorder(assembly, stack, target_stack) + # NOTE: stack in general can contain multiple copies of the same variable, + # however we are safe in the case of jmp/djmp/jnz as it's not going to + # have multiples. + target_stack_list = list(target_stack) + self._stack_reorder(assembly, stack, target_stack_list) # final step to get the inputs to this instruction ordered # correctly on the stack - self._stack_reorder(assembly, stack, OrderedSet(operands)) + self._stack_reorder(assembly, stack, operands) # type: ignore # some instructions (i.e. 
invoke) need to do stack manipulations # with the stack model containing the return value(s), so we fiddle @@ -359,7 +454,9 @@ def _generate_evm_for_instruction( assembly.append(f"_sym_{inst.operands[0].value}") assembly.append("JUMP") elif opcode == "djmp": - assert isinstance(inst.operands[0], IRVariable) + assert isinstance( + inst.operands[0], IRVariable + ), f"Expected IRVariable, got {inst.operands[0]}" assembly.append("JUMP") elif opcode == "gt": assembly.append("GT") @@ -367,7 +464,9 @@ def _generate_evm_for_instruction( assembly.append("LT") elif opcode == "invoke": target = inst.operands[0] - assert isinstance(target, IRLabel), "invoke target must be a label" + assert isinstance( + target, IRLabel + ), f"invoke target must be a label (is ${type(target)} ${target})" assembly.extend( [ f"_sym_label_ret_{self.label_counter}", @@ -378,16 +477,12 @@ def _generate_evm_for_instruction( ] ) self.label_counter += 1 - if stack.height > 0 and stack.peek(0) in inst.dup_requirements: - self.pop(assembly, stack) - elif opcode == "call": - assembly.append("CALL") - elif opcode == "staticcall": - assembly.append("STATICCALL") elif opcode == "ret": assembly.append("JUMP") elif opcode == "return": assembly.append("RETURN") + elif opcode == "exit": + assembly.extend(["_sym__ctor_exit", "JUMP"]) elif opcode == "phi": pass elif opcode == "sha3": @@ -395,10 +490,10 @@ def _generate_evm_for_instruction( elif opcode == "sha3_64": assembly.extend( [ - *PUSH(MemoryPositions.FREE_VAR_SPACE2), - "MSTORE", *PUSH(MemoryPositions.FREE_VAR_SPACE), "MSTORE", + *PUSH(MemoryPositions.FREE_VAR_SPACE2), + "MSTORE", *PUSH(64), *PUSH(MemoryPositions.FREE_VAR_SPACE), "SHA3", @@ -408,12 +503,23 @@ def _generate_evm_for_instruction( assembly.extend([*PUSH(31), "ADD", *PUSH(31), "NOT", "AND"]) elif opcode == "assert": assembly.extend(["ISZERO", "_sym___revert", "JUMPI"]) + elif opcode == "assert_unreachable": + end_symbol = mksymbol("reachable") + assembly.extend([end_symbol, "JUMPI", "INVALID", 
end_symbol, "JUMPDEST"]) elif opcode == "iload": - loc = inst.operands[0].value - assembly.extend(["_OFST", "_mem_deploy_end", loc, "MLOAD"]) + addr = inst.operands[0] + if isinstance(addr, IRLiteral): + assembly.extend(["_OFST", "_mem_deploy_end", addr.value]) + else: + assembly.extend(["_mem_deploy_end", "ADD"]) + assembly.append("MLOAD") elif opcode == "istore": - loc = inst.operands[1].value - assembly.extend(["_OFST", "_mem_deploy_end", loc, "MSTORE"]) + addr = inst.operands[1] + if isinstance(addr, IRLiteral): + assembly.extend(["_OFST", "_mem_deploy_end", addr.value]) + else: + assembly.extend(["_mem_deploy_end", "ADD"]) + assembly.append("MSTORE") elif opcode == "log": assembly.extend([f"LOG{log_topic_count}"]) else: @@ -421,19 +527,20 @@ def _generate_evm_for_instruction( # Step 6: Emit instructions output operands (if any) if inst.output is not None: - assert isinstance(inst.output, IRVariable), "Return value must be a variable" - if inst.output.mem_type == MemType.MEMORY: - assembly.extend([*PUSH(inst.output.mem_addr)]) + if "call" in inst.opcode and inst.output not in next_liveness: + self.pop(assembly, stack) - return assembly + return apply_line_numbers(inst, assembly) def pop(self, assembly, stack, num=1): stack.pop(num) assembly.extend(["POP"] * num) def swap(self, assembly, stack, depth): + # Swaps of the top is no op if depth == 0: return + stack.swap(depth) assembly.append(_evm_swap_for(depth)) @@ -450,11 +557,13 @@ def dup_op(self, assembly, stack, op): def _evm_swap_for(depth: int) -> str: swap_idx = -depth - assert 1 <= swap_idx <= 16, "Unsupported swap depth" + if not (1 <= swap_idx <= 16): + raise StackTooDeep(f"Unsupported swap depth {swap_idx}") return f"SWAP{swap_idx}" def _evm_dup_for(depth: int) -> str: dup_idx = 1 - depth - assert 1 <= dup_idx <= 16, "Unsupported dup depth" + if not (1 <= dup_idx <= 16): + raise StackTooDeep(f"Unsupported dup depth {dup_idx}") return f"DUP{dup_idx}" From e34ca9ca75b6c55dd5c5a0b49a5a5689b2e71bcd Mon Sep 
17 00:00:00 2001 From: Charles Cooper Date: Tue, 2 Apr 2024 14:09:57 -0400 Subject: [PATCH 4/4] fix[tool]: fix `combined_json` output for CLI (#3901) The output JSON would not be produced because the compiled-output dict is keyed by `Path` objects, and `Path` does not have a JSON serializer. There are actually tests for `combined_json`, but they test the `compile_files` API directly, whereas the offending code is in the outermost `_cli_helper()` function. The best (long-term) way to test this might be to have a harness which runs the vyper CLI directly from the shell, but that is not explored here to reduce scope. --- vyper/cli/vyper_compile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 778d68b5b1..7a3aa800f6 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -65,6 +65,7 @@ def _parse_cli_args(): def _cli_helper(f, output_formats, compiled): if output_formats == ("combined_json",): + compiled = {str(path): v for (path, v) in compiled.items()} print(json.dumps(compiled), file=f) return