diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index fad6ce889c..36b87137b9 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -495,8 +495,8 @@ def f(x: Bytes[32 * 3]): decoded_y1: Bytes[32] = _abi_decode(y, Bytes[32]) a = b"bar" decoded_y2: Bytes[32] = _abi_decode(y, Bytes[32]) - - assert decoded_y1 != decoded_y2 + # original POC: + # assert decoded_y1 != decoded_y2 """ c = get_contract(code) @@ -1043,7 +1043,7 @@ def run(): c.run() -def test_abi_decode_extcall_zero_len_array(get_contract): +def test_abi_decode_extcall_empty_array(get_contract): code = """ @external def bar() -> (uint256, uint256): @@ -1061,6 +1061,59 @@ def run(): c.run() +def test_abi_decode_extcall_complex_empty_dynarray(get_contract): + # 5th word of the payload points to the last word of the payload + # which is considered the length of the Point.y array + # because the length is 0, the decoding should succeed + code = """ +struct Point: + x: uint256 + y: DynArray[uint256, 2] + z: uint256 + +@external +def bar() -> (uint256, uint256, uint256, uint256, uint256, uint256): + return 32, 1, 32, 1, 64, 0 + +interface A: + def bar() -> DynArray[Point, 2]: nonpayable + +@external +def run(): + x: DynArray[Point, 2] = extcall A(self).bar() + assert len(x) == 1 and len(x[0].y) == 0 + """ + c = get_contract(code) + + c.run() + + +def test_abi_decode_extcall_complex_empty_dynarray2(tx_failed, get_contract): + # top-level head points 1B over the runtime buffer end + # thus the decoding should fail although the length is 0 + code = """ +struct Point: + x: uint256 + y: DynArray[uint256, 2] + z: uint256 + +@external +def bar() -> (uint256, uint256): + return 33, 0 + +interface A: + def bar() -> DynArray[Point, 2]: nonpayable + +@external +def run(): + x: DynArray[Point, 2] = extcall A(self).bar() + """ + c = get_contract(code) + + with tx_failed(): + c.run() + + def test_abi_decode_extcall_zero_len_array2(get_contract): code = """ @external @@ -1080,3 +1133,193 @@ def run() -> uint256: length = c.run() assert length == 0 + + +def test_abi_decode_top_level_head_oob(tx_failed, get_contract): + code = """ +@external +def run(x: Bytes[256], y: uint256): + player_lost: bool = empty(bool) + + if y == 1: + player_lost = True + + decoded: DynArray[Bytes[1], 2] = empty(DynArray[Bytes[1], 2]) + decoded = _abi_decode(x, DynArray[Bytes[1], 2]) + """ + c = get_contract(code) + + # head points over the buffer end + payload = (0x0100, *_replicate(0x00, 7)) + + data = _abi_payload_from_tuple(payload) + + with tx_failed(): + c.run(data, 1) + + with tx_failed(): + c.run(data, 0) + + +def test_abi_decode_dynarray_complex_insufficient_data(env, tx_failed, get_contract): + code = """ +struct Point: + x: uint256 + y: uint256 + +@external +def run(x: Bytes[32 * 8]): + y: Bytes[32 * 8] = x + decoded_y1: DynArray[Point, 3] = _abi_decode(y, DynArray[Point, 3]) + """ + c = get_contract(code) + + # runtime buffer has insufficient size - we decode 3 points, but provide only + # 3 * 32B of payload + payload = (0x20, 0x03, *_replicate(0x03, 3)) + + data = _abi_payload_from_tuple(payload) + + with tx_failed(): + c.run(data) + + +def test_abi_decode_dynarray_complex2(env, tx_failed, get_contract): + # point head to the 1st 0x01 word (ie the length) + # but size of the point is 3 * 32B, thus we'd decode 2B over the buffer end + code = """ +struct Point: + x: uint256 + y: uint256 + z: uint256 + + +@external +def run(x: Bytes[32 * 8]): + y: Bytes[32 * 11] = x + decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) + """ + c = get_contract(code) + + payload = ( + 0xC0, # points to the 1st 0x01 word (ie the length) + *_replicate(0x03, 5), + *_replicate(0x01, 2), + ) + + data = _abi_payload_from_tuple(payload) + + with tx_failed(): + c.run(data) + + +def test_abi_decode_complex_empty_dynarray(env, tx_failed, get_contract): + # point head to the last word of the payload + # this will be the length, but because it's set to 0, the decoding should succeed + code = """ +struct Point: + x: uint256 + y: DynArray[uint256, 2] + z: uint256 + + +@external +def run(x: Bytes[32 * 16]): + y: Bytes[32 * 16] = x + decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) + assert len(decoded_y1) == 1 and len(decoded_y1[0].y) == 0 + """ + c = get_contract(code) + + payload = ( + 0x20, + 0x01, + 0x20, + 0x01, + 0xA0, # points to the last word of the payload + 0x04, + 0x02, + 0x02, + 0x00, # length is 0, so decoding should succeed + ) + + data = _abi_payload_from_tuple(payload) + + c.run(data) + + +def test_abi_decode_complex_arithmetic_overflow(tx_failed, get_contract): + # inner head roundtrips due to arithmetic overflow + code = """ +struct Point: + x: uint256 + y: DynArray[uint256, 2] + z: uint256 + + +@external +def run(x: Bytes[32 * 16]): + y: Bytes[32 * 16] = x + decoded_y1: DynArray[Point, 2] = _abi_decode(y, DynArray[Point, 2]) + """ + c = get_contract(code) + + payload = ( + 0x20, + 0x01, + 0x20, + 0x01, # both Point.x and Point.y length + 2**256 - 0x20, # points to the "previous" word of the payload + 0x04, + 0x02, + 0x02, + 0x00, + ) + + data = _abi_payload_from_tuple(payload) + + with tx_failed(): + c.run(data) + + +def test_abi_decode_empty_toplevel_dynarray(get_contract): + code = """ +@external +def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): + y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x + assert len(y) == 2 * 32 + decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( + y, + DynArray[DynArray[uint256, 3], 3] + ) + assert len(decoded_y1) == 0 + """ + c = get_contract(code) + + payload = (0x20, 0x00) # DynArray head, DynArray length + + data = _abi_payload_from_tuple(payload) + + c.run(data) + + +def test_abi_decode_invalid_toplevel_dynarray_head(tx_failed, get_contract): + # head points 1B over the bounds of the runtime buffer + code = """ +@external +def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): + y: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4] = x + decoded_y1: DynArray[DynArray[uint256, 3], 3] = _abi_decode( + y, + DynArray[DynArray[uint256, 3], 3] + ) + """ + c = get_contract(code) + + # head points 1B over the bounds of the runtime buffer + payload = (0x21, 0x00) # DynArray head, DynArray length + + data = _abi_payload_from_tuple(payload) + + with tx_failed(): + c.run(data) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 3c81778660..5d4621518f 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -889,11 +889,17 @@ def _dirty_read_risk(ir_node): def _abi_payload_size(ir_node): SCALE = ir_node.location.word_scale assert SCALE == 32 # we must be in some byte-addressable region, like memory - OFFSET = DYNAMIC_ARRAY_OVERHEAD * SCALE if isinstance(ir_node.typ, DArrayT): - return ["add", OFFSET, ["mul", get_dyn_array_count(ir_node), SCALE]] + # the amount of size each value occupies in static section + # (the amount of size it occupies in the dynamic section is handled in + # make_setter recursion) + item_size = ir_node.typ.value_type.abi_type.static_size() + if item_size == 0: + # manual optimization; the mload cannot currently be optimized out + return ["add", OFFSET, 0] + return ["add", OFFSET, ["mul", get_dyn_array_count(ir_node), item_size]] if isinstance(ir_node.typ, _BytestringT): return ["add", OFFSET, get_bytearray_length(ir_node)] @@ -1175,14 +1181,17 @@ def clamp_dyn_array(ir_node, hi=None): assert (hi is not None) == _dirty_read_risk(ir_node) - # if the subtype is dynamic, the check will be performed in the recursion - if hi is not None and not t.abi_type.subtyp.is_dynamic(): + if hi is not None: assert t.count < 2**64 # sanity check # note: this add does not risk arithmetic overflow because # length is bounded by count * elemsize. item_end = add_ofst(ir_node, _abi_payload_size(ir_node)) + # if the subtype is dynamic, the length check is performed in + # the recursion, UNLESS the count is zero. here we perform the + # check all the time, but it could maybe be optimized out in the + # make_setter loop (in the common case that runtime count > 0). len_check = ["seq", ["assert", ["le", item_end, hi]], len_check] return IRnode.from_list(len_check, error_msg=f"{ir_node.typ} bounds check")