Skip to content

Commit

Permalink
refactor bounds check to be generalized
Browse files Browse the repository at this point in the history
  • Loading branch information
cyberthirst committed Apr 15, 2024
1 parent ab874ac commit 867dd03
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
21 changes: 5 additions & 16 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
sar,
shl,
shr,
slice_bounds_check,
unwrap_location,
)
from vyper.codegen.expr import Expr
Expand Down Expand Up @@ -230,18 +231,6 @@ def build_IR(self, expr, context):
ADHOC_SLICE_NODE_MACROS = ["~calldata", "~selfcode", "~extcode"]


# make sure we don't overrun the source buffer, checking for overflow:
# valid inputs satisfy:
# `assert !(start+length > src_len || start+length < start`
def _make_slice_bounds_check(start, length, src_len):
with start.cache_when_complex("start") as (b1, start):
with add_ofst(start, length).cache_when_complex("end") as (b2, end):
arithmetic_overflow = ["lt", end, start]
buffer_oob = ["gt", end, src_len]
ok = ["iszero", ["or", arithmetic_overflow, buffer_oob]]
return b1.resolve(b2.resolve(["assert", ok]))


def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: Context) -> IRnode:
assert length.is_literal, "typechecker failed"
assert isinstance(length.value, int) # mypy hint
Expand All @@ -254,7 +243,7 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
if sub.value == "~calldata":
node = [
"seq",
_make_slice_bounds_check(start, length, "calldatasize"),
slice_bounds_check(start, length, "calldatasize"),
["mstore", np, length],
["calldatacopy", np + 32, start, length],
np,
Expand All @@ -264,7 +253,7 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
elif sub.value == "~selfcode":
node = [
"seq",
_make_slice_bounds_check(start, length, "codesize"),
slice_bounds_check(start, length, "codesize"),
["mstore", np, length],
["codecopy", np + 32, start, length],
np,
Expand All @@ -279,7 +268,7 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
sub.args[0],
[
"seq",
_make_slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]),
slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]),
["mstore", np, length],
["extcodecopy", "_extcode_address", np + 32, start, length],
np,
Expand Down Expand Up @@ -452,7 +441,7 @@ def build_IR(self, expr, args, kwargs, context):

ret = [
"seq",
_make_slice_bounds_check(start, length, src_len),
slice_bounds_check(start, length, src_len),
do_copy,
["mstore", dst, length], # set length
dst, # return pointer to dst
Expand Down
23 changes: 17 additions & 6 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,12 +467,11 @@ def _getelemptr_abi_helper(parent, member_t, ofst, clamp_=True):
if parent.location == MEMORY: # TODO: replace with utility function
with abi_ofst.cache_when_complex("abi_ofst") as (b1, abi_ofst):
bound = parent_abi_t.size_bound()
end = ["add", abi_ofst, member_abi_t.size_bound()]
# head + member_size must be 'le' than the upper bound of parent buffer
end_clamped = clamp("le", end, bound)
# head + member_size must be 'gt' than the head (ie no overflow due to 'add')
end_clamped = ["assert", ["gt", end_clamped, abi_ofst]]
ofst_ir = ["seq", end_clamped, add_ofst(parent, abi_ofst)]
ofst_ir = [
"seq",
slice_bounds_check(abi_ofst, member_abi_t.size_bound(), bound),
add_ofst(parent, abi_ofst),
]
ofst_ir = b1.resolve(ofst_ir)

return IRnode.from_list(
Expand Down Expand Up @@ -1243,3 +1242,15 @@ def clamp2(lo, arg, hi, signed):
LE = "sle" if signed else "le"
ret = ["seq", ["assert", ["and", [GE, arg, lo], [LE, arg, hi]]], arg]
return IRnode.from_list(b1.resolve(ret), typ=arg.typ)


# make sure we don't overrun the source buffer, checking for overflow:
# valid inputs satisfy:
# `assert !(start+length > src_len || start+length < start)`
def slice_bounds_check(start, length, src_len):
with start.cache_when_complex("start") as (b1, start):
with add_ofst(start, length).cache_when_complex("end") as (b2, end):
arithmetic_overflow = ["lt", end, start]
buffer_oob = ["gt", end, src_len]
ok = ["iszero", ["or", arithmetic_overflow, buffer_oob]]
return b1.resolve(b2.resolve(["assert", ok]))

0 comments on commit 867dd03

Please sign in to comment.