Skip to content

Commit

Permalink
fix[codegen]: fix make_setter overlap with internal calls (vyperlan…
Browse files Browse the repository at this point in the history
…g#4037)

fix overlap analysis for `make_setter` when the RHS contains an internal
call. this represents a bug which was partially fixed in 1c8349e;
this commit fixes two variants:

- where the RHS contains an internal call, the analysis needs to take
  into account all storage variables touched by the called function.

- where the RHS contains an external call, the analysis needs to assume
  any storage variable could be touched (this is a conservative
  analysis; a more advanced analysis could limit the assumption to
  storage variables touched by reentrant functions)

---------

Co-authored-by: cyberthirst <[email protected]>
Co-authored-by: trocher <[email protected]>
  • Loading branch information
3 people authored May 26, 2024
1 parent 5445960 commit ad9c10b
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2527,3 +2527,58 @@ def foo(a: DynArray[{typ}, 3], b: String[5]):
# Dynamic size is at least minimum (132 bytes * 2 + 2 (for 0x) = 266)
valid = data[:266]
env.message_call(c1.address, data=valid)


def test_make_setter_external_call(get_contract):
# variant of GH #3503
code = """
interface A:
def boo() -> uint256:nonpayable
a: DynArray[uint256, 10]
@external
def foo() -> DynArray[uint256, 10]:
self.a = [1, 2, extcall A(self).boo(), 4]
return self.a # returns [11, 12, 3, 4]
@external
def boo() -> uint256:
self.a = [11, 12, 13, 14, 15, 16]
self.a = []
# it should now be impossible to read any of [11, 12, 13, 14, 15, 16]
return 3
"""
c = get_contract(code)

assert c.foo() == [1, 2, 3, 4]


def test_make_setter_external_call2(get_contract):
# variant of GH #3503
code = """
interface A:
def boo(): nonpayable
a: DynArray[uint256, 10]
@external
def foo() -> DynArray[uint256, 10]:
self.a = [1, 2, self.baz(), 4]
return self.a # returns [11, 12, 3, 4]
@internal
def baz() -> uint256:
extcall A(self).boo()
return 3
@external
def boo():
self.a = [11, 12, 13, 14, 15, 16]
self.a = []
# it should now be impossible to read any of [11, 12, 13, 14, 15, 16]
"""
c = get_contract(code)

assert c.foo() == [1, 2, 3, 4]
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,49 @@ def bar() -> Bytes[6]:
assert c.bar() == b"hello"


def test_make_setter_internal_call(get_contract):
# cf. GH #3503
code = """
a:DynArray[uint256,2]
@external
def foo() -> DynArray[uint256,2]:
# Initial value
self.a = [1, 2]
self.a = [self.bar(1), self.bar(0)]
return self.a
@internal
def bar(i: uint256) -> uint256:
return self.a[i]
"""
c = get_contract(code)

assert c.foo() == [2, 1]


def test_make_setter_internal_call2(get_contract):
# cf. GH #3503
code = """
a: DynArray[uint256, 10]
@external
def foo() -> DynArray[uint256, 10]:
self.a = [1, 2, self.boo(), 4]
return self.a # returns [11, 12, 3, 4]
@internal
def boo() -> uint256:
self.a = [11, 12, 13, 14, 15, 16]
self.a = []
# it should now be impossible to read any of [11, 12, 13, 14, 15, 16]
return 3
"""
c = get_contract(code)

assert c.foo() == [1, 2, 3, 4]


def test_dynamically_sized_struct_member_as_arg_2(get_contract):
contract = """
struct X:
Expand Down
1 change: 1 addition & 0 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ def parse_Call(self):

assert isinstance(func_t, ContractFunctionT)
assert func_t.is_internal or func_t.is_constructor

return self_call.ir_for_self_call(self.expr, self.context)

@classmethod
Expand Down
5 changes: 5 additions & 0 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class _FuncIRInfo:
func_t: ContractFunctionT
gas_estimate: Optional[int] = None
frame_info: Optional[FrameInfo] = None
func_ir: Optional["InternalFuncIR"] = None

@property
def visibility(self):
Expand Down Expand Up @@ -56,6 +57,10 @@ def set_frame_info(self, frame_info: FrameInfo) -> None:
else:
self.frame_info = frame_info

def set_func_ir(self, func_ir: "InternalFuncIR") -> None:
assert self.func_t.is_internal or self.func_t.is_deploy
self.func_ir = func_ir

@property
# common entry point for external function with kwargs
def external_function_base_entry_label(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion vyper/codegen/function_definitions/internal_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,7 @@ def generate_ir_for_internal_function(
func_t._ir_info.gas_estimate = ir_node.gas
tag_frame_info(func_t, context)

return InternalFuncIR(ir_node)
ret = InternalFuncIR(ir_node)
func_t._ir_info.func_ir = ret

return ret
18 changes: 16 additions & 2 deletions vyper/codegen/ir_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,11 +456,25 @@ def cache_when_complex(self, name):

@cached_property
def referenced_variables(self):
ret = set()
ret = getattr(self, "_referenced_variables", set())

for arg in self.args:
ret |= arg.referenced_variables

ret |= getattr(self, "_referenced_variables", set())
if getattr(self, "is_self_call", False):
ret |= self.invoked_function_ir.func_ir.referenced_variables

return ret

@cached_property
def contains_risky_call(self):
ret = self.value in ("call", "delegatecall", "create", "create2")

for arg in self.args:
ret |= arg.contains_risky_call

if getattr(self, "is_self_call", False):
ret |= self.invoked_function_ir.func_ir.contains_risky_call

return ret

Expand Down
1 change: 1 addition & 0 deletions vyper/codegen/self_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,5 @@ def ir_for_self_call(stmt_expr, context):
add_gas_estimate=func_t._ir_info.gas_estimate,
)
o.is_self_call = True
o.invoked_function_ir = func_t._ir_info.func_ir
return o
2 changes: 2 additions & 0 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def parse_Assign(self):

ret = ["seq"]
overlap = len(dst.referenced_variables & src.referenced_variables) > 0
overlap |= len(dst.referenced_variables) > 0 and src.contains_risky_call
overlap |= dst.contains_risky_call and len(src.referenced_variables) > 0
if overlap and not dst.typ._is_prim_word:
# there is overlap between the lhs and rhs, and the type is
# complex - i.e., it spans multiple words. for safety, we
Expand Down

0 comments on commit ad9c10b

Please sign in to comment.