Skip to content

Commit

Permalink
Merge branch 'master' into fix/cyclic-check
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper authored Aug 7, 2024
2 parents e2c20c7 + b91730b commit 91c5b41
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 20 deletions.
31 changes: 31 additions & 0 deletions tests/functional/codegen/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,34 @@ def test_call(a: address, b: {type_str}) -> {type_str}:
make_file("jsonabi.json", json.dumps(convert_v1_abi(abi)))
c3 = get_contract(code, input_bundle=input_bundle)
assert c3.test_call(c1.address, value) == value


def test_interface_function_without_visibility(make_input_bundle, get_contract):
interface_code = """
def foo() -> uint256:
...
@external
def bar() -> uint256:
...
"""

code = """
import a as FooInterface
implements: FooInterface
@external
def foo() -> uint256:
return 1
@external
def bar() -> uint256:
return 1
"""

input_bundle = make_input_bundle({"a.vyi": interface_code})

c = get_contract(code, input_bundle=input_bundle)

assert c.foo() == c.bar() == 1
60 changes: 56 additions & 4 deletions tests/functional/codegen/types/numbers/test_unsigned_ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,65 @@ def foo():
compile_code(code)


def test_invalid_div():
code = """
div_code_with_hint = [
(
"""
@external
def foo():
a: uint256 = 5 / 9
"""
""",
"did you mean `5 // 9`?",
),
(
"""
@external
def foo():
a: uint256 = 10
a /= (3 + 10) // (2 + 3)
""",
"did you mean `a //= (3 + 10) // (2 + 3)`?",
),
(
"""
@external
def foo(a: uint256, b:uint256, c: uint256) -> uint256:
return (a + b) / c
""",
"did you mean `(a + b) // c`?",
),
(
"""
@external
def foo(a: uint256, b:uint256, c: uint256) -> uint256:
return (a + b) / (a + c)
""",
"did you mean `(a + b) // (a + c)`?",
),
(
"""
@external
def foo(a: uint256, b:uint256, c: uint256) -> uint256:
return (a + (c + b)) / (a + c)
""",
"did you mean `(a + (c + b)) // (a + c)`?",
),
(
"""
interface Foo:
def foo() -> uint256: view
@external
def foo(a: uint256, b:uint256, c: uint256) -> uint256:
return (a + b) / staticcall Foo(self).foo()
""",
"did you mean `(a + b) // staticcall Foo(self).foo()`?",
),
]


@pytest.mark.parametrize("code, expected_hint", div_code_with_hint)
def test_invalid_div(code, expected_hint):
with pytest.raises(InvalidOperation) as e:
compile_code(code)

assert e.value._hint == "did you mean `5 // 9`?"
assert e.value._hint == expected_hint
78 changes: 78 additions & 0 deletions tests/functional/syntax/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,81 @@ def baz():
"""

assert compiler.compile_code(code, input_bundle=input_bundle) is not None


invalid_visibility_code = [
"""
import foo as Foo
implements: Foo
@external
def foobar():
pass
""",
"""
import foo as Foo
implements: Foo
@internal
def foobar():
pass
""",
"""
import foo as Foo
implements: Foo
def foobar():
pass
""",
]


@pytest.mark.parametrize("code", invalid_visibility_code)
def test_internal_visibility_in_interface(make_input_bundle, code):
interface_code = """
@internal
def foobar():
...
"""

input_bundle = make_input_bundle({"foo.vyi": interface_code})

with pytest.raises(FunctionDeclarationException) as e:
compiler.compile_code(code, input_bundle=input_bundle)

assert e.value._message == "Interface functions can only be marked as `@external`"


external_visibility_interface = [
"""
@external
def foobar():
...
def bar():
...
""",
"""
def foobar():
...
@external
def bar():
...
""",
]


@pytest.mark.parametrize("iface", external_visibility_interface)
def test_internal_implemenatation_of_external_interface(make_input_bundle, iface):
input_bundle = make_input_bundle({"foo.vyi": iface})

code = """
import foo as Foo
implements: Foo
@internal
def foobar():
pass
def bar():
pass
"""

with pytest.raises(InterfaceViolation) as e:
compiler.compile_code(code, input_bundle=input_bundle)

assert e.value.message == "Contract does not implement all interface functions: bar(), foobar()"
4 changes: 3 additions & 1 deletion vyper/codegen/external_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def _pack_arguments(fn_type, args, context):
pack_args.append(["mstore", buf, util.method_id_int(abi_signature)])

if len(args) != 0:
pack_args.append(abi_encode(add_ofst(buf, 32), args_as_tuple, context, bufsz=buflen))
encode_buf = add_ofst(buf, 32)
encode_buflen = buflen - 32
pack_args.append(abi_encode(encode_buf, args_as_tuple, context, bufsz=encode_buflen))

return buf, pack_args, args_ofst, args_len

Expand Down
8 changes: 4 additions & 4 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {}

# keep track of exported functions to prevent duplicate exports
self._exposed_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {}
self._all_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {}

self._events: list[EventT] = []

Expand Down Expand Up @@ -414,7 +414,7 @@ def visit_ImplementsDecl(self, node):
raise StructureException(msg, node.annotation, hint=hint)

# grab exposed functions
funcs = self._exposed_functions
funcs = {fn_t: node for fn_t, node in self._all_functions.items() if fn_t.is_external}
type_.validate_implements(node, funcs)

node._metadata["interface_type"] = type_
Expand Down Expand Up @@ -608,10 +608,10 @@ def _self_t(self):
def _add_exposed_function(self, func_t, node, relax=True):
# call this before self._self_t.typ.add_member() for exception raising
# priority
if not relax and (prev_decl := self._exposed_functions.get(func_t)) is not None:
if not relax and (prev_decl := self._all_functions.get(func_t)) is not None:
raise StructureException("already exported!", node, prev_decl=prev_decl)

self._exposed_functions[func_t] = node
self._all_functions[func_t] = node

def visit_VariableDecl(self, node):
# postcondition of VariableDecl.validate
Expand Down
38 changes: 29 additions & 9 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,23 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef)

if nonreentrant:
raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef)
# TODO: refactor so parse_decorators returns the AST location
decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant")
raise FunctionDeclarationException(
"`@nonreentrant` not allowed in interfaces", decorator
)

# it's redundant to specify visibility in vyi - always should be external
if function_visibility is None:
function_visibility = FunctionVisibility.EXTERNAL

if function_visibility != FunctionVisibility.EXTERNAL:
nonexternal = next(
d for d in funcdef.decorator_list if d.id in FunctionVisibility.values()
)
raise FunctionDeclarationException(
"Interface functions can only be marked as `@external`", nonexternal
)

if funcdef.name == "__init__":
raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef)
Expand Down Expand Up @@ -381,6 +397,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
"""
function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef)

