Skip to content

Commit

Permalink
feat: allow constant interfaces (#3718)
Browse files Browse the repository at this point in the history
This PR adds support for interfaces as constants.
- also introduced a mild refactor of `check_modifiability()`
  • Loading branch information
tserg authored Jan 15, 2024
1 parent af5c49f commit 88d9c22
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,38 @@ def fooBar(a: Bytes[100], b: uint256[2], c: Bytes[6] = b"hello", d: int128[3] =
assert c.fooBar(b"booo", [55, 66]) == [b"booo", 66, c_default, d_default]


def test_default_param_interface(get_contract):
code = """
interface Foo:
def bar(): payable
FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)
@external
def bar(a: uint256, b: Foo = Foo(0xF5D4020dCA6a62bB1efFcC9212AAF3c9819E30D7)) -> Foo:
return b
@external
def baz(a: uint256, b: Foo = Foo(empty(address))) -> Foo:
return b
@external
def faz(a: uint256, b: Foo = FOO) -> Foo:
return b
"""
c = get_contract(code)

addr1 = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"
addr2 = "0xF5D4020dCA6a62bB1efFcC9212AAF3c9819E30D7"

assert c.bar(1) == addr2
assert c.bar(1, addr1) == addr1
assert c.baz(1) is None
assert c.baz(1, "0x0000000000000000000000000000000000000000") is None
assert c.faz(1) == addr1
assert c.faz(1, addr1) == addr1


def test_default_param_internal_function(get_contract):
code = """
@internal
Expand Down
5 changes: 5 additions & 0 deletions tests/functional/codegen/storage_variables/test_getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def foo() -> int128:

def test_getter_code(get_contract_with_gas_estimation_for_constants):
getter_code = """
interface V:
def foo(): nonpayable
struct W:
a: uint256
b: int128[7]
Expand All @@ -36,6 +39,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants):
d: public(immutable(uint256))
e: public(immutable(uint256[2]))
f: public(constant(uint256[2])) = [3, 7]
g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)
@external
def __init__():
Expand Down Expand Up @@ -70,6 +74,7 @@ def __init__():
assert c.d() == 1729
assert c.e(0) == 2
assert [c.f(i) for i in range(2)] == [3, 7]
assert c.g() == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF"


def test_getter_mutability(get_contract):
Expand Down
13 changes: 13 additions & 0 deletions tests/functional/syntax/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,19 @@ def deposit(deposit_input: Bytes[2048]):
CONST_BAR: constant(Bar) = Bar({c: C, d: D})
""",
"""
interface Foo:
def foo(): nonpayable
FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)
""",
"""
interface Foo:
def foo(): nonpayable
FOO: constant(Foo) = Foo(BAR)
BAR: constant(address) = 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF
""",
]


Expand Down
3 changes: 3 additions & 0 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None:
# ensures the type can be inferred exactly.
get_exact_type_from_node(arg)

def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return self._modifiability >= modifiability

def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
self._validate_arg_types(node)

Expand Down
6 changes: 3 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ def parse_Name(self):
varinfo = self.context.globals[self.expr.id]

if varinfo.is_constant:
# non-struct constants should have already gotten propagated
# during constant folding
assert isinstance(varinfo.typ, StructT)
# constants other than structs and interfaces should have already gotten
# propagated during constant folding
assert isinstance(varinfo.typ, (InterfaceT, StructT))
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

assert varinfo.is_immutable, "not an immutable!"
Expand Down
10 changes: 3 additions & 7 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,15 +645,11 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) ->
return all(check_modifiability(item, modifiability) for item in node.elements)

if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_modifiability(v, modifiability) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)

# builtins
call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE)
return call_type_modifiability >= modifiability
# structs and interfaces
if hasattr(call_type, "check_modifiability_for_call"):
return call_type.check_modifiability_for_call(node, modifiability)

value_type = get_expr_info(node)
return value_type.modifiability >= modifiability
5 changes: 5 additions & 0 deletions vyper/semantics/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,11 @@ def __init__(self, typedef):
def __repr__(self):
return f"type({self.typedef})"

def check_modifiability_for_call(self, node, modifiability):
if hasattr(self.typedef, "_ctor_modifiability_for_call"):
return self.typedef._ctor_modifiability_for_call(node, modifiability)
raise StructureException("Value is not callable", node)

# dispatch into ctor if it's called
def fetch_call_return(self, node):
if hasattr(self.typedef, "_ctor_call_return"):
Expand Down
11 changes: 9 additions & 2 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
StructureException,
UnfoldableNode,
)
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.analysis.utils import (
check_modifiability,
validate_expected_type,
validate_unique_method_ids,
)
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
from vyper.semantics.types.function import ContractFunctionT
Expand Down Expand Up @@ -81,6 +85,9 @@ def _ctor_arg_types(self, node):
def _ctor_kwarg_types(self, node):
return {}

def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return check_modifiability(node.args[0], modifiability)

# TODO x.validate_implements(other)
def validate_implements(self, node: vy_ast.ImplementsDecl) -> None:
namespace = get_namespace()
Expand Down
6 changes: 5 additions & 1 deletion vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
UnknownAttribute,
VariableDeclarationException,
)
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.subscriptable import HashMapT
Expand Down Expand Up @@ -419,3 +420,6 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT":
)

return self

def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return all(check_modifiability(v, modifiability) for v in node.args[0].values)

0 comments on commit 88d9c22

Please sign in to comment.