Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check length of return data from external calls #2076

Merged
merged 5 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 196 additions & 59 deletions tests/parser/features/external_contracts/test_external_contract_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@ def bar(arg1: address) -> int128:
print("Successfully executed a complicated external contract call")


def test_external_contract_calls_with_bytes(get_contract, get_contract_with_gas_estimation):
contract_1 = """
@pytest.mark.parametrize("length", [3, 32, 33, 64])
def test_external_contract_calls_with_bytes(get_contract, length):
contract_1 = f"""
@external
def array() -> bytes[3]:
def array() -> bytes[{length}]:
return b'dog'
"""

c = get_contract_with_gas_estimation(contract_1)
c = get_contract(contract_1)

contract_2 = """
interface Foo:
Expand All @@ -89,6 +90,113 @@ def get_array(arg1: address) -> bytes[3]:
assert c2.get_array(c.address) == b"dog"


def test_bytes_too_long(get_contract, assert_tx_failed):
contract_1 = """
@external
def array() -> bytes[4]:
return b'doge'
"""

c = get_contract(contract_1)

contract_2 = """
interface Foo:
def array() -> bytes[3]: view

@external
def get_array(arg1: address) -> bytes[3]:
return Foo(arg1).array()
"""

c2 = get_contract(contract_2)
assert_tx_failed(lambda: c2.get_array(c.address))


@pytest.mark.parametrize("a,b", [(3, 3), (4, 3), (3, 4), (32, 32), (33, 33), (64, 64)])
@pytest.mark.parametrize("actual", [3, 32, 64])
def test_tuple_with_bytes(get_contract, assert_tx_failed, a, b, actual):
contract_1 = f"""
@external
def array() -> (bytes[{actual}], int128, bytes[{actual}]):
return b'dog', 255, b'cat'
"""

c = get_contract(contract_1)

contract_2 = f"""
interface Foo:
def array() -> (bytes[{a}], int128, bytes[{b}]): view

@external
def get_array(arg1: address) -> (bytes[{a}], int128, bytes[{b}]):
a: bytes[{a}] = b""
b: int128 = 0
c: bytes[{b}] = b""
a, b, c = Foo(arg1).array()
return a, b, c
"""

c2 = get_contract(contract_2)
assert c.array() == [b"dog", 255, b"cat"]
assert c2.get_array(c.address) == [b"dog", 255, b"cat"]


@pytest.mark.parametrize("a,b", [(18, 7), (18, 18), (19, 6), (64, 6), (7, 19)])
@pytest.mark.parametrize("c,d", [(19, 7), (64, 64)])
def test_tuple_with_bytes_too_long(get_contract, assert_tx_failed, a, c, b, d):
contract_1 = f"""
@external
def array() -> (bytes[{c}], int128, bytes[{d}]):
return b'nineteen characters', 255, b'seven!!'
"""

c = get_contract(contract_1)

contract_2 = f"""
interface Foo:
def array() -> (bytes[{a}], int128, bytes[{b}]): view

@external
def get_array(arg1: address) -> (bytes[{a}], int128, bytes[{b}]):
a: bytes[{a}] = b""
b: int128 = 0
c: bytes[{b}] = b""
a, b, c = Foo(arg1).array()
return a, b, c
"""

c2 = get_contract(contract_2)
assert c.array() == [b"nineteen characters", 255, b"seven!!"]
assert_tx_failed(lambda: c2.get_array(c.address))


def test_tuple_with_bytes_too_long_two(get_contract, assert_tx_failed):
contract_1 = """
@external
def array() -> (bytes[30], int128, bytes[30]):
return b'nineteen characters', 255, b'seven!!'
"""

c = get_contract(contract_1)

contract_2 = """
interface Foo:
def array() -> (bytes[30], int128, bytes[3]): view

@external
def get_array(arg1: address) -> (bytes[30], int128, bytes[3]):
a: bytes[30] = b""
b: int128 = 0
c: bytes[3] = b""
a, b, c = Foo(arg1).array()
return a, b, c
"""

c2 = get_contract(contract_2)
assert c.array() == [b"nineteen characters", 255, b"seven!!"]
assert_tx_failed(lambda: c2.get_array(c.address))


def test_external_contract_call_state_change(get_contract):
contract_1 = """
lucky: public(int128)
Expand Down Expand Up @@ -664,61 +772,6 @@ def test_bad_code_struct_exc(assert_compile_failed, get_contract_with_gas_estima
assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), ArgumentException)


def test_external_value_arg_without_return(w3, get_contract_with_gas_estimation):
contract_1 = """
@payable
@external
def get_lucky():
pass

@external
def get_balance() -> uint256:
return self.balance
"""

contract_2 = """
interface Bar:
def get_lucky() -> int128: payable

bar_contract: Bar

@external
def set_contract(contract_address: address):
self.bar_contract = Bar(contract_address)

@payable
@external
def get_lucky(amount_to_send: uint256):
if amount_to_send != 0:
self.bar_contract.get_lucky(value=amount_to_send)
else: # send it all
self.bar_contract.get_lucky(value=msg.value)
"""

c1 = get_contract_with_gas_estimation(contract_1)
c2 = get_contract_with_gas_estimation(contract_2)

assert c1.get_balance() == 0

c2.set_contract(c1.address, transact={})

# Send some eth
c2.get_lucky(0, transact={"value": 500})

# Contract 1 received money.
assert c1.get_balance() == 500
assert w3.eth.getBalance(c1.address) == 500
assert w3.eth.getBalance(c2.address) == 0

# Send subset of amount
c2.get_lucky(250, transact={"value": 500})

# Contract 1 received more money.
assert c1.get_balance() == 750
assert w3.eth.getBalance(c1.address) == 750
assert w3.eth.getBalance(c2.address) == 250


def test_tuple_return_external_contract_call(get_contract):
contract_1 = """
@external
Expand Down Expand Up @@ -795,3 +848,87 @@ def get_array(arg1: address) -> int128[3]:

