From 867dd03dbfa443051ceb71d029f7d838eec7a2e4 Mon Sep 17 00:00:00 2001 From: cyberthirst Date: Mon, 15 Apr 2024 16:57:42 +0200 Subject: [PATCH] refactor bounds check to be generalized --- vyper/builtins/functions.py | 21 +++++---------------- vyper/codegen/core.py | 23 +++++++++++++++++------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index debf5c5c32..4cb62cda2d 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -32,6 +32,7 @@ sar, shl, shr, + slice_bounds_check, unwrap_location, ) from vyper.codegen.expr import Expr @@ -230,18 +231,6 @@ def build_IR(self, expr, context): ADHOC_SLICE_NODE_MACROS = ["~calldata", "~selfcode", "~extcode"] -# make sure we don't overrun the source buffer, checking for overflow: -# valid inputs satisfy: -# `assert !(start+length > src_len || start+length < start` -def _make_slice_bounds_check(start, length, src_len): - with start.cache_when_complex("start") as (b1, start): - with add_ofst(start, length).cache_when_complex("end") as (b2, end): - arithmetic_overflow = ["lt", end, start] - buffer_oob = ["gt", end, src_len] - ok = ["iszero", ["or", arithmetic_overflow, buffer_oob]] - return b1.resolve(b2.resolve(["assert", ok])) - - def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: Context) -> IRnode: assert length.is_literal, "typechecker failed" assert isinstance(length.value, int) # mypy hint @@ -254,7 +243,7 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: if sub.value == "~calldata": node = [ "seq", - _make_slice_bounds_check(start, length, "calldatasize"), + slice_bounds_check(start, length, "calldatasize"), ["mstore", np, length], ["calldatacopy", np + 32, start, length], np, @@ -264,7 +253,7 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: elif sub.value == "~selfcode": node = [ "seq", - _make_slice_bounds_check(start, length, "codesize"), + slice_bounds_check(start, length, "codesize"), ["mstore", np, length], ["codecopy", np + 32, start, length], np, @@ -279,7 +268,7 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: sub.args[0], [ "seq", - _make_slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]), + slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]), ["mstore", np, length], ["extcodecopy", "_extcode_address", np + 32, start, length], np, @@ -452,7 +441,7 @@ def build_IR(self, expr, args, kwargs, context): ret = [ "seq", - _make_slice_bounds_check(start, length, src_len), + slice_bounds_check(start, length, src_len), do_copy, ["mstore", dst, length], # set length dst, # return pointer to dst diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 9ab014ffb3..6544a91d14 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -467,12 +467,11 @@ def _getelemptr_abi_helper(parent, member_t, ofst, clamp_=True): if parent.location == MEMORY: # TODO: replace with utility function with abi_ofst.cache_when_complex("abi_ofst") as (b1, abi_ofst): bound = parent_abi_t.size_bound() - end = ["add", abi_ofst, member_abi_t.size_bound()] - # head + member_size must be 'le' than the upper bound of parent buffer - end_clamped = clamp("le", end, bound) - # head + member_size must be 'gt' than the head (ie no overflow due to 'add') - end_clamped = ["assert", ["gt", end_clamped, abi_ofst]] - ofst_ir = ["seq", end_clamped, add_ofst(parent, abi_ofst)] + ofst_ir = [ + "seq", + slice_bounds_check(abi_ofst, member_abi_t.size_bound(), bound), + add_ofst(parent, abi_ofst), + ] ofst_ir = b1.resolve(ofst_ir) return IRnode.from_list( @@ -1243,3 +1242,15 @@ def clamp2(lo, arg, hi, signed): LE = "sle" if signed else "le" ret = ["seq", ["assert", ["and", [GE, arg, lo], [LE, arg, hi]]], arg] return IRnode.from_list(b1.resolve(ret), typ=arg.typ) + + +# make sure we don't overrun the source buffer, checking for overflow: +# valid inputs satisfy: +# `assert !(start+length > src_len || start+length < start)` +def slice_bounds_check(start, length, src_len): + with start.cache_when_complex("start") as (b1, start): + with add_ofst(start, length).cache_when_complex("end") as (b2, end): + arithmetic_overflow = ["lt", end, start] + buffer_oob = ["gt", end, src_len] + ok = ["iszero", ["or", arithmetic_overflow, buffer_oob]] + return b1.resolve(b2.resolve(["assert", ok]))