# it's redundant to specify internal visibility - it's implied by not being external
if function_visibility is None:
function_visibility = FunctionVisibility.INTERNAL

positional_args, keyword_args = _parse_args(funcdef)

return_type = _parse_return_type(funcdef)
Expand Down Expand Up @@ -419,6 +439,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
raise FunctionDeclarationException(
"Constructor may not use default arguments", funcdef.args.defaults[0]
)
if nonreentrant:
decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant")
msg = "`@nonreentrant` decorator disallowed on `__init__`"
raise FunctionDeclarationException(msg, decorator)

return cls(
funcdef.name,
Expand Down Expand Up @@ -495,6 +519,8 @@ def implements(self, other: "ContractFunctionT") -> bool:
if not self.is_external: # pragma: nocover
raise CompilerPanic("unreachable!")

assert self.visibility == other.visibility

arguments, return_type = self._iface_sig
other_arguments, other_return_type = other._iface_sig

Expand Down Expand Up @@ -700,7 +726,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]:

def _parse_decorators(
funcdef: vy_ast.FunctionDef,
) -> tuple[FunctionVisibility, StateMutability, bool]:
) -> tuple[Optional[FunctionVisibility], StateMutability, bool]:
function_visibility = None
state_mutability = None
nonreentrant_node = None
Expand All @@ -719,10 +745,6 @@ def _parse_decorators(
if nonreentrant_node is not None:
raise StructureException("nonreentrant decorator is already set", nonreentrant_node)

if funcdef.name == "__init__":
msg = "`@nonreentrant` decorator disallowed on `__init__`"
raise FunctionDeclarationException(msg, decorator)

nonreentrant_node = decorator

elif isinstance(decorator, vy_ast.Name):
Expand All @@ -733,6 +755,7 @@ def _parse_decorators(
decorator,
hint="only one visibility decorator is allowed per function",
)

function_visibility = FunctionVisibility(decorator.id)

elif StateMutability.is_valid_value(decorator.id):
Expand All @@ -755,9 +778,6 @@ def _parse_decorators(
else:
raise StructureException("Bad decorator syntax", decorator)

if function_visibility is None:
function_visibility = FunctionVisibility.INTERNAL

if state_mutability is None:
# default to nonpayable
state_mutability = StateMutability.NONPAYABLE
Expand Down
3 changes: 3 additions & 0 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifia
def validate_implements(
self, node: vy_ast.ImplementsDecl, functions: dict[ContractFunctionT, vy_ast.VyperNode]
) -> None:
# only external functions can implement interfaces
fns_by_name = {fn_t.name: fn_t for fn_t in functions.keys()}

unimplemented = []
Expand All @@ -116,7 +117,9 @@ def _is_function_implemented(fn_name, fn_type):
return False

to_compare = fns_by_name[fn_name]
assert to_compare.is_external
assert isinstance(to_compare, ContractFunctionT)
assert isinstance(fn_type, ContractFunctionT)

return to_compare.implements(fn_type)

Expand Down
11 changes: 9 additions & 2 deletions vyper/semantics/types/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,16 @@ def _add_div_hint(node, e):
else:
return e

def _get_source(node):
source = node.node_source_code
if isinstance(node, vy_ast.BinOp):
# parenthesize, to preserve precedence
return f"({source})"
return source

if isinstance(node, vy_ast.BinOp):
e._hint = f"did you mean `{node.left.node_source_code} "
e._hint += f"{suggested} {node.right.node_source_code}`?"
e._hint = f"did you mean `{_get_source(node.left)} "
e._hint += f"{suggested} {_get_source(node.right)}`?"
elif isinstance(node, vy_ast.AugAssign):
e._hint = f"did you mean `{node.target.node_source_code} "
e._hint += f"{suggested}= {node.value.node_source_code}`?"
Expand Down

0 comments on commit 91c5b41

Please sign in to comment.