diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 021cb7a33c..d975ab7d80 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -3,7 +3,7 @@ from tests.evm_backends.base_env import EvmError, ExecutionReverted from tests.utils import decimal_to_int -from vyper.exceptions import ArgumentException, StackTooDeep, StructureException +from vyper.exceptions import ArgumentException, StructureException from vyper.utils import method_id TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex() @@ -474,7 +474,7 @@ def test_abi_decode_length_mismatch(get_contract, assert_compile_failed, bad_cod assert_compile_failed(lambda: get_contract(bad_code), exception) -def test_abi_decode_arithmetic_overflow(w3, tx_failed, get_contract): +def test_abi_decode_arithmetic_overflow(env, tx_failed, get_contract): # test based on GHSA-9p8r-4xp4-gw5w: # https://github.com/vyperlang/vyper/security/advisories/GHSA-9p8r-4xp4-gw5w#advisory-comment-91841 # note: doesn't even reach the assert but reverts internally on the clamp in getelemptr @@ -498,10 +498,10 @@ def f(x: Bytes[32 * 3]): # and it will be added to base ptr leading to an arithmetic overflow data += (2**256 - 0x60).to_bytes(32, "big") with tx_failed(): - w3.eth.send_transaction({"to": c.address, "data": data}) + env.message_call(c.address, data=data) -def test_abi_decode_oob_due_to_invalid_head(w3, tx_failed, get_contract): +def test_abi_decode_oob_due_to_invalid_head(env, tx_failed, get_contract): code = """ @external def f(x: Bytes[32 * 5]): @@ -526,10 +526,10 @@ def f(x: Bytes[32 * 5]): data += (0x00).to_bytes(31, "big") data += (0x03).to_bytes(32, "big") * 2 # with tx_failed(): - w3.eth.send_transaction({"to": c.address, "data": data}) + env.message_call(c.address, data=data) -def test_abi_decode_oob_due_to_invalid_head2(w3, tx_failed, get_contract): +def test_abi_decode_oob_due_to_invalid_head2(tx_failed, get_contract): code = """ @external def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): @@ -569,7 +569,7 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): c.run(data) -def test_abi_decode_oob_due_to_invalid_size(w3, tx_failed, get_contract): +def test_abi_decode_oob_due_to_invalid_size(tx_failed, get_contract, env): code = """ @external def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): @@ -599,7 +599,7 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]): data += (0x01).to_bytes(32, "big") * 3 # DynArray[Bytes[96], 3][2] data with tx_failed(): - w3.eth.send_transaction({"to": c.address, "data": data}) + env.message_call(c.address, data=data) def test_abi_decode_oob_due_to_invalid_head3(tx_failed, get_contract): diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index f14dfabb48..6925264d0b 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -443,7 +443,7 @@ def _mul(x, y): # Resolve pointer locations for ABI-encoded data -def _getelemptr_abi_helper(parent, member_t, ofst, clamp_=True): +def _getelemptr_abi_helper(parent, member_t, ofst): member_abi_t = member_t.abi_type # ABI encoding has length word and then pretends length is not there @@ -461,17 +461,14 @@ def _getelemptr_abi_helper(parent, member_t, ofst, clamp_=True): # double dereference, according to ABI spec # `ofst_ir` is the "real" (absolute) pointer to the item - if parent.location != MEMORY: - ofst_ir = add_ofst(parent, abi_ofst) - - else: - with abi_ofst.cache_when_complex("abi_ofst") as (b1, abi_ofst): - # TODO: cache add_ofst - arithmetic_overflow = ["lt", add_ofst(parent, abi_ofst), parent] + ofst_ir = add_ofst(parent, abi_ofst) + with ofst_ir.cache_when_complex("ofst_ir") as (b1, ofst_ir): + if parent.location == MEMORY: + arithmetic_overflow = ["lt", ofst_ir, parent] bounds_check = ["assert", ["iszero", arithmetic_overflow]] + ofst_ir = ["seq", bounds_check, ofst_ir] - ofst_ir = ["seq", bounds_check, add_ofst(parent, abi_ofst)] - ofst_ir = b1.resolve(ofst_ir) + ofst_ir = b1.resolve(ofst_ir) return IRnode.from_list( ofst_ir, @@ -494,7 +491,7 @@ def _get_element_ptr_tuplelike(parent, key, hi=None): index = attrs.index(key) annotation = key else: - # TupleT + assert isinstance(typ, TupleT) assert isinstance(key, int) subtype = typ.member_types[key] attrs = list(typ.tuple_keys()) @@ -1092,18 +1089,15 @@ def clamp_bytestring(ir_node, hi=None): if not isinstance(t, _BytestringT): # pragma: nocover raise CompilerPanic(f"{t} passed to clamp_bytestring") - # TODO: cache get_bytearray_length # check if byte array length is within type max - bslen_check = ["assert", ["le", get_bytearray_length(ir_node), t.maxlen]] - - if hi: - payload_sz = ["add", get_bytearray_length(ir_node), 32] - absolute_end = add_ofst(ir_node, payload_sz) - ret = ["seq", ["assert", ["le", absolute_end, hi]], bslen_check] - else: - ret = bslen_check + with get_bytearray_length(ir_node).cache_when_complex("length") as (b1, length): + len_check = ["assert", ["le", length, t.maxlen]] + if hi: + payload_len = ["add", length, 32] + absolute_end = add_ofst(ir_node, payload_len) + len_check = ["seq", ["assert", ["le", absolute_end, hi]], len_check] - return IRnode.from_list(ret, error_msg=f"{ir_node.typ} bounds check") + return IRnode.from_list(b1.resolve(len_check), error_msg=f"{ir_node.typ} bounds check") def clamp_dyn_array(ir_node, hi=None): @@ -1113,16 +1107,11 @@ def clamp_dyn_array(ir_node, hi=None): dynarr_len_check = ["assert", ["le", get_dyn_array_count(ir_node), t.count]] if hi and not t.abi_type.subtyp.is_dynamic(): - pass payload_sz = ["add", ["mul", get_dyn_array_count(ir_node), 32], 32] absolute_end = add_ofst(ir_node, payload_sz) - ret = ["seq"] - ret.append(["assert", ["le", absolute_end, hi]]) - ret.append(dynarr_len_check) - else: - ret = dynarr_len_check + dynarr_len_check = ["seq", ["assert", ["le", absolute_end, hi]], dynarr_len_check] - return IRnode.from_list(ret, error_msg=f"{ir_node.typ} bounds check") + return IRnode.from_list(dynarr_len_check, error_msg=f"{ir_node.typ} bounds check") # clampers for basetype