diff --git a/tests/parser/features/iteration/test_repeater.py b/tests/parser/features/iteration/test_repeater.py index aa0b9a7aee..9cc4bf7697 100644 --- a/tests/parser/features/iteration/test_repeater.py +++ b/tests/parser/features/iteration/test_repeater.py @@ -9,7 +9,6 @@ def repeat(z: int128) -> int128: """ c = get_contract_with_gas_estimation(basic_repeater) assert c.repeat(9) == 54 - print("Passed basic repeater test") def test_digit_reverser(get_contract_with_gas_estimation): @@ -30,7 +29,6 @@ def reverse_digits(x: int128) -> int128: c = get_contract_with_gas_estimation(digit_reverser) assert c.reverse_digits(123456) == 654321 - print("Passed digit reverser test") def test_more_complex_repeater(get_contract_with_gas_estimation): @@ -48,8 +46,6 @@ def repeat() -> int128: c = get_contract_with_gas_estimation(more_complex_repeater) assert c.repeat() == 666666 - print("Passed complex repeater test") - def test_offset_repeater(get_contract_with_gas_estimation): offset_repeater = """ @@ -64,8 +60,6 @@ def sum() -> int128: c = get_contract_with_gas_estimation(offset_repeater) assert c.sum() == 4100 - print("Passed repeater with offset test") - def test_offset_repeater_2(get_contract_with_gas_estimation): offset_repeater_2 = """ @@ -83,8 +77,6 @@ def sum(frm: int128, to: int128) -> int128: assert c.sum(100, 99999) == 15150 assert c.sum(70, 131) == 6100 - print("Passed more complex repeater with offset test") - def test_loop_call_priv(get_contract_with_gas_estimation): code = """ @@ -101,3 +93,86 @@ def foo() -> bool: c = get_contract_with_gas_estimation(code) assert c.foo() is True + + +def test_return_inside_repeater(get_contract): + code = """ +@internal +def _final(a: int128) -> int128: + for i in range(10): + if i > a: + return i + return -42 + +@internal +def _middle(a: int128) -> int128: + b: int128 = self._final(a) + return b + +@external +def foo(a: int128) -> int128: + b: int128 = self._middle(a) + return b + """ + + c = get_contract(code) + assert c.foo(6) == 7 + assert c.foo(100) == -42 + + +def test_return_inside_nested_repeater(get_contract): + code = """ +@internal +def _final(a: int128) -> int128: + for i in range(10): + for x in range(10): + if i + x > a: + return i + x + return -42 + +@internal +def _middle(a: int128) -> int128: + b: int128 = self._final(a) + return b + +@external +def foo(a: int128) -> int128: + b: int128 = self._middle(a) + return b + """ + + c = get_contract(code) + assert c.foo(14) == 15 + assert c.foo(100) == -42 + + +def test_breaks_and_returns_inside_nested_repeater(get_contract): + code = """ +@internal +def _final(a: int128) -> int128: + for i in range(10): + for x in range(10): + if a < 2: + break + return 6 + if a == 1: + break + return 31337 + + return -42 + +@internal +def _middle(a: int128) -> int128: + b: int128 = self._final(a) + return b + +@external +def foo(a: int128) -> int128: + b: int128 = self._middle(a) + return b + """ + + c = get_contract(code) + assert c.foo(100) == 6 + assert c.foo(1) == -42 + assert c.foo(0) == 31337 diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index 93d8941b64..dc1bc9f98b 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -26,12 +26,6 @@ def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None) if isinstance(begin_pos, int) and isinstance(_size, int): # static values, unroll the mloads instead. mloads = [["mload", pos] for pos in range(begin_pos, _size, 32)] - return ( - ["seq_unchecked"] - + mloads - + nonreentrant_post - + [["jump", ["mload", context.callback_ptr]]] - ) else: mloads = [ "seq_unchecked", @@ -54,12 +48,17 @@ def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None) ["goto", start_label], ["label", exit_label], ] - return ( - ["seq_unchecked"] - + [mloads] - + nonreentrant_post - + [["jump", ["mload", context.callback_ptr]]] - ) + + # if we are in a for loop, we have to exit prior to returning + exit_repeater = ["exit_repeater"] if context.forvars else [] + + return ( + ["seq_unchecked"] + + exit_repeater + + mloads + + nonreentrant_post + + [["jump", ["mload", context.callback_ptr]]] + ) else: return ["seq_unchecked"] + nonreentrant_post + [["return", begin_pos, _size]] diff --git a/vyper/compile_lll.py b/vyper/compile_lll.py index 224446e493..93b9de10e4 100644 --- a/vyper/compile_lll.py +++ b/vyper/compile_lll.py @@ -214,15 +214,21 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No # Continue to the next iteration of the for loop elif code.value == "continue": if not break_dest: - raise Exception("Invalid break") + raise CompilerPanic("Invalid break") dest, continue_dest, break_height = break_dest return [continue_dest, "JUMP"] # Break from inside a for loop elif code.value == "break": if not break_dest: - raise Exception("Invalid break") + raise CompilerPanic("Invalid break") dest, continue_dest, break_height = break_dest return ["POP"] * (height - break_height) + [dest, "JUMP"] + # Break from inside one or more for loops prior to a return statement inside the loop + elif code.value == "exit_repeater": + if not break_dest: + raise CompilerPanic("Invalid break") + _, _, break_height = break_dest + return ["POP"] * break_height # With statements elif code.value == "with": o = []