diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8bd03de79b..10312413e9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -67,8 +67,9 @@ jobs: matrix: python-version: [["3.11", "311"]] opt-mode: ["gas", "none", "codesize"] - evm-version: [shanghai] debug: [true, false] + evm-version: [shanghai] + experimental-codegen: [false] memorymock: [false] # https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs#expanding-or-adding-matrix-configurations @@ -94,6 +95,14 @@ jobs: opt-mode: gas evm-version: cancun + # test experimental pipeline + - python-version: ["3.11", "311"] + opt-mode: gas + debug: false + evm-version: shanghai + experimental-codegen: true + # TODO: test experimental_codegen + -Ocodesize + # run with `--memorymock`, but only need to do it one configuration # TODO: consider removing the memorymock tests - python-version: ["3.11", "311"] @@ -108,12 +117,14 @@ jobs: opt-mode: gas debug: false evm-version: shanghai + - python-version: ["3.12", "312"] opt-mode: gas debug: false evm-version: shanghai - name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }}${{ matrix.memorymock && '-memorymock' || '' }}-${{ matrix.evm-version }} + + name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }}${{ matrix.memorymock && '-memorymock' || '' }}${{ matrix.experimental-codegen && '-experimental' || '' }}-${{ matrix.evm-version }} steps: - uses: actions/checkout@v4 @@ -141,6 +152,7 @@ jobs: --evm-version ${{ matrix.evm-version }} \ ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} \ ${{ matrix.memorymock && '--memorymock' || '' }} \ + ${{ matrix.experimental-codegen && '--experimental-codegen' || '' }} \ --cov-branch \ --cov-report xml:coverage.xml \ --cov=vyper \ @@ -193,8 +205,7 @@ jobs: # NOTE: if the tests get poorly distributed, run this and commit the resulting `.test_durations` file to the `vyper-test-durations` repo. # `pytest -m "fuzzing" --store-durations -r aR tests/` - name: Fetch test-durations - run: | - curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/master/test_durations" -o .test_durations + run: curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/master/test_durations" -o .test_durations - name: Run tests run: | diff --git a/setup.cfg b/setup.cfg index 467c6a372b..f84c947981 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,3 +31,4 @@ testpaths = tests xfail_strict = true markers = fuzzing: Run Hypothesis fuzz test suite (deselect with '-m "not fuzzing"') + venom_xfail: mark a test case as a regression (expected to fail) under the venom pipeline diff --git a/setup.py b/setup.py index b403405122..933e8bfa4b 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ "hypothesis[lark]>=6.0,<7.0", "eth-stdlib==0.2.7", "setuptools", + "hexbytes<1.0", "typing_extensions", # we can remove this once dependencies upgrade to eth-rlp>=2.0 ], "lint": [ diff --git a/tests/conftest.py b/tests/conftest.py index 2e5f11b9b8..d0681cdf42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,7 @@ def pytest_addoption(parser): help="change optimization mode", ) parser.addoption("--enable-compiler-debug-mode", action="store_true") + parser.addoption("--experimental-codegen", action="store_true") parser.addoption( "--evm-version", @@ -73,6 +74,8 @@ def output_formats(): output_formats = compiler.OUTPUT_FORMATS.copy() del output_formats["bb"] del output_formats["bb_runtime"] + del output_formats["cfg"] + del output_formats["cfg_runtime"] return output_formats @@ -89,6 +92,36 @@ def debug(pytestconfig): _set_debug_mode(debug) +@pytest.fixture(scope="session") +def experimental_codegen(pytestconfig): + ret = pytestconfig.getoption("experimental_codegen") + assert isinstance(ret, bool) + return ret + + +@pytest.fixture(autouse=True) +def check_venom_xfail(request, experimental_codegen): + if not experimental_codegen: + return + + marker = request.node.get_closest_marker("venom_xfail") + if marker is None: + return + + # https://github.com/okken/pytest-runtime-xfail?tab=readme-ov-file#alternatives + request.node.add_marker(pytest.mark.xfail(strict=True, **marker.kwargs)) + + +@pytest.fixture +def venom_xfail(request, experimental_codegen): + def _xfail(*args, **kwargs): + if not experimental_codegen: + return + request.node.add_marker(pytest.mark.xfail(*args, strict=True, **kwargs)) + + return _xfail + + @pytest.fixture(scope="session", autouse=True) def evm_version(pytestconfig): # note: we configure the evm version that we emit code for, @@ -108,6 +141,7 @@ def chdir_tmp_path(tmp_path): yield +# CMC 2024-03-01 this doesn't need to be a fixture @pytest.fixture def keccak(): return Web3.keccak @@ -321,6 +355,7 @@ def _get_contract( w3, source_code, optimize, + experimental_codegen, output_formats, *args, override_opt_level=None, @@ -329,6 +364,7 @@ def _get_contract( ): settings = Settings() settings.optimize = override_opt_level or optimize + settings.experimental_codegen = experimental_codegen out = compiler.compile_code( source_code, # test that all output formats can get generated @@ -352,17 +388,21 @@ def _get_contract( @pytest.fixture(scope="module") -def get_contract(w3, optimize, output_formats): +def get_contract(w3, optimize, experimental_codegen, output_formats): def fn(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return fn @pytest.fixture -def get_contract_with_gas_estimation(tester, w3, optimize, output_formats): +def get_contract_with_gas_estimation(tester, w3, optimize, experimental_codegen, output_formats): def get_contract_with_gas_estimation(source_code, *args, **kwargs): - contract = _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + contract = _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) for abi_ in contract._classic_contract.functions.abi: if abi_["type"] == "function": set_decorator_to_contract_function(w3, tester, contract, source_code, abi_["name"]) @@ -372,15 +412,19 @@ def get_contract_with_gas_estimation(source_code, *args, **kwargs): @pytest.fixture -def get_contract_with_gas_estimation_for_constants(w3, optimize, output_formats): +def get_contract_with_gas_estimation_for_constants( + w3, optimize, experimental_codegen, output_formats +): def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return get_contract_with_gas_estimation_for_constants @pytest.fixture(scope="module") -def get_contract_module(optimize, output_formats): +def get_contract_module(optimize, experimental_codegen, output_formats): """ This fixture is used for Hypothesis tests to ensure that the same contract is called over multiple runs of the test. @@ -393,16 +437,25 @@ def get_contract_module(optimize, output_formats): w3.eth.set_gas_price_strategy(zero_gas_price_strategy) def get_contract_module(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return get_contract_module def _deploy_blueprint_for( - w3, source_code, optimize, output_formats, initcode_prefix=ERC5202_PREFIX, **kwargs + w3, + source_code, + optimize, + experimental_codegen, + output_formats, + initcode_prefix=ERC5202_PREFIX, + **kwargs, ): settings = Settings() settings.optimize = optimize + settings.experimental_codegen = experimental_codegen out = compiler.compile_code( source_code, output_formats=output_formats, @@ -438,9 +491,11 @@ def factory(address): @pytest.fixture(scope="module") -def deploy_blueprint_for(w3, optimize, output_formats): +def deploy_blueprint_for(w3, optimize, experimental_codegen, output_formats): def deploy_blueprint_for(source_code, *args, **kwargs): - return _deploy_blueprint_for(w3, source_code, optimize, output_formats, *args, **kwargs) + return _deploy_blueprint_for( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return deploy_blueprint_for diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index d281851f8e..213738957b 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -3,7 +3,7 @@ import pytest from eth.codecs import abi -from vyper.exceptions import ArgumentException, StructureException +from vyper.exceptions import ArgumentException, StackTooDeep, StructureException TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex() @@ -196,6 +196,7 @@ def abi_decode(x: Bytes[{len}]) -> DynArray[DynArray[uint256, 3], 3]: @pytest.mark.parametrize("args", nested_3d_array_args) @pytest.mark.parametrize("unwrap_tuple", (True, False)) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_decode_nested_dynarray2(get_contract, args, unwrap_tuple): if unwrap_tuple is True: encoded = abi.encode("(uint256[][][])", (args,)) @@ -273,6 +274,7 @@ def foo(bs: Bytes[160]) -> (uint256, DynArray[uint256, 3]): assert c.foo(encoded) == [2**256 - 1, bs] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_decode_private_nested_dynarray(get_contract): code = """ bytez: DynArray[DynArray[DynArray[uint256, 3], 3], 3] diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index f014c47a19..305c4b1356 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -3,6 +3,8 @@ import pytest from eth.codecs import abi +from vyper.exceptions import StackTooDeep + # @pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) def test_abi_encode(get_contract): @@ -226,6 +228,7 @@ def abi_encode( @pytest.mark.parametrize("args", nested_3d_array_args) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_encode_nested_dynarray_2(get_contract, args): code = """ @external @@ -330,6 +333,7 @@ def foo(bs: DynArray[uint256, 3]) -> (uint256, Bytes[160]): assert c.foo(bs) == [2**256 - 1, abi.encode("(uint256[])", (bs,))] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_encode_private_nested_dynarray(get_contract): code = """ bytez: Bytes[1696] diff --git a/tests/functional/codegen/features/decorators/test_pure.py b/tests/functional/codegen/features/decorators/test_pure.py index 5f4c168687..7c49c2091b 100644 --- a/tests/functional/codegen/features/decorators/test_pure.py +++ b/tests/functional/codegen/features/decorators/test_pure.py @@ -1,21 +1,22 @@ +import pytest + +from vyper.compiler import compile_code from vyper.exceptions import FunctionDeclarationException, StateAccessViolation -def test_pure_operation(get_contract_with_gas_estimation_for_constants): - c = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_operation(get_contract): + code = """ @pure @external def foo() -> int128: return 5 """ - ) + c = get_contract(code) assert c.foo() == 5 -def test_pure_call(get_contract_with_gas_estimation_for_constants): - c = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_call(get_contract): + code = """ @pure @internal def _foo() -> int128: @@ -26,21 +27,18 @@ def _foo() -> int128: def foo() -> int128: return self._foo() """ - ) + c = get_contract(code) assert c.foo() == 5 -def test_pure_interface(get_contract_with_gas_estimation_for_constants): - c1 = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_interface(get_contract): + code1 = """ @pure @external def foo() -> int128: return 5 """ - ) - c2 = get_contract_with_gas_estimation_for_constants( - """ + code2 = """ interface Foo: def foo() -> int128: pure @@ -49,28 +47,35 @@ def foo() -> int128: pure def foo(a: address) -> int128: return staticcall Foo(a).foo() """ - ) + c1 = get_contract(code1) + c2 = get_contract(code2) assert c2.foo(c1.address) == 5 -def test_invalid_envar_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_envar_access(get_contract): + code = """ @pure @external def foo() -> uint256: return chain.id """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_codesize_access(get_contract): + code = """ +@pure +@external +def foo(s: address) -> uint256: + return s.codesize + """ + with pytest.raises(StateAccessViolation): + compile_code(code) def test_invalid_state_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ + code = """ x: uint256 @pure @@ -78,29 +83,84 @@ def test_invalid_state_access(get_contract, assert_compile_failed): def foo() -> uint256: return self.x """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_immutable_access(): + code = """ +COUNTER: immutable(uint256) + +@deploy +def __init__(): + COUNTER = 1234 + +@pure +@external +def foo() -> uint256: + return COUNTER + """ + with pytest.raises(StateAccessViolation): + compile_code(code) -def test_invalid_self_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_self_access(): + code = """ @pure @external def foo() -> address: return self """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_module_variable_access(make_input_bundle): + lib1 = """ +counter: uint256 + """ + code = """ +import lib1 +initializes: lib1 + +@pure +@external +def foo() -> uint256: + return lib1.counter + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StateAccessViolation): + compile_code(code, input_bundle=input_bundle) + + +def test_invalid_module_immutable_access(make_input_bundle): + lib1 = """ +COUNTER: immutable(uint256) + +@deploy +def __init__(): + COUNTER = 123 + """ + code = """ +import lib1 +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__() + +@pure +@external +def foo() -> uint256: + return lib1.COUNTER + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StateAccessViolation): + compile_code(code, input_bundle=input_bundle) -def test_invalid_call(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_call(): + code = """ @view @internal def _foo() -> uint256: @@ -111,21 +171,17 @@ def _foo() -> uint256: def foo() -> uint256: return self._foo() # Fails because of calling non-pure fn """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) -def test_invalid_conflicting_decorators(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_conflicting_decorators(): + code = """ @pure @external @payable def foo() -> uint256: return 5 """ - ), - FunctionDeclarationException, - ) + with pytest.raises(FunctionDeclarationException): + compile_code(code) diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 578413a8f4..fe51c026fe 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -4,6 +4,7 @@ from eth.codecs import abi from eth_utils import keccak +from vyper.exceptions import StackTooDeep from vyper.utils import int_bounds @@ -506,6 +507,7 @@ def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]: @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_multidimension_dynarray_clamper_passing(w3, get_contract, value): code = """ @external diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index 9146ace8a6..d96a889497 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -1,6 +1,8 @@ import pytest from web3.exceptions import ValidationError +from vyper.exceptions import StackTooDeep + def test_init_argument_test(get_contract_with_gas_estimation): init_argument_test = """ @@ -163,6 +165,7 @@ def get_foo() -> uint256: assert c.get_foo() == 39 +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estimation): code = """ foo: int128 @@ -208,6 +211,7 @@ def get_foo() -> DynArray[DynArray[uint256, 3], 3]: assert c.get_foo() == [[37, 41, 73], [37041, 41073, 73037], [146, 123, 148]] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation): code = """ foo: DynArray[DynArray[DynArray[int128, 3], 3], 3] diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index 49ff54b353..874600633a 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -1,6 +1,7 @@ import pytest from vyper.compiler.settings import OptimizationLevel +from vyper.exceptions import StackTooDeep @pytest.mark.parametrize( @@ -198,6 +199,7 @@ def get_idx_two() -> uint256: assert c.get_idx_two() == expected_values[2][2] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3]) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index b55f07639b..efa2799480 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -8,6 +8,7 @@ ArrayIndexException, ImmutableViolation, OverflowException, + StackTooDeep, StateAccessViolation, TypeMismatch, ) @@ -60,6 +61,7 @@ def loo(x: DynArray[DynArray[int128, 2], 2]) -> int128: print("Passed list tests") +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_string_list(get_contract): code = """ @external @@ -732,6 +734,7 @@ def test_array_decimal_return3() -> DynArray[DynArray[decimal, 2], 2]: assert c.test_array_decimal_return3() == [[1.0, 2.0], [3.0]] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_mult_list(get_contract_with_gas_estimation): code = """ nest3: DynArray[DynArray[DynArray[uint256, 2], 2], 2] @@ -1478,6 +1481,7 @@ def foo(x: int128) -> int128: assert c.foo(7) == 392 +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_struct_of_lists(get_contract): code = """ struct Foo: @@ -1566,6 +1570,7 @@ def bar(x: int128) -> DynArray[int128, 3]: assert c.bar(7) == [7, 14] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_struct_of_lists(get_contract, assert_compile_failed, optimize): code = """ struct nestedFoo: @@ -1695,7 +1700,9 @@ def __init__(): ("DynArray[DynArray[DynArray[uint256, 5], 5], 5]", [[[], []], []]), ], ) -def test_empty_nested_dynarray(get_contract, typ, val): +def test_empty_nested_dynarray(get_contract, typ, val, venom_xfail): + if val == [[[], []], []]: + venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") code = f""" @external def foo() -> {typ}: diff --git a/tests/functional/examples/factory/test_factory.py b/tests/functional/examples/factory/test_factory.py index 0c5cf61b04..18f6222c20 100644 --- a/tests/functional/examples/factory/test_factory.py +++ b/tests/functional/examples/factory/test_factory.py @@ -31,12 +31,14 @@ def create_exchange(token, factory): @pytest.fixture -def factory(get_contract, optimize): +def factory(get_contract, optimize, experimental_codegen): with open("examples/factory/Exchange.vy") as f: code = f.read() exchange_interface = vyper.compile_code( - code, output_formats=["bytecode_runtime"], settings=Settings(optimize=optimize) + code, + output_formats=["bytecode_runtime"], + settings=Settings(optimize=optimize, experimental_codegen=experimental_codegen), ) exchange_deployed_bytecode = exchange_interface["bytecode_runtime"] diff --git a/tests/functional/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py index 6556fc90b9..6be50a509b 100644 --- a/tests/functional/syntax/test_address_code.py +++ b/tests/functional/syntax/test_address_code.py @@ -161,7 +161,7 @@ def test_address_code_compile_success(code: str): compiler.compile_code(code) -def test_address_code_self_success(get_contract, optimize): +def test_address_code_self_success(get_contract, optimize, experimental_codegen): code = """ code_deployment: public(Bytes[32]) @@ -174,7 +174,7 @@ def code_runtime() -> Bytes[32]: return slice(self.code, 0, 32) """ contract = get_contract(code) - settings = Settings(optimize=optimize) + settings = Settings(optimize=optimize, experimental_codegen=experimental_codegen) code_compiled = compiler.compile_code( code, output_formats=["bytecode", "bytecode_runtime"], settings=settings ) diff --git a/tests/functional/syntax/test_codehash.py b/tests/functional/syntax/test_codehash.py index d351981946..7aa01a68e9 100644 --- a/tests/functional/syntax/test_codehash.py +++ b/tests/functional/syntax/test_codehash.py @@ -3,7 +3,7 @@ from vyper.utils import keccak256 -def test_get_extcodehash(get_contract, optimize): +def test_get_extcodehash(get_contract, optimize, experimental_codegen): code = """ a: address @@ -28,7 +28,7 @@ def foo3() -> bytes32: def foo4() -> bytes32: return self.a.codehash """ - settings = Settings(optimize=optimize) + settings = Settings(optimize=optimize, experimental_codegen=experimental_codegen) compiled = compile_code(code, output_formats=["bytecode_runtime"], settings=settings) bytecode = bytes.fromhex(compiled["bytecode_runtime"][2:]) hash_ = keccak256(bytecode) diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 4fe2111f43..62a799db65 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -113,11 +113,13 @@ def test_keyerror_becomes_jsonerror(input_json): def test_compile_json(input_json, input_bundle): foo_input = input_bundle.load_file("contracts/foo.vy") - # remove bb and bb_runtime from output formats + # remove venom related from output formats # because they require venom (experimental) output_formats = OUTPUT_FORMATS.copy() del output_formats["bb"] del output_formats["bb_runtime"] + del output_formats["cfg"] + del output_formats["cfg_runtime"] foo = compile_from_file_input( foo_input, output_formats=output_formats, input_bundle=input_bundle ) diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index ce32249202..5742f7c8df 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -95,7 +95,7 @@ def test_dead_code_eliminator(code): assert all(ctor_only not in instr for instr in runtime_asm) -def test_library_code_eliminator(make_input_bundle): +def test_library_code_eliminator(make_input_bundle, experimental_codegen): library = """ @internal def unused1(): @@ -120,5 +120,6 @@ def foo(): res = compile_code(code, input_bundle=input_bundle, output_formats=["asm"]) asm = res["asm"] assert "some_function()" in asm + assert "unused1()" not in asm assert "unused2()" not in asm diff --git a/tests/unit/compiler/venom/test_dominator_tree.py b/tests/unit/compiler/venom/test_dominator_tree.py new file mode 100644 index 0000000000..dc27380796 --- /dev/null +++ b/tests/unit/compiler/venom/test_dominator_tree.py @@ -0,0 +1,73 @@ +from typing import Optional + +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRLiteral, IRVariable +from vyper.venom.dominators import DominatorTree +from vyper.venom.function import IRFunction +from vyper.venom.passes.make_ssa import MakeSSA + + +def _add_bb( + ctx: IRFunction, label: IRLabel, cfg_outs: [IRLabel], bb: Optional[IRBasicBlock] = None +) -> IRBasicBlock: + bb = bb if bb is not None else IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + cfg_outs_len = len(cfg_outs) + if cfg_outs_len == 0: + bb.append_instruction("stop") + elif cfg_outs_len == 1: + bb.append_instruction("jmp", cfg_outs[0]) + elif cfg_outs_len == 2: + bb.append_instruction("jnz", IRLiteral(1), cfg_outs[0], cfg_outs[1]) + else: + raise CompilerPanic("Invalid number of CFG outs") + return bb + + +def _make_test_ctx(): + lab = [IRLabel(str(i)) for i in range(0, 9)] + + ctx = IRFunction(lab[1]) + + bb1 = ctx.basic_blocks[0] + bb1.append_instruction("jmp", lab[2]) + + _add_bb(ctx, lab[7], []) + _add_bb(ctx, lab[6], [lab[7], lab[2]]) + _add_bb(ctx, lab[5], [lab[6], lab[3]]) + _add_bb(ctx, lab[4], [lab[6]]) + _add_bb(ctx, lab[3], [lab[5]]) + _add_bb(ctx, lab[2], [lab[3], lab[4]]) + + return ctx + + +def test_deminator_frontier_calculation(): + ctx = _make_test_ctx() + bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [ctx.get_basic_block(str(i)) for i in range(1, 8)] + + calculate_cfg(ctx) + dom = DominatorTree.build_dominator_tree(ctx, bb1) + df = dom.dominator_frontiers + + assert len(df[bb1]) == 0, df[bb1] + assert df[bb2] == OrderedSet({bb2}), df[bb2] + assert df[bb3] == OrderedSet({bb3, bb6}), df[bb3] + assert df[bb4] == OrderedSet({bb6}), df[bb4] + assert df[bb5] == OrderedSet({bb3, bb6}), df[bb5] + assert df[bb6] == OrderedSet({bb2}), df[bb6] + assert len(df[bb7]) == 0, df[bb7] + + +def test_phi_placement(): + ctx = _make_test_ctx() + bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [ctx.get_basic_block(str(i)) for i in range(1, 8)] + + x = IRVariable("%x") + bb1.insert_instruction(IRInstruction("mload", [IRLiteral(0)], x), 0) + bb2.insert_instruction(IRInstruction("add", [x, IRLiteral(1)], x), 0) + bb7.insert_instruction(IRInstruction("mstore", [x, IRLiteral(0)]), 0) + + MakeSSA.run_pass(ctx, bb1) diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index 437185cc72..7cc58e6f5c 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -23,5 +23,4 @@ def test_duplicate_operands(): bb.append_instruction("stop") asm = generate_assembly_experimental(ctx, optimize=OptimizationLevel.GAS) - - assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP"] + assert asm == ["PUSH1", 10, "DUP1", "DUP1", "ADD", "MUL", "STOP"] diff --git a/tests/unit/compiler/venom/test_liveness_simple_loop.py b/tests/unit/compiler/venom/test_liveness_simple_loop.py new file mode 100644 index 0000000000..e725518179 --- /dev/null +++ b/tests/unit/compiler/venom/test_liveness_simple_loop.py @@ -0,0 +1,16 @@ +import vyper +from vyper.compiler.settings import Settings + +source = """ +@external +def foo(a: uint256): + _numBids: uint256 = 20 + b: uint256 = 10 + + for i: uint256 in range(128): + b = 1 + _numBids +""" + + +def test_liveness_simple_loop(): + vyper.compile_code(source, ["opcodes"], settings=Settings(experimental_codegen=True)) diff --git a/tests/unit/compiler/venom/test_make_ssa.py b/tests/unit/compiler/venom/test_make_ssa.py new file mode 100644 index 0000000000..2a04dfc134 --- /dev/null +++ b/tests/unit/compiler/venom/test_make_ssa.py @@ -0,0 +1,48 @@ +from vyper.venom.analysis import calculate_cfg, calculate_liveness +from vyper.venom.basicblock import IRBasicBlock, IRLabel +from vyper.venom.function import IRFunction +from vyper.venom.passes.make_ssa import MakeSSA + + +def test_phi_case(): + ctx = IRFunction(IRLabel("_global")) + + bb = ctx.get_basic_block() + + bb_cont = IRBasicBlock(IRLabel("condition"), ctx) + bb_then = IRBasicBlock(IRLabel("then"), ctx) + bb_else = IRBasicBlock(IRLabel("else"), ctx) + bb_if_exit = IRBasicBlock(IRLabel("if_exit"), ctx) + ctx.append_basic_block(bb_cont) + ctx.append_basic_block(bb_then) + ctx.append_basic_block(bb_else) + ctx.append_basic_block(bb_if_exit) + + v = bb.append_instruction("mload", 64) + bb_cont.append_instruction("jnz", v, bb_then.label, bb_else.label) + + bb_if_exit.append_instruction("add", v, 1, ret=v) + bb_if_exit.append_instruction("jmp", bb_cont.label) + + bb_then.append_instruction("assert", bb_then.append_instruction("mload", 96)) + bb_then.append_instruction("jmp", bb_if_exit.label) + bb_else.append_instruction("jmp", bb_if_exit.label) + + bb.append_instruction("jmp", bb_cont.label) + + calculate_cfg(ctx) + MakeSSA.run_pass(ctx, ctx.basic_blocks[0]) + calculate_liveness(ctx) + + condition_block = ctx.get_basic_block("condition") + assert len(condition_block.instructions) == 2 + + phi_inst = condition_block.instructions[0] + assert phi_inst.opcode == "phi" + assert phi_inst.operands[0].name == "_global" + assert phi_inst.operands[1].name == "%1" + assert phi_inst.operands[2].name == "if_exit" + assert phi_inst.operands[3].name == "%1" + assert phi_inst.output.name == "%1" + assert phi_inst.output.value != phi_inst.operands[1].value + assert phi_inst.output.value != phi_inst.operands[3].value diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 6d8b074994..47f4b88707 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -39,10 +39,10 @@ def test_multi_entry_block_1(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" # more complicated one @@ -91,10 +91,10 @@ def test_multi_entry_block_2(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" def test_multi_entry_block_with_dynamic_jump(): @@ -132,7 +132,7 @@ def test_multi_entry_block_with_dynamic_jump(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" diff --git a/tests/unit/compiler/venom/test_variables.py b/tests/unit/compiler/venom/test_variables.py new file mode 100644 index 0000000000..cded8d0e1a --- /dev/null +++ b/tests/unit/compiler/venom/test_variables.py @@ -0,0 +1,8 @@ +from vyper.venom.basicblock import IRVariable + + +def test_variable_equality(): + v1 = IRVariable("%x") + v2 = IRVariable("%x") + assert v1 == v2 + assert v1 != IRVariable("%y") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index f673bb765c..fe01bf9260 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -141,7 +141,7 @@ class Dict(VyperNode): keys: list = ... values: list = ... -class Name(VyperNode): +class Name(ExprNode): id: str = ... _type: str = ... diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 227b639ad5..f0c339cca7 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -203,6 +203,12 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: if evm_version not in EVM_VERSIONS: raise StructureException("Invalid evm version: `{evm_version}`", start) settings.evm_version = evm_version + elif pragma.startswith("experimental-codegen"): + if settings.experimental_codegen is not None: + raise StructureException( + "pragma experimental-codegen specified twice!", start + ) + settings.experimental_codegen = True else: raise StructureException(f"Unknown pragma `{pragma.split()[0]}`") diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 778d68b5b1..7a3aa800f6 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -65,6 +65,7 @@ def _parse_cli_args(): def _cli_helper(f, output_formats, compiled): if output_formats == ("combined_json",): + compiled = {str(path): v for (path, v) in compiled.items()} print(json.dumps(compiled), file=f) return diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 6f783bb9c5..fe706699bb 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -154,10 +154,6 @@ def _adjust_gas_estimate(func_t, common_ir): common_ir.add_gas_estimate += mem_expansion_cost func_t._ir_info.gas_estimate = common_ir.gas - # pass metadata through for venom pipeline: - common_ir.passthrough_metadata["func_t"] = func_t - common_ir.passthrough_metadata["frame_info"] = frame_info - def generate_ir_for_external_function(code, compilation_target): # TODO type hints: diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 0cf9850b70..cde1ec5c87 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -80,10 +80,6 @@ def generate_ir_for_internal_function( # tag gas estimate and frame info func_t._ir_info.gas_estimate = ir_node.gas - frame_info = tag_frame_info(func_t, context) - - # pass metadata through for venom pipeline: - ir_node.passthrough_metadata["frame_info"] = frame_info - ir_node.passthrough_metadata["func_t"] = func_t + tag_frame_info(func_t, context) return InternalFuncIR(ir_node) diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index da585ff0a1..a8dac640db 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -42,7 +42,6 @@ def finalize(fill_return_buffer): # NOTE: because stack analysis is incomplete, cleanup_repeat must # come after fill_return_buffer otherwise the stack will break jump_to_exit_ir = IRnode.from_list(jump_to_exit) - jump_to_exit_ir.passthrough_metadata["func_t"] = func_t return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit_ir]) if context.return_type is None: diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f53e4a81b4..2363de3641 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -112,6 +112,4 @@ def ir_for_self_call(stmt_expr, context): add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True - o.passthrough_metadata["func_t"] = func_t - o.passthrough_metadata["args_ir"] = args_ir return o diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index ee909a57d4..84aea73071 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -25,6 +25,8 @@ "interface": output.build_interface_output, "bb": output.build_bb_output, "bb_runtime": output.build_bb_runtime_output, + "cfg": output.build_cfg_output, + "cfg_runtime": output.build_cfg_runtime_output, "ir": output.build_ir_output, "ir_runtime": output.build_ir_runtime_output, "ir_dict": output.build_ir_dict_output, diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index de8e34370d..f8beb9d11b 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -90,6 +90,14 @@ def build_bb_runtime_output(compiler_data: CompilerData) -> IRnode: return compiler_data.venom_functions[1] +def build_cfg_output(compiler_data: CompilerData) -> str: + return compiler_data.venom_functions[0].as_graph() + + +def build_cfg_runtime_output(compiler_data: CompilerData) -> str: + return compiler_data.venom_functions[1].as_graph() + + def build_ir_output(compiler_data: CompilerData) -> IRnode: if compiler_data.show_gas_estimates: IRnode.repr_show_gas = True diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index e343938021..d794185195 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -97,8 +97,6 @@ def __init__( no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False """ - # to force experimental codegen, uncomment: - # settings.experimental_codegen = True if isinstance(file_input, str): file_input = FileInput( diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3897f0ea41..996a1ddbd9 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -386,6 +386,10 @@ class CodegenPanic(VyperInternalException): """Invalid code generated during codegen phase""" +class StackTooDeep(CodegenPanic): + """Stack too deep""" # (should not happen) + + class UnexpectedNodeType(VyperInternalException): """Unexpected AST node type.""" diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index e4a4cc60f7..191803295e 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -1020,7 +1020,11 @@ def _stack_peephole_opts(assembly): changed = True del assembly[i] continue - if assembly[i : i + 2] == ["SWAP1", "SWAP1"]: + if ( + isinstance(assembly[i], str) + and assembly[i].startswith("SWAP") + and assembly[i] == assembly[i + 1] + ): changed = True del assembly[i : i + 2] if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index e424f94e19..82b8dbe359 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -8,6 +8,7 @@ from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.semantics.types.primitives import SelfT from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: @@ -199,8 +200,11 @@ def set_position(self, position: VarOffset) -> None: assert isinstance(position, VarOffset) # sanity check self.position = position - def is_module_variable(self): - return self.location not in (DataLocation.UNSET, DataLocation.MEMORY) + def is_state_variable(self): + non_state_locations = (DataLocation.UNSET, DataLocation.MEMORY, DataLocation.CALLDATA) + # `self` gets a VarInfo, but it is not considered a state + # variable (it is magic), so we ignore it here. + return self.location not in non_state_locations and not isinstance(self.typ, SelfT) def get_size(self) -> int: return self.typ.get_size_in(self.location) diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index bb4322c7b2..e5e8b998ca 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -282,9 +282,9 @@ def _allocate_layout_r( continue assert isinstance(node, vy_ast.VariableDecl) - # skip non-storage variables + # skip non-state variables varinfo = node.target._metadata["varinfo"] - if not varinfo.is_module_variable(): + if not varinfo.is_state_variable(): continue location = varinfo.location diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 33dcce3645..1b2e3252c8 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -38,7 +38,7 @@ from vyper.semantics.data_locations import DataLocation # TODO consolidate some of these imports -from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS, MUTABLE_ENVIRONMENT_VARS +from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS from vyper.semantics.namespace import get_namespace from vyper.semantics.types import ( TYPE_T, @@ -51,6 +51,7 @@ HashMapT, IntegerT, SArrayT, + SelfT, StringT, StructT, TupleT, @@ -164,21 +165,32 @@ def _validate_msg_value_access(node: vy_ast.Attribute) -> None: raise NonPayableViolation("msg.value is not allowed in non-payable functions", node) -def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: - env_vars = set(CONSTANT_ENVIRONMENT_VARS.keys()) | set(MUTABLE_ENVIRONMENT_VARS.keys()) - if isinstance(node.value, vy_ast.Name) and node.value.id in env_vars: - if isinstance(typ, ContractFunctionT) and typ.mutability == StateMutability.PURE: - return +def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name) -> None: + info = get_expr_info(node) - raise StateAccessViolation( - "not allowed to query contract or environment variables in pure functions", node - ) + env_vars = CONSTANT_ENVIRONMENT_VARS + # check env variable access like `block.number` + if isinstance(node, vy_ast.Attribute): + if node.get("value.id") in env_vars: + raise StateAccessViolation( + "not allowed to query environment variables in pure functions" + ) + parent_info = get_expr_info(node.value) + if isinstance(parent_info.typ, AddressT) and node.attr in AddressT._type_members: + raise StateAccessViolation("not allowed to query address members in pure functions") + if (varinfo := info.var_info) is None: + return + # self is magic. we only need to check it if it is not the root of an Attribute + # node. (i.e. it is bare like `self`, not `self.foo`) + is_naked_self = isinstance(varinfo.typ, SelfT) and not isinstance( + node.get_ancestor(), vy_ast.Attribute + ) + if is_naked_self: + raise StateAccessViolation("not allowed to query `self` in pure functions") -def _validate_self_reference(node: vy_ast.Name) -> None: - # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through - if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): - raise StateAccessViolation("not allowed to query self in pure functions", node) + if varinfo.is_state_variable() or is_naked_self: + raise StateAccessViolation("not allowed to query state variables in pure functions") # analyse the variable access for the attribute chain for a node @@ -429,7 +441,7 @@ def _handle_modification(self, target: vy_ast.ExprNode): info._writes.add(var_access) def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode): - if not var_access.variable.is_module_variable(): + if not var_access.variable.is_state_variable(): return root_module_info = check_module_uses(target) @@ -693,7 +705,7 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_value_access(node) if self.func and self.func.mutability == StateMutability.PURE: - _validate_pure_access(node, typ) + _validate_pure_access(node) value_type = get_exact_type_from_node(node.value) @@ -765,10 +777,10 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if not func_type.from_interface: for s in func_type.get_variable_writes(): - if s.variable.is_module_variable(): + if s.variable.is_state_variable(): func_info._writes.add(s) for s in func_type.get_variable_reads(): - if s.variable.is_module_variable(): + if s.variable.is_state_variable(): func_info._reads.add(s) if self.function_analyzer: @@ -873,10 +885,8 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func: - # TODO: refactor to use expr_info mutability - if self.func.mutability == StateMutability.PURE: - _validate_self_reference(node) + if self.func and self.func.mutability == StateMutability.PURE: + _validate_pure_access(node) def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index b4f4381444..84199ec82c 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -536,7 +536,7 @@ def visit_ExportsDecl(self, node): # check module uses var_accesses = func_t.get_variable_accesses() - if any(s.variable.is_module_variable() for s in var_accesses): + if any(s.variable.is_state_variable() for s in var_accesses): module_info = check_module_uses(item) assert module_info is not None # guaranteed by above checks used_modules.add(module_info) diff --git a/vyper/utils.py b/vyper/utils.py index ba615e58d7..114ddf97c2 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -15,7 +15,7 @@ _T = TypeVar("_T") -class OrderedSet(Generic[_T], dict[_T, None]): +class OrderedSet(Generic[_T]): """ a minimal "ordered set" class. this is needed in some places because, while dict guarantees you can recover insertion order @@ -25,57 +25,82 @@ class OrderedSet(Generic[_T], dict[_T, None]): """ def __init__(self, iterable=None): - super().__init__() + self._data = dict() if iterable is not None: - for item in iterable: - self.add(item) + self.update(iterable) def __repr__(self): - keys = ", ".join(repr(k) for k in self.keys()) + keys = ", ".join(repr(k) for k in self) return f"{{{keys}}}" - def get(self, *args, **kwargs): - raise RuntimeError("can't call get() on OrderedSet!") + def __iter__(self): + return iter(self._data) + + def __contains__(self, item): + return self._data.__contains__(item) + + def __len__(self): + return len(self._data) def first(self): return next(iter(self)) def add(self, item: _T) -> None: - self[item] = None + self._data[item] = None def remove(self, item: _T) -> None: - del self[item] + del self._data[item] + + def drop(self, item: _T): + # friendly version of remove + self._data.pop(item, None) + + def dropmany(self, iterable): + for item in iterable: + self._data.pop(item, None) def difference(self, other): ret = self.copy() - for k in other.keys(): - if k in ret: - ret.remove(k) + ret.dropmany(other) return ret + def update(self, other): + # CMC 2024-03-22 for some reason, this is faster than dict.update? + # (maybe size dependent) + for item in other: + self._data[item] = None + def union(self, other): return self | other - def update(self, other): - super().update(self.__class__.fromkeys(other)) + def __ior__(self, other): + self.update(other) + return self def __or__(self, other): - return self.__class__(super().__or__(other)) + ret = self.copy() + ret.update(other) + return ret + + def __eq__(self, other): + return self._data == other._data def copy(self): - return self.__class__(super().copy()) + cls = self.__class__ + ret = cls.__new__(cls) + ret._data = self._data.copy() + return ret @classmethod def intersection(cls, *sets): - res = OrderedSet() if len(sets) == 0: raise ValueError("undefined: intersection of no sets") - if len(sets) == 1: - return sets[0].copy() - for e in sets[0].keys(): - if all(e in s for s in sets[1:]): - res.add(e) - return res + + ret = sets[0].copy() + for e in sets[0]: + if any(e not in s for s in sets[1:]): + ret.remove(e) + return ret class StringEnum(enum.Enum): diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index d1c2d0c342..2efd58ad6c 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -11,10 +11,14 @@ ir_pass_optimize_unused_variables, ir_pass_remove_unreachable_blocks, ) +from vyper.venom.dominators import DominatorTree from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation from vyper.venom.passes.dft import DFTPass +from vyper.venom.passes.make_ssa import MakeSSA +from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.passes.simplify_cfg import SimplifyCFGPass from vyper.venom.venom_to_assembly import VenomCompiler DEFAULT_OPT_LEVEL = OptimizationLevel.default() @@ -38,6 +42,24 @@ def generate_assembly_experimental( def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: # Run passes on Venom IR # TODO: Add support for optimization levels + + ir_pass_optimize_empty_blocks(ctx) + ir_pass_remove_unreachable_blocks(ctx) + + internals = [ + bb + for bb in ctx.basic_blocks + if bb.label.value.startswith("internal") and len(bb.cfg_in) == 0 + ] + + SimplifyCFGPass.run_pass(ctx, ctx.basic_blocks[0]) + for entry in internals: + SimplifyCFGPass.run_pass(ctx, entry) + + MakeSSA.run_pass(ctx, ctx.basic_blocks[0]) + for entry in internals: + MakeSSA.run_pass(ctx, entry) + while True: changes = 0 @@ -51,7 +73,6 @@ def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: calculate_cfg(ctx) calculate_liveness(ctx) - changes += ir_pass_constant_propagation(ctx) changes += DFTPass.run_pass(ctx) calculate_cfg(ctx) diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index daebd2560c..066a60f45e 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -1,3 +1,5 @@ +from typing import Optional + from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet from vyper.venom.basicblock import ( @@ -38,6 +40,7 @@ def calculate_cfg(ctx: IRFunction) -> None: def _reset_liveness(ctx: IRFunction) -> None: for bb in ctx.basic_blocks: + bb.out_vars = OrderedSet() for inst in bb.instructions: inst.liveness = OrderedSet() @@ -50,16 +53,15 @@ def _calculate_liveness(bb: IRBasicBlock) -> bool: orig_liveness = bb.instructions[0].liveness.copy() liveness = bb.out_vars.copy() for instruction in reversed(bb.instructions): - ops = instruction.get_inputs() + ins = instruction.get_inputs() + outs = instruction.get_outputs() - for op in ops: - if op in liveness: - instruction.dup_requirements.add(op) + if ins or outs: + # perf: only copy if changed + liveness = liveness.copy() + liveness.update(ins) + liveness.dropmany(outs) - liveness = liveness.union(OrderedSet.fromkeys(ops)) - out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None - if out in liveness: - liveness.remove(out) instruction.liveness = liveness return orig_liveness != bb.instructions[0].liveness @@ -89,6 +91,18 @@ def calculate_liveness(ctx: IRFunction) -> None: break +def calculate_dup_requirements(ctx: IRFunction) -> None: + for bb in ctx.basic_blocks: + last_liveness = bb.out_vars + for inst in reversed(bb.instructions): + inst.dup_requirements = OrderedSet() + ops = inst.get_inputs() + for op in ops: + if op in last_liveness: + inst.dup_requirements.add(op) + last_liveness = inst.liveness + + # calculate the input variables into self from source def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: liveness = target.instructions[0].liveness.copy() @@ -104,19 +118,17 @@ def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IR # will arbitrarily choose either %12 or %14 to be in the liveness # set, and then during instruction selection, after this instruction, # %12 will be replaced by %56 in the liveness set - source1, source2 = inst.operands[0], inst.operands[2] - phi1, phi2 = inst.operands[1], inst.operands[3] - if source.label == source1: - liveness.add(phi1) - if phi2 in liveness: - liveness.remove(phi2) - elif source.label == source2: - liveness.add(phi2) - if phi1 in liveness: - liveness.remove(phi1) - else: - # bad path into this phi node - raise CompilerPanic(f"unreachable: {inst}") + + # bad path into this phi node + if source.label not in inst.operands: + raise CompilerPanic(f"unreachable: {inst} from {source.label}") + + for label, var in inst.phi_operands: + if label == source.label: + liveness.add(var) + else: + if var in liveness: + liveness.remove(var) return liveness @@ -137,8 +149,8 @@ def get_uses(self, op: IRVariable) -> list[IRInstruction]: return self._dfg_inputs.get(op, []) # the instruction which produces this variable. - def get_producing_instruction(self, op: IRVariable) -> IRInstruction: - return self._dfg_outputs[op] + def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]: + return self._dfg_outputs.get(op) @classmethod def build_dfg(cls, ctx: IRFunction) -> "DFG": @@ -163,3 +175,20 @@ def build_dfg(cls, ctx: IRFunction) -> "DFG": dfg._dfg_outputs[op] = inst return dfg + + def as_graph(self) -> str: + """ + Generate a graphviz representation of the dfg + """ + lines = ["digraph dfg_graph {"] + for var, inputs in self._dfg_inputs.items(): + for input in inputs: + for op in input.get_outputs(): + if isinstance(op, IRVariable): + lines.append(f' " {var.name} " -> " {op.name} "') + + lines.append("}") + return "\n".join(lines) + + def __repr__(self) -> str: + return self.as_graph() diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index ed70a5eaa0..6c509d8f95 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,10 +1,9 @@ -from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional, Union from vyper.utils import OrderedSet # instructions which can terminate a basic block -BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop"]) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop", "exit"]) VOLATILE_INSTRUCTIONS = frozenset( [ @@ -12,15 +11,22 @@ "alloca", "call", "staticcall", + "delegatecall", "invoke", "sload", "sstore", "iload", "istore", + "tload", + "tstore", "assert", + "assert_unreachable", "mstore", "mload", "calldatacopy", + "mcopy", + "extcodecopy", + "returndatacopy", "codecopy", "dloadbytes", "dload", @@ -39,11 +45,15 @@ "istore", "dloadbytes", "calldatacopy", + "mcopy", + "returndatacopy", "codecopy", + "extcodecopy", "return", "ret", "revert", "assert", + "assert_unreachable", "selfdestruct", "stop", "invalid", @@ -52,6 +62,7 @@ "djmp", "jnz", "log", + "exit", ] ) @@ -87,6 +98,10 @@ class IROperand: value: Any + @property + def name(self) -> str: + return self.value + class IRValue(IROperand): """ @@ -109,13 +124,16 @@ def __init__(self, value: int) -> None: assert isinstance(value, int), "value must be an int" self.value = value - def __repr__(self) -> str: - return str(self.value) + def __hash__(self) -> int: + return self.value.__hash__() + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value -class MemType(Enum): - OPERAND_STACK = auto() - MEMORY = auto() + def __repr__(self) -> str: + return str(self.value) class IRVariable(IRValue): @@ -126,18 +144,34 @@ class IRVariable(IRValue): value: str offset: int = 0 - # some variables can be in memory for conversion from legacy IR to venom - mem_type: MemType = MemType.OPERAND_STACK - mem_addr: Optional[int] = None - - def __init__( - self, value: str, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: int = None - ) -> None: + def __init__(self, value: str, version: Optional[str | int] = None) -> None: assert isinstance(value, str) + assert ":" not in value, "Variable name cannot contain ':'" + if version: + assert isinstance(value, str) or isinstance(value, int), "value must be an str or int" + value = f"{value}:{version}" + if value[0] != "%": + value = f"%{value}" self.value = value self.offset = 0 - self.mem_type = mem_type - self.mem_addr = mem_addr + + @property + def name(self) -> str: + return self.value.split(":")[0] + + @property + def version(self) -> int: + if ":" not in self.value: + return 0 + return int(self.value.split(":")[1]) + + def __hash__(self) -> int: + return self.value.__hash__() + + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value def __repr__(self) -> str: return self.value @@ -158,6 +192,14 @@ def __init__(self, value: str, is_symbol: bool = False) -> None: self.value = value self.is_symbol = is_symbol + def __hash__(self) -> int: + return hash(self.value) + + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value + def __repr__(self) -> str: return self.value @@ -182,6 +224,8 @@ class IRInstruction: parent: Optional["IRBasicBlock"] fence_id: int annotation: Optional[str] + ast_source: Optional[int] + error_msg: Optional[str] def __init__( self, @@ -200,6 +244,8 @@ def __init__( self.parent = None self.fence_id = -1 self.annotation = None + self.ast_source = None + self.error_msg = None def get_label_operands(self) -> list[IRLabel]: """ @@ -246,22 +292,37 @@ def replace_label_operands(self, replacements: dict) -> None: if isinstance(operand, IRLabel) and operand.value in replacements: self.operands[i] = replacements[operand.value] + @property + def phi_operands(self) -> Generator[tuple[IRLabel, IRVariable], None, None]: + """ + Get phi operands for instruction. + """ + assert self.opcode == "phi", "instruction must be a phi" + for i in range(0, len(self.operands), 2): + label = self.operands[i] + var = self.operands[i + 1] + assert isinstance(label, IRLabel), "phi operand must be a label" + assert isinstance(var, IRVariable), "phi operand must be a variable" + yield label, var + def __repr__(self) -> str: s = "" if self.output: s += f"{self.output} = " opcode = f"{self.opcode} " if self.opcode != "store" else "" s += opcode - operands = ", ".join( - [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in self.operands] + operands = self.operands + if opcode not in ["jmp", "jnz", "invoke"]: + operands = reversed(operands) # type: ignore + s += ", ".join( + [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in operands] ) - s += operands if self.annotation: s += f" <{self.annotation}>" - # if self.liveness: - # return f"{s: <30} # {self.liveness}" + if self.liveness: + return f"{s: <30} # {self.liveness}" return s @@ -307,6 +368,9 @@ class IRBasicBlock: # stack items which this basic block produces out_vars: OrderedSet[IRVariable] + reachable: OrderedSet["IRBasicBlock"] + is_reachable: bool = False + def __init__(self, label: IRLabel, parent: "IRFunction") -> None: assert isinstance(label, IRLabel), "label must be an IRLabel" self.label = label @@ -315,6 +379,8 @@ def __init__(self, label: IRLabel, parent: "IRFunction") -> None: self.cfg_in = OrderedSet() self.cfg_out = OrderedSet() self.out_vars = OrderedSet() + self.reachable = OrderedSet() + self.is_reachable = False def add_cfg_in(self, bb: "IRBasicBlock") -> None: self.cfg_in.add(bb) @@ -333,23 +399,26 @@ def remove_cfg_out(self, bb: "IRBasicBlock") -> None: assert bb in self.cfg_out self.cfg_out.remove(bb) - @property - def is_reachable(self) -> bool: - return len(self.cfg_in) > 0 - - def append_instruction(self, opcode: str, *args: Union[IROperand, int]) -> Optional[IRVariable]: + def append_instruction( + self, opcode: str, *args: Union[IROperand, int], ret: IRVariable = None + ) -> Optional[IRVariable]: """ Append an instruction to the basic block Returns the output variable if the instruction supports one """ - ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None + assert not self.is_terminated, self + + if ret is None: + ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None # Wrap raw integers in IRLiterals inst_args = [_ir_operand_from_value(arg) for arg in args] inst = IRInstruction(opcode, inst_args, ret) inst.parent = self + inst.ast_source = self.parent.ast_source + inst.error_msg = self.parent.error_msg self.instructions.append(inst) return ret @@ -357,10 +426,9 @@ def append_invoke_instruction( self, args: list[IROperand | int], returns: bool ) -> Optional[IRVariable]: """ - Append an instruction to the basic block - - Returns the output variable if the instruction supports one + Append an invoke to the basic block """ + assert not self.is_terminated, self ret = None if returns: ret = self.parent.get_next_variable() @@ -368,16 +436,30 @@ def append_invoke_instruction( # Wrap raw integers in IRLiterals inst_args = [_ir_operand_from_value(arg) for arg in args] + assert isinstance(inst_args[0], IRLabel), "Invoked non label" + inst = IRInstruction("invoke", inst_args, ret) inst.parent = self + inst.ast_source = self.parent.ast_source + inst.error_msg = self.parent.error_msg self.instructions.append(inst) return ret - def insert_instruction(self, instruction: IRInstruction, index: int) -> None: + def insert_instruction(self, instruction: IRInstruction, index: Optional[int] = None) -> None: assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + + if index is None: + assert not self.is_terminated, self + index = len(self.instructions) instruction.parent = self + instruction.ast_source = self.parent.ast_source + instruction.error_msg = self.parent.error_msg self.instructions.insert(index, instruction) + def remove_instruction(self, instruction: IRInstruction) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + self.instructions.remove(instruction) + def clear_instructions(self) -> None: self.instructions = [] @@ -388,6 +470,19 @@ def replace_operands(self, replacements: dict) -> None: for instruction in self.instructions: instruction.replace_operands(replacements) + def get_assignments(self): + """ + Get all assignments in basic block. + """ + return [inst.output for inst in self.instructions if inst.output] + + @property + def is_empty(self) -> bool: + """ + Check if the basic block is empty, i.e. it has no instructions. + """ + return len(self.instructions) == 0 + @property def is_terminated(self) -> bool: """ @@ -399,6 +494,20 @@ def is_terminated(self) -> bool: return False return self.instructions[-1].opcode in BB_TERMINATORS + @property + def is_terminal(self) -> bool: + """ + Check if the basic block is terminal. + """ + return len(self.cfg_out) == 0 + + @property + def in_vars(self) -> OrderedSet[IRVariable]: + for inst in self.instructions: + if inst.opcode != "phi": + return inst.liveness + return OrderedSet() + def copy(self): bb = IRBasicBlock(self.label, self.parent) bb.instructions = self.instructions.copy() diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py index 620ee66d15..60dd8bbee1 100644 --- a/vyper/venom/bb_optimizer.py +++ b/vyper/venom/bb_optimizer.py @@ -56,6 +56,25 @@ def _optimize_empty_basicblocks(ctx: IRFunction) -> int: return count +def _daisychain_empty_basicblocks(ctx: IRFunction) -> int: + count = 0 + i = 0 + while i < len(ctx.basic_blocks): + bb = ctx.basic_blocks[i] + i += 1 + if bb.is_terminated: + continue + + if i < len(ctx.basic_blocks) - 1: + bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) + else: + bb.append_instruction("stop") + + count += 1 + + return count + + @ir_pass def ir_pass_optimize_empty_blocks(ctx: IRFunction) -> int: changes = _optimize_empty_basicblocks(ctx) diff --git a/vyper/venom/dominators.py b/vyper/venom/dominators.py new file mode 100644 index 0000000000..b69c17e1d8 --- /dev/null +++ b/vyper/venom/dominators.py @@ -0,0 +1,166 @@ +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction + + +class DominatorTree: + """ + Dominator tree implementation. This class computes the dominator tree of a + function and provides methods to query the tree. The tree is computed using + the Lengauer-Tarjan algorithm. + """ + + ctx: IRFunction + entry_block: IRBasicBlock + dfs_order: dict[IRBasicBlock, int] + dfs_walk: list[IRBasicBlock] + dominators: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + immediate_dominators: dict[IRBasicBlock, IRBasicBlock] + dominated: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + dominator_frontiers: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + + @classmethod + def build_dominator_tree(cls, ctx, entry): + ret = DominatorTree() + ret.compute(ctx, entry) + return ret + + def compute(self, ctx: IRFunction, entry: IRBasicBlock): + """ + Compute the dominator tree. + """ + self.ctx = ctx + self.entry_block = entry + self.dfs_order = {} + self.dfs_walk = [] + self.dominators = {} + self.immediate_dominators = {} + self.dominated = {} + self.dominator_frontiers = {} + + self._compute_dfs(self.entry_block, OrderedSet()) + self._compute_dominators() + self._compute_idoms() + self._compute_df() + + def dominates(self, bb1, bb2): + """ + Check if bb1 dominates bb2. + """ + return bb2 in self.dominators[bb1] + + def immediate_dominator(self, bb): + """ + Return the immediate dominator of a basic block. + """ + return self.immediate_dominators.get(bb) + + def _compute_dominators(self): + """ + Compute dominators + """ + basic_blocks = list(self.dfs_order.keys()) + self.dominators = {bb: OrderedSet(basic_blocks) for bb in basic_blocks} + self.dominators[self.entry_block] = OrderedSet([self.entry_block]) + changed = True + count = len(basic_blocks) ** 2 # TODO: find a proper bound for this + while changed: + count -= 1 + if count < 0: + raise CompilerPanic("Dominators computation failed to converge") + changed = False + for bb in basic_blocks: + if bb == self.entry_block: + continue + preds = bb.cfg_in + if len(preds) == 0: + continue + new_dominators = OrderedSet.intersection(*[self.dominators[pred] for pred in preds]) + new_dominators.add(bb) + if new_dominators != self.dominators[bb]: + self.dominators[bb] = new_dominators + changed = True + + def _compute_idoms(self): + """ + Compute immediate dominators + """ + self.immediate_dominators = {bb: None for bb in self.dfs_order.keys()} + self.immediate_dominators[self.entry_block] = self.entry_block + for bb in self.dfs_walk: + if bb == self.entry_block: + continue + doms = sorted(self.dominators[bb], key=lambda x: self.dfs_order[x]) + self.immediate_dominators[bb] = doms[1] + + self.dominated = {bb: OrderedSet() for bb in self.dfs_walk} + for dom, target in self.immediate_dominators.items(): + self.dominated[target].add(dom) + + def _compute_df(self): + """ + Compute dominance frontier + """ + basic_blocks = self.dfs_walk + self.dominator_frontiers = {bb: OrderedSet() for bb in basic_blocks} + + for bb in self.dfs_walk: + if len(bb.cfg_in) > 1: + for pred in bb.cfg_in: + runner = pred + while runner != self.immediate_dominators[bb]: + self.dominator_frontiers[runner].add(bb) + runner = self.immediate_dominators[runner] + + def dominance_frontier(self, basic_blocks: list[IRBasicBlock]) -> OrderedSet[IRBasicBlock]: + """ + Compute dominance frontier of a set of basic blocks. + """ + df = OrderedSet[IRBasicBlock]() + for bb in basic_blocks: + df.update(self.dominator_frontiers[bb]) + return df + + def _intersect(self, bb1, bb2): + """ + Find the nearest common dominator of two basic blocks. + """ + dfs_order = self.dfs_order + while bb1 != bb2: + while dfs_order[bb1] < dfs_order[bb2]: + bb1 = self.immediate_dominators[bb1] + while dfs_order[bb1] > dfs_order[bb2]: + bb2 = self.immediate_dominators[bb2] + return bb1 + + def _compute_dfs(self, entry: IRBasicBlock, visited): + """ + Depth-first search to compute the DFS order of the basic blocks. This + is used to compute the dominator tree. The sequence of basic blocks in + the DFS order is stored in `self.dfs_walk`. The DFS order of each basic + block is stored in `self.dfs_order`. + """ + visited.add(entry) + + for bb in entry.cfg_out: + if bb not in visited: + self._compute_dfs(bb, visited) + + self.dfs_walk.append(entry) + self.dfs_order[entry] = len(self.dfs_walk) + + def as_graph(self) -> str: + """ + Generate a graphviz representation of the dominator tree. + """ + lines = ["digraph dominator_tree {"] + for bb in self.ctx.basic_blocks: + if bb == self.entry_block: + continue + idom = self.immediate_dominator(bb) + if idom is None: + continue + lines.append(f' " {idom.label} " -> " {bb.label} "') + lines.append("}") + return "\n".join(lines) diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 771dcf73ce..d1680385f5 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -1,12 +1,14 @@ -from typing import Optional +from typing import Iterator, Optional +from vyper.codegen.ir_node import IRnode +from vyper.utils import OrderedSet from vyper.venom.basicblock import ( + CFG_ALTERING_INSTRUCTIONS, IRBasicBlock, IRInstruction, IRLabel, IROperand, IRVariable, - MemType, ) GLOBAL_LABEL = IRLabel("__global") @@ -27,6 +29,11 @@ class IRFunction: last_label: int last_variable: int + # Used during code generation + _ast_source_stack: list[int] + _error_msg_stack: list[str] + _bb_index: dict[str, int] + def __init__(self, name: IRLabel = None) -> None: if name is None: name = GLOBAL_LABEL @@ -40,6 +47,10 @@ def __init__(self, name: IRLabel = None) -> None: self.last_label = 0 self.last_variable = 0 + self._ast_source_stack = [] + self._error_msg_stack = [] + self._bb_index = {} + self.add_entry_point(name) self.append_basic_block(IRBasicBlock(name, self)) @@ -62,10 +73,22 @@ def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" self.basic_blocks.append(bb) - # TODO add sanity check somewhere that basic blocks have unique labels - return self.basic_blocks[-1] + def _get_basicblock_index(self, label: str): + # perf: keep an "index" of labels to block indices to + # perform fast lookup. + # TODO: maybe better just to throw basic blocks in an ordered + # dict of some kind. + ix = self._bb_index.get(label, -1) + if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label: + return ix + # do a reindex + self._bb_index = dict((bb.label.name, ix) for ix, bb in enumerate(self.basic_blocks)) + # sanity check - no duplicate labels + assert len(self._bb_index) == len(self.basic_blocks) + return self._bb_index[label] + def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ Get basic block by label. @@ -73,49 +96,97 @@ def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ if label is None: return self.basic_blocks[-1] - for bb in self.basic_blocks: - if bb.label.value == label: - return bb - raise AssertionError(f"Basic block '{label}' not found") + ix = self._get_basicblock_index(label) + return self.basic_blocks[ix] def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: """ Get basic block after label. """ - for i, bb in enumerate(self.basic_blocks[:-1]): - if bb.label.value == label.value: - return self.basic_blocks[i + 1] + ix = self._get_basicblock_index(label.value) + if 0 <= ix < len(self.basic_blocks) - 1: + return self.basic_blocks[ix + 1] raise AssertionError(f"Basic block after '{label}' not found") + def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]: + """ + Get basic blocks that are terminal. + """ + for bb in self.basic_blocks: + if bb.is_terminal: + yield bb + def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: """ - Get basic blocks that contain label. + Get basic blocks that point to the given basic block """ return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] - def get_next_label(self) -> IRLabel: + def get_next_label(self, suffix: str = "") -> IRLabel: + if suffix != "": + suffix = f"_{suffix}" self.last_label += 1 - return IRLabel(f"{self.last_label}") + return IRLabel(f"{self.last_label}{suffix}") - def get_next_variable( - self, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: Optional[int] = None - ) -> IRVariable: + def get_next_variable(self) -> IRVariable: self.last_variable += 1 - return IRVariable(f"%{self.last_variable}", mem_type, mem_addr) + return IRVariable(f"%{self.last_variable}") def get_last_variable(self) -> str: return f"%{self.last_variable}" def remove_unreachable_blocks(self) -> int: - removed = 0 + self._compute_reachability() + + removed = [] new_basic_blocks = [] + + # Remove unreachable basic blocks for bb in self.basic_blocks: - if not bb.is_reachable and bb.label not in self.entry_points: - removed += 1 + if not bb.is_reachable: + removed.append(bb) else: new_basic_blocks.append(bb) self.basic_blocks = new_basic_blocks - return removed + + # Remove phi instructions that reference removed basic blocks + for bb in removed: + for out_bb in bb.cfg_out: + out_bb.remove_cfg_in(bb) + for inst in out_bb.instructions: + if inst.opcode != "phi": + continue + in_labels = inst.get_label_operands() + if bb.label in in_labels: + out_bb.remove_instruction(inst) + + return len(removed) + + def _compute_reachability(self) -> None: + """ + Compute reachability of basic blocks. + """ + for bb in self.basic_blocks: + bb.reachable = OrderedSet() + bb.is_reachable = False + + for entry in self.entry_points: + entry_bb = self.get_basic_block(entry.value) + self._compute_reachability_from(entry_bb) + + def _compute_reachability_from(self, bb: IRBasicBlock) -> None: + """ + Compute reachability of basic blocks from bb. + """ + if bb.is_reachable: + return + bb.is_reachable = True + for inst in bb.instructions: + if inst.opcode in CFG_ALTERING_INSTRUCTIONS or inst.opcode == "invoke": + for op in inst.get_label_operands(): + out_bb = self.get_basic_block(op.value) + bb.reachable.add(out_bb) + self._compute_reachability_from(out_bb) def append_data(self, opcode: str, args: list[IROperand]) -> None: """ @@ -147,6 +218,25 @@ def normalized(self) -> bool: # The function is normalized return True + def push_source(self, ir): + if isinstance(ir, IRnode): + self._ast_source_stack.append(ir.ast_source) + self._error_msg_stack.append(ir.error_msg) + + def pop_source(self): + assert len(self._ast_source_stack) > 0, "Empty source stack" + self._ast_source_stack.pop() + assert len(self._error_msg_stack) > 0, "Empty error stack" + self._error_msg_stack.pop() + + @property + def ast_source(self) -> Optional[int]: + return self._ast_source_stack[-1] if len(self._ast_source_stack) > 0 else None + + @property + def error_msg(self) -> Optional[str]: + return self._error_msg_stack[-1] if len(self._error_msg_stack) > 0 else None + def copy(self): new = IRFunction(self.name) new.basic_blocks = self.basic_blocks.copy() @@ -155,6 +245,32 @@ def copy(self): new.last_variable = self.last_variable return new + def as_graph(self) -> str: + import html + + def _make_label(bb): + ret = '<' + ret += f'\n' + for inst in bb.instructions: + ret += f'\n' + ret += "
{html.escape(str(bb.label))}
{html.escape(str(inst))}
>" + + return ret + # return f"{bb.label.value}:\n" + "\n".join([f" {inst}" for inst in bb.instructions]) + + ret = "digraph G {\n" + + for bb in self.basic_blocks: + for out_bb in bb.cfg_out: + ret += f' "{bb.label.value}" -> "{out_bb.label.value}"\n' + + for bb in self.basic_blocks: + ret += f' "{bb.label.value}" [shape=plaintext, ' + ret += f'label={_make_label(bb)}, fontname="Courier" fontsize="8"]\n' + + ret += "}\n" + return ret + def __repr__(self) -> str: str = f"IRFunction: {self.name}\n" for bb in self.basic_blocks: diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index b3ac3c1ad7..f610e17f58 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -1,12 +1,10 @@ +import functools +import re from typing import Optional -from vyper.codegen.context import VariableRecord from vyper.codegen.ir_node import IRnode from vyper.evm.opcodes import get_opcodes -from vyper.exceptions import CompilerPanic -from vyper.ir.compile_ir import is_mem_sym, is_symbol -from vyper.semantics.types.function import ContractFunctionT -from vyper.utils import MemoryPositions, OrderedSet +from vyper.utils import MemoryPositions from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -14,12 +12,17 @@ IRLiteral, IROperand, IRVariable, - MemType, ) from vyper.venom.function import IRFunction -_BINARY_IR_INSTRUCTIONS = frozenset( +# Instructions that are mapped to their inverse +INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} + +# Instructions that have a direct EVM opcode equivalent and can +# be passed through to the EVM assembly without special handling +PASS_THROUGH_INSTRUCTIONS = frozenset( [ + # binary instructions "eq", "gt", "lt", @@ -27,6 +30,7 @@ "sgt", "shr", "shl", + "sar", "or", "xor", "and", @@ -34,98 +38,104 @@ "sub", "mul", "div", + "smul", + "sdiv", "mod", + "smod", "exp", "sha3", "sha3_64", "signextend", + "chainid", + "basefee", + "timestamp", + "blockhash", + "caller", + "selfbalance", + "calldatasize", + "callvalue", + "address", + "origin", + "codesize", + "gas", + "gasprice", + "gaslimit", + "returndatasize", + "iload", + "sload", + "tload", + "coinbase", + "number", + "prevrandao", + "difficulty", + "iszero", + "not", + "calldataload", + "extcodesize", + "extcodehash", + "balance", + "msize", + "basefee", + "invalid", + "stop", + "selfdestruct", + "assert", + "assert_unreachable", + "exit", + "calldatacopy", + "mcopy", + "extcodecopy", + "codecopy", + "returndatacopy", + "revert", + "istore", + "sstore", + "tstore", + "create", + "create2", + "addmod", + "mulmod", + "call", + "delegatecall", + "staticcall", ] ) -# Instructions that are mapped to their inverse -INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} - -# Instructions that have a direct EVM opcode equivalent and can -# be passed through to the EVM assembly without special handling -PASS_THROUGH_INSTRUCTIONS = [ - "chainid", - "basefee", - "timestamp", - "blockhash", - "caller", - "selfbalance", - "calldatasize", - "callvalue", - "address", - "origin", - "codesize", - "gas", - "gasprice", - "gaslimit", - "returndatasize", - "coinbase", - "number", - "iszero", - "not", - "calldataload", - "extcodesize", - "extcodehash", - "balance", -] +NOOP_INSTRUCTIONS = frozenset(["pass", "cleanup_repeat", "var_list", "unique_symbol"]) SymbolTable = dict[str, Optional[IROperand]] -def _get_symbols_common(a: dict, b: dict) -> dict: - ret = {} - # preserves the ordering in `a` - for k in a.keys(): - if k not in b: - continue - if a[k] == b[k]: - continue - ret[k] = a[k], b[k] - return ret - - # convert IRnode directly to venom def ir_node_to_venom(ir: IRnode) -> IRFunction: ctx = IRFunction() - _convert_ir_bb(ctx, ir, {}, OrderedSet(), {}) + _convert_ir_bb(ctx, ir, {}) # Patch up basic blocks. Connect unterminated blocks to the next with # a jump. terminate final basic block with STOP. for i, bb in enumerate(ctx.basic_blocks): if not bb.is_terminated: - if i < len(ctx.basic_blocks) - 1: - bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) + if len(ctx.basic_blocks) - 1 > i: + # TODO: revisit this. When contructor calls internal functions they + # are linked to the last ctor block. Should separate them before this + # so we don't have to handle this here + if ctx.basic_blocks[i + 1].label.value.startswith("internal"): + bb.append_instruction("stop") + else: + bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) else: - bb.append_instruction("stop") + bb.append_instruction("exit") return ctx -def _convert_binary_op( - ctx: IRFunction, - ir: IRnode, - symbols: SymbolTable, - variables: OrderedSet, - allocated_variables: dict[str, IRVariable], - swap: bool = False, -) -> Optional[IRVariable]: - ir_args = ir.args[::-1] if swap else ir.args - arg_0, arg_1 = _convert_ir_bb_list(ctx, ir_args, symbols, variables, allocated_variables) - - assert isinstance(ir.value, str) # mypy hint - return ctx.get_basic_block().append_instruction(ir.value, arg_1, arg_0) - - def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: - ctx.get_basic_block().append_instruction("jmp", label) + bb = ctx.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("jmp_target"), ctx) + ctx.append_basic_block(bb) - label = ctx.get_next_label() - bb = IRBasicBlock(label, ctx) - ctx.append_basic_block(bb) + bb.append_instruction("jmp", label) def _new_block(ctx: IRFunction) -> IRBasicBlock: @@ -134,65 +144,46 @@ def _new_block(ctx: IRFunction) -> IRBasicBlock: return bb -def _handle_self_call( - ctx: IRFunction, - ir: IRnode, - symbols: SymbolTable, - variables: OrderedSet, - allocated_variables: dict[str, IRVariable], -) -> Optional[IRVariable]: - func_t = ir.passthrough_metadata.get("func_t", None) - args_ir = ir.passthrough_metadata["args_ir"] +def _append_return_args(ctx: IRFunction, ofst: int = 0, size: int = 0): + bb = ctx.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("exit_to"), ctx) + ctx.append_basic_block(bb) + ret_ofst = IRVariable("ret_ofst") + ret_size = IRVariable("ret_size") + bb.append_instruction("store", ofst, ret=ret_ofst) + bb.append_instruction("store", size, ret=ret_size) + + +def _handle_self_call(ctx: IRFunction, ir: IRnode, symbols: SymbolTable) -> Optional[IRVariable]: + setup_ir = ir.args[1] goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] target_label = goto_ir.args[0].value # goto - return_buf = goto_ir.args[1] # return buffer + return_buf_ir = goto_ir.args[1] # return buffer ret_args: list[IROperand] = [IRLabel(target_label)] # type: ignore - for arg in args_ir: - if arg.is_literal: - sym = symbols.get(f"&{arg.value}", None) - if sym is None: - ret = _convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) - ret_args.append(ret) - else: - ret_args.append(sym) # type: ignore - else: - ret = _convert_ir_bb(ctx, arg._optimized, symbols, variables, allocated_variables) - if arg.location and arg.location.load_op == "calldataload": - bb = ctx.get_basic_block() - ret = bb.append_instruction(arg.location.load_op, ret) - ret_args.append(ret) + if setup_ir != goto_ir: + _convert_ir_bb(ctx, setup_ir, symbols) - if return_buf.is_literal: - ret_args.append(return_buf.value) # type: ignore + return_buf = _convert_ir_bb(ctx, return_buf_ir, symbols) bb = ctx.get_basic_block() - do_ret = func_t.return_type is not None - if do_ret: - invoke_ret = bb.append_invoke_instruction(ret_args, returns=True) # type: ignore - allocated_variables["return_buffer"] = invoke_ret # type: ignore - return invoke_ret - else: - bb.append_invoke_instruction(ret_args, returns=False) # type: ignore - return None + if len(goto_ir.args) > 2: + ret_args.append(return_buf) # type: ignore + + bb.append_invoke_instruction(ret_args, returns=False) # type: ignore + + return return_buf def _handle_internal_func( - ctx: IRFunction, ir: IRnode, func_t: ContractFunctionT, symbols: SymbolTable -) -> IRnode: + ctx: IRFunction, ir: IRnode, does_return_data: bool, symbols: SymbolTable +): bb = IRBasicBlock(IRLabel(ir.args[0].args[0].value, True), ctx) # type: ignore bb = ctx.append_basic_block(bb) - old_ir_mempos = 0 - old_ir_mempos += 64 - - for arg in func_t.arguments: - symbols[f"&{old_ir_mempos}"] = bb.append_instruction("param") - bb.instructions[-1].annotation = arg.name - old_ir_mempos += 32 # arg.typ.memory_bytes_required - # return buffer - if func_t.return_type is not None: + if does_return_data: symbols["return_buffer"] = bb.append_instruction("param") bb.instructions[-1].annotation = "return_buffer" @@ -200,17 +191,16 @@ def _handle_internal_func( symbols["return_pc"] = bb.append_instruction("param") bb.instructions[-1].annotation = "return_pc" - return ir.args[0].args[2] + _convert_ir_bb(ctx, ir.args[0].args[2], symbols) def _convert_ir_simple_node( - ctx: IRFunction, - ir: IRnode, - symbols: SymbolTable, - variables: OrderedSet, - allocated_variables: dict[str, IRVariable], + ctx: IRFunction, ir: IRnode, symbols: SymbolTable ) -> Optional[IRVariable]: - args = [_convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args] + # execute in order + args = _convert_ir_bb_list(ctx, ir.args, symbols) + # reverse output variables for stack + args.reverse() return ctx.get_basic_block().append_instruction(ir.value, *args) # type: ignore @@ -218,265 +208,162 @@ def _convert_ir_simple_node( _continue_target: Optional[IRBasicBlock] = None -def _get_variable_from_address( - variables: OrderedSet[VariableRecord], addr: int -) -> Optional[VariableRecord]: - assert isinstance(addr, int), "non-int address" - for var in variables.keys(): - if var.location.name != "memory": - continue - if addr >= var.pos and addr < var.pos + var.size: # type: ignore - return var - return None - - -def _append_return_for_stack_operand( - ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable -) -> None: - bb = ctx.get_basic_block() - if isinstance(ret_ir, IRLiteral): - sym = symbols.get(f"&{ret_ir.value}", None) - new_var = bb.append_instruction("alloca", 32, ret_ir) - bb.append_instruction("mstore", sym, new_var) # type: ignore - else: - sym = symbols.get(ret_ir.value, None) - if sym is None: - # FIXME: needs real allocations - new_var = bb.append_instruction("alloca", 32, 0) - bb.append_instruction("mstore", ret_ir, new_var) # type: ignore - else: - new_var = ret_ir - bb.append_instruction("return", last_ir, new_var) # type: ignore - - -def _convert_ir_bb_list(ctx, ir, symbols, variables, allocated_variables): +def _convert_ir_bb_list(ctx, ir, symbols): ret = [] for ir_node in ir: - venom = _convert_ir_bb(ctx, ir_node, symbols, variables, allocated_variables) - assert venom is not None, ir_node + venom = _convert_ir_bb(ctx, ir_node, symbols) ret.append(venom) return ret -def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): - assert isinstance(ir, IRnode), ir - assert isinstance(variables, OrderedSet) - global _break_target, _continue_target +current_func = None +var_list: list[str] = [] - frame_info = ir.passthrough_metadata.get("frame_info", None) - if frame_info is not None: - local_vars = OrderedSet[VariableRecord](frame_info.frame_vars.values()) - variables |= local_vars - assert isinstance(variables, OrderedSet) +def pop_source_on_return(func): + @functools.wraps(func) + def pop_source(*args, **kwargs): + ctx = args[0] + ret = func(*args, **kwargs) + ctx.pop_source() + return ret - if ir.value in _BINARY_IR_INSTRUCTIONS: - return _convert_binary_op( - ctx, ir, symbols, variables, allocated_variables, ir.value in ["sha3_64"] - ) + return pop_source + + +@pop_source_on_return +def _convert_ir_bb(ctx, ir, symbols): + assert isinstance(ir, IRnode), ir + global _break_target, _continue_target, current_func, var_list + + ctx.push_source(ir) - elif ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: + if ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: org_value = ir.value ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] - new_var = _convert_binary_op(ctx, ir, symbols, variables, allocated_variables) + new_var = _convert_ir_simple_node(ctx, ir, symbols) ir.value = org_value return ctx.get_basic_block().append_instruction("iszero", new_var) - elif ir.value in PASS_THROUGH_INSTRUCTIONS: - return _convert_ir_simple_node(ctx, ir, symbols, variables, allocated_variables) - - elif ir.value in ["pass", "stop", "return"]: - pass + return _convert_ir_simple_node(ctx, ir, symbols) + elif ir.value == "return": + ctx.get_basic_block().append_instruction( + "return", IRVariable("ret_size"), IRVariable("ret_ofst") + ) elif ir.value == "deploy": ctx.ctor_mem_size = ir.args[0].value ctx.immutables_len = ir.args[2].value return None elif ir.value == "seq": - func_t = ir.passthrough_metadata.get("func_t", None) + if len(ir.args) == 0: + return None if ir.is_self_call: - return _handle_self_call(ctx, ir, symbols, variables, allocated_variables) - elif func_t is not None: - symbols = {} - allocated_variables = {} - variables = OrderedSet( - {v: True for v in ir.passthrough_metadata["frame_info"].frame_vars.values()} - ) - if func_t.is_internal: - ir = _handle_internal_func(ctx, ir, func_t, symbols) - # fallthrough - - ret = None - for ir_node in ir.args: # NOTE: skip the last one - ret = _convert_ir_bb(ctx, ir_node, symbols, variables, allocated_variables) - - return ret - elif ir.value in ["staticcall", "call"]: # external call - idx = 0 - gas = _convert_ir_bb(ctx, ir.args[idx], symbols, variables, allocated_variables) - address = _convert_ir_bb(ctx, ir.args[idx + 1], symbols, variables, allocated_variables) - - value = None - if ir.value == "call": - value = _convert_ir_bb(ctx, ir.args[idx + 2], symbols, variables, allocated_variables) - else: - idx -= 1 - - argsOffset, argsSize, retOffset, retSize = _convert_ir_bb_list( - ctx, ir.args[idx + 3 : idx + 7], symbols, variables, allocated_variables - ) - - if isinstance(argsOffset, IRLiteral): - offset = int(argsOffset.value) - addr = offset - 32 + 4 if offset > 0 else 0 - argsOffsetVar = symbols.get(f"&{addr}", None) - if argsOffsetVar is None: - argsOffsetVar = argsOffset - elif isinstance(argsOffsetVar, IRVariable): - argsOffsetVar.mem_type = MemType.MEMORY - argsOffsetVar.mem_addr = addr - argsOffsetVar.offset = 32 - 4 if offset > 0 else 0 - else: # pragma: nocover - raise CompilerPanic("unreachable") + return _handle_self_call(ctx, ir, symbols) + elif ir.args[0].value == "label": + current_func = ir.args[0].args[0].value + is_external = current_func.startswith("external") + is_internal = current_func.startswith("internal") + if is_internal or len(re.findall(r"external.*__init__\(.*_deploy", current_func)) > 0: + # Internal definition + var_list = ir.args[0].args[1] + does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args + symbols = {} + _handle_internal_func(ctx, ir, does_return_data, symbols) + for ir_node in ir.args[1:]: + ret = _convert_ir_bb(ctx, ir_node, symbols) + + return ret + elif is_external: + ret = _convert_ir_bb(ctx, ir.args[0], symbols) + _append_return_args(ctx) else: - argsOffsetVar = argsOffset - - retOffsetValue = int(retOffset.value) if retOffset else 0 - retVar = ctx.get_next_variable(MemType.MEMORY, retOffsetValue) - symbols[f"&{retOffsetValue}"] = retVar + bb = ctx.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("seq"), ctx) + ctx.append_basic_block(bb) + ret = _convert_ir_bb(ctx, ir.args[0], symbols) - bb = ctx.get_basic_block() + for ir_node in ir.args[1:]: + ret = _convert_ir_bb(ctx, ir_node, symbols) - if ir.value == "call": - args = [retSize, retOffset, argsSize, argsOffsetVar, value, address, gas] - return bb.append_instruction(ir.value, *args) - else: - args = [retSize, retOffset, argsSize, argsOffsetVar, address, gas] - return bb.append_instruction(ir.value, *args) + return ret elif ir.value == "if": cond = ir.args[0] # convert the condition - cont_ret = _convert_ir_bb(ctx, cond, symbols, variables, allocated_variables) - current_bb = ctx.get_basic_block() + cont_ret = _convert_ir_bb(ctx, cond, symbols) + cond_block = ctx.get_basic_block() + + cond_symbols = symbols.copy() - else_block = IRBasicBlock(ctx.get_next_label(), ctx) + else_block = IRBasicBlock(ctx.get_next_label("else"), ctx) ctx.append_basic_block(else_block) # convert "else" else_ret_val = None - else_syms = symbols.copy() if len(ir.args) == 3: - else_ret_val = _convert_ir_bb( - ctx, ir.args[2], else_syms, variables, allocated_variables.copy() - ) + else_ret_val = _convert_ir_bb(ctx, ir.args[2], cond_symbols) if isinstance(else_ret_val, IRLiteral): assert isinstance(else_ret_val.value, int) # help mypy else_ret_val = ctx.get_basic_block().append_instruction("store", else_ret_val) - after_else_syms = else_syms.copy() - else_block = ctx.get_basic_block() + + else_block_finish = ctx.get_basic_block() # convert "then" - then_block = IRBasicBlock(ctx.get_next_label(), ctx) + cond_symbols = symbols.copy() + + then_block = IRBasicBlock(ctx.get_next_label("then"), ctx) ctx.append_basic_block(then_block) - then_ret_val = _convert_ir_bb(ctx, ir.args[1], symbols, variables, allocated_variables) + then_ret_val = _convert_ir_bb(ctx, ir.args[1], cond_symbols) if isinstance(then_ret_val, IRLiteral): then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val) - current_bb.append_instruction("jnz", cont_ret, then_block.label, else_block.label) + cond_block.append_instruction("jnz", cont_ret, then_block.label, else_block.label) - after_then_syms = symbols.copy() - then_block = ctx.get_basic_block() + then_block_finish = ctx.get_basic_block() # exit bb - exit_label = ctx.get_next_label() - exit_bb = IRBasicBlock(exit_label, ctx) + exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), ctx) exit_bb = ctx.append_basic_block(exit_bb) - if_ret = None + if_ret = ctx.get_next_variable() if then_ret_val is not None and else_ret_val is not None: - if_ret = exit_bb.append_instruction( - "phi", then_block.label, then_ret_val, else_block.label, else_ret_val - ) - - common_symbols = _get_symbols_common(after_then_syms, after_else_syms) - for sym, val in common_symbols.items(): - ret = exit_bb.append_instruction( - "phi", then_block.label, val[0], else_block.label, val[1] - ) - old_var = symbols.get(sym, None) - symbols[sym] = ret - if old_var is not None: - for idx, var_rec in allocated_variables.items(): # type: ignore - if var_rec.value == old_var.value: - allocated_variables[idx] = ret # type: ignore - - if not else_block.is_terminated: - else_block.append_instruction("jmp", exit_bb.label) - - if not then_block.is_terminated: - then_block.append_instruction("jmp", exit_bb.label) + then_block_finish.append_instruction("store", then_ret_val, ret=if_ret) + else_block_finish.append_instruction("store", else_ret_val, ret=if_ret) + + if not else_block_finish.is_terminated: + else_block_finish.append_instruction("jmp", exit_bb.label) + + if not then_block_finish.is_terminated: + then_block_finish.append_instruction("jmp", exit_bb.label) return if_ret elif ir.value == "with": - ret = _convert_ir_bb( - ctx, ir.args[1], symbols, variables, allocated_variables - ) # initialization + ret = _convert_ir_bb(ctx, ir.args[1], symbols) # initialization + + ret = ctx.get_basic_block().append_instruction("store", ret) # Handle with nesting with same symbol with_symbols = symbols.copy() sym = ir.args[0] - if isinstance(ret, IRLiteral): - new_var = ctx.get_basic_block().append_instruction("store", ret) # type: ignore - with_symbols[sym.value] = new_var - else: - with_symbols[sym.value] = ret # type: ignore + with_symbols[sym.value] = ret - return _convert_ir_bb(ctx, ir.args[2], with_symbols, variables, allocated_variables) # body + return _convert_ir_bb(ctx, ir.args[2], with_symbols) # body elif ir.value == "goto": _append_jmp(ctx, IRLabel(ir.args[0].value)) elif ir.value == "djump": - args = [_convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables)] + args = [_convert_ir_bb(ctx, ir.args[0], symbols)] for target in ir.args[1:]: args.append(IRLabel(target.value)) ctx.get_basic_block().append_instruction("djmp", *args) _new_block(ctx) elif ir.value == "set": sym = ir.args[0] - arg_1 = _convert_ir_bb(ctx, ir.args[1], symbols, variables, allocated_variables) - new_var = ctx.get_basic_block().append_instruction("store", arg_1) # type: ignore - symbols[sym.value] = new_var - - elif ir.value == "calldatacopy": - arg_0, arg_1, size = _convert_ir_bb_list( - ctx, ir.args, symbols, variables, allocated_variables - ) - - new_v = arg_0 - var = ( - _get_variable_from_address(variables, int(arg_0.value)) - if isinstance(arg_0, IRLiteral) - else None - ) - bb = ctx.get_basic_block() - if var is not None: - if allocated_variables.get(var.name, None) is None: - new_v = bb.append_instruction("alloca", var.size, var.pos) # type: ignore - allocated_variables[var.name] = new_v # type: ignore - bb.append_instruction("calldatacopy", size, arg_1, new_v) # type: ignore - symbols[f"&{var.pos}"] = new_v # type: ignore - else: - bb.append_instruction("calldatacopy", size, arg_1, new_v) # type: ignore - - return new_v - elif ir.value == "codecopy": - arg_0, arg_1, size = _convert_ir_bb_list( - ctx, ir.args, symbols, variables, allocated_variables - ) - - ctx.get_basic_block().append_instruction("codecopy", size, arg_1, arg_0) # type: ignore + arg_1 = _convert_ir_bb(ctx, ir.args[1], symbols) + ctx.get_basic_block().append_instruction("store", arg_1, ret=symbols[sym.value]) elif ir.value == "symbol": return IRLabel(ir.args[0].value, True) elif ir.value == "data": @@ -486,15 +373,11 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): if isinstance(c, int): assert 0 <= c <= 255, "data with invalid size" ctx.append_data("db", [c]) # type: ignore - elif isinstance(c, bytes): - ctx.append_data("db", [c]) # type: ignore + elif isinstance(c.value, bytes): + ctx.append_data("db", [c.value]) # type: ignore elif isinstance(c, IRnode): - data = _convert_ir_bb(ctx, c, symbols, variables, allocated_variables) + data = _convert_ir_bb(ctx, c, symbols) ctx.append_data("db", [data]) # type: ignore - elif ir.value == "assert": - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) - current_bb = ctx.get_basic_block() - current_bb.append_instruction("assert", arg_0) elif ir.value == "label": label = IRLabel(ir.args[0].value, True) bb = ctx.get_basic_block() @@ -502,97 +385,30 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): bb.append_instruction("jmp", label) bb = IRBasicBlock(label, ctx) ctx.append_basic_block(bb) - _convert_ir_bb(ctx, ir.args[2], symbols, variables, allocated_variables) + code = ir.args[2] + if code.value == "pass": + bb.append_instruction("exit") + else: + _convert_ir_bb(ctx, code, symbols) elif ir.value == "exit_to": - func_t = ir.passthrough_metadata.get("func_t", None) - assert func_t is not None, "exit_to without func_t" - - if func_t.is_external: - # Hardcoded constructor special case - bb = ctx.get_basic_block() - if func_t.name == "__init__": - label = IRLabel(ir.args[0].value, True) - bb.append_instruction("jmp", label) - return None - if func_t.return_type is None: - bb.append_instruction("stop") - return None - else: - last_ir = None - ret_var = ir.args[1] - deleted = None - if ret_var.is_literal and symbols.get(f"&{ret_var.value}", None) is not None: - deleted = symbols[f"&{ret_var.value}"] - del symbols[f"&{ret_var.value}"] - for arg in ir.args[2:]: - last_ir = _convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) - if deleted is not None: - symbols[f"&{ret_var.value}"] = deleted - - ret_ir = _convert_ir_bb(ctx, ret_var, symbols, variables, allocated_variables) - - bb = ctx.get_basic_block() - - var = ( - _get_variable_from_address(variables, int(ret_ir.value)) - if isinstance(ret_ir, IRLiteral) - else None - ) - if var is not None: - allocated_var = allocated_variables.get(var.name, None) - assert allocated_var is not None, "unallocated variable" - new_var = symbols.get(f"&{ret_ir.value}", allocated_var) # type: ignore - - if var.size and int(var.size) > 32: - offset = int(ret_ir.value) - var.pos # type: ignore - if offset > 0: - ptr_var = bb.append_instruction("add", var.pos, offset) - else: - ptr_var = allocated_var - bb.append_instruction("return", last_ir, ptr_var) - else: - _append_return_for_stack_operand(ctx, symbols, new_var, last_ir) - else: - if isinstance(ret_ir, IRLiteral): - sym = symbols.get(f"&{ret_ir.value}", None) - if sym is None: - bb.append_instruction("return", last_ir, ret_ir) - else: - if func_t.return_type.memory_bytes_required > 32: - new_var = bb.append_instruction("alloca", 32, ret_ir) - bb.append_instruction("mstore", sym, new_var) - bb.append_instruction("return", last_ir, new_var) - else: - bb.append_instruction("return", last_ir, ret_ir) - else: - if last_ir and int(last_ir.value) > 32: - bb.append_instruction("return", last_ir, ret_ir) - else: - ret_buf = 128 # TODO: need allocator - new_var = bb.append_instruction("alloca", 32, ret_buf) - bb.append_instruction("mstore", ret_ir, new_var) - bb.append_instruction("return", last_ir, new_var) - - ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) - + args = _convert_ir_bb_list(ctx, ir.args[1:], symbols) + var_list = args + _append_return_args(ctx, *var_list) + bb = ctx.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("exit_to"), ctx) + ctx.append_basic_block(bb) bb = ctx.get_basic_block() - if func_t.is_internal: - assert ir.args[1].value == "return_pc", "return_pc not found" - if func_t.return_type is None: - bb.append_instruction("ret", symbols["return_pc"]) - else: - if func_t.return_type.memory_bytes_required > 32: - bb.append_instruction("ret", symbols["return_buffer"], symbols["return_pc"]) - else: - ret_by_value = bb.append_instruction("mload", symbols["return_buffer"]) - bb.append_instruction("ret", ret_by_value, symbols["return_pc"]) - elif ir.value == "revert": - arg_0, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) - ctx.get_basic_block().append_instruction("revert", arg_1, arg_0) + label = IRLabel(ir.args[0].value) + if label.value == "return_pc": + label = symbols.get("return_pc") + bb.append_instruction("ret", label) + else: + bb.append_instruction("jmp", label) elif ir.value == "dload": - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols) bb = ctx.get_basic_block() src = bb.append_instruction("add", arg_0, IRLabel("code_end")) @@ -600,9 +416,7 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): return bb.append_instruction("mload", MemoryPositions.FREE_VAR_SPACE) elif ir.value == "dloadbytes": - dst, src_offset, len_ = _convert_ir_bb_list( - ctx, ir.args, symbols, variables, allocated_variables - ) + dst, src_offset, len_ = _convert_ir_bb_list(ctx, ir.args, symbols) bb = ctx.get_basic_block() src = bb.append_instruction("add", src_offset, IRLabel("code_end")) @@ -610,210 +424,106 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): return None elif ir.value == "mload": - sym_ir = ir.args[0] - var = ( - _get_variable_from_address(variables, int(sym_ir.value)) if sym_ir.is_literal else None - ) + arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols) bb = ctx.get_basic_block() - if var is not None: - if var.size and var.size > 32: - if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = bb.append_instruction( - "alloca", var.size, var.pos - ) - - offset = int(sym_ir.value) - var.pos - if offset > 0: - ptr_var = bb.append_instruction("add", var.pos, offset) - else: - ptr_var = allocated_variables[var.name] + if isinstance(arg_0, IRVariable): + return bb.append_instruction("mload", arg_0) - return bb.append_instruction("mload", ptr_var) - else: - if sym_ir.is_literal: - sym = symbols.get(f"&{sym_ir.value}", None) - if sym is None: - new_var = _convert_ir_bb( - ctx, sym_ir, symbols, variables, allocated_variables - ) - symbols[f"&{sym_ir.value}"] = new_var - if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = new_var - return new_var - else: - return sym - - sym = symbols.get(f"&{sym_ir.value}", None) - assert sym is not None, "unallocated variable" - return sym - else: - if sym_ir.is_literal: - new_var = symbols.get(f"&{sym_ir.value}", None) - if new_var is not None: - return bb.append_instruction("mload", new_var) - else: - return bb.append_instruction("mload", sym_ir.value) - else: - new_var = _convert_ir_bb(ctx, sym_ir, symbols, variables, allocated_variables) - # - # Old IR gets it's return value as a reference in the stack - # New IR gets it's return value in stack in case of 32 bytes or less - # So here we detect ahead of time if this mload leads a self call and - # and we skip the mload - # - if sym_ir.is_self_call: - return new_var - return bb.append_instruction("mload", new_var) + if isinstance(arg_0, IRLiteral): + avar = symbols.get(f"%{arg_0.value}") + if avar is not None: + return bb.append_instruction("mload", avar) + return bb.append_instruction("mload", arg_0) elif ir.value == "mstore": - sym_ir, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) - - bb = ctx.get_basic_block() - - var = None - if isinstance(sym_ir, IRLiteral): - var = _get_variable_from_address(variables, int(sym_ir.value)) + # some upstream code depends on reversed order of evaluation -- + # to fix upstream. + arg_1, arg_0 = _convert_ir_bb_list(ctx, reversed(ir.args), symbols) - if var is not None and var.size is not None: - if var.size and var.size > 32: - if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = bb.append_instruction( - "alloca", var.size, var.pos - ) - - offset = int(sym_ir.value) - var.pos - if offset > 0: - ptr_var = bb.append_instruction("add", var.pos, offset) - else: - ptr_var = allocated_variables[var.name] + if isinstance(arg_1, IRVariable): + symbols[f"&{arg_0.value}"] = arg_1 - bb.append_instruction("mstore", arg_1, ptr_var) - else: - if isinstance(sym_ir, IRLiteral): - new_var = bb.append_instruction("store", arg_1) - symbols[f"&{sym_ir.value}"] = new_var - # if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = new_var - return new_var - else: - if not isinstance(sym_ir, IRLiteral): - bb.append_instruction("mstore", arg_1, sym_ir) - return None - - sym = symbols.get(f"&{sym_ir.value}", None) - if sym is None: - bb.append_instruction("mstore", arg_1, sym_ir) - if arg_1 and not isinstance(sym_ir, IRLiteral): - symbols[f"&{sym_ir.value}"] = arg_1 - return None - - if isinstance(sym_ir, IRLiteral): - bb.append_instruction("mstore", arg_1, sym) - return None - else: - symbols[sym_ir.value] = arg_1 - return arg_1 + ctx.get_basic_block().append_instruction("mstore", arg_1, arg_0) elif ir.value == "ceil32": x = ir.args[0] expanded = IRnode.from_list(["and", ["add", x, 31], ["not", 31]]) - return _convert_ir_bb(ctx, expanded, symbols, variables, allocated_variables) + return _convert_ir_bb(ctx, expanded, symbols) elif ir.value == "select": - # b ^ ((a ^ b) * cond) where cond is 1 or 0 cond, a, b = ir.args - expanded = IRnode.from_list(["xor", b, ["mul", cond, ["xor", a, b]]]) - return _convert_ir_bb(ctx, expanded, symbols, variables, allocated_variables) - - elif ir.value in ["sload", "iload"]: - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) - return ctx.get_basic_block().append_instruction(ir.value, arg_0) - elif ir.value in ["sstore", "istore"]: - arg_0, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) - ctx.get_basic_block().append_instruction(ir.value, arg_1, arg_0) - elif ir.value == "unique_symbol": - sym = ir.args[0] - new_var = ctx.get_next_variable() - symbols[f"&{sym.value}"] = new_var - return new_var + expanded = IRnode.from_list( + [ + "with", + "cond", + cond, + [ + "with", + "a", + a, + ["with", "b", b, ["xor", "b", ["mul", "cond", ["xor", "a", "b"]]]], + ], + ] + ) + return _convert_ir_bb(ctx, expanded, symbols) elif ir.value == "repeat": - # - # repeat(sym, start, end, bound, body) - # 1) entry block ] - # 2) init counter block ] -> same block - # 3) condition block (exit block, body block) - # 4) body block - # 5) increment block - # 6) exit block - # TODO: Add the extra bounds check after clarify + def emit_body_blocks(): global _break_target, _continue_target old_targets = _break_target, _continue_target - _break_target, _continue_target = exit_block, increment_block - _convert_ir_bb(ctx, body, symbols, variables, allocated_variables) + _break_target, _continue_target = exit_block, incr_block + _convert_ir_bb(ctx, body, symbols.copy()) _break_target, _continue_target = old_targets sym = ir.args[0] - start, end, _ = _convert_ir_bb_list( - ctx, ir.args[1:4], symbols, variables, allocated_variables - ) + start, end, _ = _convert_ir_bb_list(ctx, ir.args[1:4], symbols) + + assert ir.args[3].is_literal, "repeat bound expected to be literal" + + bound = ir.args[3].value + if ( + isinstance(end, IRLiteral) + and isinstance(start, IRLiteral) + and end.value + start.value <= bound + ): + bound = None body = ir.args[4] - entry_block = ctx.get_basic_block() - cond_block = IRBasicBlock(ctx.get_next_label(), ctx) - body_block = IRBasicBlock(ctx.get_next_label(), ctx) - jump_up_block = IRBasicBlock(ctx.get_next_label(), ctx) - increment_block = IRBasicBlock(ctx.get_next_label(), ctx) - exit_block = IRBasicBlock(ctx.get_next_label(), ctx) + entry_block = IRBasicBlock(ctx.get_next_label("repeat"), ctx) + cond_block = IRBasicBlock(ctx.get_next_label("condition"), ctx) + body_block = IRBasicBlock(ctx.get_next_label("body"), ctx) + incr_block = IRBasicBlock(ctx.get_next_label("incr"), ctx) + exit_block = IRBasicBlock(ctx.get_next_label("exit"), ctx) - counter_inc_var = ctx.get_next_variable() + bb = ctx.get_basic_block() + bb.append_instruction("jmp", entry_block.label) + ctx.append_basic_block(entry_block) - counter_var = ctx.get_basic_block().append_instruction("store", start) + counter_var = entry_block.append_instruction("store", start) symbols[sym.value] = counter_var - ctx.get_basic_block().append_instruction("jmp", cond_block.label) - - ret = cond_block.append_instruction( - "phi", entry_block.label, counter_var, increment_block.label, counter_inc_var - ) - symbols[sym.value] = ret + end = entry_block.append_instruction("add", start, end) + if bound: + bound = entry_block.append_instruction("add", start, bound) + entry_block.append_instruction("jmp", cond_block.label) - xor_ret = cond_block.append_instruction("xor", ret, end) + xor_ret = cond_block.append_instruction("xor", counter_var, end) cont_ret = cond_block.append_instruction("iszero", xor_ret) ctx.append_basic_block(cond_block) - start_syms = symbols.copy() ctx.append_basic_block(body_block) - emit_body_blocks() - end_syms = symbols.copy() - diff_syms = _get_symbols_common(start_syms, end_syms) - - replacements = {} - for sym, val in diff_syms.items(): - new_var = ctx.get_next_variable() - symbols[sym] = new_var - replacements[val[0]] = new_var - replacements[val[1]] = new_var - cond_block.insert_instruction( - IRInstruction( - "phi", [entry_block.label, val[0], increment_block.label, val[1]], new_var - ), - 1, - ) - - body_block.replace_operands(replacements) + if bound: + xor_ret = body_block.append_instruction("xor", counter_var, bound) + body_block.append_instruction("assert", xor_ret) + emit_body_blocks() body_end = ctx.get_basic_block() - if not body_end.is_terminated: - body_end.append_instruction("jmp", jump_up_block.label) + if body_end.is_terminated is False: + body_end.append_instruction("jmp", incr_block.label) - jump_up_block.append_instruction("jmp", increment_block.label) - ctx.append_basic_block(jump_up_block) - - increment_block.insert_instruction( - IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var), 0 + ctx.append_basic_block(incr_block) + incr_block.insert_instruction( + IRInstruction("add", [counter_var, IRLiteral(1)], counter_var) ) - - increment_block.append_instruction("jmp", cond_block.label) - ctx.append_basic_block(increment_block) + incr_block.append_instruction("jmp", cond_block.label) ctx.append_basic_block(exit_block) @@ -826,32 +536,15 @@ def emit_body_blocks(): assert _continue_target is not None, "Continue with no contrinue target" ctx.get_basic_block().append_instruction("jmp", _continue_target.label) ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) - elif ir.value == "gas": - return ctx.get_basic_block().append_instruction("gas") - elif ir.value == "returndatasize": - return ctx.get_basic_block().append_instruction("returndatasize") - elif ir.value == "returndatacopy": - assert len(ir.args) == 3, "returndatacopy with wrong number of arguments" - arg_0, arg_1, size = _convert_ir_bb_list( - ctx, ir.args, symbols, variables, allocated_variables - ) - - new_var = ctx.get_basic_block().append_instruction("returndatacopy", arg_1, size) - - symbols[f"&{arg_0.value}"] = new_var - return new_var - elif ir.value == "selfdestruct": - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) - ctx.get_basic_block().append_instruction("selfdestruct", arg_0) + elif ir.value in NOOP_INSTRUCTIONS: + pass elif isinstance(ir.value, str) and ir.value.startswith("log"): - args = reversed( - [_convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args] - ) + args = reversed(_convert_ir_bb_list(ctx, ir.args, symbols)) topic_count = int(ir.value[3:]) assert topic_count >= 0 and topic_count <= 4, "invalid topic count" ctx.get_basic_block().append_instruction("log", topic_count, *args) elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): - _convert_ir_opcode(ctx, ir, symbols, variables, allocated_variables) + _convert_ir_opcode(ctx, ir, symbols) elif isinstance(ir.value, str) and ir.value in symbols: return symbols[ir.value] elif ir.is_literal: @@ -862,28 +555,10 @@ def emit_body_blocks(): return None -def _convert_ir_opcode( - ctx: IRFunction, - ir: IRnode, - symbols: SymbolTable, - variables: OrderedSet, - allocated_variables: dict[str, IRVariable], -) -> None: +def _convert_ir_opcode(ctx: IRFunction, ir: IRnode, symbols: SymbolTable) -> None: opcode = ir.value.upper() # type: ignore inst_args = [] for arg in ir.args: if isinstance(arg, IRnode): - inst_args.append(_convert_ir_bb(ctx, arg, symbols, variables, allocated_variables)) + inst_args.append(_convert_ir_bb(ctx, arg, symbols)) ctx.get_basic_block().append_instruction(opcode, *inst_args) - - -def _data_ofst_of(sym, ofst, height_): - # e.g. _OFST _sym_foo 32 - assert is_symbol(sym) or is_mem_sym(sym) - if isinstance(ofst.value, int): - # resolve at compile time using magic _OFST op - return ["_OFST", sym, ofst.value] - else: - # if we can't resolve at compile time, resolve at runtime - # ofst = _compile_to_assembly(ofst, withargs, existing_labels, break_dest, height_) - return ofst + [sym, "ADD"] diff --git a/vyper/venom/passes/base_pass.py b/vyper/venom/passes/base_pass.py index 11da80ac66..3fbbdef6df 100644 --- a/vyper/venom/passes/base_pass.py +++ b/vyper/venom/passes/base_pass.py @@ -9,11 +9,13 @@ def run_pass(cls, *args, **kwargs): t = cls() count = 0 - while True: + for _ in range(1000): changes_count = t._run_pass(*args, **kwargs) or 0 count += changes_count if changes_count == 0: break + else: + raise Exception("Too many iterations in IR pass!", t.__class__) return count diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index 26994bd27f..5d149cf003 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -1,13 +1,31 @@ from vyper.utils import OrderedSet from vyper.venom.analysis import DFG -from vyper.venom.basicblock import IRBasicBlock, IRInstruction +from vyper.venom.basicblock import BB_TERMINATORS, IRBasicBlock, IRInstruction, IRVariable from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass -# DataFlow Transformation class DFTPass(IRPass): - def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction): + inst_order: dict[IRInstruction, int] + inst_order_num: int + + def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset: int = 0): + for op in inst.get_outputs(): + assert isinstance(op, IRVariable), f"expected variable, got {op}" + uses = self.dfg.get_uses(op) + + for uses_this in uses: + if uses_this.parent != inst.parent or uses_this.fence_id != inst.fence_id: + # don't reorder across basic block or fence boundaries + continue + + # if the instruction is a terminator, we need to place + # it at the end of the basic block + # along with all the instructions that "lead" to it + if uses_this.opcode in BB_TERMINATORS: + offset = len(bb.instructions) + self._process_instruction_r(bb, uses_this, offset) + if inst in self.visited_instructions: return self.visited_instructions.add(inst) @@ -15,35 +33,43 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction): if inst.opcode == "phi": # phi instructions stay at the beginning of the basic block # and no input processing is needed - bb.instructions.append(inst) + # bb.instructions.append(inst) + self.inst_order[inst] = 0 return for op in inst.get_inputs(): target = self.dfg.get_producing_instruction(op) + assert target is not None, f"no producing instruction for {op}" if target.parent != inst.parent or target.fence_id != inst.fence_id: # don't reorder across basic block or fence boundaries continue - self._process_instruction_r(bb, target) + self._process_instruction_r(bb, target, offset) - bb.instructions.append(inst) + self.inst_order_num += 1 + self.inst_order[inst] = self.inst_order_num + offset def _process_basic_block(self, bb: IRBasicBlock) -> None: self.ctx.append_basic_block(bb) - instructions = bb.instructions - bb.instructions = [] - - for inst in instructions: + for inst in bb.instructions: inst.fence_id = self.fence_id if inst.volatile: self.fence_id += 1 - for inst in instructions: + # We go throught the instructions and calculate the order in which they should be executed + # based on the data flow graph. This order is stored in the inst_order dictionary. + # We then sort the instructions based on this order. + self.inst_order = {} + self.inst_order_num = 0 + for inst in bb.instructions: self._process_instruction_r(bb, inst) + bb.instructions.sort(key=lambda x: self.inst_order[x]) + def _run_pass(self, ctx: IRFunction) -> None: self.ctx = ctx self.dfg = DFG.build_dfg(ctx) + self.fence_id = 0 self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() diff --git a/vyper/venom/passes/make_ssa.py b/vyper/venom/passes/make_ssa.py new file mode 100644 index 0000000000..06c61c9ea7 --- /dev/null +++ b/vyper/venom/passes/make_ssa.py @@ -0,0 +1,174 @@ +from vyper.utils import OrderedSet +from vyper.venom.analysis import calculate_cfg, calculate_liveness +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IROperand, IRVariable +from vyper.venom.dominators import DominatorTree +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class MakeSSA(IRPass): + """ + This pass converts the function into Static Single Assignment (SSA) form. + """ + + dom: DominatorTree + defs: dict[IRVariable, OrderedSet[IRBasicBlock]] + + def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> int: + self.ctx = ctx + + calculate_cfg(ctx) + self.dom = DominatorTree.build_dominator_tree(ctx, entry) + + calculate_liveness(ctx) + self._add_phi_nodes() + + self.var_name_counters = {var.name: 0 for var in self.defs.keys()} + self.var_name_stacks = {var.name: [0] for var in self.defs.keys()} + self._rename_vars(entry) + self._remove_degenerate_phis(entry) + + return 0 + + def _add_phi_nodes(self): + """ + Add phi nodes to the function. + """ + self._compute_defs() + work = {var: 0 for var in self.dom.dfs_walk} + has_already = {var: 0 for var in self.dom.dfs_walk} + i = 0 + + # Iterate over all variables + for var, d in self.defs.items(): + i += 1 + defs = list(d) + while len(defs) > 0: + bb = defs.pop() + for dom in self.dom.dominator_frontiers[bb]: + if has_already[dom] >= i: + continue + + self._place_phi(var, dom) + has_already[dom] = i + if work[dom] < i: + work[dom] = i + defs.append(dom) + + def _place_phi(self, var: IRVariable, basic_block: IRBasicBlock): + if var not in basic_block.in_vars: + return + + args: list[IROperand] = [] + for bb in basic_block.cfg_in: + if bb == basic_block: + continue + + args.append(bb.label) # type: ignore + args.append(var) # type: ignore + + basic_block.insert_instruction(IRInstruction("phi", args, var), 0) + + def _add_phi(self, var: IRVariable, basic_block: IRBasicBlock) -> bool: + for inst in basic_block.instructions: + if inst.opcode == "phi" and inst.output is not None and inst.output.name == var.name: + return False + + args: list[IROperand] = [] + for bb in basic_block.cfg_in: + if bb == basic_block: + continue + + args.append(bb.label) + args.append(var) + + phi = IRInstruction("phi", args, var) + basic_block.instructions.insert(0, phi) + + return True + + def _rename_vars(self, basic_block: IRBasicBlock): + """ + Rename variables. This follows the placement of phi nodes. + """ + outs = [] + + # Pre-action + for inst in basic_block.instructions: + new_ops = [] + if inst.opcode != "phi": + for op in inst.operands: + if not isinstance(op, IRVariable): + new_ops.append(op) + continue + + new_ops.append(IRVariable(op.name, version=self.var_name_stacks[op.name][-1])) + + inst.operands = new_ops + + if inst.output is not None: + v_name = inst.output.name + i = self.var_name_counters[v_name] + + self.var_name_stacks[v_name].append(i) + self.var_name_counters[v_name] = i + 1 + + inst.output = IRVariable(v_name, version=i) + # note - after previous line, inst.output.name != v_name + outs.append(inst.output.name) + + for bb in basic_block.cfg_out: + for inst in bb.instructions: + if inst.opcode != "phi": + continue + assert inst.output is not None, "Phi instruction without output" + for i, op in enumerate(inst.operands): + if op == basic_block.label: + inst.operands[i + 1] = IRVariable( + inst.output.name, version=self.var_name_stacks[inst.output.name][-1] + ) + + for bb in self.dom.dominated[basic_block]: + if bb == basic_block: + continue + self._rename_vars(bb) + + # Post-action + for op_name in outs: + # NOTE: each pop corresponds to an append in the pre-action above + self.var_name_stacks[op_name].pop() + + def _remove_degenerate_phis(self, entry: IRBasicBlock): + for inst in entry.instructions.copy(): + if inst.opcode != "phi": + continue + + new_ops = [] + for label, op in inst.phi_operands: + if op == inst.output: + continue + new_ops.extend([label, op]) + new_ops_len = len(new_ops) + if new_ops_len == 0: + entry.instructions.remove(inst) + elif new_ops_len == 2: + entry.instructions.remove(inst) + else: + inst.operands = new_ops + + for bb in self.dom.dominated[entry]: + if bb == entry: + continue + self._remove_degenerate_phis(bb) + + def _compute_defs(self): + """ + Compute the definition points of variables in the function. + """ + self.defs = {} + for bb in self.dom.dfs_walk: + assignments = bb.get_assignments() + for var in assignments: + if var not in self.defs: + self.defs[var] = OrderedSet() + self.defs[var].add(bb) diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 26699099b2..9ca8127b2d 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -28,7 +28,7 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB source = in_bb.label.value target = bb.label.value - split_label = IRLabel(f"{target}_split_{source}") + split_label = IRLabel(f"{source}_split_{target}") in_terminal = in_bb.instructions[-1] in_terminal.replace_label_operands({bb.label: split_label}) @@ -36,6 +36,13 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB split_bb.append_instruction("jmp", bb.label) self.ctx.append_basic_block(split_bb) + for inst in bb.instructions: + if inst.opcode != "phi": + continue + for i in range(0, len(inst.operands), 2): + if inst.operands[i] == in_bb.label: + inst.operands[i] = split_bb.label + # Update the labels in the data segment for inst in self.ctx.data_segment: if inst.opcode == "db" and inst.operands[0] == bb.label: @@ -55,5 +62,6 @@ def _run_pass(self, ctx: IRFunction) -> int: # If we made changes, recalculate the cfg if self.changes > 0: calculate_cfg(ctx) + ctx.remove_unreachable_blocks() return self.changes diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py new file mode 100644 index 0000000000..7f02ccf819 --- /dev/null +++ b/vyper/venom/passes/simplify_cfg.py @@ -0,0 +1,82 @@ +from vyper.utils import OrderedSet +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class SimplifyCFGPass(IRPass): + visited: OrderedSet + + def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): + a.instructions.pop() + for inst in b.instructions: + assert inst.opcode != "phi", "Not implemented yet" + if inst.opcode == "phi": + a.instructions.insert(0, inst) + else: + inst.parent = a + a.instructions.append(inst) + + # Update CFG + a.cfg_out = b.cfg_out + if len(b.cfg_out) > 0: + next_bb = b.cfg_out.first() + next_bb.remove_cfg_in(b) + next_bb.add_cfg_in(a) + + self.ctx.basic_blocks.remove(b) + + def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): + next_bb = b.cfg_out.first() + jump_inst = a.instructions[-1] + assert b.label in jump_inst.operands, f"{b.label} {jump_inst.operands}" + jump_inst.operands[jump_inst.operands.index(b.label)] = next_bb.label + + # Update CFG + a.remove_cfg_out(b) + a.add_cfg_out(next_bb) + next_bb.remove_cfg_in(b) + next_bb.add_cfg_in(a) + + self.ctx.basic_blocks.remove(b) + + def _collapse_chained_blocks_r(self, bb: IRBasicBlock): + """ + DFS into the cfg and collapse blocks with a single predecessor to the predecessor + """ + if len(bb.cfg_out) == 1: + next_bb = bb.cfg_out.first() + if len(next_bb.cfg_in) == 1: + self._merge_blocks(bb, next_bb) + self._collapse_chained_blocks_r(bb) + return + elif len(bb.cfg_out) == 2: + bb_out = bb.cfg_out.copy() + for next_bb in bb_out: + if ( + len(next_bb.cfg_in) == 1 + and len(next_bb.cfg_out) == 1 + and len(next_bb.instructions) == 1 + ): + self._merge_jump(bb, next_bb) + self._collapse_chained_blocks_r(bb) + return + + if bb in self.visited: + return + self.visited.add(bb) + + for bb_out in bb.cfg_out: + self._collapse_chained_blocks_r(bb_out) + + def _collapse_chained_blocks(self, entry: IRBasicBlock): + """ + Collapse blocks with a single predecessor to their predecessor + """ + self.visited = OrderedSet() + self._collapse_chained_blocks_r(entry) + + def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> None: + self.ctx = ctx + + self._collapse_chained_blocks(entry) diff --git a/vyper/venom/passes/stack_reorder.py b/vyper/venom/passes/stack_reorder.py new file mode 100644 index 0000000000..b32ec4abde --- /dev/null +++ b/vyper/venom/passes/stack_reorder.py @@ -0,0 +1,24 @@ +from vyper.utils import OrderedSet +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class StackReorderPass(IRPass): + visited: OrderedSet + + def _reorder_stack(self, bb: IRBasicBlock): + pass + + def _visit(self, bb: IRBasicBlock): + if bb in self.visited: + return + self.visited.add(bb) + + for bb_out in bb.cfg_out: + self._visit(bb_out) + + def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock): + self.ctx = ctx + self.visited = OrderedSet() + self._visit(entry) diff --git a/vyper/venom/stack_model.py b/vyper/venom/stack_model.py index 66c62b74d2..a98e5bb25b 100644 --- a/vyper/venom/stack_model.py +++ b/vyper/venom/stack_model.py @@ -30,34 +30,36 @@ def push(self, op: IROperand) -> None: def pop(self, num: int = 1) -> None: del self._stack[len(self._stack) - num :] - def get_depth(self, op: IROperand) -> int: + def get_depth(self, op: IROperand, n: int = 1) -> int: """ - Returns the depth of the first matching operand in the stack map. + Returns the depth of the n-th matching operand in the stack map. If the operand is not in the stack map, returns NOT_IN_STACK. """ assert isinstance(op, IROperand), f"{type(op)}: {op}" for i, stack_op in enumerate(reversed(self._stack)): if stack_op.value == op.value: - return -i + if n <= 1: + return -i + else: + n -= 1 return StackModel.NOT_IN_STACK # type: ignore - def get_phi_depth(self, phi1: IRVariable, phi2: IRVariable) -> int: + def get_phi_depth(self, phis: list[IRVariable]) -> int: """ Returns the depth of the first matching phi variable in the stack map. If the none of the phi operands are in the stack, returns NOT_IN_STACK. - Asserts that exactly one of phi1 and phi2 is found. + Asserts that exactly one of phis is found. """ - assert isinstance(phi1, IRVariable) - assert isinstance(phi2, IRVariable) + assert isinstance(phis, list) ret = StackModel.NOT_IN_STACK for i, stack_item in enumerate(reversed(self._stack)): - if stack_item in (phi1, phi2): + if stack_item in phis: assert ( ret is StackModel.NOT_IN_STACK - ), f"phi argument is not unique! {phi1}, {phi2}, {self._stack}" + ), f"phi argument is not unique! {phis}, {self._stack}" ret = -i return ret # type: ignore diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 608e100cd1..0cb13becf2 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -1,8 +1,22 @@ +from collections import Counter from typing import Any -from vyper.ir.compile_ir import PUSH, DataHeader, RuntimeHeader, optimize_assembly +from vyper.exceptions import CompilerPanic, StackTooDeep +from vyper.ir.compile_ir import ( + PUSH, + DataHeader, + Instruction, + RuntimeHeader, + mksymbol, + optimize_assembly, +) from vyper.utils import MemoryPositions, OrderedSet -from vyper.venom.analysis import calculate_cfg, calculate_liveness, input_vars_from +from vyper.venom.analysis import ( + calculate_cfg, + calculate_dup_requirements, + calculate_liveness, + input_vars_from, +) from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -10,7 +24,6 @@ IRLiteral, IROperand, IRVariable, - MemType, ) from vyper.venom.function import IRFunction from vyper.venom.passes.normalization import NormalizationPass @@ -23,15 +36,18 @@ "coinbase", "calldatasize", "calldatacopy", + "mcopy", "calldataload", "gas", "gasprice", "gaslimit", + "chainid", "address", "origin", "number", "extcodesize", "extcodehash", + "extcodecopy", "returndatasize", "returndatacopy", "callvalue", @@ -40,13 +56,17 @@ "sstore", "mload", "mstore", + "tload", + "tstore", "timestamp", "caller", + "blockhash", "selfdestruct", "signextend", "stop", "shr", "shl", + "sar", "and", "xor", "or", @@ -54,8 +74,13 @@ "sub", "mul", "div", + "smul", + "sdiv", "mod", + "smod", "exp", + "addmod", + "mulmod", "eq", "iszero", "not", @@ -63,12 +88,34 @@ "lt", "slt", "sgt", + "create", + "create2", + "msize", + "balance", + "call", + "staticcall", + "delegatecall", + "codesize", + "basefee", + "prevrandao", + "difficulty", + "invalid", ] ) _REVERT_POSTAMBLE = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] +def apply_line_numbers(inst: IRInstruction, asm) -> list[str]: + ret = [] + for op in asm: + if isinstance(op, str) and not isinstance(op, Instruction): + ret.append(Instruction(op, inst.ast_source, inst.error_msg)) + else: + ret.append(op) + return ret # type: ignore + + # TODO: "assembly" gets into the recursion due to how the original # IR was structured recursively in regards with the deploy instruction. # There, recursing into the deploy instruction was by design, and @@ -105,18 +152,18 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: # This is a side-effect of how dynamic jumps are temporarily being used # to support the O(1) dispatcher. -> look into calculate_cfg() for ctx in self.ctxs: - calculate_cfg(ctx) NormalizationPass.run_pass(ctx) + calculate_cfg(ctx) calculate_liveness(ctx) + calculate_dup_requirements(ctx) assert ctx.normalized, "Non-normalized CFG!" self._generate_evm_for_basicblock_r(asm, ctx.basic_blocks[0], StackModel()) # TODO make this property on IRFunction + asm.extend(["_sym__ctor_exit", "JUMPDEST"]) if ctx.immutables_len is not None and ctx.ctor_mem_size is not None: - while asm[-1] != "JUMPDEST": - asm.pop() asm.extend( ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] ) @@ -139,7 +186,11 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: label = inst.operands[0].value data_segments[label] = [DataHeader(f"_sym_{label}")] elif inst.opcode == "db": - data_segments[label].append(f"_sym_{inst.operands[0].value}") + data = inst.operands[0] + if isinstance(data, IRLabel): + data_segments[label].append(f"_sym_{data.value}") + else: + data_segments[label].append(data) asm.extend(list(data_segments.values())) @@ -149,20 +200,27 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: return top_asm def _stack_reorder( - self, assembly: list, stack: StackModel, _stack_ops: OrderedSet[IRVariable] + self, assembly: list, stack: StackModel, stack_ops: list[IRVariable] ) -> None: - # make a list so we can index it - stack_ops = [x for x in _stack_ops.keys()] - stack_ops_count = len(_stack_ops) + stack_ops_count = len(stack_ops) + + counts = Counter(stack_ops) for i in range(stack_ops_count): op = stack_ops[i] final_stack_depth = -(stack_ops_count - i - 1) - depth = stack.get_depth(op) # type: ignore + depth = stack.get_depth(op, counts[op]) # type: ignore + counts[op] -= 1 + + if depth == StackModel.NOT_IN_STACK: + raise CompilerPanic(f"Variable {op} not in stack") if depth == final_stack_depth: continue + if op == stack.peek(final_stack_depth): + continue + self.swap(assembly, stack, depth) self.swap(assembly, stack, final_stack_depth) @@ -192,23 +250,20 @@ def _emit_input_operands( continue if isinstance(op, IRLiteral): - assembly.extend([*PUSH(op.value)]) + if op.value < -(2**255): + raise Exception(f"Value too low: {op.value}") + elif op.value >= 2**256: + raise Exception(f"Value too high: {op.value}") + assembly.extend(PUSH(op.value % 2**256)) stack.push(op) continue - if op in inst.dup_requirements: + if op in inst.dup_requirements and op not in emitted_ops: self.dup_op(assembly, stack, op) if op in emitted_ops: self.dup_op(assembly, stack, op) - # REVIEW: this seems like it can be reordered across volatile - # boundaries (which includes memory fences). maybe just - # remove it entirely at this point - if isinstance(op, IRVariable) and op.mem_type == MemType.MEMORY: - assembly.extend([*PUSH(op.mem_addr)]) - assembly.append("MLOAD") - emitted_ops.add(op) def _generate_evm_for_basicblock_r( @@ -224,12 +279,34 @@ def _generate_evm_for_basicblock_r( self.clean_stack_from_cfg_in(asm, basicblock, stack) - for inst in basicblock.instructions: - asm = self._generate_evm_for_instruction(asm, inst, stack) + param_insts = [inst for inst in basicblock.instructions if inst.opcode == "param"] + main_insts = [inst for inst in basicblock.instructions if inst.opcode != "param"] + + for inst in param_insts: + asm.extend(self._generate_evm_for_instruction(inst, stack)) + + self._clean_unused_params(asm, basicblock, stack) + + for i, inst in enumerate(main_insts): + next_liveness = main_insts[i + 1].liveness if i + 1 < len(main_insts) else OrderedSet() - for bb in basicblock.cfg_out: + asm.extend(self._generate_evm_for_instruction(inst, stack, next_liveness)) + + for bb in basicblock.reachable: self._generate_evm_for_basicblock_r(asm, bb, stack.copy()) + def _clean_unused_params(self, asm: list, bb: IRBasicBlock, stack: StackModel) -> None: + for i, inst in enumerate(bb.instructions): + if inst.opcode != "param": + break + if inst.volatile and i + 1 < len(bb.instructions): + liveness = bb.instructions[i + 1].liveness + if inst.output is not None and inst.output not in liveness: + depth = stack.get_depth(inst.output) + if depth != 0: + self.swap(asm, stack, depth) + self.pop(asm, stack) + # pop values from stack at entry to bb # note this produces the same result(!) no matter which basic block # we enter from in the CFG. @@ -258,12 +335,14 @@ def clean_stack_from_cfg_in( continue if depth != 0: - stack.swap(depth) + self.swap(asm, stack, depth) self.pop(asm, stack) def _generate_evm_for_instruction( - self, assembly: list, inst: IRInstruction, stack: StackModel + self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet = None ) -> list[str]: + assembly: list[str | int] = [] + next_liveness = next_liveness or OrderedSet() opcode = inst.opcode # @@ -276,10 +355,22 @@ def _generate_evm_for_instruction( operands = inst.get_non_label_operands() elif opcode == "alloca": operands = inst.operands[1:2] + + # iload and istore are special cases because they can take a literal + # that is handled specialy with the _OFST macro. Look below, after the + # stack reordering. elif opcode == "iload": - operands = [] + addr = inst.operands[0] + if isinstance(addr, IRLiteral): + operands = [] + else: + operands = inst.operands elif opcode == "istore": - operands = inst.operands[0:1] + addr = inst.operands[1] + if isinstance(addr, IRLiteral): + operands = inst.operands[:1] + else: + operands = inst.operands elif opcode == "log": log_topic_count = inst.operands[0].value assert log_topic_count in [0, 1, 2, 3, 4], "Invalid topic count" @@ -289,8 +380,8 @@ def _generate_evm_for_instruction( if opcode == "phi": ret = inst.get_outputs()[0] - phi1, phi2 = inst.get_inputs() - depth = stack.get_phi_depth(phi1, phi2) + phis = inst.get_inputs() + depth = stack.get_phi_depth(phis) # collapse the arguments to the phi node in the stack. # example, for `%56 = %label1 %13 %label2 %14`, we will # find an instance of %13 *or* %14 in the stack and replace it with %56. @@ -301,7 +392,7 @@ def _generate_evm_for_instruction( stack.poke(0, ret) else: stack.poke(depth, ret) - return assembly + return apply_line_numbers(inst, assembly) # Step 2: Emit instruction's input operands self._emit_input_operands(assembly, inst, operands, stack) @@ -313,11 +404,15 @@ def _generate_evm_for_instruction( b = next(iter(inst.parent.cfg_out)) target_stack = input_vars_from(inst.parent, b) # TODO optimize stack reordering at entry and exit from basic blocks - self._stack_reorder(assembly, stack, target_stack) + # NOTE: stack in general can contain multiple copies of the same variable, + # however we are safe in the case of jmp/djmp/jnz as it's not going to + # have multiples. + target_stack_list = list(target_stack) + self._stack_reorder(assembly, stack, target_stack_list) # final step to get the inputs to this instruction ordered # correctly on the stack - self._stack_reorder(assembly, stack, OrderedSet(operands)) + self._stack_reorder(assembly, stack, operands) # type: ignore # some instructions (i.e. invoke) need to do stack manipulations # with the stack model containing the return value(s), so we fiddle @@ -359,7 +454,9 @@ def _generate_evm_for_instruction( assembly.append(f"_sym_{inst.operands[0].value}") assembly.append("JUMP") elif opcode == "djmp": - assert isinstance(inst.operands[0], IRVariable) + assert isinstance( + inst.operands[0], IRVariable + ), f"Expected IRVariable, got {inst.operands[0]}" assembly.append("JUMP") elif opcode == "gt": assembly.append("GT") @@ -367,7 +464,9 @@ def _generate_evm_for_instruction( assembly.append("LT") elif opcode == "invoke": target = inst.operands[0] - assert isinstance(target, IRLabel), "invoke target must be a label" + assert isinstance( + target, IRLabel + ), f"invoke target must be a label (is ${type(target)} ${target})" assembly.extend( [ f"_sym_label_ret_{self.label_counter}", @@ -378,16 +477,12 @@ def _generate_evm_for_instruction( ] ) self.label_counter += 1 - if stack.height > 0 and stack.peek(0) in inst.dup_requirements: - self.pop(assembly, stack) - elif opcode == "call": - assembly.append("CALL") - elif opcode == "staticcall": - assembly.append("STATICCALL") elif opcode == "ret": assembly.append("JUMP") elif opcode == "return": assembly.append("RETURN") + elif opcode == "exit": + assembly.extend(["_sym__ctor_exit", "JUMP"]) elif opcode == "phi": pass elif opcode == "sha3": @@ -395,10 +490,10 @@ def _generate_evm_for_instruction( elif opcode == "sha3_64": assembly.extend( [ - *PUSH(MemoryPositions.FREE_VAR_SPACE2), - "MSTORE", *PUSH(MemoryPositions.FREE_VAR_SPACE), "MSTORE", + *PUSH(MemoryPositions.FREE_VAR_SPACE2), + "MSTORE", *PUSH(64), *PUSH(MemoryPositions.FREE_VAR_SPACE), "SHA3", @@ -408,12 +503,23 @@ def _generate_evm_for_instruction( assembly.extend([*PUSH(31), "ADD", *PUSH(31), "NOT", "AND"]) elif opcode == "assert": assembly.extend(["ISZERO", "_sym___revert", "JUMPI"]) + elif opcode == "assert_unreachable": + end_symbol = mksymbol("reachable") + assembly.extend([end_symbol, "JUMPI", "INVALID", end_symbol, "JUMPDEST"]) elif opcode == "iload": - loc = inst.operands[0].value - assembly.extend(["_OFST", "_mem_deploy_end", loc, "MLOAD"]) + addr = inst.operands[0] + if isinstance(addr, IRLiteral): + assembly.extend(["_OFST", "_mem_deploy_end", addr.value]) + else: + assembly.extend(["_mem_deploy_end", "ADD"]) + assembly.append("MLOAD") elif opcode == "istore": - loc = inst.operands[1].value - assembly.extend(["_OFST", "_mem_deploy_end", loc, "MSTORE"]) + addr = inst.operands[1] + if isinstance(addr, IRLiteral): + assembly.extend(["_OFST", "_mem_deploy_end", addr.value]) + else: + assembly.extend(["_mem_deploy_end", "ADD"]) + assembly.append("MSTORE") elif opcode == "log": assembly.extend([f"LOG{log_topic_count}"]) else: @@ -421,19 +527,20 @@ def _generate_evm_for_instruction( # Step 6: Emit instructions output operands (if any) if inst.output is not None: - assert isinstance(inst.output, IRVariable), "Return value must be a variable" - if inst.output.mem_type == MemType.MEMORY: - assembly.extend([*PUSH(inst.output.mem_addr)]) + if "call" in inst.opcode and inst.output not in next_liveness: + self.pop(assembly, stack) - return assembly + return apply_line_numbers(inst, assembly) def pop(self, assembly, stack, num=1): stack.pop(num) assembly.extend(["POP"] * num) def swap(self, assembly, stack, depth): + # Swaps of the top is no op if depth == 0: return + stack.swap(depth) assembly.append(_evm_swap_for(depth)) @@ -450,11 +557,13 @@ def dup_op(self, assembly, stack, op): def _evm_swap_for(depth: int) -> str: swap_idx = -depth - assert 1 <= swap_idx <= 16, "Unsupported swap depth" + if not (1 <= swap_idx <= 16): + raise StackTooDeep(f"Unsupported swap depth {swap_idx}") return f"SWAP{swap_idx}" def _evm_dup_for(depth: int) -> str: dup_idx = 1 - depth - assert 1 <= dup_idx <= 16, "Unsupported dup depth" + if not (1 <= dup_idx <= 16): + raise StackTooDeep(f"Unsupported dup depth {dup_idx}") return f"DUP{dup_idx}"