Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: tests for functionality fixed in #2471 #2568

Merged
merged 19 commits into from
Jan 1, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/control-structures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ The ``@nonreentrant(<key>)`` decorator places a lock on a function, and all func
# this function is protected from re-entrancy
...

You can put the ``@nonreentrant(<key>)`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way.

The `__default__` Function
--------------------------

Expand Down
67 changes: 67 additions & 0 deletions tests/functional/codegen/test_abi_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,70 @@ def abi_encode3(x: uint256, ensure_tuple: bool, include_method_id: bool) -> Byte
human_encoded = abi_encode(f"({human_t})", (human_tuple,))
assert c.abi_encode(*args, True, False).hex() == human_encoded.hex()
assert c.abi_encode(*args, True, True).hex() == (method_id + human_encoded).hex()


def test_abi_encode_length_failing(get_contract, assert_compile_failed):
tserg marked this conversation as resolved.
Show resolved Hide resolved
code = """
struct WrappedBytes:
bs: Bytes[6]

@internal
def foo():
x: WrappedBytes = WrappedBytes({bs: b"hello"})
y: Bytes[96] = _abi_encode(x, ensure_tuple=True) # should be Bytes[128]
"""

assert_compile_failed(
lambda: get_contract(code)
)


def test_abi_encode_length_failing_two(get_contract, assert_compile_failed):
code = """
struct WrappedBytes:
bs: String[6]

@internal
def foo():
x: WrappedBytes = WrappedBytes({bs: "hello"})
y: Bytes[96] = _abi_encode(x, ensure_tuple=True) # should be Bytes[128]
"""

assert_compile_failed(
lambda: get_contract(code)
)


def test_side_effects_evaluation(get_contract, abi_encode):
contract_1 = """
counter: uint256

@external
def __init__():
self.counter = 0

@external
def get_counter() -> (uint256, String[6]):
self.counter += 1
return (self.counter, "hello")
"""

c = get_contract(contract_1)

contract_2 = """
interface Foo:
def get_counter() -> (uint256, String[6]): nonpayable

@external
def foo(addr: address) -> Bytes[164]:
return _abi_encode(Foo(addr).get_counter(), method_id=0xdeadbeef)
"""

c2 = get_contract(contract_2)

method_id = 0xDEADBEEF .to_bytes(4, "big")

# call to get_counter() should be evaluated only once
get_counter_encoded = abi_encode("((uint256,string))", ((1, "hello"),))
tserg marked this conversation as resolved.
Show resolved Hide resolved

assert c2.foo(c.address).hex() == (method_id + get_counter_encoded).hex()
63 changes: 59 additions & 4 deletions tests/functional/codegen/test_struct_return.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest


def test_nested_tuple(get_contract):
def test_nested_struct(get_contract):
code = """
struct Animal:
location: address
Expand All @@ -12,7 +12,7 @@ def test_nested_tuple(get_contract):
animal: Animal

@external
def modify_nested_tuple(_human: Human) -> Human:
def modify_nested_struct(_human: Human) -> Human:
human: Human = _human

# do stuff, edit the structs
Expand All @@ -25,16 +25,43 @@ def modify_nested_tuple(_human: Human) -> Human:
addr1 = "0x1234567890123456789012345678901234567890"
addr2 = "0x1234567890123456789012345678900000000000"
# assert c.modify_nested_tuple([addr1, 123], [addr2, 456]) == [[addr1, 124], [addr2, 457]]
assert c.modify_nested_tuple(
assert c.modify_nested_struct(
{"location": addr1, "animal": {"location": addr2, "fur": "wool"}}
) == (
addr1,
(addr2, "wool is great"),
)


def test_nested_single_struct(get_contract):
code = """
struct Animal:
fur: String[32]

struct Human:
animal: Animal

@external
def modify_nested_single_struct(_human: Human) -> Human:
human: Human = _human

# do stuff, edit the structs
# (13 is the length of the result)
human.animal.fur = slice(concat(human.animal.fur, " is great"), 0, 13)

return human
"""
c = get_contract(code)

assert c.modify_nested_single_struct(
{"animal": {"fur": "wool"}}
) == (
("wool is great",),
)


@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"])
def test_string_inside_tuple(get_contract, string):
def test_string_inside_struct(get_contract, string):
code = f"""
struct Person:
name: String[6]
Expand All @@ -61,3 +88,31 @@ def test_values(a: address) -> Person:

c2 = get_contract(code)
assert c2.test_values(c1.address) == (string, 42)


@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"])
def test_string_inside_single_struct(get_contract, string):
code = f"""
struct Person:
name: String[6]

@external
def test_return() -> Person:
return Person({{ name:"{string}"}})
"""
c1 = get_contract(code)

code = """
struct Person:
name: String[6]

interface jsonabi:
def test_return() -> Person: view

@external
def test_values(a: address) -> Person:
return jsonabi(a).test_return()
"""

c2 = get_contract(code)
assert c2.test_values(c1.address) == (string,)
100 changes: 99 additions & 1 deletion tests/parser/features/decorators/test_nonreentrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from vyper.exceptions import FunctionDeclarationException


def test_nonrentrant_decorator(get_contract, assert_tx_failed):
def test_nonreentrant_decorator(get_contract, assert_tx_failed):
calling_contract_code = """
interface SpecialContract:
def unprotected_function(val: String[100], do_callback: bool): nonpayable
Expand Down Expand Up @@ -85,6 +85,104 @@ def unprotected_function(val: String[100], do_callback: bool):
assert_tx_failed(lambda: reentrant_contract.protected_function2("zzz value", True, transact={}))


def test_nonreentrant_decorator_for_default(w3, get_contract, assert_tx_failed):
calling_contract_code = """
@external
def send_funds(_amount: uint256):
# raw_call() is used to overcome gas limit of send()
response: Bytes[32] = raw_call(
msg.sender,
concat(
tserg marked this conversation as resolved.
Show resolved Hide resolved
method_id("transfer(address,uint256)"),
convert(msg.sender, bytes32),
convert(_amount, bytes32)
),
max_outsize=32,
value=_amount
)

@external
@payable
def __default__():
pass
"""

reentrant_code = """
interface Callback:
def send_funds(_amount: uint256): nonpayable

special_value: public(String[100])
callback: public(Callback)

@external
def set_callback(c: address):
self.callback = Callback(c)

@external
@payable
@nonreentrant('default')
def protected_function(val: String[100], do_callback: bool) -> uint256:
self.special_value = val
_amount: uint256 = msg.value
send(self.callback.address, msg.value)

if do_callback:
self.callback.send_funds(_amount)
return 1
else:
return 2

@external
@payable
def unprotected_function(val: String[100], do_callback: bool):
self.special_value = val
_amount: uint256 = msg.value
send(self.callback.address, msg.value)

if do_callback:
self.callback.send_funds(_amount)

@external
@payable
@nonreentrant('default')
def __default__():
pass
"""

reentrant_contract = get_contract(reentrant_code)
calling_contract = get_contract(calling_contract_code)

reentrant_contract.set_callback(calling_contract.address, transact={})
assert reentrant_contract.callback() == calling_contract.address

# Test unprotected function without callback.
reentrant_contract.unprotected_function("some value", False, transact={"value": 1000})
assert reentrant_contract.special_value() == "some value"
assert w3.eth.getBalance(reentrant_contract.address) == 0
assert w3.eth.getBalance(calling_contract.address) == 1000

# Test unprotected function with callback to default.
reentrant_contract.unprotected_function("another value", True, transact={"value": 1000})
assert reentrant_contract.special_value() == "another value"
assert w3.eth.getBalance(reentrant_contract.address) == 1000
assert w3.eth.getBalance(calling_contract.address) == 1000

# Test protected function without callback.
reentrant_contract.protected_function("surprise!", False, transact={"value": 1000})
assert reentrant_contract.special_value() == "surprise!"
assert w3.eth.getBalance(reentrant_contract.address) == 1000
assert w3.eth.getBalance(calling_contract.address) == 2000

# Test protected function with callback to default.
assert_tx_failed(
lambda: reentrant_contract.protected_function(
"zzz value",
True,
transact={"value": 1000}
)
)


def test_disallow_on_init_function(get_contract):
# nonreentrant has no effect when used on the __init__ fn
# however, should disallow its usage regardless
Expand Down
16 changes: 16 additions & 0 deletions tests/parser/features/decorators/test_private.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,22 @@ def foo(a: int128) -> (int128, int128):
),
(
"""
struct A:
one: uint8

@internal
def _foo(_one: uint8) ->A:
return A({one: _one})

@external
def foo() -> A:
return self._foo(1)
""",
(),
(1,),
),
(
"""
struct A:
many: uint256[4]
one: uint256
Expand Down
Loading