Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[codegen]: fix _abi_decode overflow #1

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions tests/functional/builtins/codegen/test_abi_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from tests.evm_backends.base_env import EvmError, ExecutionReverted
from tests.utils import decimal_to_int
from vyper.exceptions import ArgumentException, StackTooDeep, StructureException
from vyper.exceptions import ArgumentException, StructureException
from vyper.utils import method_id

TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex()
Expand Down Expand Up @@ -474,7 +474,7 @@ def test_abi_decode_length_mismatch(get_contract, assert_compile_failed, bad_cod
assert_compile_failed(lambda: get_contract(bad_code), exception)


def test_abi_decode_arithmetic_overflow(w3, tx_failed, get_contract):
def test_abi_decode_arithmetic_overflow(env, tx_failed, get_contract):
# test based on GHSA-9p8r-4xp4-gw5w:
# https://github.com/vyperlang/vyper/security/advisories/GHSA-9p8r-4xp4-gw5w#advisory-comment-91841
# note: doesn't even reach the assert but reverts internally on the clamp in getelemptr
Expand All @@ -498,10 +498,10 @@ def f(x: Bytes[32 * 3]):
# and it will be added to base ptr leading to an arithmetic overflow
data += (2**256 - 0x60).to_bytes(32, "big")
with tx_failed():
w3.eth.send_transaction({"to": c.address, "data": data})
env.message_call(c.address, data=data)


def test_abi_decode_oob_due_to_invalid_head(w3, tx_failed, get_contract):
def test_abi_decode_oob_due_to_invalid_head(env, tx_failed, get_contract):
code = """
@external
def f(x: Bytes[32 * 5]):
Expand All @@ -526,10 +526,10 @@ def f(x: Bytes[32 * 5]):
data += (0x00).to_bytes(31, "big")
data += (0x03).to_bytes(32, "big") * 2
# with tx_failed():
w3.eth.send_transaction({"to": c.address, "data": data})
env.message_call(c.address, data=data)


def test_abi_decode_oob_due_to_invalid_head2(w3, tx_failed, get_contract):
def test_abi_decode_oob_due_to_invalid_head2(tx_failed, get_contract):
code = """
@external
def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]):
Expand Down Expand Up @@ -569,7 +569,7 @@ def run(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]):
c.run(data)


def test_abi_decode_oob_due_to_invalid_size(w3, tx_failed, get_contract):
def test_abi_decode_oob_due_to_invalid_size(tx_failed, get_contract, env):
code = """
@external
def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]):
Expand Down Expand Up @@ -599,7 +599,7 @@ def f(x: Bytes[2 * 32 + 3 * 32 + 3 * 32 * 4]):
data += (0x01).to_bytes(32, "big") * 3 # DynArray[Bytes[96], 3][2] data

with tx_failed():
w3.eth.send_transaction({"to": c.address, "data": data})
env.message_call(c.address, data=data)


def test_abi_decode_oob_due_to_invalid_head3(tx_failed, get_contract):
Expand Down
45 changes: 17 additions & 28 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def _mul(x, y):


# Resolve pointer locations for ABI-encoded data
def _getelemptr_abi_helper(parent, member_t, ofst, clamp_=True):
def _getelemptr_abi_helper(parent, member_t, ofst):
member_abi_t = member_t.abi_type

# ABI encoding has length word and then pretends length is not there
Expand All @@ -461,17 +461,14 @@ def _getelemptr_abi_helper(parent, member_t, ofst, clamp_=True):

# double dereference, according to ABI spec
# `ofst_ir` is the "real" (absolute) pointer to the item
if parent.location != MEMORY:
ofst_ir = add_ofst(parent, abi_ofst)

else:
with abi_ofst.cache_when_complex("abi_ofst") as (b1, abi_ofst):
# TODO: cache add_ofst
arithmetic_overflow = ["lt", add_ofst(parent, abi_ofst), parent]
ofst_ir = add_ofst(parent, abi_ofst)
with ofst_ir.cache_when_complex("ofst_ir") as (b1, ofst_ir):
if parent.location == MEMORY:
arithmetic_overflow = ["lt", ofst_ir, parent]
bounds_check = ["assert", ["iszero", arithmetic_overflow]]
ofst_ir = ["seq", bounds_check, ofst_ir]

ofst_ir = ["seq", bounds_check, add_ofst(parent, abi_ofst)]
ofst_ir = b1.resolve(ofst_ir)
ofst_ir = b1.resolve(ofst_ir)

return IRnode.from_list(
ofst_ir,
Expand All @@ -494,7 +491,7 @@ def _get_element_ptr_tuplelike(parent, key, hi=None):
index = attrs.index(key)
annotation = key
else:
# TupleT
assert isinstance(typ, TupleT)
assert isinstance(key, int)
subtype = typ.member_types[key]
attrs = list(typ.tuple_keys())
Expand Down Expand Up @@ -1092,18 +1089,15 @@ def clamp_bytestring(ir_node, hi=None):
if not isinstance(t, _BytestringT): # pragma: nocover
raise CompilerPanic(f"{t} passed to clamp_bytestring")

# TODO: cache get_bytearray_length
# check if byte array length is within type max
bslen_check = ["assert", ["le", get_bytearray_length(ir_node), t.maxlen]]

if hi:
payload_sz = ["add", get_bytearray_length(ir_node), 32]
absolute_end = add_ofst(ir_node, payload_sz)
ret = ["seq", ["assert", ["le", absolute_end, hi]], bslen_check]
else:
ret = bslen_check
with get_bytearray_length(ir_node).cache_when_complex("length") as (b1, length):
len_check = ["assert", ["le", length, t.maxlen]]
if hi:
payload_len = ["add", length, 32]
absolute_end = add_ofst(ir_node, payload_len)
len_check = ["seq", ["assert", ["le", absolute_end, hi]], len_check]

return IRnode.from_list(ret, error_msg=f"{ir_node.typ} bounds check")
return IRnode.from_list(b1.resolve(len_check), error_msg=f"{ir_node.typ} bounds check")


def clamp_dyn_array(ir_node, hi=None):
Expand All @@ -1113,16 +1107,11 @@ def clamp_dyn_array(ir_node, hi=None):
dynarr_len_check = ["assert", ["le", get_dyn_array_count(ir_node), t.count]]

if hi and not t.abi_type.subtyp.is_dynamic():
pass
payload_sz = ["add", ["mul", get_dyn_array_count(ir_node), 32], 32]
absolute_end = add_ofst(ir_node, payload_sz)
ret = ["seq"]
ret.append(["assert", ["le", absolute_end, hi]])
ret.append(dynarr_len_check)
else:
ret = dynarr_len_check
dynarr_len_check = ["seq", ["assert", ["le", absolute_end, hi]], dynarr_len_check]

return IRnode.from_list(ret, error_msg=f"{ir_node.typ} bounds check")
return IRnode.from_list(dynarr_len_check, error_msg=f"{ir_node.typ} bounds check")


# clampers for basetype
Expand Down
Loading