Skip to content

Commit

Permalink
Merge pull request #2110 from iamdefinitelyahuman/fix-return-in-loop
Browse files Browse the repository at this point in the history
Pop for loop values from stack prior to returning
  • Loading branch information
fubuloubu authored Jul 16, 2020
2 parents dc8ebc0 + a7364cd commit a1d92e5
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 22 deletions.
91 changes: 83 additions & 8 deletions tests/parser/features/iteration/test_repeater.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def repeat(z: int128) -> int128:
"""
c = get_contract_with_gas_estimation(basic_repeater)
assert c.repeat(9) == 54
print("Passed basic repeater test")


def test_digit_reverser(get_contract_with_gas_estimation):
Expand All @@ -30,7 +29,6 @@ def reverse_digits(x: int128) -> int128:

c = get_contract_with_gas_estimation(digit_reverser)
assert c.reverse_digits(123456) == 654321
print("Passed digit reverser test")


def test_more_complex_repeater(get_contract_with_gas_estimation):
Expand All @@ -48,8 +46,6 @@ def repeat() -> int128:
c = get_contract_with_gas_estimation(more_complex_repeater)
assert c.repeat() == 666666

print("Passed complex repeater test")


def test_offset_repeater(get_contract_with_gas_estimation):
offset_repeater = """
Expand All @@ -64,8 +60,6 @@ def sum() -> int128:
c = get_contract_with_gas_estimation(offset_repeater)
assert c.sum() == 4100

print("Passed repeater with offset test")


def test_offset_repeater_2(get_contract_with_gas_estimation):
offset_repeater_2 = """
Expand All @@ -83,8 +77,6 @@ def sum(frm: int128, to: int128) -> int128:
assert c.sum(100, 99999) == 15150
assert c.sum(70, 131) == 6100

print("Passed more complex repeater with offset test")


def test_loop_call_priv(get_contract_with_gas_estimation):
code = """
Expand All @@ -101,3 +93,86 @@ def foo() -> bool:

c = get_contract_with_gas_estimation(code)
assert c.foo() is True


def test_return_inside_repeater(get_contract):
code = """
@internal
def _final(a: int128) -> int128:
for i in range(10):
if i > a:
return i
return -42
@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
return b
@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(6) == 7
assert c.foo(100) == -42


def test_return_inside_nested_repeater(get_contract):
code = """
@internal
def _final(a: int128) -> int128:
for i in range(10):
for x in range(10):
if i + x > a:
return i + x
return -42
@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
return b
@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(14) == 15
assert c.foo(100) == -42


def test_breaks_and_returns_inside_nested_repeater(get_contract):
code = """
@internal
def _final(a: int128) -> int128:
for i in range(10):
for x in range(10):
if a < 2:
break
return 6
if a == 1:
break
return 31337
return -42
@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
return b
@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(100) == 6
assert c.foo(1) == -42
assert c.foo(0) == 31337
23 changes: 11 additions & 12 deletions vyper/codegen/return_.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None)
if isinstance(begin_pos, int) and isinstance(_size, int):
# static values, unroll the mloads instead.
mloads = [["mload", pos] for pos in range(begin_pos, _size, 32)]
return (
["seq_unchecked"]
+ mloads
+ nonreentrant_post
+ [["jump", ["mload", context.callback_ptr]]]
)
else:
mloads = [
"seq_unchecked",
Expand All @@ -54,12 +48,17 @@ def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None)
["goto", start_label],
["label", exit_label],
]
return (
["seq_unchecked"]
+ [mloads]
+ nonreentrant_post
+ [["jump", ["mload", context.callback_ptr]]]
)

# if we are in a for loop, we have to exit prior to returning
exit_repeater = ["exit_repeater"] if context.forvars else []

return (
["seq_unchecked"]
+ exit_repeater
+ mloads
+ nonreentrant_post
+ [["jump", ["mload", context.callback_ptr]]]
)
else:
return ["seq_unchecked"] + nonreentrant_post + [["return", begin_pos, _size]]

Expand Down
10 changes: 8 additions & 2 deletions vyper/compile_lll.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,21 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No
# Continue to the next iteration of the for loop
elif code.value == "continue":
if not break_dest:
raise Exception("Invalid break")
raise CompilerPanic("Invalid break")
dest, continue_dest, break_height = break_dest
return [continue_dest, "JUMP"]
# Break from inside a for loop
elif code.value == "break":
if not break_dest:
raise Exception("Invalid break")
raise CompilerPanic("Invalid break")
dest, continue_dest, break_height = break_dest
return ["POP"] * (height - break_height) + [dest, "JUMP"]
# Break from inside one or more for loops prior to a return statement inside the loop
elif code.value == "exit_repeater":
if not break_dest:
raise CompilerPanic("Invalid break")
_, _, break_height = break_dest
return ["POP"] * break_height
# With statements
elif code.value == "with":
o = []
Expand Down

0 comments on commit a1d92e5

Please sign in to comment.