diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8bd03de79b..10312413e9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -67,8 +67,9 @@ jobs: matrix: python-version: [["3.11", "311"]] opt-mode: ["gas", "none", "codesize"] - evm-version: [shanghai] debug: [true, false] + evm-version: [shanghai] + experimental-codegen: [false] memorymock: [false] # https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs#expanding-or-adding-matrix-configurations @@ -94,6 +95,14 @@ jobs: opt-mode: gas evm-version: cancun + # test experimental pipeline + - python-version: ["3.11", "311"] + opt-mode: gas + debug: false + evm-version: shanghai + experimental-codegen: true + # TODO: test experimental_codegen + -Ocodesize + # run with `--memorymock`, but only need to do it one configuration # TODO: consider removing the memorymock tests - python-version: ["3.11", "311"] @@ -108,12 +117,14 @@ jobs: opt-mode: gas debug: false evm-version: shanghai + - python-version: ["3.12", "312"] opt-mode: gas debug: false evm-version: shanghai - name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }}${{ matrix.memorymock && '-memorymock' || '' }}-${{ matrix.evm-version }} + + name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }}${{ matrix.memorymock && '-memorymock' || '' }}${{ matrix.experimental-codegen && '-experimental' || '' }}-${{ matrix.evm-version }} steps: - uses: actions/checkout@v4 @@ -141,6 +152,7 @@ jobs: --evm-version ${{ matrix.evm-version }} \ ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} \ ${{ matrix.memorymock && '--memorymock' || '' }} \ + ${{ matrix.experimental-codegen && '--experimental-codegen' || '' }} \ --cov-branch \ --cov-report xml:coverage.xml \ --cov=vyper \ @@ -193,8 +205,7 @@ jobs: # NOTE: if the tests get poorly distributed, run this and commit the resulting `.test_durations` file to the `vyper-test-durations` repo. # `pytest -m "fuzzing" --store-durations -r aR tests/` - name: Fetch test-durations - run: | - curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/master/test_durations" -o .test_durations + run: curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/master/test_durations" -o .test_durations - name: Run tests run: | diff --git a/setup.cfg b/setup.cfg index 467c6a372b..f84c947981 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,3 +31,4 @@ testpaths = tests xfail_strict = true markers = fuzzing: Run Hypothesis fuzz test suite (deselect with '-m "not fuzzing"') + venom_xfail: mark a test case as a regression (expected to fail) under the venom pipeline diff --git a/setup.py b/setup.py index b403405122..933e8bfa4b 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ "hypothesis[lark]>=6.0,<7.0", "eth-stdlib==0.2.7", "setuptools", + "hexbytes<1.0", "typing_extensions", # we can remove this once dependencies upgrade to eth-rlp>=2.0 ], "lint": [ diff --git a/tests/conftest.py b/tests/conftest.py index 2e5f11b9b8..d0681cdf42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,7 @@ def pytest_addoption(parser): help="change optimization mode", ) parser.addoption("--enable-compiler-debug-mode", action="store_true") + parser.addoption("--experimental-codegen", action="store_true") parser.addoption( "--evm-version", @@ -73,6 +74,8 @@ def output_formats(): output_formats = compiler.OUTPUT_FORMATS.copy() del output_formats["bb"] del output_formats["bb_runtime"] + del output_formats["cfg"] + del output_formats["cfg_runtime"] return output_formats @@ -89,6 +92,36 @@ def debug(pytestconfig): _set_debug_mode(debug) +@pytest.fixture(scope="session") +def experimental_codegen(pytestconfig): + ret = pytestconfig.getoption("experimental_codegen") + assert isinstance(ret, bool) + return ret + + +@pytest.fixture(autouse=True) +def check_venom_xfail(request, experimental_codegen): + if not experimental_codegen: + return + + marker = request.node.get_closest_marker("venom_xfail") + if marker is None: + return + + # https://github.com/okken/pytest-runtime-xfail?tab=readme-ov-file#alternatives + request.node.add_marker(pytest.mark.xfail(strict=True, **marker.kwargs)) + + +@pytest.fixture +def venom_xfail(request, experimental_codegen): + def _xfail(*args, **kwargs): + if not experimental_codegen: + return + request.node.add_marker(pytest.mark.xfail(*args, strict=True, **kwargs)) + + return _xfail + + @pytest.fixture(scope="session", autouse=True) def evm_version(pytestconfig): # note: we configure the evm version that we emit code for, @@ -108,6 +141,7 @@ def chdir_tmp_path(tmp_path): yield +# CMC 2024-03-01 this doesn't need to be a fixture @pytest.fixture def keccak(): return Web3.keccak @@ -321,6 +355,7 @@ def _get_contract( w3, source_code, optimize, + experimental_codegen, output_formats, *args, override_opt_level=None, @@ -329,6 +364,7 @@ def _get_contract( ): settings = Settings() settings.optimize = override_opt_level or optimize + settings.experimental_codegen = experimental_codegen out = compiler.compile_code( source_code, # test that all output formats can get generated @@ -352,17 +388,21 @@ def _get_contract( @pytest.fixture(scope="module") -def get_contract(w3, optimize, output_formats): +def get_contract(w3, optimize, experimental_codegen, output_formats): def fn(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return fn @pytest.fixture -def get_contract_with_gas_estimation(tester, w3, optimize, output_formats): +def get_contract_with_gas_estimation(tester, w3, optimize, experimental_codegen, output_formats): def get_contract_with_gas_estimation(source_code, *args, **kwargs): - contract = _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + contract = _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) for abi_ in contract._classic_contract.functions.abi: if abi_["type"] == "function": set_decorator_to_contract_function(w3, tester, contract, source_code, abi_["name"]) @@ -372,15 +412,19 @@ def get_contract_with_gas_estimation(source_code, *args, **kwargs): @pytest.fixture -def get_contract_with_gas_estimation_for_constants(w3, optimize, output_formats): +def get_contract_with_gas_estimation_for_constants( + w3, optimize, experimental_codegen, output_formats +): def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return get_contract_with_gas_estimation_for_constants @pytest.fixture(scope="module") -def get_contract_module(optimize, output_formats): +def get_contract_module(optimize, experimental_codegen, output_formats): """ This fixture is used for Hypothesis tests to ensure that the same contract is called over multiple runs of the test. @@ -393,16 +437,25 @@ def get_contract_module(optimize, output_formats): w3.eth.set_gas_price_strategy(zero_gas_price_strategy) def get_contract_module(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) + return _get_contract( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return get_contract_module def _deploy_blueprint_for( - w3, source_code, optimize, output_formats, initcode_prefix=ERC5202_PREFIX, **kwargs + w3, + source_code, + optimize, + experimental_codegen, + output_formats, + initcode_prefix=ERC5202_PREFIX, + **kwargs, ): settings = Settings() settings.optimize = optimize + settings.experimental_codegen = experimental_codegen out = compiler.compile_code( source_code, output_formats=output_formats, @@ -438,9 +491,11 @@ def factory(address): @pytest.fixture(scope="module") -def deploy_blueprint_for(w3, optimize, output_formats): +def deploy_blueprint_for(w3, optimize, experimental_codegen, output_formats): def deploy_blueprint_for(source_code, *args, **kwargs): - return _deploy_blueprint_for(w3, source_code, optimize, output_formats, *args, **kwargs) + return _deploy_blueprint_for( + w3, source_code, optimize, experimental_codegen, output_formats, *args, **kwargs + ) return deploy_blueprint_for diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index d281851f8e..213738957b 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -3,7 +3,7 @@ import pytest from eth.codecs import abi -from vyper.exceptions import ArgumentException, StructureException +from vyper.exceptions import ArgumentException, StackTooDeep, StructureException TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex() @@ -196,6 +196,7 @@ def abi_decode(x: Bytes[{len}]) -> DynArray[DynArray[uint256, 3], 3]: @pytest.mark.parametrize("args", nested_3d_array_args) @pytest.mark.parametrize("unwrap_tuple", (True, False)) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_decode_nested_dynarray2(get_contract, args, unwrap_tuple): if unwrap_tuple is True: encoded = abi.encode("(uint256[][][])", (args,)) @@ -273,6 +274,7 @@ def foo(bs: Bytes[160]) -> (uint256, DynArray[uint256, 3]): assert c.foo(encoded) == [2**256 - 1, bs] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_decode_private_nested_dynarray(get_contract): code = """ bytez: DynArray[DynArray[DynArray[uint256, 3], 3], 3] diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index f014c47a19..305c4b1356 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -3,6 +3,8 @@ import pytest from eth.codecs import abi +from vyper.exceptions import StackTooDeep + # @pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) def test_abi_encode(get_contract): @@ -226,6 +228,7 @@ def abi_encode( @pytest.mark.parametrize("args", nested_3d_array_args) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_encode_nested_dynarray_2(get_contract, args): code = """ @external @@ -330,6 +333,7 @@ def foo(bs: DynArray[uint256, 3]) -> (uint256, Bytes[160]): assert c.foo(bs) == [2**256 - 1, abi.encode("(uint256[])", (bs,))] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_encode_private_nested_dynarray(get_contract): code = """ bytez: Bytes[1696] diff --git a/tests/functional/codegen/features/decorators/test_pure.py b/tests/functional/codegen/features/decorators/test_pure.py index 5f4c168687..7c49c2091b 100644 --- a/tests/functional/codegen/features/decorators/test_pure.py +++ b/tests/functional/codegen/features/decorators/test_pure.py @@ -1,21 +1,22 @@ +import pytest + +from vyper.compiler import compile_code from vyper.exceptions import FunctionDeclarationException, StateAccessViolation -def test_pure_operation(get_contract_with_gas_estimation_for_constants): - c = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_operation(get_contract): + code = """ @pure @external def foo() -> int128: return 5 """ - ) + c = get_contract(code) assert c.foo() == 5 -def test_pure_call(get_contract_with_gas_estimation_for_constants): - c = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_call(get_contract): + code = """ @pure @internal def _foo() -> int128: @@ -26,21 +27,18 @@ def _foo() -> int128: def foo() -> int128: return self._foo() """ - ) + c = get_contract(code) assert c.foo() == 5 -def test_pure_interface(get_contract_with_gas_estimation_for_constants): - c1 = get_contract_with_gas_estimation_for_constants( - """ +def test_pure_interface(get_contract): + code1 = """ @pure @external def foo() -> int128: return 5 """ - ) - c2 = get_contract_with_gas_estimation_for_constants( - """ + code2 = """ interface Foo: def foo() -> int128: pure @@ -49,28 +47,35 @@ def foo() -> int128: pure def foo(a: address) -> int128: return staticcall Foo(a).foo() """ - ) + c1 = get_contract(code1) + c2 = get_contract(code2) assert c2.foo(c1.address) == 5 -def test_invalid_envar_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_envar_access(get_contract): + code = """ @pure @external def foo() -> uint256: return chain.id """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_codesize_access(get_contract): + code = """ +@pure +@external +def foo(s: address) -> uint256: + return s.codesize + """ + with pytest.raises(StateAccessViolation): + compile_code(code) def test_invalid_state_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ + code = """ x: uint256 @pure @@ -78,29 +83,84 @@ def test_invalid_state_access(get_contract, assert_compile_failed): def foo() -> uint256: return self.x """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_immutable_access(): + code = """ +COUNTER: immutable(uint256) + +@deploy +def __init__(): + COUNTER = 1234 + +@pure +@external +def foo() -> uint256: + return COUNTER + """ + with pytest.raises(StateAccessViolation): + compile_code(code) -def test_invalid_self_access(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_self_access(): + code = """ @pure @external def foo() -> address: return self """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) + + +def test_invalid_module_variable_access(make_input_bundle): + lib1 = """ +counter: uint256 + """ + code = """ +import lib1 +initializes: lib1 + +@pure +@external +def foo() -> uint256: + return lib1.counter + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StateAccessViolation): + compile_code(code, input_bundle=input_bundle) + + +def test_invalid_module_immutable_access(make_input_bundle): + lib1 = """ +COUNTER: immutable(uint256) + +@deploy +def __init__(): + COUNTER = 123 + """ + code = """ +import lib1 +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__() + +@pure +@external +def foo() -> uint256: + return lib1.COUNTER + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StateAccessViolation): + compile_code(code, input_bundle=input_bundle) -def test_invalid_call(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_call(): + code = """ @view @internal def _foo() -> uint256: @@ -111,21 +171,17 @@ def _foo() -> uint256: def foo() -> uint256: return self._foo() # Fails because of calling non-pure fn """ - ), - StateAccessViolation, - ) + with pytest.raises(StateAccessViolation): + compile_code(code) -def test_invalid_conflicting_decorators(get_contract, assert_compile_failed): - assert_compile_failed( - lambda: get_contract( - """ +def test_invalid_conflicting_decorators(): + code = """ @pure @external @payable def foo() -> uint256: return 5 """ - ), - FunctionDeclarationException, - ) + with pytest.raises(FunctionDeclarationException): + compile_code(code) diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 578413a8f4..fe51c026fe 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -4,6 +4,7 @@ from eth.codecs import abi from eth_utils import keccak +from vyper.exceptions import StackTooDeep from vyper.utils import int_bounds @@ -506,6 +507,7 @@ def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]: @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_multidimension_dynarray_clamper_passing(w3, get_contract, value): code = """ @external diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index 9146ace8a6..d96a889497 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -1,6 +1,8 @@ import pytest from web3.exceptions import ValidationError +from vyper.exceptions import StackTooDeep + def test_init_argument_test(get_contract_with_gas_estimation): init_argument_test = """ @@ -163,6 +165,7 @@ def get_foo() -> uint256: assert c.get_foo() == 39 +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estimation): code = """ foo: int128 @@ -208,6 +211,7 @@ def get_foo() -> DynArray[DynArray[uint256, 3], 3]: assert c.get_foo() == [[37, 41, 73], [37041, 41073, 73037], [146, 123, 148]] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation): code = """ foo: DynArray[DynArray[DynArray[int128, 3], 3], 3] diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index 49ff54b353..874600633a 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -1,6 +1,7 @@ import pytest from vyper.compiler.settings import OptimizationLevel +from vyper.exceptions import StackTooDeep @pytest.mark.parametrize( @@ -198,6 +199,7 @@ def get_idx_two() -> uint256: assert c.get_idx_two() == expected_values[2][2] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3]) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index b55f07639b..efa2799480 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -8,6 +8,7 @@ ArrayIndexException, ImmutableViolation, OverflowException, + StackTooDeep, StateAccessViolation, TypeMismatch, ) @@ -60,6 +61,7 @@ def loo(x: DynArray[DynArray[int128, 2], 2]) -> int128: print("Passed list tests") +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_string_list(get_contract): code = """ @external @@ -732,6 +734,7 @@ def test_array_decimal_return3() -> DynArray[DynArray[decimal, 2], 2]: assert c.test_array_decimal_return3() == [[1.0, 2.0], [3.0]] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_mult_list(get_contract_with_gas_estimation): code = """ nest3: DynArray[DynArray[DynArray[uint256, 2], 2], 2] @@ -1478,6 +1481,7 @@ def foo(x: int128) -> int128: assert c.foo(7) == 392 +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_struct_of_lists(get_contract): code = """ struct Foo: @@ -1566,6 +1570,7 @@ def bar(x: int128) -> DynArray[int128, 3]: assert c.bar(7) == [7, 14] +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_struct_of_lists(get_contract, assert_compile_failed, optimize): code = """ struct nestedFoo: @@ -1695,7 +1700,9 @@ def __init__(): ("DynArray[DynArray[DynArray[uint256, 5], 5], 5]", [[[], []], []]), ], ) -def test_empty_nested_dynarray(get_contract, typ, val): +def test_empty_nested_dynarray(get_contract, typ, val, venom_xfail): + if val == [[[], []], []]: + venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") code = f""" @external def foo() -> {typ}: diff --git a/tests/functional/examples/factory/test_factory.py b/tests/functional/examples/factory/test_factory.py index 0c5cf61b04..18f6222c20 100644 --- a/tests/functional/examples/factory/test_factory.py +++ b/tests/functional/examples/factory/test_factory.py @@ -31,12 +31,14 @@ def create_exchange(token, factory): @pytest.fixture -def factory(get_contract, optimize): +def factory(get_contract, optimize, experimental_codegen): with open("examples/factory/Exchange.vy") as f: code = f.read() exchange_interface = vyper.compile_code( - code, output_formats=["bytecode_runtime"], settings=Settings(optimize=optimize) + code, + output_formats=["bytecode_runtime"], + settings=Settings(optimize=optimize, experimental_codegen=experimental_codegen), ) exchange_deployed_bytecode = exchange_interface["bytecode_runtime"] diff --git a/tests/functional/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py index 6556fc90b9..6be50a509b 100644 --- a/tests/functional/syntax/test_address_code.py +++ b/tests/functional/syntax/test_address_code.py @@ -161,7 +161,7 @@ def test_address_code_compile_success(code: str): compiler.compile_code(code) -def test_address_code_self_success(get_contract, optimize): +def test_address_code_self_success(get_contract, optimize, experimental_codegen): code = """ code_deployment: public(Bytes[32]) @@ -174,7 +174,7 @@ def code_runtime() -> Bytes[32]: return slice(self.code, 0, 32) """ contract = get_contract(code) - settings = Settings(optimize=optimize) + settings = Settings(optimize=optimize, experimental_codegen=experimental_codegen) code_compiled = compiler.compile_code( code, output_formats=["bytecode", "bytecode_runtime"], settings=settings ) diff --git a/tests/functional/syntax/test_codehash.py b/tests/functional/syntax/test_codehash.py index d351981946..7aa01a68e9 100644 --- a/tests/functional/syntax/test_codehash.py +++ b/tests/functional/syntax/test_codehash.py @@ -3,7 +3,7 @@ from vyper.utils import keccak256 -def test_get_extcodehash(get_contract, optimize): +def test_get_extcodehash(get_contract, optimize, experimental_codegen): code = """ a: address @@ -28,7 +28,7 @@ def foo3() -> bytes32: def foo4() -> bytes32: return self.a.codehash """ - settings = Settings(optimize=optimize) + settings = Settings(optimize=optimize, experimental_codegen=experimental_codegen) compiled = compile_code(code, output_formats=["bytecode_runtime"], settings=settings) bytecode = bytes.fromhex(compiled["bytecode_runtime"][2:]) hash_ = keccak256(bytecode) diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 4fe2111f43..62a799db65 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -113,11 +113,13 @@ def test_keyerror_becomes_jsonerror(input_json): def test_compile_json(input_json, input_bundle): foo_input = input_bundle.load_file("contracts/foo.vy") - # remove bb and bb_runtime from output formats + # remove venom related from output formats # because they require venom (experimental) output_formats = OUTPUT_FORMATS.copy() del output_formats["bb"] del output_formats["bb_runtime"] + del output_formats["cfg"] + del output_formats["cfg_runtime"] foo = compile_from_file_input( foo_input, output_formats=output_formats, input_bundle=input_bundle ) diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index ce32249202..5742f7c8df 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -95,7 +95,7 @@ def test_dead_code_eliminator(code): assert all(ctor_only not in instr for instr in runtime_asm) -def test_library_code_eliminator(make_input_bundle): +def test_library_code_eliminator(make_input_bundle, experimental_codegen): library = """ @internal def unused1(): @@ -120,5 +120,6 @@ def foo(): res = compile_code(code, input_bundle=input_bundle, output_formats=["asm"]) asm = res["asm"] assert "some_function()" in asm + assert "unused1()" not in asm assert "unused2()" not in asm diff --git a/tests/unit/compiler/venom/test_dominator_tree.py b/tests/unit/compiler/venom/test_dominator_tree.py new file mode 100644 index 0000000000..dc27380796 --- /dev/null +++ b/tests/unit/compiler/venom/test_dominator_tree.py @@ -0,0 +1,73 @@ +from typing import Optional + +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRLiteral, IRVariable +from vyper.venom.dominators import DominatorTree +from vyper.venom.function import IRFunction +from vyper.venom.passes.make_ssa import MakeSSA + + +def _add_bb( + ctx: IRFunction, label: IRLabel, cfg_outs: [IRLabel], bb: Optional[IRBasicBlock] = None +) -> IRBasicBlock: + bb = bb if bb is not None else IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + cfg_outs_len = len(cfg_outs) + if cfg_outs_len == 0: + bb.append_instruction("stop") + elif cfg_outs_len == 1: + bb.append_instruction("jmp", cfg_outs[0]) + elif cfg_outs_len == 2: + bb.append_instruction("jnz", IRLiteral(1), cfg_outs[0], cfg_outs[1]) + else: + raise CompilerPanic("Invalid number of CFG outs") + return bb + + +def _make_test_ctx(): + lab = [IRLabel(str(i)) for i in range(0, 9)] + + ctx = IRFunction(lab[1]) + + bb1 = ctx.basic_blocks[0] + bb1.append_instruction("jmp", lab[2]) + + _add_bb(ctx, lab[7], []) + _add_bb(ctx, lab[6], [lab[7], lab[2]]) + _add_bb(ctx, lab[5], [lab[6], lab[3]]) + _add_bb(ctx, lab[4], [lab[6]]) + _add_bb(ctx, lab[3], [lab[5]]) + _add_bb(ctx, lab[2], [lab[3], lab[4]]) + + return ctx + + +def test_deminator_frontier_calculation(): + ctx = _make_test_ctx() + bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [ctx.get_basic_block(str(i)) for i in range(1, 8)] + + calculate_cfg(ctx) + dom = DominatorTree.build_dominator_tree(ctx, bb1) + df = dom.dominator_frontiers + + assert len(df[bb1]) == 0, df[bb1] + assert df[bb2] == OrderedSet({bb2}), df[bb2] + assert df[bb3] == OrderedSet({bb3, bb6}), df[bb3] + assert df[bb4] == OrderedSet({bb6}), df[bb4] + assert df[bb5] == OrderedSet({bb3, bb6}), df[bb5] + assert df[bb6] == OrderedSet({bb2}), df[bb6] + assert len(df[bb7]) == 0, df[bb7] + + +def test_phi_placement(): + ctx = _make_test_ctx() + bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [ctx.get_basic_block(str(i)) for i in range(1, 8)] + + x = IRVariable("%x") + bb1.insert_instruction(IRInstruction("mload", [IRLiteral(0)], x), 0) + bb2.insert_instruction(IRInstruction("add", [x, IRLiteral(1)], x), 0) + bb7.insert_instruction(IRInstruction("mstore", [x, IRLiteral(0)]), 0) + + MakeSSA.run_pass(ctx, bb1) diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index 437185cc72..7cc58e6f5c 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -23,5 +23,4 @@ def test_duplicate_operands(): bb.append_instruction("stop") asm = generate_assembly_experimental(ctx, optimize=OptimizationLevel.GAS) - - assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP"] + assert asm == ["PUSH1", 10, "DUP1", "DUP1", "ADD", "MUL", "STOP"] diff --git a/tests/unit/compiler/venom/test_liveness_simple_loop.py b/tests/unit/compiler/venom/test_liveness_simple_loop.py new file mode 100644 index 0000000000..e725518179 --- /dev/null +++ b/tests/unit/compiler/venom/test_liveness_simple_loop.py @@ -0,0 +1,16 @@ +import vyper +from vyper.compiler.settings import Settings + +source = """ +@external +def foo(a: uint256): + _numBids: uint256 = 20 + b: uint256 = 10 + + for i: uint256 in range(128): + b = 1 + _numBids +""" + + +def test_liveness_simple_loop(): + vyper.compile_code(source, ["opcodes"], settings=Settings(experimental_codegen=True)) diff --git a/tests/unit/compiler/venom/test_make_ssa.py b/tests/unit/compiler/venom/test_make_ssa.py new file mode 100644 index 0000000000..2a04dfc134 --- /dev/null +++ b/tests/unit/compiler/venom/test_make_ssa.py @@ -0,0 +1,48 @@ +from vyper.venom.analysis import calculate_cfg, calculate_liveness +from vyper.venom.basicblock import IRBasicBlock, IRLabel +from vyper.venom.function import IRFunction +from vyper.venom.passes.make_ssa import MakeSSA + + +def test_phi_case(): + ctx = IRFunction(IRLabel("_global")) + + bb = ctx.get_basic_block() + + bb_cont = IRBasicBlock(IRLabel("condition"), ctx) + bb_then = IRBasicBlock(IRLabel("then"), ctx) + bb_else = IRBasicBlock(IRLabel("else"), ctx) + bb_if_exit = IRBasicBlock(IRLabel("if_exit"), ctx) + ctx.append_basic_block(bb_cont) + ctx.append_basic_block(bb_then) + ctx.append_basic_block(bb_else) + ctx.append_basic_block(bb_if_exit) + + v = bb.append_instruction("mload", 64) + bb_cont.append_instruction("jnz", v, bb_then.label, bb_else.label) + + bb_if_exit.append_instruction("add", v, 1, ret=v) + bb_if_exit.append_instruction("jmp", bb_cont.label) + + bb_then.append_instruction("assert", bb_then.append_instruction("mload", 96)) + bb_then.append_instruction("jmp", bb_if_exit.label) + bb_else.append_instruction("jmp", bb_if_exit.label) + + bb.append_instruction("jmp", bb_cont.label) + + calculate_cfg(ctx) + MakeSSA.run_pass(ctx, ctx.basic_blocks[0]) + calculate_liveness(ctx) + + condition_block = ctx.get_basic_block("condition") + assert len(condition_block.instructions) == 2 + + phi_inst = condition_block.instructions[0] + assert phi_inst.opcode == "phi" + assert phi_inst.operands[0].name == "_global" + assert phi_inst.operands[1].name == "%1" + assert phi_inst.operands[2].name == "if_exit" + assert phi_inst.operands[3].name == "%1" + assert phi_inst.output.name == "%1" + assert phi_inst.output.value != phi_inst.operands[1].value + assert phi_inst.output.value != phi_inst.operands[3].value diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 6d8b074994..47f4b88707 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -39,10 +39,10 @@ def test_multi_entry_block_1(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" # more complicated one @@ -91,10 +91,10 @@ def test_multi_entry_block_2(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" def test_multi_entry_block_with_dynamic_jump(): @@ -132,7 +132,7 @@ def test_multi_entry_block_with_dynamic_jump(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" - assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" + assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" diff --git a/tests/unit/compiler/venom/test_variables.py b/tests/unit/compiler/venom/test_variables.py new file mode 100644 index 0000000000..cded8d0e1a --- /dev/null +++ b/tests/unit/compiler/venom/test_variables.py @@ -0,0 +1,8 @@ +from vyper.venom.basicblock import IRVariable + + +def test_variable_equality(): + v1 = IRVariable("%x") + v2 = IRVariable("%x") + assert v1 == v2 + assert v1 != IRVariable("%y") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index f673bb765c..fe01bf9260 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -141,7 +141,7 @@ class Dict(VyperNode): keys: list = ... values: list = ... -class Name(VyperNode): +class Name(ExprNode): id: str = ... _type: str = ... diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 227b639ad5..f0c339cca7 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -203,6 +203,12 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: if evm_version not in EVM_VERSIONS: raise StructureException("Invalid evm version: `{evm_version}`", start) settings.evm_version = evm_version + elif pragma.startswith("experimental-codegen"): + if settings.experimental_codegen is not None: + raise StructureException( + "pragma experimental-codegen specified twice!", start + ) + settings.experimental_codegen = True else: raise StructureException(f"Unknown pragma `{pragma.split()[0]}`") diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 778d68b5b1..7a3aa800f6 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -65,6 +65,7 @@ def _parse_cli_args(): def _cli_helper(f, output_formats, compiled): if output_formats == ("combined_json",): + compiled = {str(path): v for (path, v) in compiled.items()} print(json.dumps(compiled), file=f) return diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 6f783bb9c5..fe706699bb 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -154,10 +154,6 @@ def _adjust_gas_estimate(func_t, common_ir): common_ir.add_gas_estimate += mem_expansion_cost func_t._ir_info.gas_estimate = common_ir.gas - # pass metadata through for venom pipeline: - common_ir.passthrough_metadata["func_t"] = func_t - common_ir.passthrough_metadata["frame_info"] = frame_info - def generate_ir_for_external_function(code, compilation_target): # TODO type hints: diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 0cf9850b70..cde1ec5c87 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -80,10 +80,6 @@ def generate_ir_for_internal_function( # tag gas estimate and frame info func_t._ir_info.gas_estimate = ir_node.gas - frame_info = tag_frame_info(func_t, context) - - # pass metadata through for venom pipeline: - ir_node.passthrough_metadata["frame_info"] = frame_info - ir_node.passthrough_metadata["func_t"] = func_t + tag_frame_info(func_t, context) return InternalFuncIR(ir_node) diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index da585ff0a1..a8dac640db 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -42,7 +42,6 @@ def finalize(fill_return_buffer): # NOTE: because stack analysis is incomplete, cleanup_repeat must # come after fill_return_buffer otherwise the stack will break jump_to_exit_ir = IRnode.from_list(jump_to_exit) - jump_to_exit_ir.passthrough_metadata["func_t"] = func_t return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit_ir]) if context.return_type is None: diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f53e4a81b4..2363de3641 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -112,6 +112,4 @@ def ir_for_self_call(stmt_expr, context): add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True - o.passthrough_metadata["func_t"] = func_t - o.passthrough_metadata["args_ir"] = args_ir return o diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index ee909a57d4..84aea73071 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -25,6 +25,8 @@ "interface": output.build_interface_output, "bb": output.build_bb_output, "bb_runtime": output.build_bb_runtime_output, + "cfg": output.build_cfg_output, + "cfg_runtime": output.build_cfg_runtime_output, "ir": output.build_ir_output, "ir_runtime": output.build_ir_runtime_output, "ir_dict": output.build_ir_dict_output, diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index de8e34370d..f8beb9d11b 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -90,6 +90,14 @@ def build_bb_runtime_output(compiler_data: CompilerData) -> IRnode: return compiler_data.venom_functions[1] +def build_cfg_output(compiler_data: CompilerData) -> str: + return compiler_data.venom_functions[0].as_graph() + + +def build_cfg_runtime_output(compiler_data: CompilerData) -> str: + return compiler_data.venom_functions[1].as_graph() + + def build_ir_output(compiler_data: CompilerData) -> IRnode: if compiler_data.show_gas_estimates: IRnode.repr_show_gas = True diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index e343938021..d794185195 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -97,8 +97,6 @@ def __init__( no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False """ - # to force experimental codegen, uncomment: - # settings.experimental_codegen = True if isinstance(file_input, str): file_input = FileInput( diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3897f0ea41..996a1ddbd9 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -386,6 +386,10 @@ class CodegenPanic(VyperInternalException): """Invalid code generated during codegen phase""" +class StackTooDeep(CodegenPanic): + """Stack too deep""" # (should not happen) + + class UnexpectedNodeType(VyperInternalException): """Unexpected AST node type.""" diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index e4a4cc60f7..191803295e 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -1020,7 +1020,11 @@ def _stack_peephole_opts(assembly): changed = True del assembly[i] continue - if assembly[i : i + 2] == ["SWAP1", "SWAP1"]: + if ( + isinstance(assembly[i], str) + and assembly[i].startswith("SWAP") + and assembly[i] == assembly[i + 1] + ): changed = True del assembly[i : i + 2] if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index e424f94e19..82b8dbe359 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -8,6 +8,7 @@ from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.semantics.types.primitives import SelfT from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: @@ -199,8 +200,11 @@ def set_position(self, position: VarOffset) -> None: assert isinstance(position, VarOffset) # sanity check self.position = position - def is_module_variable(self): - return self.location not in (DataLocation.UNSET, DataLocation.MEMORY) + def is_state_variable(self): + non_state_locations = (DataLocation.UNSET, DataLocation.MEMORY, DataLocation.CALLDATA) + # `self` gets a VarInfo, but it is not considered a state + # variable (it is magic), so we ignore it here. + return self.location not in non_state_locations and not isinstance(self.typ, SelfT) def get_size(self) -> int: return self.typ.get_size_in(self.location) diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index bb4322c7b2..e5e8b998ca 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -282,9 +282,9 @@ def _allocate_layout_r( continue assert isinstance(node, vy_ast.VariableDecl) - # skip non-storage variables + # skip non-state variables varinfo = node.target._metadata["varinfo"] - if not varinfo.is_module_variable(): + if not varinfo.is_state_variable(): continue location = varinfo.location diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 33dcce3645..1b2e3252c8 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -38,7 +38,7 @@ from vyper.semantics.data_locations import DataLocation # TODO consolidate some of these imports -from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS, MUTABLE_ENVIRONMENT_VARS +from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS from vyper.semantics.namespace import get_namespace from vyper.semantics.types import ( TYPE_T, @@ -51,6 +51,7 @@ HashMapT, IntegerT, SArrayT, + SelfT, StringT, StructT, TupleT, @@ -164,21 +165,32 @@ def _validate_msg_value_access(node: vy_ast.Attribute) -> None: raise NonPayableViolation("msg.value is not allowed in non-payable functions", node) -def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: - env_vars = set(CONSTANT_ENVIRONMENT_VARS.keys()) | set(MUTABLE_ENVIRONMENT_VARS.keys()) - if isinstance(node.value, vy_ast.Name) and node.value.id in env_vars: - if isinstance(typ, ContractFunctionT) and typ.mutability == StateMutability.PURE: - return +def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name) -> None: + info = get_expr_info(node) - raise StateAccessViolation( - "not allowed to query contract or environment variables in pure functions", node - ) + env_vars = CONSTANT_ENVIRONMENT_VARS + # check env variable access like `block.number` + if isinstance(node, vy_ast.Attribute): + if node.get("value.id") in env_vars: + raise StateAccessViolation( + "not allowed to query environment variables in pure functions" + ) + parent_info = get_expr_info(node.value) + if isinstance(parent_info.typ, AddressT) and node.attr in AddressT._type_members: + raise StateAccessViolation("not allowed to query address members in pure functions") + if (varinfo := info.var_info) is None: + return + # self is magic. we only need to check it if it is not the root of an Attribute + # node. (i.e. it is bare like `self`, not `self.foo`) + is_naked_self = isinstance(varinfo.typ, SelfT) and not isinstance( + node.get_ancestor(), vy_ast.Attribute + ) + if is_naked_self: + raise StateAccessViolation("not allowed to query `self` in pure functions") -def _validate_self_reference(node: vy_ast.Name) -> None: - # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through - if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): - raise StateAccessViolation("not allowed to query self in pure functions", node) + if varinfo.is_state_variable() or is_naked_self: + raise StateAccessViolation("not allowed to query state variables in pure functions") # analyse the variable access for the attribute chain for a node @@ -429,7 +441,7 @@ def _handle_modification(self, target: vy_ast.ExprNode): info._writes.add(var_access) def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode): - if not var_access.variable.is_module_variable(): + if not var_access.variable.is_state_variable(): return root_module_info = check_module_uses(target) @@ -693,7 +705,7 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_value_access(node) if self.func and self.func.mutability == StateMutability.PURE: - _validate_pure_access(node, typ) + _validate_pure_access(node) value_type = get_exact_type_from_node(node.value) @@ -765,10 +777,10 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if not func_type.from_interface: for s in func_type.get_variable_writes(): - if s.variable.is_module_variable(): + if s.variable.is_state_variable(): func_info._writes.add(s) for s in func_type.get_variable_reads(): - if s.variable.is_module_variable(): + if s.variable.is_state_variable(): func_info._reads.add(s) if self.function_analyzer: @@ -873,10 +885,8 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func: - # TODO: refactor to use expr_info mutability - if self.func.mutability == StateMutability.PURE: - _validate_self_reference(node) + if self.func and self.func.mutability == StateMutability.PURE: + _validate_pure_access(node) def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index b4f4381444..84199ec82c 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -536,7 +536,7 @@ def visit_ExportsDecl(self, node): # check module uses var_accesses = func_t.get_variable_accesses() - if any(s.variable.is_module_variable() for s in var_accesses): + if any(s.variable.is_state_variable() for s in var_accesses): module_info = check_module_uses(item) assert module_info is not None # guaranteed by above checks used_modules.add(module_info) diff --git a/vyper/utils.py b/vyper/utils.py index ba615e58d7..114ddf97c2 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -15,7 +15,7 @@ _T = TypeVar("_T") -class OrderedSet(Generic[_T], dict[_T, None]): +class OrderedSet(Generic[_T]): """ a minimal "ordered set" class. this is needed in some places because, while dict guarantees you can recover insertion order @@ -25,57 +25,82 @@ class OrderedSet(Generic[_T], dict[_T, None]): """ def __init__(self, iterable=None): - super().__init__() + self._data = dict() if iterable is not None: - for item in iterable: - self.add(item) + self.update(iterable) def __repr__(self): - keys = ", ".join(repr(k) for k in self.keys()) + keys = ", ".join(repr(k) for k in self) return f"{{{keys}}}" - def get(self, *args, **kwargs): - raise RuntimeError("can't call get() on OrderedSet!") + def __iter__(self): + return iter(self._data) + + def __contains__(self, item): + return self._data.__contains__(item) + + def __len__(self): + return len(self._data) def first(self): return next(iter(self)) def add(self, item: _T) -> None: - self[item] = None + self._data[item] = None def remove(self, item: _T) -> None: - del self[item] + del self._data[item] + + def drop(self, item: _T): + # friendly version of remove + self._data.pop(item, None) + + def dropmany(self, iterable): + for item in iterable: + self._data.pop(item, None) def difference(self, other): ret = self.copy() - for k in other.keys(): - if k in ret: - ret.remove(k) + ret.dropmany(other) return ret + def update(self, other): + # CMC 2024-03-22 for some reason, this is faster than dict.update? + # (maybe size dependent) + for item in other: + self._data[item] = None + def union(self, other): return self | other - def update(self, other): - super().update(self.__class__.fromkeys(other)) + def __ior__(self, other): + self.update(other) + return self def __or__(self, other): - return self.__class__(super().__or__(other)) + ret = self.copy() + ret.update(other) + return ret + + def __eq__(self, other): + return self._data == other._data def copy(self): - return self.__class__(super().copy()) + cls = self.__class__ + ret = cls.__new__(cls) + ret._data = self._data.copy() + return ret @classmethod def intersection(cls, *sets): - res = OrderedSet() if len(sets) == 0: raise ValueError("undefined: intersection of no sets") - if len(sets) == 1: - return sets[0].copy() - for e in sets[0].keys(): - if all(e in s for s in sets[1:]): - res.add(e) - return res + + ret = sets[0].copy() + for e in sets[0]: + if any(e not in s for s in sets[1:]): + ret.remove(e) + return ret class StringEnum(enum.Enum): diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index d1c2d0c342..2efd58ad6c 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -11,10 +11,14 @@ ir_pass_optimize_unused_variables, ir_pass_remove_unreachable_blocks, ) +from vyper.venom.dominators import DominatorTree from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation from vyper.venom.passes.dft import DFTPass +from vyper.venom.passes.make_ssa import MakeSSA +from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.passes.simplify_cfg import SimplifyCFGPass from vyper.venom.venom_to_assembly import VenomCompiler DEFAULT_OPT_LEVEL = OptimizationLevel.default() @@ -38,6 +42,24 @@ def generate_assembly_experimental( def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: # Run passes on Venom IR # TODO: Add support for optimization levels + + ir_pass_optimize_empty_blocks(ctx) + ir_pass_remove_unreachable_blocks(ctx) + + internals = [ + bb + for bb in ctx.basic_blocks + if bb.label.value.startswith("internal") and len(bb.cfg_in) == 0 + ] + + SimplifyCFGPass.run_pass(ctx, ctx.basic_blocks[0]) + for entry in internals: + SimplifyCFGPass.run_pass(ctx, entry) + + MakeSSA.run_pass(ctx, ctx.basic_blocks[0]) + for entry in internals: + MakeSSA.run_pass(ctx, entry) + while True: changes = 0 @@ -51,7 +73,6 @@ def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: calculate_cfg(ctx) calculate_liveness(ctx) - changes += ir_pass_constant_propagation(ctx) changes += DFTPass.run_pass(ctx) calculate_cfg(ctx) diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index daebd2560c..066a60f45e 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -1,3 +1,5 @@ +from typing import Optional + from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet from vyper.venom.basicblock import ( @@ -38,6 +40,7 @@ def calculate_cfg(ctx: IRFunction) -> None: def _reset_liveness(ctx: IRFunction) -> None: for bb in ctx.basic_blocks: + bb.out_vars = OrderedSet() for inst in bb.instructions: inst.liveness = OrderedSet() @@ -50,16 +53,15 @@ def _calculate_liveness(bb: IRBasicBlock) -> bool: orig_liveness = bb.instructions[0].liveness.copy() liveness = bb.out_vars.copy() for instruction in reversed(bb.instructions): - ops = instruction.get_inputs() + ins = instruction.get_inputs() + outs = instruction.get_outputs() - for op in ops: - if op in liveness: - instruction.dup_requirements.add(op) + if ins or outs: + # perf: only copy if changed + liveness = liveness.copy() + liveness.update(ins) + liveness.dropmany(outs) - liveness = liveness.union(OrderedSet.fromkeys(ops)) - out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None - if out in liveness: - liveness.remove(out) instruction.liveness = liveness return orig_liveness != bb.instructions[0].liveness @@ -89,6 +91,18 @@ def calculate_liveness(ctx: IRFunction) -> None: break +def calculate_dup_requirements(ctx: IRFunction) -> None: + for bb in ctx.basic_blocks: + last_liveness = bb.out_vars + for inst in reversed(bb.instructions): + inst.dup_requirements = OrderedSet() + ops = inst.get_inputs() + for op in ops: + if op in last_liveness: + inst.dup_requirements.add(op) + last_liveness = inst.liveness + + # calculate the input variables into self from source def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: liveness = target.instructions[0].liveness.copy() @@ -104,19 +118,17 @@ def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IR # will arbitrarily choose either %12 or %14 to be in the liveness # set, and then during instruction selection, after this instruction, # %12 will be replaced by %56 in the liveness set - source1, source2 = inst.operands[0], inst.operands[2] - phi1, phi2 = inst.operands[1], inst.operands[3] - if source.label == source1: - liveness.add(phi1) - if phi2 in liveness: - liveness.remove(phi2) - elif source.label == source2: - liveness.add(phi2) - if phi1 in liveness: - liveness.remove(phi1) - else: - # bad path into this phi node - raise CompilerPanic(f"unreachable: {inst}") + + # bad path into this phi node + if source.label not in inst.operands: + raise CompilerPanic(f"unreachable: {inst} from {source.label}") + + for label, var in inst.phi_operands: + if label == source.label: + liveness.add(var) + else: + if var in liveness: + liveness.remove(var) return liveness @@ -137,8 +149,8 @@ def get_uses(self, op: IRVariable) -> list[IRInstruction]: return self._dfg_inputs.get(op, []) # the instruction which produces this variable. - def get_producing_instruction(self, op: IRVariable) -> IRInstruction: - return self._dfg_outputs[op] + def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]: + return self._dfg_outputs.get(op) @classmethod def build_dfg(cls, ctx: IRFunction) -> "DFG": @@ -163,3 +175,20 @@ def build_dfg(cls, ctx: IRFunction) -> "DFG": dfg._dfg_outputs[op] = inst return dfg + + def as_graph(self) -> str: + """ + Generate a graphviz representation of the dfg + """ + lines = ["digraph dfg_graph {"] + for var, inputs in self._dfg_inputs.items(): + for input in inputs: + for op in input.get_outputs(): + if isinstance(op, IRVariable): + lines.append(f' " {var.name} " -> " {op.name} "') + + lines.append("}") + return "\n".join(lines) + + def __repr__(self) -> str: + return self.as_graph() diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index ed70a5eaa0..6c509d8f95 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,10 +1,9 @@ -from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional, Union from vyper.utils import OrderedSet # instructions which can terminate a basic block -BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop"]) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop", "exit"]) VOLATILE_INSTRUCTIONS = frozenset( [ @@ -12,15 +11,22 @@ "alloca", "call", "staticcall", + "delegatecall", "invoke", "sload", "sstore", "iload", "istore", + "tload", + "tstore", "assert", + "assert_unreachable", "mstore", "mload", "calldatacopy", + "mcopy", + "extcodecopy", + "returndatacopy", "codecopy", "dloadbytes", "dload", @@ -39,11 +45,15 @@ "istore", "dloadbytes", "calldatacopy", + "mcopy", + "returndatacopy", "codecopy", + "extcodecopy", "return", "ret", "revert", "assert", + "assert_unreachable", "selfdestruct", "stop", "invalid", @@ -52,6 +62,7 @@ "djmp", "jnz", "log", + "exit", ] ) @@ -87,6 +98,10 @@ class IROperand: value: Any + @property + def name(self) -> str: + return self.value + class IRValue(IROperand): """ @@ -109,13 +124,16 @@ def __init__(self, value: int) -> None: assert isinstance(value, int), "value must be an int" self.value = value - def __repr__(self) -> str: - return str(self.value) + def __hash__(self) -> int: + return self.value.__hash__() + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value -class MemType(Enum): - OPERAND_STACK = auto() - MEMORY = auto() + def __repr__(self) -> str: + return str(self.value) class IRVariable(IRValue): @@ -126,18 +144,34 @@ class IRVariable(IRValue): value: str offset: int = 0 - # some variables can be in memory for conversion from legacy IR to venom - mem_type: MemType = MemType.OPERAND_STACK - mem_addr: Optional[int] = None - - def __init__( - self, value: str, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: int = None - ) -> None: + def __init__(self, value: str, version: Optional[str | int] = None) -> None: assert isinstance(value, str) + assert ":" not in value, "Variable name cannot contain ':'" + if version: + assert isinstance(value, str) or isinstance(value, int), "value must be an str or int" + value = f"{value}:{version}" + if value[0] != "%": + value = f"%{value}" self.value = value self.offset = 0 - self.mem_type = mem_type - self.mem_addr = mem_addr + + @property + def name(self) -> str: + return self.value.split(":")[0] + + @property + def version(self) -> int: + if ":" not in self.value: + return 0 + return int(self.value.split(":")[1]) + + def __hash__(self) -> int: + return self.value.__hash__() + + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value def __repr__(self) -> str: return self.value @@ -158,6 +192,14 @@ def __init__(self, value: str, is_symbol: bool = False) -> None: self.value = value self.is_symbol = is_symbol + def __hash__(self) -> int: + return hash(self.value) + + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return False + return self.value == other.value + def __repr__(self) -> str: return self.value @@ -182,6 +224,8 @@ class IRInstruction: parent: Optional["IRBasicBlock"] fence_id: int annotation: Optional[str] + ast_source: Optional[int] + error_msg: Optional[str] def __init__( self, @@ -200,6 +244,8 @@ def __init__( self.parent = None self.fence_id = -1 self.annotation = None + self.ast_source = None + self.error_msg = None def get_label_operands(self) -> list[IRLabel]: """ @@ -246,22 +292,37 @@ def replace_label_operands(self, replacements: dict) -> None: if isinstance(operand, IRLabel) and operand.value in replacements: self.operands[i] = replacements[operand.value] + @property + def phi_operands(self) -> Generator[tuple[IRLabel, IRVariable], None, None]: + """ + Get phi operands for instruction. + """ + assert self.opcode == "phi", "instruction must be a phi" + for i in range(0, len(self.operands), 2): + label = self.operands[i] + var = self.operands[i + 1] + assert isinstance(label, IRLabel), "phi operand must be a label" + assert isinstance(var, IRVariable), "phi operand must be a variable" + yield label, var + def __repr__(self) -> str: s = "" if self.output: s += f"{self.output} = " opcode = f"{self.opcode} " if self.opcode != "store" else "" s += opcode - operands = ", ".join( - [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in self.operands] + operands = self.operands + if opcode not in ["jmp", "jnz", "invoke"]: + operands = reversed(operands) # type: ignore + s += ", ".join( + [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in operands] ) - s += operands if self.annotation: s += f" <{self.annotation}>" - # if self.liveness: - # return f"{s: <30} # {self.liveness}" + if self.liveness: + return f"{s: <30} # {self.liveness}" return s @@ -307,6 +368,9 @@ class IRBasicBlock: # stack items which this basic block produces out_vars: OrderedSet[IRVariable] + reachable: OrderedSet["IRBasicBlock"] + is_reachable: bool = False + def __init__(self, label: IRLabel, parent: "IRFunction") -> None: assert isinstance(label, IRLabel), "label must be an IRLabel" self.label = label @@ -315,6 +379,8 @@ def __init__(self, label: IRLabel, parent: "IRFunction") -> None: self.cfg_in = OrderedSet() self.cfg_out = OrderedSet() self.out_vars = OrderedSet() + self.reachable = OrderedSet() + self.is_reachable = False def add_cfg_in(self, bb: "IRBasicBlock") -> None: self.cfg_in.add(bb) @@ -333,23 +399,26 @@ def remove_cfg_out(self, bb: "IRBasicBlock") -> None: assert bb in self.cfg_out self.cfg_out.remove(bb) - @property - def is_reachable(self) -> bool: - return len(self.cfg_in) > 0 - - def append_instruction(self, opcode: str, *args: Union[IROperand, int]) -> Optional[IRVariable]: + def append_instruction( + self, opcode: str, *args: Union[IROperand, int], ret: IRVariable = None + ) -> Optional[IRVariable]: """ Append an instruction to the basic block Returns the output variable if the instruction supports one """ - ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None + assert not self.is_terminated, self + + if ret is None: + ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None # Wrap raw integers in IRLiterals inst_args = [_ir_operand_from_value(arg) for arg in args] inst = IRInstruction(opcode, inst_args, ret) inst.parent = self + inst.ast_source = self.parent.ast_source + inst.error_msg = self.parent.error_msg self.instructions.append(inst) return ret @@ -357,10 +426,9 @@ def append_invoke_instruction( self, args: list[IROperand | int], returns: bool ) -> Optional[IRVariable]: """ - Append an instruction to the basic block - - Returns the output variable if the instruction supports one + Append an invoke to the basic block """ + assert not self.is_terminated, self ret = None if returns: ret = self.parent.get_next_variable() @@ -368,16 +436,30 @@ def append_invoke_instruction( # Wrap raw integers in IRLiterals inst_args = [_ir_operand_from_value(arg) for arg in args] + assert isinstance(inst_args[0], IRLabel), "Invoked non label" + inst = IRInstruction("invoke", inst_args, ret) inst.parent = self + inst.ast_source = self.parent.ast_source + inst.error_msg = self.parent.error_msg self.instructions.append(inst) return ret - def insert_instruction(self, instruction: IRInstruction, index: int) -> None: + def insert_instruction(self, instruction: IRInstruction, index: Optional[int] = None) -> None: assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + + if index is None: + assert not self.is_terminated, self + index = len(self.instructions) instruction.parent = self + instruction.ast_source = self.parent.ast_source + instruction.error_msg = self.parent.error_msg self.instructions.insert(index, instruction) + def remove_instruction(self, instruction: IRInstruction) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + self.instructions.remove(instruction) + def clear_instructions(self) -> None: self.instructions = [] @@ -388,6 +470,19 @@ def replace_operands(self, replacements: dict) -> None: for instruction in self.instructions: instruction.replace_operands(replacements) + def get_assignments(self): + """ + Get all assignments in basic block. + """ + return [inst.output for inst in self.instructions if inst.output] + + @property + def is_empty(self) -> bool: + """ + Check if the basic block is empty, i.e. it has no instructions. + """ + return len(self.instructions) == 0 + @property def is_terminated(self) -> bool: """ @@ -399,6 +494,20 @@ def is_terminated(self) -> bool: return False return self.instructions[-1].opcode in BB_TERMINATORS + @property + def is_terminal(self) -> bool: + """ + Check if the basic block is terminal. + """ + return len(self.cfg_out) == 0 + + @property + def in_vars(self) -> OrderedSet[IRVariable]: + for inst in self.instructions: + if inst.opcode != "phi": + return inst.liveness + return OrderedSet() + def copy(self): bb = IRBasicBlock(self.label, self.parent) bb.instructions = self.instructions.copy() diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py index 620ee66d15..60dd8bbee1 100644 --- a/vyper/venom/bb_optimizer.py +++ b/vyper/venom/bb_optimizer.py @@ -56,6 +56,25 @@ def _optimize_empty_basicblocks(ctx: IRFunction) -> int: return count +def _daisychain_empty_basicblocks(ctx: IRFunction) -> int: + count = 0 + i = 0 + while i < len(ctx.basic_blocks): + bb = ctx.basic_blocks[i] + i += 1 + if bb.is_terminated: + continue + + if i < len(ctx.basic_blocks) - 1: + bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) + else: + bb.append_instruction("stop") + + count += 1 + + return count + + @ir_pass def ir_pass_optimize_empty_blocks(ctx: IRFunction) -> int: changes = _optimize_empty_basicblocks(ctx) diff --git a/vyper/venom/dominators.py b/vyper/venom/dominators.py new file mode 100644 index 0000000000..b69c17e1d8 --- /dev/null +++ b/vyper/venom/dominators.py @@ -0,0 +1,166 @@ +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction + + +class DominatorTree: + """ + Dominator tree implementation. This class computes the dominator tree of a + function and provides methods to query the tree. The tree is computed using + the Lengauer-Tarjan algorithm. + """ + + ctx: IRFunction + entry_block: IRBasicBlock + dfs_order: dict[IRBasicBlock, int] + dfs_walk: list[IRBasicBlock] + dominators: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + immediate_dominators: dict[IRBasicBlock, IRBasicBlock] + dominated: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + dominator_frontiers: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + + @classmethod + def build_dominator_tree(cls, ctx, entry): + ret = DominatorTree() + ret.compute(ctx, entry) + return ret + + def compute(self, ctx: IRFunction, entry: IRBasicBlock): + """ + Compute the dominator tree. + """ + self.ctx = ctx + self.entry_block = entry + self.dfs_order = {} + self.dfs_walk = [] + self.dominators = {} + self.immediate_dominators = {} + self.dominated = {} + self.dominator_frontiers = {} + + self._compute_dfs(self.entry_block, OrderedSet()) + self._compute_dominators() + self._compute_idoms() + self._compute_df() + + def dominates(self, bb1, bb2): + """ + Check if bb1 dominates bb2. + """ + return bb2 in self.dominators[bb1] + + def immediate_dominator(self, bb): + """ + Return the immediate dominator of a basic block. + """ + return self.immediate_dominators.get(bb) + + def _compute_dominators(self): + """ + Compute dominators + """ + basic_blocks = list(self.dfs_order.keys()) + self.dominators = {bb: OrderedSet(basic_blocks) for bb in basic_blocks} + self.dominators[self.entry_block] = OrderedSet([self.entry_block]) + changed = True + count = len(basic_blocks) ** 2 # TODO: find a proper bound for this + while changed: + count -= 1 + if count < 0: + raise CompilerPanic("Dominators computation failed to converge") + changed = False + for bb in basic_blocks: + if bb == self.entry_block: + continue + preds = bb.cfg_in + if len(preds) == 0: + continue + new_dominators = OrderedSet.intersection(*[self.dominators[pred] for pred in preds]) + new_dominators.add(bb) + if new_dominators != self.dominators[bb]: + self.dominators[bb] = new_dominators + changed = True + + def _compute_idoms(self): + """ + Compute immediate dominators + """ + self.immediate_dominators = {bb: None for bb in self.dfs_order.keys()} + self.immediate_dominators[self.entry_block] = self.entry_block + for bb in self.dfs_walk: + if bb == self.entry_block: + continue + doms = sorted(self.dominators[bb], key=lambda x: self.dfs_order[x]) + self.immediate_dominators[bb] = doms[1] + + self.dominated = {bb: OrderedSet() for bb in self.dfs_walk} + for dom, target in self.immediate_dominators.items(): + self.dominated[target].add(dom) + + def _compute_df(self): + """ + Compute dominance frontier + """ + basic_blocks = self.dfs_walk + self.dominator_frontiers = {bb: OrderedSet() for bb in basic_blocks} + + for bb in self.dfs_walk: + if len(bb.cfg_in) > 1: + for pred in bb.cfg_in: + runner = pred + while runner != self.immediate_dominators[bb]: + self.dominator_frontiers[runner].add(bb) + runner = self.immediate_dominators[runner] + + def dominance_frontier(self, basic_blocks: list[IRBasicBlock]) -> OrderedSet[IRBasicBlock]: + """ + Compute dominance frontier of a set of basic blocks. + """ + df = OrderedSet[IRBasicBlock]() + for bb in basic_blocks: + df.update(self.dominator_frontiers[bb]) + return df + + def _intersect(self, bb1, bb2): + """ + Find the nearest common dominator of two basic blocks. + """ + dfs_order = self.dfs_order + while bb1 != bb2: + while dfs_order[bb1] < dfs_order[bb2]: + bb1 = self.immediate_dominators[bb1] + while dfs_order[bb1] > dfs_order[bb2]: + bb2 = self.immediate_dominators[bb2] + return bb1 + + def _compute_dfs(self, entry: IRBasicBlock, visited): + """ + Depth-first search to compute the DFS order of the basic blocks. This + is used to compute the dominator tree. The sequence of basic blocks in + the DFS order is stored in `self.dfs_walk`. The DFS order of each basic + block is stored in `self.dfs_order`. + """ + visited.add(entry) + + for bb in entry.cfg_out: + if bb not in visited: + self._compute_dfs(bb, visited) + + self.dfs_walk.append(entry) + self.dfs_order[entry] = len(self.dfs_walk) + + def as_graph(self) -> str: + """ + Generate a graphviz representation of the dominator tree. + """ + lines = ["digraph dominator_tree {"] + for bb in self.ctx.basic_blocks: + if bb == self.entry_block: + continue + idom = self.immediate_dominator(bb) + if idom is None: + continue + lines.append(f' " {idom.label} " -> " {bb.label} "') + lines.append("}") + return "\n".join(lines) diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 771dcf73ce..d1680385f5 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -1,12 +1,14 @@ -from typing import Optional +from typing import Iterator, Optional +from vyper.codegen.ir_node import IRnode +from vyper.utils import OrderedSet from vyper.venom.basicblock import ( + CFG_ALTERING_INSTRUCTIONS, IRBasicBlock, IRInstruction, IRLabel, IROperand, IRVariable, - MemType, ) GLOBAL_LABEL = IRLabel("__global") @@ -27,6 +29,11 @@ class IRFunction: last_label: int last_variable: int + # Used during code generation + _ast_source_stack: list[int] + _error_msg_stack: list[str] + _bb_index: dict[str, int] + def __init__(self, name: IRLabel = None) -> None: if name is None: name = GLOBAL_LABEL @@ -40,6 +47,10 @@ def __init__(self, name: IRLabel = None) -> None: self.last_label = 0 self.last_variable = 0 + self._ast_source_stack = [] + self._error_msg_stack = [] + self._bb_index = {} + self.add_entry_point(name) self.append_basic_block(IRBasicBlock(name, self)) @@ -62,10 +73,22 @@ def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" self.basic_blocks.append(bb) - # TODO add sanity check somewhere that basic blocks have unique labels - return self.basic_blocks[-1] + def _get_basicblock_index(self, label: str): + # perf: keep an "index" of labels to block indices to + # perform fast lookup. + # TODO: maybe better just to throw basic blocks in an ordered + # dict of some kind. + ix = self._bb_index.get(label, -1) + if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label: + return ix + # do a reindex + self._bb_index = dict((bb.label.name, ix) for ix, bb in enumerate(self.basic_blocks)) + # sanity check - no duplicate labels + assert len(self._bb_index) == len(self.basic_blocks) + return self._bb_index[label] + def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ Get basic block by label. @@ -73,49 +96,97 @@ def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ if label is None: return self.basic_blocks[-1] - for bb in self.basic_blocks: - if bb.label.value == label: - return bb - raise AssertionError(f"Basic block '{label}' not found") + ix = self._get_basicblock_index(label) + return self.basic_blocks[ix] def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: """ Get basic block after label. """ - for i, bb in enumerate(self.basic_blocks[:-1]): - if bb.label.value == label.value: - return self.basic_blocks[i + 1] + ix = self._get_basicblock_index(label.value) + if 0 <= ix < len(self.basic_blocks) - 1: + return self.basic_blocks[ix + 1] raise AssertionError(f"Basic block after '{label}' not found") + def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]: + """ + Get basic blocks that are terminal. + """ + for bb in self.basic_blocks: + if bb.is_terminal: + yield bb + def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: """ - Get basic blocks that contain label. + Get basic blocks that point to the given basic block """ return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] - def get_next_label(self) -> IRLabel: + def get_next_label(self, suffix: str = "") -> IRLabel: + if suffix != "": + suffix = f"_{suffix}" self.last_label += 1 - return IRLabel(f"{self.last_label}") + return IRLabel(f"{self.last_label}{suffix}") - def get_next_variable( - self, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: Optional[int] = None - ) -> IRVariable: + def get_next_variable(self) -> IRVariable: self.last_variable += 1 - return IRVariable(f"%{self.last_variable}", mem_type, mem_addr) + return IRVariable(f"%{self.last_variable}") def get_last_variable(self) -> str: return f"%{self.last_variable}" def remove_unreachable_blocks(self) -> int: - removed = 0 + self._compute_reachability() + + removed = [] new_basic_blocks = [] + + # Remove unreachable basic blocks for bb in self.basic_blocks: - if not bb.is_reachable and bb.label not in self.entry_points: - removed += 1 + if not bb.is_reachable: + removed.append(bb) else: new_basic_blocks.append(bb) self.basic_blocks = new_basic_blocks - return removed + + # Remove phi instructions that reference removed basic blocks + for bb in removed: + for out_bb in bb.cfg_out: + out_bb.remove_cfg_in(bb) + for inst in out_bb.instructions: + if inst.opcode != "phi": + continue + in_labels = inst.get_label_operands() + if bb.label in in_labels: + out_bb.remove_instruction(inst) + + return len(removed) + + def _compute_reachability(self) -> None: + """ + Compute reachability of basic blocks. + """ + for bb in self.basic_blocks: + bb.reachable = OrderedSet() + bb.is_reachable = False + + for entry in self.entry_points: + entry_bb = self.get_basic_block(entry.value) + self._compute_reachability_from(entry_bb) + + def _compute_reachability_from(self, bb: IRBasicBlock) -> None: + """ + Compute reachability of basic blocks from bb. + """ + if bb.is_reachable: + return + bb.is_reachable = True + for inst in bb.instructions: + if inst.opcode in CFG_ALTERING_INSTRUCTIONS or inst.opcode == "invoke": + for op in inst.get_label_operands(): + out_bb = self.get_basic_block(op.value) + bb.reachable.add(out_bb) + self._compute_reachability_from(out_bb) def append_data(self, opcode: str, args: list[IROperand]) -> None: """ @@ -147,6 +218,25 @@ def normalized(self) -> bool: # The function is normalized return True + def push_source(self, ir): + if isinstance(ir, IRnode): + self._ast_source_stack.append(ir.ast_source) + self._error_msg_stack.append(ir.error_msg) + + def pop_source(self): + assert len(self._ast_source_stack) > 0, "Empty source stack" + self._ast_source_stack.pop() + assert len(self._error_msg_stack) > 0, "Empty error stack" + self._error_msg_stack.pop() + + @property + def ast_source(self) -> Optional[int]: + return self._ast_source_stack[-1] if len(self._ast_source_stack) > 0 else None + + @property + def error_msg(self) -> Optional[str]: + return self._error_msg_stack[-1] if len(self._error_msg_stack) > 0 else None + def copy(self): new = IRFunction(self.name) new.basic_blocks = self.basic_blocks.copy() @@ -155,6 +245,32 @@ def copy(self): new.last_variable = self.last_variable return new + def as_graph(self) -> str: + import html + + def _make_label(bb): + ret = '<
{html.escape(str(bb.label))} |
{html.escape(str(inst))} |