Skip to content

Commit

Permalink
fix: allow using interface defs from imported modules (#3725)
Browse files Browse the repository at this point in the history
- add interface defs to ModuleT's exposed `get_type_members()`
- slight refactor of ModuleT to have a special helper instead of
dispatching to `self.interface`
  • Loading branch information
charles-cooper authored Jan 12, 2024
1 parent 07ab92f commit 5c2177b
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 10 deletions.
31 changes: 31 additions & 0 deletions tests/functional/codegen/modules/test_interface_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
def test_import_interface_types(make_input_bundle, get_contract):
ifaces = """
interface IFoo:
def foo() -> uint256: nonpayable
"""

foo_impl = """
import ifaces
implements: ifaces.IFoo
@external
def foo() -> uint256:
return block.number
"""

contract = """
import ifaces
@external
def test_foo(s: ifaces.IFoo) -> bool:
assert s.foo() == block.number
return True
"""

input_bundle = make_input_bundle({"ifaces.vy": ifaces})

foo = get_contract(foo_impl, input_bundle=input_bundle)
c = get_contract(contract, input_bundle=input_bundle)

assert c.test_foo(foo.address) is True
2 changes: 1 addition & 1 deletion tests/functional/syntax/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def foo(): nonpayable
"""
implements: self.x
""",
StructureException,
InvalidType,
),
(
"""
Expand Down
4 changes: 2 additions & 2 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,8 +1372,8 @@ class ImplementsDecl(Stmt):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if not isinstance(self.annotation, Name):
raise StructureException("not an identifier", self.annotation)
if not isinstance(self.annotation, (Name, Attribute)):
raise StructureException("invalid implements", self.annotation)


class If(Stmt):
Expand Down
5 changes: 3 additions & 2 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,9 @@ def visit_ImportFrom(self, node):
self._add_import(node, node.level, qualified_module_name, alias)

def visit_InterfaceDef(self, node):
obj = InterfaceT.from_InterfaceDef(node)
self.namespace[node.name] = obj
interface_t = InterfaceT.from_InterfaceDef(node)
node._metadata["interface_type"] = interface_t
self.namespace[node.name] = interface_t

def visit_StructDef(self, node):
struct_t = StructT.from_StructDef(node)
Expand Down
18 changes: 17 additions & 1 deletion vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, _id: str, functions: dict, events: dict, structs: dict) -> No

self._helper = VyperType(events | structs)
self._id = _id
self._helper._id = _id
self.functions = functions
self.events = events
self.structs = structs
Expand Down Expand Up @@ -267,6 +268,8 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT":

# Datatype to store all module information.
class ModuleT(VyperType):
_attribute_in_annotation = True

def __init__(self, module: vy_ast.Module, name: Optional[str] = None):
super().__init__()

Expand All @@ -276,7 +279,10 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None):

# compute the interface, note this has the side effect of checking
# for function collisions
self._helper = self.interface
_ = self.interface

self._helper = VyperType()
self._helper._id = self._id

for f in self.function_defs:
# note: this checks for collisions
Expand All @@ -289,6 +295,12 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None):
for s in self.struct_defs:
# add the type of the struct so it can be used in call position
self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore
self._helper.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore

for i in self.interface_defs:
# add the type of the interface so it can be used in call position
self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore
self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore

for v in self.variable_decls:
self.add_member(v.target.id, v.target._metadata["varinfo"])
Expand Down Expand Up @@ -322,6 +334,10 @@ def event_defs(self):
def struct_defs(self):
return self._module.get_children(vy_ast.StructDef)

@property
def interface_defs(self):
return self._module.get_children(vy_ast.InterfaceDef)

@property
def import_stmts(self):
return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom))
Expand Down
10 changes: 6 additions & 4 deletions vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType:
except UndeclaredDefinition:
raise InvalidType(err_msg, node) from None

interface = module_or_interface
if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo
interface = module_or_interface.module_t.interface
module_or_interface = module_or_interface.module_t

if not interface._attribute_in_annotation:
if not isinstance(module_or_interface, VyperType):
raise InvalidType(err_msg, node)

type_t = interface.get_type_member(node.attr, node)
if not module_or_interface._attribute_in_annotation:
raise InvalidType(err_msg, node)

type_t = module_or_interface.get_type_member(node.attr, node) # type: ignore
assert isinstance(type_t, TYPE_T) # sanity check
return type_t.typedef

Expand Down

0 comments on commit 5c2177b

Please sign in to comment.