Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: tests for functionality fixed in #2471 #2568

Merged
merged 19 commits into from
Jan 1, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/control-structures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ The ``@nonreentrant(<key>)`` decorator places a lock on a function, and all func
# this function is protected from re-entrancy
...

You can put the ``@nonreentrant(<key>)`` decorator on a ``__default__`` function but we recommend against it.
tserg marked this conversation as resolved.
Show resolved Hide resolved

The `__default__` Function
--------------------------

Expand Down
100 changes: 99 additions & 1 deletion tests/parser/features/decorators/test_nonreentrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from vyper.exceptions import FunctionDeclarationException


def test_nonrentrant_decorator(get_contract, assert_tx_failed):
def test_nonreentrant_decorator(get_contract, assert_tx_failed):
calling_contract_code = """
interface SpecialContract:
def unprotected_function(val: String[100], do_callback: bool): nonpayable
Expand Down Expand Up @@ -85,6 +85,104 @@ def unprotected_function(val: String[100], do_callback: bool):
assert_tx_failed(lambda: reentrant_contract.protected_function2("zzz value", True, transact={}))


def test_nonreentrant_decorator_for_default(w3, get_contract, assert_tx_failed):
calling_contract_code = """
@external
def send_funds(_amount: uint256):
# raw_call() is used to overcome gas limit of send()
response: Bytes[32] = raw_call(
msg.sender,
concat(
tserg marked this conversation as resolved.
Show resolved Hide resolved
method_id("transfer(address,uint256)"),
convert(msg.sender, bytes32),
convert(_amount, bytes32)
),
max_outsize=32,
value=_amount
)

@external
@payable
def __default__():
pass
"""

reentrant_code = """
interface Callback:
def send_funds(_amount: uint256): nonpayable

special_value: public(String[100])
callback: public(Callback)

@external
def set_callback(c: address):
self.callback = Callback(c)

@external
@payable
@nonreentrant('default')
def protected_function(val: String[100], do_callback: bool) -> uint256:
self.special_value = val
_amount: uint256 = msg.value
send(self.callback.address, msg.value)

if do_callback:
self.callback.send_funds(_amount)
return 1
else:
return 2

@external
@payable
def unprotected_function(val: String[100], do_callback: bool):
self.special_value = val
_amount: uint256 = msg.value
send(self.callback.address, msg.value)

if do_callback:
self.callback.send_funds(_amount)

@external
@payable
@nonreentrant('default')
def __default__():
pass
"""

reentrant_contract = get_contract(reentrant_code)
calling_contract = get_contract(calling_contract_code)

reentrant_contract.set_callback(calling_contract.address, transact={})
assert reentrant_contract.callback() == calling_contract.address

# Test unprotected function without callback.
reentrant_contract.unprotected_function("some value", False, transact={"value": 1000})
assert reentrant_contract.special_value() == "some value"
assert w3.eth.getBalance(reentrant_contract.address) == 0
assert w3.eth.getBalance(calling_contract.address) == 1000

# Test unprotected function with callback to default.
reentrant_contract.unprotected_function("another value", True, transact={"value": 1000})
assert reentrant_contract.special_value() == "another value"
assert w3.eth.getBalance(reentrant_contract.address) == 1000
assert w3.eth.getBalance(calling_contract.address) == 1000

# Test protected function without callback.
reentrant_contract.protected_function("surprise!", False, transact={"value": 1000})
assert reentrant_contract.special_value() == "surprise!"
assert w3.eth.getBalance(reentrant_contract.address) == 1000
assert w3.eth.getBalance(calling_contract.address) == 2000

# Test protected function with callback to default.
assert_tx_failed(
lambda: reentrant_contract.protected_function(
"zzz value",
True,
transact={"value": 1000}
)
)


def test_disallow_on_init_function(get_contract):
# nonreentrant has no effect when used on the __init__ fn
# however, should disallow its usage regardless
Expand Down
Loading