Skip to content

Commit

Permalink
fix[lang]: recursion in uses analysis for nonreentrant functions (v…
Browse files Browse the repository at this point in the history
…yperlang#3971)

this commit fixes `uses` analysis for nonreentrant functions, which are
called recursively.

a partial fix for this was applied in cb94068, but it missed the
case where a nonreentrant function is deep in the call tree.
  • Loading branch information
charles-cooper authored May 9, 2024
1 parent fb55f4c commit 4c66c8c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 24 deletions.
80 changes: 57 additions & 23 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,52 +1300,86 @@ def foo():
assert e.value._hint == "try importing lib1 first"


def test_nonreentrant_exports(make_input_bundle):
@pytest.fixture
def nonreentrant_library_bundle(make_input_bundle):
# test simple case
lib1 = """
# lib1.vy
@external
@internal
@nonreentrant
def bar():
pass
# lib1.vy
@external
@nonreentrant
def ext_bar():
pass
"""
main = """
# test case with recursion
lib2 = """
@internal
def bar():
self.baz()
@external
def ext_bar():
self.baz()
@nonreentrant
@internal
def baz():
return
"""
# test case with nested recursion
lib3 = """
import lib1
uses: lib1
exports: lib1.bar # line 4
@internal
def bar():
lib1.bar()
@external
def ext_bar():
lib1.bar()
"""

return make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3})


@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3"))
def test_nonreentrant_exports(nonreentrant_library_bundle, lib):
main = f"""
import {lib}
exports: {lib}.ext_bar # line 4
@external
def foo():
pass
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
compile_code(main, input_bundle=nonreentrant_library_bundle)
assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE
hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 4


def test_internal_nonreentrant_import(make_input_bundle):
lib1 = """
# lib1.vy
@internal
@nonreentrant
def bar():
pass
"""
main = """
import lib1
@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3"))
def test_internal_nonreentrant_import(nonreentrant_library_bundle, lib):
main = f"""
import {lib}
@external
def foo():
lib1.bar() # line 6
{lib}.bar() # line 6
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
compile_code(main, input_bundle=nonreentrant_library_bundle)
assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE

hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 6
6 changes: 5 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ def get_variable_accesses(self):
return self._variable_reads | self._variable_writes

def uses_state(self):
return self.nonreentrant or uses_state(self.get_variable_accesses())
return (
self.nonreentrant
or uses_state(self.get_variable_accesses())
or any(f.nonreentrant for f in self.reachable_internal_functions)
)

def get_used_modules(self):
# _used_modules is populated during analysis
Expand Down

0 comments on commit 4c66c8c

Please sign in to comment.