diff --git a/tests/parser/features/external_contracts/test_external_contract_calls.py b/tests/parser/features/external_contracts/test_external_contract_calls.py index 97763b7ae2..dce56014a8 100644 --- a/tests/parser/features/external_contracts/test_external_contract_calls.py +++ b/tests/parser/features/external_contracts/test_external_contract_calls.py @@ -67,14 +67,15 @@ def bar(arg1: address) -> int128: print("Successfully executed a complicated external contract call") -def test_external_contract_calls_with_bytes(get_contract, get_contract_with_gas_estimation): - contract_1 = """ +@pytest.mark.parametrize("length", [3, 32, 33, 64]) +def test_external_contract_calls_with_bytes(get_contract, length): + contract_1 = f""" @external -def array() -> bytes[3]: +def array() -> bytes[{length}]: return b'dog' """ - c = get_contract_with_gas_estimation(contract_1) + c = get_contract(contract_1) contract_2 = """ interface Foo: @@ -89,6 +90,113 @@ def get_array(arg1: address) -> bytes[3]: assert c2.get_array(c.address) == b"dog" +def test_bytes_too_long(get_contract, assert_tx_failed): + contract_1 = """ +@external +def array() -> bytes[4]: + return b'doge' + """ + + c = get_contract(contract_1) + + contract_2 = """ +interface Foo: + def array() -> bytes[3]: view + +@external +def get_array(arg1: address) -> bytes[3]: + return Foo(arg1).array() +""" + + c2 = get_contract(contract_2) + assert_tx_failed(lambda: c2.get_array(c.address)) + + +@pytest.mark.parametrize("a,b", [(3, 3), (4, 3), (3, 4), (32, 32), (33, 33), (64, 64)]) +@pytest.mark.parametrize("actual", [3, 32, 64]) +def test_tuple_with_bytes(get_contract, assert_tx_failed, a, b, actual): + contract_1 = f""" +@external +def array() -> (bytes[{actual}], int128, bytes[{actual}]): + return b'dog', 255, b'cat' + """ + + c = get_contract(contract_1) + + contract_2 = f""" +interface Foo: + def array() -> (bytes[{a}], int128, bytes[{b}]): view + +@external +def get_array(arg1: address) -> (bytes[{a}], int128, bytes[{b}]): + a: bytes[{a}] = b"" + b: int128 = 0 + c: bytes[{b}] = b"" + a, b, c = Foo(arg1).array() + return a, b, c +""" + + c2 = get_contract(contract_2) + assert c.array() == [b"dog", 255, b"cat"] + assert c2.get_array(c.address) == [b"dog", 255, b"cat"] + + +@pytest.mark.parametrize("a,b", [(18, 7), (18, 18), (19, 6), (64, 6), (7, 19)]) +@pytest.mark.parametrize("c,d", [(19, 7), (64, 64)]) +def test_tuple_with_bytes_too_long(get_contract, assert_tx_failed, a, c, b, d): + contract_1 = f""" +@external +def array() -> (bytes[{c}], int128, bytes[{d}]): + return b'nineteen characters', 255, b'seven!!' + """ + + c = get_contract(contract_1) + + contract_2 = f""" +interface Foo: + def array() -> (bytes[{a}], int128, bytes[{b}]): view + +@external +def get_array(arg1: address) -> (bytes[{a}], int128, bytes[{b}]): + a: bytes[{a}] = b"" + b: int128 = 0 + c: bytes[{b}] = b"" + a, b, c = Foo(arg1).array() + return a, b, c +""" + + c2 = get_contract(contract_2) + assert c.array() == [b"nineteen characters", 255, b"seven!!"] + assert_tx_failed(lambda: c2.get_array(c.address)) + + +def test_tuple_with_bytes_too_long_two(get_contract, assert_tx_failed): + contract_1 = """ +@external +def array() -> (bytes[30], int128, bytes[30]): + return b'nineteen characters', 255, b'seven!!' + """ + + c = get_contract(contract_1) + + contract_2 = """ +interface Foo: + def array() -> (bytes[30], int128, bytes[3]): view + +@external +def get_array(arg1: address) -> (bytes[30], int128, bytes[3]): + a: bytes[30] = b"" + b: int128 = 0 + c: bytes[3] = b"" + a, b, c = Foo(arg1).array() + return a, b, c +""" + + c2 = get_contract(contract_2) + assert c.array() == [b"nineteen characters", 255, b"seven!!"] + assert_tx_failed(lambda: c2.get_array(c.address)) + + def test_external_contract_call_state_change(get_contract): contract_1 = """ lucky: public(int128) @@ -664,61 +772,6 @@ def test_bad_code_struct_exc(assert_compile_failed, get_contract_with_gas_estima assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), ArgumentException) -def test_external_value_arg_without_return(w3, get_contract_with_gas_estimation): - contract_1 = """ -@payable -@external -def get_lucky(): - pass - -@external -def get_balance() -> uint256: - return self.balance -""" - - contract_2 = """ -interface Bar: - def get_lucky() -> int128: payable - -bar_contract: Bar - -@external -def set_contract(contract_address: address): - self.bar_contract = Bar(contract_address) - -@payable -@external -def get_lucky(amount_to_send: uint256): - if amount_to_send != 0: - self.bar_contract.get_lucky(value=amount_to_send) - else: # send it all - self.bar_contract.get_lucky(value=msg.value) -""" - - c1 = get_contract_with_gas_estimation(contract_1) - c2 = get_contract_with_gas_estimation(contract_2) - - assert c1.get_balance() == 0 - - c2.set_contract(c1.address, transact={}) - - # Send some eth - c2.get_lucky(0, transact={"value": 500}) - - # Contract 1 received money. - assert c1.get_balance() == 500 - assert w3.eth.getBalance(c1.address) == 500 - assert w3.eth.getBalance(c2.address) == 0 - - # Send subset of amount - c2.get_lucky(250, transact={"value": 500}) - - # Contract 1 received more money. - assert c1.get_balance() == 750 - assert w3.eth.getBalance(c1.address) == 750 - assert w3.eth.getBalance(c2.address) == 250 - - def test_tuple_return_external_contract_call(get_contract): contract_1 = """ @external @@ -795,3 +848,87 @@ def get_array(arg1: address) -> int128[3]: c2 = get_contract(contract_2) assert c2.get_array(c.address) == [0, 0, 0] + + +def test_returndatasize_too_short(get_contract, assert_tx_failed): + contract_1 = """ +@external +def bar(a: int128) -> int128: + return a +""" + contract_2 = """ +interface Bar: + def bar(a: int128) -> (int128, int128): view + +@external +def foo(_addr: address): + Bar(_addr).bar(456) +""" + c1 = get_contract(contract_1) + c2 = get_contract(contract_2) + assert_tx_failed(lambda: c2.foo(c1.address)) + + +def test_returndatasize_empty(get_contract, assert_tx_failed): + contract_1 = """ +@external +def bar(a: int128): + pass +""" + contract_2 = """ +interface Bar: + def bar(a: int128) -> int128: view + +@external +def foo(_addr: address) -> int128: + return Bar(_addr).bar(456) +""" + c1 = get_contract(contract_1) + c2 = get_contract(contract_2) + assert_tx_failed(lambda: c2.foo(c1.address)) + + +def test_returndatasize_too_long(get_contract, assert_tx_failed): + contract_1 = """ +@external +def bar(a: int128) -> (int128, int128): + return a, 789 +""" + contract_2 = """ +interface Bar: + def bar(a: int128) -> int128: view + +@external +def foo(_addr: address) -> int128: + return Bar(_addr).bar(456) +""" + c1 = get_contract(contract_1) + c2 = get_contract(contract_2) + + # excess return data does not raise + assert c2.foo(c1.address) == 456 + + +def test_no_returndata(get_contract, assert_tx_failed): + contract_1 = """ +@external +def bar(a: int128) -> int128: + return a +""" + contract_2 = """ +interface Bar: + def bar(a: int128) -> int128: view + +@external +def foo(_addr: address, _addr2: address) -> int128: + x: int128 = Bar(_addr).bar(456) + # make two calls to confirm EVM behavior: RETURNDATA is always based on the last call + y: int128 = Bar(_addr2).bar(123) + return y + +""" + c1 = get_contract(contract_1) + c2 = get_contract(contract_2) + + assert c2.foo(c1.address, c1.address) == 123 + assert_tx_failed(lambda: c2.foo(c1.address, "0x1234567890123456789012345678901234567890")) diff --git a/tests/parser/functions/test_interfaces.py b/tests/parser/functions/test_interfaces.py index 3c0e0a34de..b2d0f226ea 100644 --- a/tests/parser/functions/test_interfaces.py +++ b/tests/parser/functions/test_interfaces.py @@ -227,8 +227,9 @@ def test_external_call_to_builtin_interface(w3, get_contract): balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256): +def transfer(to: address, _value: uint256) -> bool: self.balanceOf[to] += _value + return True """ code = """ diff --git a/vyper/parser/external_call.py b/vyper/parser/external_call.py index 3dd67cb892..0df7091fca 100644 --- a/vyper/parser/external_call.py +++ b/vyper/parser/external_call.py @@ -12,6 +12,8 @@ ListType, TupleLike, get_size_of_type, + get_static_size_of_type, + has_dynamic_data, ) @@ -33,10 +35,11 @@ def external_call(node, context, interface_name, contract_address, pos, value=No is_external_call=True, ) output_placeholder, output_size, returner = get_external_call_output(sig, context) - sub = [ - "seq", - ["assert", ["extcodesize", contract_address]], - ] + sub = ["seq"] + if not output_size: + # if we do not expect return data, check that a contract exists at the target address + # we can omit this when we _do_ expect return data because we later check `returndatasize` + sub.append(["assert", ["extcodesize", contract_address]]) if context.is_constant() and sig.mutability not in ("view", "pure"): # TODO this can probably go raise StateAccessViolation( @@ -76,9 +79,52 @@ def external_call(node, context, interface_name, contract_address, pos, value=No ], ] ) + if output_size: + # when return data is expected, revert when the length of `returndatasize` is insufficient + output_type = sig.output_type + if not has_dynamic_data(output_type): + static_output_size = get_static_size_of_type(output_type) * 32 + sub.append(["assert", ["gt", "returndatasize", static_output_size - 1]]) + else: + if isinstance(output_type, ByteArrayLike): + types_list = (output_type,) + elif isinstance(output_type, TupleLike): + types_list = output_type.members + else: + raise + + dynamic_checks = [] + static_offset = output_placeholder + static_output_size = 0 + for typ in types_list: + # ensure length of bytes does not exceed max allowable length for type + if isinstance(typ, ByteArrayLike): + static_output_size += 32 + # do not perform this check on calls to a JSON interface - we don't know + # for certain how long the expected data is + if not sig.is_from_json: + dynamic_checks.append( + [ + "assert", + [ + "lt", + [ + "mload", + ["add", ["mload", static_offset], output_placeholder], + ], + typ.maxlen + 1, + ], + ] + ) + static_offset += get_static_size_of_type(typ) * 32 + static_output_size += get_static_size_of_type(typ) * 32 + + sub.append(["assert", ["gt", "returndatasize", static_output_size - 1]]) + sub.extend(dynamic_checks) + sub.extend(returner) - o = LLLnode.from_list(sub, typ=sig.output_type, location="memory", pos=getpos(node)) - return o + + return LLLnode.from_list(sub, typ=sig.output_type, location="memory", pos=getpos(node)) def get_external_call_output(sig, context): diff --git a/vyper/signatures/function_signature.py b/vyper/signatures/function_signature.py index c5bbdf5f77..1bf13b6885 100644 --- a/vyper/signatures/function_signature.py +++ b/vyper/signatures/function_signature.py @@ -68,6 +68,7 @@ def __init__( sig, method_id, func_ast_code, + is_from_json, ): self.name = name self.args = args @@ -79,6 +80,7 @@ def __init__( self.gas = None self.nonreentrant_key = nonreentrant_key self.func_ast_code = func_ast_code + self.is_from_json = is_from_json self.calculate_arg_totals() def __str__(self): @@ -154,6 +156,7 @@ def from_definition( interface_def=False, constants=None, constant_override=False, + is_from_json=False, ): if not custom_structs: custom_structs = {} @@ -285,7 +288,16 @@ def from_definition( # Take the first 4 bytes of the hash of the sig to get the method ID method_id = fourbytes_to_int(keccak256(bytes(sig, "utf-8"))[:4]) return cls( - name, args, output_type, mutability, internal, nonreentrant_key, sig, method_id, code + name, + args, + output_type, + mutability, + internal, + nonreentrant_key, + sig, + method_id, + code, + is_from_json, ) @iterable_cast(dict) diff --git a/vyper/signatures/interface.py b/vyper/signatures/interface.py index 44fb6ec755..47714b5aa4 100644 --- a/vyper/signatures/interface.py +++ b/vyper/signatures/interface.py @@ -97,6 +97,7 @@ def mk_full_signature_from_json(abi): ), custom_structs=dict(), constants=Constants(), + is_from_json=True, ) sigs.append(sig) return sigs