c2 = get_contract(contract_2)
assert c2.get_array(c.address) == [0, 0, 0]


def test_returndatasize_too_short(get_contract, assert_tx_failed):
contract_1 = """
@external
def bar(a: int128) -> int128:
return a
"""
contract_2 = """
interface Bar:
def bar(a: int128) -> (int128, int128): view

@external
def foo(_addr: address):
Bar(_addr).bar(456)
"""
c1 = get_contract(contract_1)
c2 = get_contract(contract_2)
assert_tx_failed(lambda: c2.foo(c1.address))


def test_returndatasize_empty(get_contract, assert_tx_failed):
contract_1 = """
@external
def bar(a: int128):
pass
"""
contract_2 = """
interface Bar:
def bar(a: int128) -> int128: view

@external
def foo(_addr: address) -> int128:
return Bar(_addr).bar(456)
"""
c1 = get_contract(contract_1)
c2 = get_contract(contract_2)
assert_tx_failed(lambda: c2.foo(c1.address))


def test_returndatasize_too_long(get_contract, assert_tx_failed):
contract_1 = """
@external
def bar(a: int128) -> (int128, int128):
return a, 789
"""
contract_2 = """
interface Bar:
def bar(a: int128) -> int128: view

@external
def foo(_addr: address) -> int128:
return Bar(_addr).bar(456)
fubuloubu marked this conversation as resolved.
Show resolved Hide resolved
"""
c1 = get_contract(contract_1)
c2 = get_contract(contract_2)

# excess return data does not raise
assert c2.foo(c1.address) == 456


def test_no_returndata(get_contract, assert_tx_failed):
contract_1 = """
@external
def bar(a: int128) -> int128:
return a
"""
contract_2 = """
interface Bar:
def bar(a: int128) -> int128: view

@external
def foo(_addr: address, _addr2: address) -> int128:
x: int128 = Bar(_addr).bar(456)
# make two calls to confirm EVM behavior: RETURNDATA is always based on the last call
y: int128 = Bar(_addr2).bar(123)
return y

"""
c1 = get_contract(contract_1)
c2 = get_contract(contract_2)

assert c2.foo(c1.address, c1.address) == 123
assert_tx_failed(lambda: c2.foo(c1.address, "0x1234567890123456789012345678901234567890"))
3 changes: 2 additions & 1 deletion tests/parser/functions/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,9 @@ def test_external_call_to_builtin_interface(w3, get_contract):
balanceOf: public(HashMap[address, uint256])

@external
def transfer(to: address, _value: uint256):
def transfer(to: address, _value: uint256) -> bool:
self.balanceOf[to] += _value
return True
"""

code = """
Expand Down
58 changes: 52 additions & 6 deletions vyper/parser/external_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
ListType,
TupleLike,
get_size_of_type,
get_static_size_of_type,
has_dynamic_data,
)


Expand All @@ -33,10 +35,11 @@ def external_call(node, context, interface_name, contract_address, pos, value=No
is_external_call=True,
)
output_placeholder, output_size, returner = get_external_call_output(sig, context)
sub = [
"seq",
["assert", ["extcodesize", contract_address]],
]
sub = ["seq"]
if not output_size:
# if we do not expect return data, check that a contract exists at the target address
# we can omit this when we _do_ expect return data because we later check `returndatasize`
sub.append(["assert", ["extcodesize", contract_address]])
iamdefinitelyahuman marked this conversation as resolved.
Show resolved Hide resolved
if context.is_constant() and sig.mutability not in ("view", "pure"):
# TODO this can probably go
raise StateAccessViolation(
Expand Down Expand Up @@ -76,9 +79,52 @@ def external_call(node, context, interface_name, contract_address, pos, value=No
],
]
)
if output_size:
# when return data is expected, revert when the length of `returndatasize` is insufficient
output_type = sig.output_type
if not has_dynamic_data(output_type):
static_output_size = get_static_size_of_type(output_type) * 32
sub.append(["assert", ["gt", "returndatasize", static_output_size - 1]])
else:
if isinstance(output_type, ByteArrayLike):
types_list = (output_type,)
elif isinstance(output_type, TupleLike):
types_list = output_type.members
else:
raise

dynamic_checks = []
static_offset = output_placeholder
static_output_size = 0
for typ in types_list:
# ensure length of bytes does not exceed max allowable length for type
if isinstance(typ, ByteArrayLike):
static_output_size += 32
# do not perform this check on calls to a JSON interface - we don't know
# for certain how long the expected data is
if not sig.is_from_json:
dynamic_checks.append(
[
"assert",
[
"lt",
[
"mload",
["add", ["mload", static_offset], output_placeholder],
],
typ.maxlen + 1,
],
]
)
static_offset += get_static_size_of_type(typ) * 32
static_output_size += get_static_size_of_type(typ) * 32

sub.append(["assert", ["gt", "returndatasize", static_output_size - 1]])
sub.extend(dynamic_checks)

sub.extend(returner)
o = LLLnode.from_list(sub, typ=sig.output_type, location="memory", pos=getpos(node))
return o

return LLLnode.from_list(sub, typ=sig.output_type, location="memory", pos=getpos(node))


def get_external_call_output(sig, context):
Expand Down
Loading