From f38b61a27a975c80600d54981386e3cf8697edb6 Mon Sep 17 00:00:00 2001 From: cyberthirst Date: Thu, 21 Nov 2024 16:54:55 +0000 Subject: [PATCH 1/6] refactor[test]: add some sanity checks to `abi_decode` tests (#4096) QOL improvements for abi_decode tests: - added sanity checks to abi decode tests to ensure that we're never failing on calldatasize - also split the creation of payload into two parts when calling target contract with low-level msg_call to increase readability --------- Co-authored-by: Charles Cooper --- .../builtins/codegen/test_abi_decode.py | 377 ++++++++++-------- 1 file changed, 216 insertions(+), 161 deletions(-) diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 9ae869c9cc..475118c7e3 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -8,6 +8,8 @@ TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex() +BUFFER_OVERHEAD = 4 + 2 * 32 + def test_abi_decode_complex(get_contract): contract = """ @@ -474,8 +476,10 @@ def test_abi_decode_length_mismatch(get_contract, assert_compile_failed, bad_cod assert_compile_failed(lambda: get_contract(bad_code), exception) -def _abi_payload_from_tuple(payload: tuple[int | bytes, ...]) -> bytes: - return b"".join(p.to_bytes(32, "big") if isinstance(p, int) else p for p in payload) +def _abi_payload_from_tuple(payload: tuple[int | bytes, ...], max_sz: int) -> bytes: + ret = b"".join(p.to_bytes(32, "big") if isinstance(p, int) else p for p in payload) + assert len(ret) <= max_sz + return ret def _replicate(value: int, count: int) -> tuple[int, ...]: @@ -486,11 +490,12 @@ def test_abi_decode_arithmetic_overflow(env, tx_failed, get_contract): # test based on GHSA-9p8r-4xp4-gw5w: # https://github.com/vyperlang/vyper/security/advisories/GHSA-9p8r-4xp4-gw5w#advisory-comment-91841 # buf + head causes arithmetic overflow - code = """ + buffer_size = 32 * 3 + code = f""" @external -def f(x: Bytes[32 * 3]): +def f(x: Bytes[{buffer_size}]): a: Bytes[32] = b"foo" - y: Bytes[32 * 3] = x + y: Bytes[{buffer_size}] = x decoded_y1: Bytes[32] = _abi_decode(y, Bytes[32]) a = b"bar" @@ -500,39 +505,47 @@ def f(x: Bytes[32 * 3]): """ c = get_contract(code) - data = method_id("f(bytes)") - payload = ( - 0x20, # tuple head - 0x60, # parent array length - # parent payload - this word will be considered as the head of the abi-encoded inner array - # and it will be added to base ptr leading to an arithmetic overflow - 2**256 - 0x60, - ) - data += _abi_payload_from_tuple(payload) + tuple_head_ofst = 0x20 + parent_array_len = 0x60 + msg_call_overhead = (method_id("f(bytes)"), tuple_head_ofst, parent_array_len) + + data = _abi_payload_from_tuple(msg_call_overhead, BUFFER_OVERHEAD) + + # parent payload - this word will be considered as the head of the + # abi-encoded inner array and it will be added to base ptr leading to an + # arithmetic overflow + buffer_payload = (2**256 - 0x60,) + + data += _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): env.message_call(c.address, data=data) -def test_abi_decode_nonstrict_head(env, tx_failed, get_contract): +def test_abi_decode_nonstrict_head(env, get_contract): # data isn't strictly encoded - head is 0x21 instead of 0x20 # but the head + length is still within runtime bounds of the parent buffer - code = """ + buffer_size = 32 * 5 + code = f""" @external -def f(x: Bytes[32 * 5]): - y: Bytes[32 * 5] = x +def f(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x a: Bytes[32] = b"a" decoded_y1: DynArray[uint256, 3] = _abi_decode(y, DynArray[uint256, 3]) + assert len(decoded_y1) == 1 and decoded_y1[0] == 0 a = b"aaaa" decoded_y1 = _abi_decode(y, DynArray[uint256, 3]) + assert len(decoded_y1) == 1 and decoded_y1[0] == 0 """ c = get_contract(code) - data = method_id("f(bytes)") + tuple_head_ofst = 0x20 + parent_array_len = 0xA0 + msg_call_overhead = (method_id("f(bytes)"), tuple_head_ofst, parent_array_len) - payload = ( - 0x20, # tuple head - 0xA0, # parent array length + data = _abi_payload_from_tuple(msg_call_overhead, BUFFER_OVERHEAD) + + buffer_payload = ( # head should be 0x20 but is 0x21 thus the data isn't strictly encoded 0x21, # we don't want to revert on invalid length, so set this to 0 @@ -543,27 +556,30 @@ def f(x: Bytes[32 * 5]): *_replicate(0x03, 2), ) - data += _abi_payload_from_tuple(payload) + data += _abi_payload_from_tuple(buffer_payload, buffer_size) env.message_call(c.address, data=data) def test_abi_decode_child_head_points_to_parent(tx_failed, get_contract): # data isn't strictly encoded and the head for the inner array - # skipts the corresponding payload and points to other valid section of the parent buffer - code = """ + # skips the corresponding payload and points to other valid section of the + # parent buffer + buffer_size = 14 * 32 + code = f""" @external -def run(x: Bytes[14 * 32]): - y: Bytes[14 * 32] = x +def run(x: Bytes[{buffer_size}]) -> DynArray[DynArray[DynArray[uint256, 2], 1], 2]: + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[DynArray[DynArray[uint256, 2], 1], 2] = _abi_decode( y, DynArray[DynArray[DynArray[uint256, 2], 1], 2] ) + return decoded_y1 """ c = get_contract(code) # encode [[[1, 1]], [[2, 2]]] and modify the head for [1, 1] # to actually point to [2, 2] - payload = ( + buffer_payload = ( 0x20, # top-level array head 0x02, # top-level array length 0x40, # head of DAr[DAr[DAr, uint256]]][0] @@ -582,30 +598,33 @@ def run(x: Bytes[14 * 32]): 0x02, # DAr[DAr[DAr, uint256]]][1][0][1] ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) - c.run(data) + res = c.run(data) + assert res == [[[2, 2]], [[2, 2]]] def test_abi_decode_nonstrict_head_oob(tx_failed, get_contract): # data isn't strictly encoded and (non_strict_head + len(DynArray[..][2])) > parent_static_sz # thus decoding the data pointed to by the head would cause an OOB read # non_strict_head + length == parent + parent_static_sz + 1 - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Bytes[32 * 3], 3] = _abi_decode(y, DynArray[Bytes[32 * 3], 3]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length - # non_strict_head - if the length pointed to by this head is 0x60 (which is valid - # length for the Bytes[32*3] buffer), the decoding function would decode - # 1 byte over the end of the buffer - # we define the non_strict_head as: skip the remaining heads, 1st and 2nd tail + # non_strict_head - if the length pointed to by this head is 0x60 + # (which is valid length for the Bytes[32*3] buffer), the decoding + # function would decode 1 byte over the end of the buffer + # we define the non_strict_head as: + # skip the remaining heads, 1st and 2nd tail # to the third tail + 1B 0x20 * 8 + 0x20 * 3 + 0x01, # inner array0 head 0x20 * 4 + 0x20 * 3, # inner array1 head @@ -622,7 +641,7 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x03, 2), ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) @@ -631,10 +650,11 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): def test_abi_decode_nonstrict_head_oob2(tx_failed, get_contract): # same principle as in Test_abi_decode_nonstrict_head_oob # but adapted for dynarrays - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y, DynArray[DynArray[uint256, 3], 3] @@ -642,7 +662,7 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length (0x20 * 8 + 0x20 * 3 + 0x01), # inner array0 head @@ -658,7 +678,7 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x01, 2), # DynArray[..][2] data ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) @@ -666,33 +686,36 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): def test_abi_decode_head_pointing_outside_buffer(tx_failed, get_contract): # the head points completely outside the buffer - code = """ + buffer_size = 3 * 32 + code = f""" @external -def run(x: Bytes[3 * 32]): - y: Bytes[3 * 32] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: Bytes[32] = _abi_decode(y, Bytes[32]) """ c = get_contract(code) - payload = (0x80, 0x20, 0x01) - data = _abi_payload_from_tuple(payload) + buffer_payload = (0x80, 0x20, 0x01) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) def test_abi_decode_bytearray_clamp(tx_failed, get_contract): - # data has valid encoding, but the length of DynArray[Bytes[96], 3][0] is set to 0x61 + # data has valid encoding, but the length of DynArray[Bytes[96], 3][0] is + # set to 0x61 # and thus the decoding should fail on bytestring clamp - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Bytes[32 * 3], 3] = _abi_decode(y, DynArray[Bytes[32 * 3], 3]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length 0x20 * 3, # inner array0 head @@ -707,32 +730,38 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x01, 3), # DynArray[Bytes[96], 3][2] data ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) def test_abi_decode_runtimesz_oob(tx_failed, get_contract, env): - # provide enough data, but set the runtime size to be smaller than the actual size - # so after y: [..] = x, y will have the incorrect size set and only part of the - # original data will be copied. This will cause oob read outside the - # runtime sz (but still within static size of the buffer) - code = """ + # provide enough data, but set the runtime size to be smaller than the + # actual size so after y: [..] = x, y will have the incorrect size set and + # only part of the original data will be copied. This will cause oob read + # outside the runtime sz (but still within static size of the buffer) + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def f(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Bytes[32 * 3], 3] = _abi_decode(y, DynArray[Bytes[32 * 3], 3]) """ c = get_contract(code) - data = method_id("f(bytes)") - - payload = ( + msg_call_overhead = ( + method_id("f(bytes)"), 0x20, # tuple head # the correct size is 0x220 (2*32+3*32+4*3*32) - # therefore we will decode after the end of runtime size (but still within the buffer) + # therefore we will decode after the end of runtime size (but still + # within the buffer) 0x01E4, # top-level bytes array length + ) + + data = _abi_payload_from_tuple(msg_call_overhead, BUFFER_OVERHEAD) + + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length 0x20 * 3, # inner array0 head @@ -746,7 +775,7 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x01, 3), # DynArray[Bytes[96], 3][2] data ) - data += _abi_payload_from_tuple(payload) + data += _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): env.message_call(c.address, data=data) @@ -755,10 +784,11 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): def test_abi_decode_runtimesz_oob2(tx_failed, get_contract, env): # same principle as in test_abi_decode_runtimesz_oob # but adapted for dynarrays - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def f(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y, DynArray[DynArray[uint256, 3], 3] @@ -766,11 +796,15 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): """ c = get_contract(code) - data = method_id("f(bytes)") - - payload = ( + msg_call_overhead = ( + method_id("f(bytes)"), 0x20, # tuple head 0x01E4, # top-level bytes array length + ) + + data = _abi_payload_from_tuple(msg_call_overhead, BUFFER_OVERHEAD) + + buffer_payload = ( 0x20, # DynArray head 0x03, # DynArray length 0x20 * 3, # inner array0 head @@ -784,7 +818,7 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x01, 3), # DynArray[..][2] data ) - data += _abi_payload_from_tuple(payload) + data += _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): env.message_call(c.address, data=data) @@ -796,11 +830,13 @@ def test_abi_decode_head_roundtrip(tx_failed, get_contract, env): # which are in turn in the y2 buffer # NOTE: the test is memory allocator dependent - we assume that y1 and y2 # have the 800 & 960 addresses respectively - code = """ + buffer_size1 = 4 * 32 + buffer_size2 = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y1: Bytes[4*32] = x1 # addr: 800 - y2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x2 # addr: 960 +def run(x1: Bytes[{buffer_size1}], x2: Bytes[{buffer_size2}]): + y1: Bytes[{buffer_size1}] = x1 # addr: 800 + y2: Bytes[{buffer_size2}] = x2 # addr: 960 decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y2, DynArray[DynArray[uint256, 3], 3] @@ -808,7 +844,7 @@ def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x03, # DynArray length # distance to y2 from y1 is 160 160 + 0x20 + 0x20 * 3, # points to DynArray[..][0] length @@ -816,9 +852,9 @@ def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): 160 + 0x20 + 0x20 * 8 + 0x20 * 3, # points to DynArray[..][2] length ) - data1 = _abi_payload_from_tuple(payload) + data1 = _abi_payload_from_tuple(buffer_payload, buffer_size1) - payload = ( + buffer_payload = ( # (960 + (2**256 - 160)) % 2**256 == 800, ie will roundtrip to y1 2**256 - 160, # points to y1 0x03, # DynArray length (not used) @@ -833,7 +869,7 @@ def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): *_replicate(0x03, 3), # DynArray[..][2] data ) - data2 = _abi_payload_from_tuple(payload) + data2 = _abi_payload_from_tuple(buffer_payload, buffer_size2) with tx_failed(): c.run(data1, data2) @@ -841,22 +877,23 @@ def run(x1: Bytes[4 * 32], x2: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): def test_abi_decode_merge_head_and_length(get_contract): # compress head and length into 33B - code = """ + buffer_size = 32 * 2 + 8 * 32 + code = f""" @external -def run(x: Bytes[32 * 2 + 8 * 32]) -> uint256: - y: Bytes[32 * 2 + 8 * 32] = x +def run(x: Bytes[{buffer_size}]) -> Bytes[{buffer_size}]: + y: Bytes[{buffer_size}] = x decoded_y1: Bytes[256] = _abi_decode(y, Bytes[256]) - return len(decoded_y1) + return decoded_y1 """ c = get_contract(code) - payload = (0x01, (0x00).to_bytes(1, "big"), *_replicate(0x00, 8)) + buffer_payload = (0x01, (0x00).to_bytes(1, "big"), *_replicate(0x00, 8)) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) - length = c.run(data) + res = c.run(data) - assert length == 256 + assert res == bytes(256) def test_abi_decode_extcall_invalid_head(tx_failed, get_contract): @@ -880,8 +917,8 @@ def foo(): def test_abi_decode_extcall_oob(tx_failed, get_contract): # the head returned from the extcall is 1 byte bigger than expected - # thus we'll take the last 31 0-bytes from tuple[1] and the 1st byte from tuple[2] - # and consider this the length - thus the length is 2**5 + # thus we'll take the last 31 0-bytes from tuple[1] and the 1st byte from + # tuple[2] and consider this the length - thus the length is 2**5 # and thus we'll read 1B over the buffer end (33 + 32 + 32) code = """ @external @@ -902,7 +939,8 @@ def foo(): def test_abi_decode_extcall_runtimesz_oob(tx_failed, get_contract): # the runtime size (33) is bigger than the actual payload (32 bytes) - # thus we'll read 1B over the runtime size - but still within the static size of the buffer + # thus we'll read 1 byte over the runtime size - but still within the + # static size of the buffer code = """ @external def bar() -> (uint256, uint256, uint256): @@ -932,11 +970,13 @@ def bar() -> (uint256, uint256, uint256, uint256): def bar() -> Bytes[32]: nonpayable @external -def foo(): - x:Bytes[32] = extcall A(self).bar() +def foo() -> Bytes[32]: + return extcall A(self).bar() """ c = get_contract(code) - c.foo() + res = c.foo() + + assert res == (36).to_bytes(32, "big") def test_abi_decode_extcall_truncate_returndata2(tx_failed, get_contract): @@ -1053,12 +1093,14 @@ def bar() -> (uint256, uint256): def bar() -> DynArray[Bytes[32], 2]: nonpayable @external -def run(): - x: DynArray[Bytes[32], 2] = extcall A(self).bar() +def run() -> DynArray[Bytes[32], 2]: + return extcall A(self).bar() """ c = get_contract(code) - c.run() + res = c.run() + + assert res == [] def test_abi_decode_extcall_complex_empty_dynarray(get_contract): @@ -1079,13 +1121,14 @@ def bar() -> (uint256, uint256, uint256, uint256, uint256, uint256): def bar() -> DynArray[Point, 2]: nonpayable @external -def run(): - x: DynArray[Point, 2] = extcall A(self).bar() - assert len(x) == 1 and len(x[0].y) == 0 +def run() -> DynArray[Point, 2]: + return extcall A(self).bar() """ c = get_contract(code) - c.run() + res = c.run() + + assert res == [(1, [], 0)] def test_abi_decode_extcall_complex_empty_dynarray2(tx_failed, get_contract): @@ -1124,21 +1167,21 @@ def bar() -> (uint256, uint256): def bar() -> DynArray[Bytes[32], 2]: nonpayable @external -def run() -> uint256: - x: DynArray[Bytes[32], 2] = extcall A(self).bar() - return len(x) +def run() -> DynArray[Bytes[32], 2]: + return extcall A(self).bar() """ c = get_contract(code) - length = c.run() + res = c.run() - assert length == 0 + assert res == [] def test_abi_decode_top_level_head_oob(tx_failed, get_contract): - code = """ + buffer_size = 256 + code = f""" @external -def run(x: Bytes[256], y: uint256): +def run(x: Bytes[{buffer_size}], y: uint256): player_lost: bool = empty(bool) if y == 1: @@ -1150,9 +1193,9 @@ def run(x: Bytes[256], y: uint256): c = get_contract(code) # head points over the buffer end - payload = (0x0100, *_replicate(0x00, 7)) + bufffer_payload = (0x0100, *_replicate(0x00, 7)) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(bufffer_payload, buffer_size) with tx_failed(): c.run(data, 1) @@ -1162,23 +1205,24 @@ def run(x: Bytes[256], y: uint256): def test_abi_decode_dynarray_complex_insufficient_data(env, tx_failed, get_contract): - code = """ + buffer_size = 32 * 8 + code = f""" struct Point: x: uint256 y: uint256 @external -def run(x: Bytes[32 * 8]): - y: Bytes[32 * 8] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Point, 3] = _abi_decode(y, DynArray[Point, 3]) """ c = get_contract(code) # runtime buffer has insufficient size - we decode 3 points, but provide only # 3 * 32B of payload - payload = (0x20, 0x03, *_replicate(0x03, 3)) + buffer_payload = (0x20, 0x03, *_replicate(0x03, 3)) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) @@ -1187,7 +1231,8 @@ def run(x: Bytes[32 * 8]): def test_abi_decode_dynarray_complex2(env, tx_failed, get_contract): # point head to the 1st 0x01 word (ie the length) # but size of the point is 3 * 32B, thus we'd decode 2B over the buffer end - code = """ + buffer_size = 32 * 8 + code = f""" struct Point: x: uint256 y: uint256 @@ -1195,19 +1240,19 @@ def test_abi_decode_dynarray_complex2(env, tx_failed, get_contract): @external -def run(x: Bytes[32 * 8]): +def run(x: Bytes[{buffer_size}]): y: Bytes[32 * 11] = x decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0xC0, # points to the 1st 0x01 word (ie the length) *_replicate(0x03, 5), *_replicate(0x01, 2), ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) @@ -1216,7 +1261,8 @@ def run(x: Bytes[32 * 8]): def test_abi_decode_complex_empty_dynarray(env, tx_failed, get_contract): # point head to the last word of the payload # this will be the length, but because it's set to 0, the decoding should succeed - code = """ + buffer_size = 32 * 16 + code = f""" struct Point: x: uint256 y: DynArray[uint256, 2] @@ -1224,14 +1270,13 @@ def test_abi_decode_complex_empty_dynarray(env, tx_failed, get_contract): @external -def run(x: Bytes[32 * 16]): - y: Bytes[32 * 16] = x - decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) - assert len(decoded_y1) == 1 and len(decoded_y1[0].y) == 0 +def run(x: Bytes[{buffer_size}]) -> DynArray[Point, 2]: + y: Bytes[{buffer_size}] = x + return _abi_decode(y, DynArray[Point, 2]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, 0x01, 0x20, @@ -1243,14 +1288,17 @@ def run(x: Bytes[32 * 16]): 0x00, # length is 0, so decoding should succeed ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) + + res = c.run(data) - c.run(data) + assert res == [(1, [], 4)] def test_abi_decode_complex_arithmetic_overflow(tx_failed, get_contract): # inner head roundtrips due to arithmetic overflow - code = """ + buffer_size = 32 * 16 + code = f""" struct Point: x: uint256 y: DynArray[uint256, 2] @@ -1258,13 +1306,13 @@ def test_abi_decode_complex_arithmetic_overflow(tx_failed, get_contract): @external -def run(x: Bytes[32 * 16]): - y: Bytes[32 * 16] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) """ c = get_contract(code) - payload = ( + buffer_payload = ( 0x20, 0x01, 0x20, @@ -1276,39 +1324,43 @@ def run(x: Bytes[32 * 16]): 0x00, ) - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) def test_abi_decode_empty_toplevel_dynarray(get_contract): - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]) -> DynArray[DynArray[uint256, 3], 3]: + y: Bytes[{buffer_size}] = x assert len(y) == 2 * 32 decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y, DynArray[DynArray[uint256, 3], 3] ) - assert len(decoded_y1) == 0 + return decoded_y1 """ c = get_contract(code) - payload = (0x20, 0x00) # DynArray head, DynArray length + buffer_payload = (0x20, 0x00) # DynArray head, DynArray length + + data = _abi_payload_from_tuple(buffer_payload, buffer_size) - data = _abi_payload_from_tuple(payload) + res = c.run(data) - c.run(data) + assert res == [] def test_abi_decode_invalid_toplevel_dynarray_head(tx_failed, get_contract): # head points 1B over the bounds of the runtime buffer - code = """ + buffer_size = 2 * 32 + 3 * 32 + 3 * 32 * 4 + code = f""" @external -def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): - y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x +def run(x: Bytes[{buffer_size}]): + y: Bytes[{buffer_size}] = x decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( y, DynArray[DynArray[uint256, 3], 3] @@ -1317,33 +1369,34 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): c = get_contract(code) # head points 1B over the bounds of the runtime buffer - payload = (0x21, 0x00) # DynArray head, DynArray length + buffer_payload = (0x21, 0x00) # DynArray head, DynArray length - data = _abi_payload_from_tuple(payload) + data = _abi_payload_from_tuple(buffer_payload, buffer_size) with tx_failed(): c.run(data) def test_nested_invalid_dynarray_head(get_contract, tx_failed): - code = """ + buffer_size = 320 + code = f""" @nonpayable @external -def foo(x:Bytes[320]): +def foo(x:Bytes[{buffer_size}]): if True: a: Bytes[320-32] = b'' # make the word following the buffer x_mem dirty to make a potential # OOB revert fake_head: uint256 = 32 - x_mem: Bytes[320] = x + x_mem: Bytes[{buffer_size}] = x y: DynArray[DynArray[uint256, 2], 2] = _abi_decode(x_mem,DynArray[DynArray[uint256, 2], 2]) @nonpayable @external -def bar(x:Bytes[320]): - x_mem: Bytes[320] = x +def bar(x:Bytes[{buffer_size}]): + x_mem: Bytes[{buffer_size}] = x y:DynArray[DynArray[uint256, 2], 2] = _abi_decode(x_mem,DynArray[DynArray[uint256, 2], 2]) """ @@ -1355,7 +1408,7 @@ def bar(x:Bytes[320]): # 0x0, # head2 ) - encoded = _abi_payload_from_tuple(encoded + inner) + encoded = _abi_payload_from_tuple(encoded + inner, buffer_size) with tx_failed(): c.foo(encoded) # revert with tx_failed(): @@ -1363,22 +1416,23 @@ def bar(x:Bytes[320]): def test_static_outer_type_invalid_heads(get_contract, tx_failed): - code = """ + buffer_size = 320 + code = f""" @nonpayable @external -def foo(x:Bytes[320]): - x_mem: Bytes[320] = x +def foo(x:Bytes[{buffer_size}]): + x_mem: Bytes[{buffer_size}] = x y:DynArray[uint256, 2][2] = _abi_decode(x_mem,DynArray[uint256, 2][2]) @nonpayable @external -def bar(x:Bytes[320]): +def bar(x:Bytes[{buffer_size}]): if True: a: Bytes[160] = b'' # write stuff here to make the call revert in case decode do # an out of bound access: fake_head: uint256 = 32 - x_mem: Bytes[320] = x + x_mem: Bytes[{buffer_size}] = x y:DynArray[uint256, 2][2] = _abi_decode(x_mem,DynArray[uint256, 2][2]) """ c = get_contract(code) @@ -1389,7 +1443,7 @@ def bar(x:Bytes[320]): # 0x00, # head of the second dynarray ) - encoded = _abi_payload_from_tuple(encoded + inner) + encoded = _abi_payload_from_tuple(encoded + inner, buffer_size) with tx_failed(): c.foo(encoded) @@ -1402,9 +1456,10 @@ def test_abi_decode_max_size(get_contract, tx_failed): # of abi encoding the type. this can happen when the payload is # "sparse" and has garbage bytes in between the static and dynamic # sections - code = """ + buffer_size = 1000 + code = f""" @external -def foo(a:Bytes[1000]): +def foo(a:Bytes[{buffer_size}]): v: DynArray[uint256, 1] = _abi_decode(a,DynArray[uint256, 1]) """ c = get_contract(code) @@ -1420,7 +1475,7 @@ def foo(a:Bytes[1000]): ) with tx_failed(): - c.foo(_abi_payload_from_tuple(payload)) + c.foo(_abi_payload_from_tuple(payload, buffer_size)) # returndatasize check for uint256 From bd876b114bc34643a7d210b319f69642ce80f018 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 24 Nov 2024 04:34:57 +0800 Subject: [PATCH 2/6] fix[lang]: use folded node for typechecking (#4365) This commit addresses several issues in the frontend where valid code fails to compile because typechecking was performed on non-literal AST nodes, specifically in `slice()` and `raw_log()` builtins. This is fixed by using the folded node for typechecking instead. Additionally, folding is applied for the argument to `convert()`, which results in the typechecker being able to reject more invalid programs. --- .../functional/codegen/features/test_logging.py | 17 +++++++++++++++++ .../test_invalid_literal_exception.py | 8 ++++++++ .../exceptions/test_type_mismatch_exception.py | 8 ++++++++ tests/functional/syntax/test_slice.py | 16 ++++++++++++++++ vyper/builtins/_convert.py | 2 +- vyper/builtins/functions.py | 7 ++++--- vyper/semantics/analysis/local.py | 4 ++-- 7 files changed, 56 insertions(+), 6 deletions(-) diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 2bb646e6ef..87d848fae5 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -1254,6 +1254,23 @@ def foo(): assert log.topics == [event_id, topic1, topic2, topic3] +valid_list = [ + # test constant folding inside raw_log + """ +topic: constant(bytes32) = 0x1212121212121210212801291212121212121210121212121212121212121212 + +@external +def foo(): + raw_log([[topic]][0], b'') + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_raw_log_pass(code): + assert compile_code(code) is not None + + fail_list = [ ( """ diff --git a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py index a0cf10ad02..f3fd73fbfc 100644 --- a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py +++ b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py @@ -36,6 +36,14 @@ def foo(): def foo(): a: bytes32 = keccak256("ั“test") """, + # test constant folding inside of `convert()` + """ +BAR: constant(uint16) = 256 + +@external +def foo(): + a: uint8 = convert(BAR, uint8) + """, ] diff --git a/tests/functional/syntax/exceptions/test_type_mismatch_exception.py b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py index 76c5c481f0..63e0eb6d11 100644 --- a/tests/functional/syntax/exceptions/test_type_mismatch_exception.py +++ b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py @@ -47,6 +47,14 @@ def foo(): """ a: constant(address) = 0x3cd751e6b0078be393132286c442345e5dc49699 """, + # test constant folding inside `convert()` + """ +BAR: constant(Bytes[5]) = b"vyper" + +@external +def foo(): + a: Bytes[4] = convert(BAR, Bytes[4]) + """, ] diff --git a/tests/functional/syntax/test_slice.py b/tests/functional/syntax/test_slice.py index 6bb666527e..6a091c9da3 100644 --- a/tests/functional/syntax/test_slice.py +++ b/tests/functional/syntax/test_slice.py @@ -53,6 +53,22 @@ def foo(inp: Bytes[10]) -> Bytes[4]: def foo() -> Bytes[10]: return slice(b"badmintonzzz", 1, 10) """, + # test constant folding for `slice()` `length` argument + """ +@external +def foo(): + x: Bytes[32] = slice(msg.data, 0, 31 + 1) + """, + """ +@external +def foo(a: address): + x: Bytes[32] = slice(a.code, 0, 31 + 1) + """, + """ +@external +def foo(inp: Bytes[5], start: uint256) -> Bytes[3]: + return slice(inp, 0, 1 + 1) + """, ] diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index aa53dee429..a494e4a344 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -463,7 +463,7 @@ def to_flag(expr, arg, out_typ): def convert(expr, context): assert len(expr.args) == 2, "bad typecheck: convert" - arg_ast = expr.args[0] + arg_ast = expr.args[0].reduced() arg = Expr(arg_ast, context).ir_node original_arg = arg diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 674efda7ce..9ed74b8cfe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -305,7 +305,7 @@ def fetch_call_return(self, node): arg = node.args[0] start_expr = node.args[1] - length_expr = node.args[2] + length_expr = node.args[2].reduced() # CMC 2022-03-22 NOTE slight code duplication with semantics/analysis/local is_adhoc_slice = arg.get("attr") == "code" or ( @@ -1257,7 +1257,8 @@ def fetch_call_return(self, node): def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) - if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4: + arg = node.args[0].reduced() + if not isinstance(arg, vy_ast.List) or len(arg.elements) > 4: raise InvalidType("Expecting a list of 0-4 topics as first argument", node.args[0]) # return a concrete type for `data` @@ -1269,7 +1270,7 @@ def infer_arg_types(self, node, expected_return_typ=None): def build_IR(self, expr, args, kwargs, context): context.check_is_not_constant(f"use {self._id}", expr) - topics_length = len(expr.args[0].elements) + topics_length = len(expr.args[0].reduced().elements) topics = args[0].args topics = [unwrap_location(topic) for topic in topics] diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 809c6532c6..461326d72d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -136,7 +136,7 @@ def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> Non parent = node.get_ancestor() if isinstance(parent, vy_ast.Call): ok_func = isinstance(parent.func, vy_ast.Name) and parent.func.id == "slice" - ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) + ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int) if ok_func and ok_args: return @@ -154,7 +154,7 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None: "msg.data is only allowed inside of the slice, len or raw_call functions", node ) if parent.get("func.id") == "slice": - ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) + ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int) if not ok_args: raise StructureException( "slice(msg.data) must use a compile-time constant for length argument", parent From 8f433f8de9ec3ead39e1691c45e2821fe8e3922b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 23 Nov 2024 21:39:41 +0100 Subject: [PATCH 3/6] refactor[tool]: refactor `compile_from_zip()` (#4366) refactor `compile_from_zip()`, and also a generalized `outputs_from_compiler_data()` so the user can pass a `CompilerData` instead of `FileInput` + a bunch of settings. --- vyper/cli/compile_archive.py | 12 ++++++++---- vyper/compiler/__init__.py | 20 ++++++++++++-------- vyper/compiler/phases.py | 4 ++++ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/vyper/cli/compile_archive.py b/vyper/cli/compile_archive.py index 1b52343c1c..c6d07de9f1 100644 --- a/vyper/cli/compile_archive.py +++ b/vyper/cli/compile_archive.py @@ -8,8 +8,9 @@ import zipfile from pathlib import PurePath -from vyper.compiler import compile_from_file_input +from vyper.compiler import outputs_from_compiler_data from vyper.compiler.input_bundle import FileInput, ZipInputBundle +from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings, merge_settings from vyper.exceptions import BadArchive @@ -19,6 +20,11 @@ class NotZipInput(Exception): def compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata): + compiler_data = compiler_data_from_zip(file_name, settings, no_bytecode_metadata) + return outputs_from_compiler_data(compiler_data, output_formats) + + +def compiler_data_from_zip(file_name, settings, no_bytecode_metadata): with open(file_name, "rb") as f: bcontents = f.read() @@ -59,11 +65,9 @@ def compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata): settings, archive_settings, lhs_source="command line", rhs_source="archive settings" ) - # TODO: validate integrity sum (probably in CompilerData) - return compile_from_file_input( + return CompilerData( file, input_bundle=input_bundle, - output_formats=output_formats, integrity_sum=integrity, settings=settings, no_bytecode_metadata=no_bytecode_metadata, diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 0345c24931..d885599cec 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -99,13 +99,6 @@ def compile_from_file_input( """ settings = settings or get_global_settings() or Settings() - if output_formats is None: - output_formats = ("bytecode",) - - # make IR output the same between runs - # TODO: move this to CompilerData.__init__() - codegen.reset_names() - compiler_data = CompilerData( file_input, input_bundle, @@ -116,6 +109,17 @@ def compile_from_file_input( no_bytecode_metadata=no_bytecode_metadata, ) + return outputs_from_compiler_data(compiler_data, output_formats, exc_handler) + + +def outputs_from_compiler_data( + compiler_data: CompilerData, + output_formats: Optional[OutputFormats] = None, + exc_handler: Optional[Callable] = None, +): + if output_formats is None: + output_formats = ("bytecode",) + ret = {} with anchor_settings(compiler_data.settings): for output_format in output_formats: @@ -126,7 +130,7 @@ def compile_from_file_input( ret[output_format] = formatter(compiler_data) except Exception as exc: if exc_handler is not None: - exc_handler(str(file_input.path), exc) + exc_handler(str(compiler_data.file_input.path), exc) else: raise exc diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index d9b6b13b48..503281a867 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -4,6 +4,7 @@ from pathlib import Path, PurePath from typing import Any, Optional +import vyper.codegen.core as codegen from vyper import ast as vy_ast from vyper.ast import natspec from vyper.codegen import module @@ -304,6 +305,9 @@ def generate_ir_nodes(global_ctx: ModuleT, settings: Settings) -> tuple[IRnode, IR to generate deployment bytecode IR to generate runtime bytecode """ + # make IR output the same between runs + codegen.reset_names() + with anchor_settings(settings): ir_nodes, ir_runtime = module.generate_ir_for_module(global_ctx) if settings.optimize != OptimizationLevel.NONE: From f249c9364a07044135e368bf846a0da1477d62e3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 25 Nov 2024 18:08:45 +0100 Subject: [PATCH 4/6] feat[lang]: add `module.__at__()` to cast to interface (#4090) add `module.__at__`, a new `MemberFunctionT`, which allows the user to cast addresses to a module's interface. additionally, fix a bug where interfaces defined inline could not be exported. this is simultaneously fixed as a related bug because previously, interfaces could come up in export analysis as `InterfaceT` or `TYPE_T` depending on their provenance. this commit fixes the bug by making them `TYPE_T` in both imported and inlined provenance. this also allows `module.__interface__` to be used in export position by adding it to `ModuleT`'s members. note this has an unwanted side effect of allowing `module.__interface__` in call position; in other words, `module.__interface__(
)` has the same behavior as `module.__at__(
)` when use as an expression. this can be addressed in a later refactor. refactor: - wrap interfaces in `TYPE_T` - streamline an `isinstance(t, (VyperType, TYPE_T))` check. TYPE_T` now inherits from `VyperType`, so it doesn't need to be listed separately --------- Co-authored-by: cyberthirst --- docs/using-modules.rst | 15 +++ .../codegen/modules/test_exports.py | 23 ++++ .../codegen/modules/test_interface_imports.py | 36 +++++- tests/functional/codegen/test_interfaces.py | 89 +++++++++++++++ .../syntax/modules/test_deploy_visibility.py | 34 +++++- .../functional/syntax/modules/test_exports.py | 106 ++++++++++++++++++ tests/functional/syntax/test_interfaces.py | 50 +++++++++ vyper/codegen/expr.py | 24 ++-- vyper/compiler/output.py | 3 + vyper/semantics/analysis/base.py | 1 + vyper/semantics/analysis/module.py | 24 +++- vyper/semantics/analysis/utils.py | 2 +- vyper/semantics/types/__init__.py | 4 +- vyper/semantics/types/base.py | 10 +- vyper/semantics/types/function.py | 2 +- vyper/semantics/types/module.py | 36 ++++-- 16 files changed, 431 insertions(+), 28 deletions(-) diff --git a/docs/using-modules.rst b/docs/using-modules.rst index 7d63eb6617..4400a8dfa8 100644 --- a/docs/using-modules.rst +++ b/docs/using-modules.rst @@ -62,6 +62,21 @@ The ``_times_two()`` helper function in the above module can be immediately used The other functions cannot be used yet, because they touch the ``ownable`` module's state. There are two ways to declare a module so that its state can be used. +Using a module as an interface +============================== + +A module can be used as an interface with the ``__at__`` syntax. + +.. code-block:: vyper + + import ownable + + an_ownable: ownable.__interface__ + + def call_ownable(addr: address): + self.an_ownable = ownable.__at__(addr) + self.an_ownable.transfer_ownership(...) + Initializing a module ===================== diff --git a/tests/functional/codegen/modules/test_exports.py b/tests/functional/codegen/modules/test_exports.py index 93f4fe6c2f..3cc21d61a9 100644 --- a/tests/functional/codegen/modules/test_exports.py +++ b/tests/functional/codegen/modules/test_exports.py @@ -440,3 +440,26 @@ def __init__(): # call `c.__default__()` env.message_call(c.address) assert c.counter() == 6 + + +def test_inline_interface_export(make_input_bundle, get_contract): + lib1 = """ +interface IAsset: + def asset() -> address: view + +implements: IAsset + +@external +@view +def asset() -> address: + return self + """ + main = """ +import lib1 + +exports: lib1.IAsset + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.asset() == c.address diff --git a/tests/functional/codegen/modules/test_interface_imports.py b/tests/functional/codegen/modules/test_interface_imports.py index 3f0f8cb010..af9f9b5e68 100644 --- a/tests/functional/codegen/modules/test_interface_imports.py +++ b/tests/functional/codegen/modules/test_interface_imports.py @@ -1,3 +1,6 @@ +import pytest + + def test_import_interface_types(make_input_bundle, get_contract): ifaces = """ interface IFoo: @@ -50,9 +53,10 @@ def foo() -> bool: # check that this typechecks both directions a: lib1.IERC20 = IERC20(msg.sender) b: lib2.IERC20 = IERC20(msg.sender) + c: IERC20 = lib1.IERC20(msg.sender) # allowed in call position # return the equality so we can sanity check it - return a == b + return a == b and b == c """ input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) c = get_contract(main, input_bundle=input_bundle) @@ -60,6 +64,36 @@ def foo() -> bool: assert c.foo() is True +@pytest.mark.parametrize("interface_syntax", ["__at__", "__interface__"]) +def test_intrinsic_interface(get_contract, make_input_bundle, interface_syntax): + lib = """ +@external +@view +def foo() -> uint256: + # detect self call + if msg.sender == self: + return 4 + else: + return 5 + """ + + main = f""" +import lib + +exports: lib.__interface__ + +@external +@view +def bar() -> uint256: + return staticcall lib.{interface_syntax}(self).foo() + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.foo() == 5 + assert c.bar() == 4 + + def test_import_interface_flags(make_input_bundle, get_contract): ifaces = """ flag Foo: diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 8887bf07cb..31475a3bc0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -774,3 +774,92 @@ def foo(s: MyStruct) -> MyStruct: assert "b: uint256" in out assert "struct Voter:" in out assert "voted: bool" in out + + +def test_intrinsic_interface_instantiation(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 + +i: lib1.__interface__ + +@external +def bar() -> lib1.__interface__: + self.i = lib1.__at__(self) + return self.i + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == c.address + + +def test_intrinsic_interface_converts(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 + +@external +def bar() -> lib1.__interface__: + return lib1.__at__(self) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == c.address + + +def test_intrinsic_interface_kws(env, make_input_bundle, get_contract): + value = 10**5 + lib1 = f""" +@external +@payable +def foo(a: address): + send(a, {value}) + """ + main = f""" +import lib1 + +exports: lib1.__interface__ + +@external +def bar(a: address): + extcall lib1.__at__(self).foo(a, value={value}) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + env.set_balance(c.address, value) + original_balance = env.get_balance(env.deployer) + c.bar(env.deployer) + assert env.get_balance(env.deployer) == original_balance + value + + +def test_intrinsic_interface_defaults(env, make_input_bundle, get_contract): + lib1 = """ +@external +@payable +def foo(i: uint256=1) -> uint256: + return i + """ + main = """ +import lib1 + +exports: lib1.__interface__ + +@external +def bar() -> uint256: + return extcall lib1.__at__(self).foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + assert c.bar() == 1 diff --git a/tests/functional/syntax/modules/test_deploy_visibility.py b/tests/functional/syntax/modules/test_deploy_visibility.py index f51bf9575b..c908d4adae 100644 --- a/tests/functional/syntax/modules/test_deploy_visibility.py +++ b/tests/functional/syntax/modules/test_deploy_visibility.py @@ -1,7 +1,7 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import CallViolation +from vyper.exceptions import CallViolation, UnknownAttribute def test_call_deploy_from_external(make_input_bundle): @@ -25,3 +25,35 @@ def foo(): compile_code(main, input_bundle=input_bundle) assert e.value.message == "Cannot call an @deploy function from an @external function!" + + +@pytest.mark.parametrize("interface_syntax", ["__interface__", "__at__"]) +def test_module_interface_init(make_input_bundle, tmp_path, interface_syntax): + lib1 = """ +#lib1.vy +k: uint256 + +@external +def bar(): + pass + +@deploy +def __init__(): + self.k = 10 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + code = f""" +import lib1 + +@deploy +def __init__(): + lib1.{interface_syntax}(self).__init__() + """ + + with pytest.raises(UnknownAttribute) as e: + compile_code(code, input_bundle=input_bundle) + + # as_posix() for windows tests + lib1_path = (tmp_path / "lib1.vy").as_posix() + assert e.value.message == f"interface {lib1_path} has no member '__init__'." diff --git a/tests/functional/syntax/modules/test_exports.py b/tests/functional/syntax/modules/test_exports.py index 7b00d29c98..4314c1bbf0 100644 --- a/tests/functional/syntax/modules/test_exports.py +++ b/tests/functional/syntax/modules/test_exports.py @@ -385,6 +385,28 @@ def do_xyz(): assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" +def test_no_export_unimplemented_inline_interface(make_input_bundle): + lib1 = """ +interface ifoo: + def do_xyz(): nonpayable + +# technically implements ifoo, but missing `implements: ifoo` + +@external +def do_xyz(): + pass + """ + main = """ +import lib1 + +exports: lib1.ifoo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InterfaceViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" + + def test_export_selector_conflict(make_input_bundle): ifoo = """ @external @@ -444,3 +466,87 @@ def __init__(): with pytest.raises(InterfaceViolation) as e: compile_code(main, input_bundle=input_bundle) assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" + + +def test_export_empty_interface(make_input_bundle, tmp_path): + lib1 = """ +def an_internal_function(): + pass + """ + main = """ +import lib1 + +exports: lib1.__interface__ + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + + # as_posix() for windows + lib1_path = (tmp_path / "lib1.vy").as_posix() + assert e.value._message == f"lib1 (located at `{lib1_path}`) has no external functions!" + + +def test_invalid_export(make_input_bundle): + lib1 = """ +@external +def foo(): + pass + """ + main = """ +import lib1 +a: address + +exports: lib1.__interface__(self.a).foo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "invalid export of a value" + assert e.value._hint == "exports should look like ." + + main = """ +interface Foo: + def foo(): nonpayable + +exports: Foo + """ + with pytest.raises(StructureException) as e: + compile_code(main) + + assert e.value._message == "invalid export" + assert e.value._hint == "exports should look like ." + + +@pytest.mark.parametrize("exports_item", ["__at__", "__at__(self)", "__at__(self).__interface__"]) +def test_invalid_at_exports(get_contract, make_input_bundle, exports_item): + lib = """ +@external +@view +def foo() -> uint256: + return 5 + """ + + main = f""" +import lib + +exports: lib.{exports_item} + +@external +@view +def bar() -> uint256: + return staticcall lib.__at__(self).foo() + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + + with pytest.raises(Exception) as e: + compile_code(main, input_bundle=input_bundle) + + if exports_item == "__at__": + assert "not a function or interface" in str(e.value) + if exports_item == "__at__(self)": + assert "invalid exports" in str(e.value) + if exports_item == "__at__(self).__interface__": + assert "has no member '__interface__'" in str(e.value) diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 86ea4bcfd0..baf0c73c30 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -571,3 +571,53 @@ def bar(): compiler.compile_code(code, input_bundle=input_bundle) assert e.value.message == "Contract does not implement all interface functions: bar(), foobar()" + + +def test_intrinsic_interfaces_different_types(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + lib2 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +@external +def bar(): + assert lib1.__at__(self) == lib2.__at__(self) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(TypeMismatch): + compiler.compile_code(main, input_bundle=input_bundle) + + +@pytest.mark.xfail +def test_intrinsic_interfaces_default_function(make_input_bundle, get_contract): + lib1 = """ +@external +@payable +def __default__(): + pass + """ + main = """ +import lib1 + +@external +def bar(): + extcall lib1.__at__(self).__default__() + + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + # TODO make the exception more precise once fixed + with pytest.raises(Exception): # noqa: B017 + compiler.compile_code(main, input_bundle=input_bundle) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index cd51966710..3a09bbe6c0 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -51,6 +51,7 @@ FlagT, HashMapT, InterfaceT, + ModuleT, SArrayT, StringT, StructT, @@ -680,7 +681,8 @@ def parse_Call(self): # TODO fix cyclic import from vyper.builtins._signatures import BuiltinFunctionT - func_t = self.expr.func._metadata["type"] + func = self.expr.func + func_t = func._metadata["type"] if isinstance(func_t, BuiltinFunctionT): return func_t.build_IR(self.expr, self.context) @@ -691,8 +693,14 @@ def parse_Call(self): return self.handle_struct_literal() # Interface constructor. Bar(
). - if is_type_t(func_t, InterfaceT): + if is_type_t(func_t, InterfaceT) or func.get("attr") == "__at__": assert not self.is_stmt # sanity check typechecker + + # magic: do sanity checks for module.__at__ + if func.get("attr") == "__at__": + assert isinstance(func_t, MemberFunctionT) + assert isinstance(func.value._metadata["type"], ModuleT) + (arg0,) = self.expr.args arg_ir = Expr(arg0, self.context).ir_node @@ -702,16 +710,16 @@ def parse_Call(self): return arg_ir if isinstance(func_t, MemberFunctionT): - darray = Expr(self.expr.func.value, self.context).ir_node + # TODO consider moving these to builtins or a dedicated file + darray = Expr(func.value, self.context).ir_node assert isinstance(darray.typ, DArrayT) args = [Expr(x, self.context).ir_node for x in self.expr.args] - if self.expr.func.attr == "pop": - # TODO consider moving this to builtins - darray = Expr(self.expr.func.value, self.context).ir_node + if func.attr == "pop": + darray = Expr(func.value, self.context).ir_node assert len(self.expr.args) == 0 return_item = not self.is_stmt return pop_dyn_array(darray, return_popped_item=return_item) - elif self.expr.func.attr == "append": + elif func.attr == "append": (arg,) = args check_assign( dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) @@ -726,6 +734,8 @@ def parse_Call(self): ret.append(append_dyn_array(darray, arg)) return IRnode.from_list(ret) + raise CompilerPanic("unreachable!") # pragma: nocover + assert isinstance(func_t, ContractFunctionT) assert func_t.is_internal or func_t.is_constructor diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index f5f99a0bc3..e0eea293bc 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -268,6 +268,9 @@ def build_abi_output(compiler_data: CompilerData) -> list: _ = compiler_data.ir_runtime # ensure _ir_info is generated abi = module_t.interface.to_toplevel_abi_dict() + if module_t.init_function: + abi += module_t.init_function.to_toplevel_abi_dict() + if compiler_data.show_gas_estimates: # Add gas estimates for each function to ABI gas_estimates = build_gas_estimates(compiler_data.function_signatures) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index e275930fa0..adfc7540a0 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -96,6 +96,7 @@ class AnalysisResult: class ModuleInfo(AnalysisResult): module_t: "ModuleT" alias: str + # import_node: vy_ast._ImportStmt # maybe could be useful ownership: ModuleOwnership = ModuleOwnership.NO_OWNERSHIP ownership_decl: Optional[vy_ast.VyperNode] = None diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 8a2beb61e6..737f675b7c 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -40,7 +40,7 @@ ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace -from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT +from vyper.semantics.types import TYPE_T, EventT, FlagT, InterfaceT, StructT, is_type_t from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation @@ -499,9 +499,19 @@ def visit_ExportsDecl(self, node): raise StructureException("not a public variable!", decl, item) funcs = [decl._expanded_getter._metadata["func_type"]] elif isinstance(info.typ, ContractFunctionT): + # e.g. lib1.__interface__(self._addr).foo + if not isinstance(get_expr_info(item.value).typ, (ModuleT, TYPE_T)): + raise StructureException( + "invalid export of a value", + item.value, + hint="exports should look like .", + ) + # regular function funcs = [info.typ] - elif isinstance(info.typ, InterfaceT): + elif is_type_t(info.typ, InterfaceT): + interface_t = info.typ.typedef + if not isinstance(item, vy_ast.Attribute): raise StructureException( "invalid export", @@ -512,7 +522,7 @@ def visit_ExportsDecl(self, node): if module_info is None: raise StructureException("not a valid module!", item.value) - if info.typ not in module_info.typ.implemented_interfaces: + if interface_t not in module_info.typ.implemented_interfaces: iface_str = item.node_source_code module_str = item.value.node_source_code msg = f"requested `{iface_str}` but `{module_str}`" @@ -523,9 +533,15 @@ def visit_ExportsDecl(self, node): # find the specific implementation of the function in the module funcs = [ module_exposed_fns[fn.name] - for fn in info.typ.functions.values() + for fn in interface_t.functions.values() if fn.is_external ] + + if len(funcs) == 0: + path = module_info.module_node.path + msg = f"{module_info.alias} (located at `{path}`) has no external functions!" + raise StructureException(msg, item) + else: raise StructureException( f"not a function or interface: `{info.typ}`", info.typ.decl_node, item diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 9734087fc3..a31ce7acc1 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -199,7 +199,7 @@ def _raise_invalid_reference(name, node): try: s = t.get_member(name, node) - if isinstance(s, (VyperType, TYPE_T)): + if isinstance(s, VyperType): # ex. foo.bar(). bar() is a ContractFunctionT return [s] diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 59a20dd99f..b881f52b2b 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -1,8 +1,8 @@ from . import primitives, subscriptable, user from .base import TYPE_T, VOID_TYPE, KwargSettings, VyperType, is_type_t, map_void from .bytestrings import BytesT, StringT, _BytestringT -from .function import MemberFunctionT -from .module import InterfaceT +from .function import ContractFunctionT, MemberFunctionT +from .module import InterfaceT, ModuleT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 128ede0d5b..aca37b33a3 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -114,8 +114,13 @@ def __eq__(self, other): ) def __lt__(self, other): + # CMC 2024-10-20 what is this for? return self.abi_type.selector_name() < other.abi_type.selector_name() + def __repr__(self): + # TODO: add `pretty()` to the VyperType API? + return self._id + # return a dict suitable for serializing in the AST def to_dict(self): ret = {"name": self._id} @@ -362,10 +367,7 @@ def get_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": raise StructureException(f"{self} instance does not have members", node) hint = get_levenshtein_error_suggestions(key, self.members, 0.3) - raise UnknownAttribute(f"{self} has no member '{key}'.", node, hint=hint) - - def __repr__(self): - return self._id + raise UnknownAttribute(f"{repr(self)} has no member '{key}'.", node, hint=hint) class KwargSettings: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 7a56b01281..ffeb5b7299 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -874,7 +874,7 @@ def _id(self): return self.name def __repr__(self): - return f"{self.underlying_type._id} member function '{self.name}'" + return f"{self.underlying_type} member function '{self.name}'" def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: validate_call_args(node, len(self.arg_types)) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index dabeaf21b6..498757b94e 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -19,7 +19,7 @@ ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import TYPE_T, VyperType, is_type_t -from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, FlagT, StructT, _UserType from vyper.utils import OrderedSet @@ -240,9 +240,6 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": for fn_t in module_t.exposed_functions: funcs.append((fn_t.name, fn_t)) - if (fn_t := module_t.init_function) is not None: - funcs.append((fn_t.name, fn_t)) - event_set: OrderedSet[EventT] = OrderedSet() event_set.update([node._metadata["event_type"] for node in module_t.event_defs]) event_set.update(module_t.used_events) @@ -273,6 +270,19 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": return cls._from_lists(node.name, node, functions) +def _module_at(module_t): + return MemberFunctionT( + # set underlying_type to a TYPE_T as a bit of a kludge, since it's + # kind of like a class method (but we don't have classmethod + # abstraction) + underlying_type=TYPE_T(module_t), + name="__at__", + arg_types=[AddressT()], + return_type=module_t.interface, + is_modifying=False, + ) + + # Datatype to store all module information. class ModuleT(VyperType): typeclass = "module" @@ -330,16 +340,28 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for i in self.import_stmts: import_info = i._metadata["import_info"] - self.add_member(import_info.alias, import_info.typ) if hasattr(import_info.typ, "module_t"): - self._helper.add_member(import_info.alias, TYPE_T(import_info.typ)) + module_info = import_info.typ + # get_expr_info uses ModuleInfo + self.add_member(import_info.alias, module_info) + # type_from_annotation uses TYPE_T + self._helper.add_member(import_info.alias, TYPE_T(module_info.module_t)) + else: # interfaces + assert isinstance(import_info.typ, InterfaceT) + self.add_member(import_info.alias, TYPE_T(import_info.typ)) for name, interface_t in self.interfaces.items(): # can access interfaces in type position self._helper.add_member(name, TYPE_T(interface_t)) - self.add_member("__interface__", self.interface) + # module.__at__(addr) + self.add_member("__at__", _module_at(self)) + + # allow `module.__interface__` (in exports declarations) + self.add_member("__interface__", TYPE_T(self.interface)) + # allow `module.__interface__` (in type position) + self._helper.add_member("__interface__", TYPE_T(self.interface)) # __eq__ is very strict on ModuleT - object equality! this is because we # don't want to reason about where a module came from (i.e. input bundle, From cda634dd3a1a20db20d3565f47aa1cf37ede8b9c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 26 Nov 2024 12:38:43 +0100 Subject: [PATCH 5/6] fix[venom]: fix duplicate allocas (#4321) this commit fixes a bug in the ir_node_to_venom translator. previously, `ir_node_to_venom` tried to detect unique allocas based on heuristics. this commit removes the heuristics and fixes the issue in the frontend by passing through a unique ID for each variable in the metadata. this ID is also passed into the `alloca` and `palloca` instructions to aid with debugging. note that this results in improved code, presumably due to more allocas being able to be reified. this commit makes a minor change to the `sqrt()`, builtin, which is to use `z_var.as_ir_node()` instead of `z_var.pos`, since `.as_ir_node()` correctly tags with the alloca metadata. to be maximally conservative, we could branch, only using `z_var.as_ir_node()` if we are using the venom pipeline, but the change should be correct for the legacy pipeline as well anyways. --------- Co-authored-by: Harry Kalogirou --- vyper/builtins/functions.py | 5 ++- vyper/codegen/context.py | 17 ++++++++- vyper/codegen/core.py | 4 +++ vyper/venom/README.md | 5 +-- vyper/venom/__init__.py | 3 ++ vyper/venom/ir_node_to_venom.py | 53 +++++++++++++---------------- vyper/venom/passes/__init__.py | 1 + vyper/venom/passes/float_allocas.py | 36 ++++++++++++++++++++ vyper/venom/passes/sccp/sccp.py | 2 +- 9 files changed, 89 insertions(+), 37 deletions(-) create mode 100644 vyper/venom/passes/float_allocas.py diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 9ed74b8cfe..0cfcb636d7 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2167,10 +2167,9 @@ def build_IR(self, expr, args, kwargs, context): variables_2=variables_2, memory_allocator=context.memory_allocator, ) + z_ir = new_ctx.vars["z"].as_ir_node() ret = IRnode.from_list( - ["seq", placeholder_copy, sqrt_ir, new_ctx.vars["z"].pos], # load x variable - typ=DecimalT(), - location=MEMORY, + ["seq", placeholder_copy, sqrt_ir, z_ir], typ=DecimalT(), location=MEMORY ) return b1.resolve(ret) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index f49914ac78..7995b7b9f5 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -15,6 +15,17 @@ class Constancy(enum.Enum): Constant = 1 +_alloca_id = 0 + + +def _generate_alloca_id(): + # note: this gets reset between compiler runs by codegen.core.reset_names + global _alloca_id + + _alloca_id += 1 + return _alloca_id + + @dataclass(frozen=True) class Alloca: name: str @@ -22,6 +33,8 @@ class Alloca: typ: VyperType size: int + _id: int + def __post_init__(self): assert self.typ.memory_bytes_required == self.size @@ -233,7 +246,9 @@ def _new_variable( pos = f"$palloca_{ofst}_{size}" else: pos = f"$alloca_{ofst}_{size}" - alloca = Alloca(name=name, offset=ofst, typ=typ, size=size) + + alloca_id = _generate_alloca_id() + alloca = Alloca(name=name, offset=ofst, typ=typ, size=size, _id=alloca_id) var = VariableRecord( name=name, diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 2bd4f81f50..0ad7fa79c6 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,3 +1,4 @@ +import vyper.codegen.context as ctx from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import _opt_codesize, _opt_gas, _opt_none from vyper.evm.address_space import ( @@ -855,6 +856,9 @@ def reset_names(): global _label _label = 0 + # could be refactored + ctx._alloca_id = 0 + # returns True if t is ABI encoded and is a type that needs any kind of # validation diff --git a/vyper/venom/README.md b/vyper/venom/README.md index 6f3b318c9b..ea6eabebaa 100644 --- a/vyper/venom/README.md +++ b/vyper/venom/README.md @@ -209,15 +209,16 @@ Assembly can be inspected with `-f asm`, whereas an opcode view of the final byt - Effectively translates to `JUMP`, and marks the call site as a valid return destination (for callee to jump back to) by `JUMPDEST`. - `alloca` - ``` - out = alloca size, offset + out = alloca size, offset, id ``` - Allocates memory of a given `size` at a given `offset` in memory. + - The `id` argument is there to help debugging translation into venom - The output is the offset value itself. - Because the SSA form does not allow changing values of registers, handling mutable variables can be tricky. The `alloca` instruction is meant to simplify that. - `palloca` - ``` - out = palloca size, offset + out = palloca size, offset, id ``` - Like the `alloca` instruction but only used for parameters of internal functions which are passed by memory. - `iload` diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index bf3115b4dd..593a9556a9 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -14,6 +14,7 @@ AlgebraicOptimizationPass, BranchOptimizationPass, DFTPass, + FloatAllocas, MakeSSA, Mem2Var, RemoveUnusedVariablesPass, @@ -47,6 +48,8 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None: ac = IRAnalysesCache(fn) + FloatAllocas(ac, fn).run_pass() + SimplifyCFGPass(ac, fn).run_pass() MakeSSA(ac, fn).run_pass() Mem2Var(ac, fn).run_pass() diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 02a9f4d1f7..782309d841 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -107,18 +107,16 @@ NOOP_INSTRUCTIONS = frozenset(["pass", "cleanup_repeat", "var_list", "unique_symbol"]) SymbolTable = dict[str, Optional[IROperand]] -_global_symbols: SymbolTable = None # type: ignore +_alloca_table: SymbolTable = None # type: ignore MAIN_ENTRY_LABEL_NAME = "__main_entry" -_external_functions: dict[int, SymbolTable] = None # type: ignore # convert IRnode directly to venom def ir_node_to_venom(ir: IRnode) -> IRContext: _ = ir.unique_symbols # run unique symbols check - global _global_symbols, _external_functions - _global_symbols = {} - _external_functions = {} + global _alloca_table + _alloca_table = {} ctx = IRContext() fn = ctx.create_function(MAIN_ENTRY_LABEL_NAME) @@ -233,7 +231,7 @@ def pop_source(*args, **kwargs): def _convert_ir_bb(fn, ir, symbols): assert isinstance(ir, IRnode), ir # TODO: refactor these to not be globals - global _break_target, _continue_target, _global_symbols, _external_functions + global _break_target, _continue_target, _alloca_table # keep a map from external functions to all possible entry points @@ -269,8 +267,8 @@ def _convert_ir_bb(fn, ir, symbols): if is_internal or len(re.findall(r"external.*__init__\(.*_deploy", current_func)) > 0: # Internal definition var_list = ir.args[0].args[1] + assert var_list.value == "var_list" does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args - _global_symbols = {} symbols = {} new_fn = _handle_internal_func(fn, ir, does_return_data, symbols) for ir_node in ir.args[1:]: @@ -298,8 +296,6 @@ def _convert_ir_bb(fn, ir, symbols): cont_ret = _convert_ir_bb(fn, cond, symbols) cond_block = fn.get_basic_block() - saved_global_symbols = _global_symbols.copy() - then_block = IRBasicBlock(ctx.get_next_label("then"), fn) else_block = IRBasicBlock(ctx.get_next_label("else"), fn) @@ -314,7 +310,6 @@ def _convert_ir_bb(fn, ir, symbols): # convert "else" cond_symbols = symbols.copy() - _global_symbols = saved_global_symbols.copy() fn.append_basic_block(else_block) else_ret_val = None if len(ir.args) == 3: @@ -343,8 +338,6 @@ def _convert_ir_bb(fn, ir, symbols): if not then_block_finish.is_terminated: then_block_finish.append_instruction("jmp", exit_bb.label) - _global_symbols = saved_global_symbols - return if_ret elif ir.value == "with": @@ -385,13 +378,6 @@ def _convert_ir_bb(fn, ir, symbols): data = _convert_ir_bb(fn, c, symbols) ctx.append_data("db", [data]) # type: ignore elif ir.value == "label": - function_id_pattern = r"external (\d+)" - function_name = ir.args[0].value - m = re.match(function_id_pattern, function_name) - if m is not None: - function_id = m.group(1) - _global_symbols = _external_functions.setdefault(function_id, {}) - label = IRLabel(ir.args[0].value, True) bb = fn.get_basic_block() if not bb.is_terminated: @@ -463,13 +449,11 @@ def _convert_ir_bb(fn, ir, symbols): elif ir.value == "repeat": def emit_body_blocks(): - global _break_target, _continue_target, _global_symbols + global _break_target, _continue_target old_targets = _break_target, _continue_target _break_target, _continue_target = exit_block, incr_block - saved_global_symbols = _global_symbols.copy() _convert_ir_bb(fn, body, symbols.copy()) _break_target, _continue_target = old_targets - _global_symbols = saved_global_symbols sym = ir.args[0] start, end, _ = _convert_ir_bb_list(fn, ir.args[1:4], symbols) @@ -540,16 +524,25 @@ def emit_body_blocks(): elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): _convert_ir_opcode(fn, ir, symbols) elif isinstance(ir.value, str): - if ir.value.startswith("$alloca") and ir.value not in _global_symbols: + if ir.value.startswith("$alloca"): alloca = ir.passthrough_metadata["alloca"] - ptr = fn.get_basic_block().append_instruction("alloca", alloca.offset, alloca.size) - _global_symbols[ir.value] = ptr - elif ir.value.startswith("$palloca") and ir.value not in _global_symbols: + if alloca._id not in _alloca_table: + ptr = fn.get_basic_block().append_instruction( + "alloca", alloca.offset, alloca.size, alloca._id + ) + _alloca_table[alloca._id] = ptr + return _alloca_table[alloca._id] + + elif ir.value.startswith("$palloca"): alloca = ir.passthrough_metadata["alloca"] - ptr = fn.get_basic_block().append_instruction("palloca", alloca.offset, alloca.size) - _global_symbols[ir.value] = ptr - - return _global_symbols.get(ir.value) or symbols.get(ir.value) + if alloca._id not in _alloca_table: + ptr = fn.get_basic_block().append_instruction( + "palloca", alloca.offset, alloca.size, alloca._id + ) + _alloca_table[alloca._id] = ptr + return _alloca_table[alloca._id] + + return symbols.get(ir.value) elif ir.is_literal: return IRLiteral(ir.value) else: diff --git a/vyper/venom/passes/__init__.py b/vyper/venom/passes/__init__.py index 83098234c1..fcd2aa1f22 100644 --- a/vyper/venom/passes/__init__.py +++ b/vyper/venom/passes/__init__.py @@ -1,6 +1,7 @@ from .algebraic_optimization import AlgebraicOptimizationPass from .branch_optimization import BranchOptimizationPass from .dft import DFTPass +from .float_allocas import FloatAllocas from .make_ssa import MakeSSA from .mem2var import Mem2Var from .normalization import NormalizationPass diff --git a/vyper/venom/passes/float_allocas.py b/vyper/venom/passes/float_allocas.py new file mode 100644 index 0000000000..81fa115645 --- /dev/null +++ b/vyper/venom/passes/float_allocas.py @@ -0,0 +1,36 @@ +from vyper.venom.passes.base_pass import IRPass + + +class FloatAllocas(IRPass): + """ + This pass moves allocas to the entry basic block of a function + We could probably move them to the immediate dominator of the basic + block defining the alloca instead of the entry (which dominates all + basic blocks), but this is done for expedience. + Without this step, sccp fails, possibly because dominators are not + guaranteed to be traversed first. + """ + + def run_pass(self): + entry_bb = self.function.entry + assert entry_bb.is_terminated + tmp = entry_bb.instructions.pop() + + for bb in self.function.get_basic_blocks(): + if bb is entry_bb: + continue + + # Extract alloca instructions + non_alloca_instructions = [] + for inst in bb.instructions: + if inst.opcode in ("alloca", "palloca"): + # note: order of allocas impacts bytecode. + # TODO: investigate. + entry_bb.insert_instruction(inst) + else: + non_alloca_instructions.append(inst) + + # Replace original instructions with filtered list + bb.instructions = non_alloca_instructions + + entry_bb.instructions.append(tmp) diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 2be84ce502..369be3e753 100644 --- a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -252,7 +252,7 @@ def finalize(ret): if eval_result is LatticeEnum.BOTTOM: return finalize(LatticeEnum.BOTTOM) - assert isinstance(eval_result, IROperand) + assert isinstance(eval_result, IROperand), (inst.parent.label, op, inst) ops.append(eval_result) # If we haven't found BOTTOM yet, evaluate the operation From e98e004235961613c3d769d4c652884b2a242608 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 26 Nov 2024 12:39:48 +0100 Subject: [PATCH 6/6] fix[venom]: add missing extcodesize+hash effects (#4373) per title -- effects.py was missing extcodesize and extcodehash effects. --- vyper/venom/effects.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vyper/venom/effects.py b/vyper/venom/effects.py index a668ff5439..97cffe2cb2 100644 --- a/vyper/venom/effects.py +++ b/vyper/venom/effects.py @@ -68,6 +68,8 @@ def __iter__(self): "balance": BALANCE, "selfbalance": BALANCE, "extcodecopy": EXTCODE, + "extcodesize": EXTCODE, + "extcodehash": EXTCODE, "selfdestruct": BALANCE, # may modify code, but after the transaction "log": MEMORY, "revert": MEMORY,