From 75fb0594ab3011491fd124abe534573dbd9ba052 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 11 Apr 2024 06:55:59 -0400 Subject: [PATCH] fix[lang]: fix array index checks when the subscript is folded (#3924) this commit fixes a regression introduced in 56c4c9dbc. prior to 56c4c9dbc, the folded subscript would be checked for OOB access, but after 56c4c9dbc, expressions like `foo[0 - 1]` can slip past the typechecker (getting demoted to a runtime check). also, a common pattern is refactored. common pattern: ```python if node.has_folded_value: node = node.get_folded_value() ``` => ``` node = node.reduced() ``` --- tests/unit/ast/nodes/test_fold_subscript.py | 22 +++++++++++++++++++++ vyper/ast/nodes.py | 5 +++++ vyper/ast/nodes.pyi | 1 + vyper/codegen/expr.py | 7 ++++--- vyper/semantics/analysis/local.py | 10 +++------- vyper/semantics/types/subscriptable.py | 9 ++++++--- vyper/semantics/types/utils.py | 3 +-- 7 files changed, 42 insertions(+), 15 deletions(-) diff --git a/tests/unit/ast/nodes/test_fold_subscript.py b/tests/unit/ast/nodes/test_fold_subscript.py index 3ed26d07b7..232f18b41d 100644 --- a/tests/unit/ast/nodes/test_fold_subscript.py +++ b/tests/unit/ast/nodes/test_fold_subscript.py @@ -3,6 +3,8 @@ from hypothesis import strategies as st from tests.utils import parse_and_fold +from vyper.compiler import compile_code +from vyper.exceptions import ArrayIndexException @pytest.mark.fuzzing @@ -24,3 +26,23 @@ def foo(array: int128[10], idx: uint256) -> int128: new_node = old_node.get_folded_value() assert contract.foo(array, idx) == new_node.value + + +def test_negative_index(): + source = """ +@external +def foo(array: int128[10]) -> int128: + return array[0 - 1] + """ + with pytest.raises(ArrayIndexException): + compile_code(source) + + +def test_oob_index(): + source = """ +@external +def foo(array: int128[10]) -> int128: + return array[9 + 1] + """ + with pytest.raises(ArrayIndexException): + compile_code(source) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index c78ecb6d89..c7b20d4f12 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -421,6 +421,11 @@ def get_folded_value(self) -> "ExprNode": except KeyError: raise UnfoldableNode("not foldable", self) + def reduced(self) -> "ExprNode": + if self.has_folded_value: + return self.get_folded_value() + return self + def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once assert "folded_value" not in self._metadata diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index fe01bf9260..d67c496188 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -38,6 +38,7 @@ class VyperNode: def get_fields(cls: Any) -> set: ... def set_parent(self, parent: VyperNode) -> VyperNode: ... def get_folded_value(self) -> ExprNode: ... + def reduced(self) -> ExprNode: ... def _set_folded_value(self, node: ExprNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 691a42876e..edd932d58e 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -71,8 +71,7 @@ class Expr: def __init__(self, node, context, is_stmt=False): assert isinstance(node, vy_ast.VyperNode) - if node.has_folded_value: - node = node.get_folded_value() + node = node.reduced() self.expr = node self.context = context @@ -347,7 +346,9 @@ def parse_Subscript(self): index = Expr.parse_value_expr(self.expr.slice, self.context) elif is_tuple_like(sub.typ): - index = self.expr.slice.n + # should we annotate expr.slice in the frontend with the + # folded value instead of calling reduced() here? + index = self.expr.slice.reduced().n # note: this check should also happen in get_element_ptr if not 0 <= index < len(sub.typ.member_types): raise TypeCheckFailure("unreachable") diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index e23a267a15..37ba371dd8 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -527,9 +527,7 @@ def _analyse_range_iter(self, iter_node, target_type): def _analyse_list_iter(self, iter_node, target_type): # iteration over a variable or literal list - iter_val = iter_node - if iter_val.has_folded_value: - iter_val = iter_val.get_folded_value() + iter_val = iter_node.reduced() if isinstance(iter_val, vy_ast.List): len_ = len(iter_val.elements) @@ -946,12 +944,10 @@ def _validate_range_call(node: vy_ast.Call): validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args - start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] + start, end = [i.reduced() for i in (start, end)] if "bound" in kwargs: - bound = kwargs["bound"] - if bound.has_folded_value: - bound = bound.get_folded_value() + bound = kwargs["bound"].reduced() if not isinstance(bound, vy_ast.Int): raise StructureException("Bound must be a literal integer", bound) if bound.value <= 0: diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index e6e8971087..5144952be8 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -128,6 +128,8 @@ def validate_index_type(self, node): # TODO break this cycle from vyper.semantics.analysis.utils import validate_expected_type + node = node.reduced() + if isinstance(node, vy_ast.Int): if node.value < 0: raise ArrayIndexException("Vyper does not support negative indexing", node) @@ -290,9 +292,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": if not isinstance(node.slice, vy_ast.Tuple) or len(node.slice.elements) != 2: raise StructureException(err_msg, node.slice) - length_node = node.slice.elements[1] - if length_node.has_folded_value: - length_node = length_node.get_folded_value() + length_node = node.slice.elements[1].reduced() if not isinstance(length_node, vy_ast.Int): raise StructureException(err_msg, length_node) @@ -367,6 +367,8 @@ def size_in_bytes(self): return sum(i.size_in_bytes for i in self.member_types) def validate_index_type(self, node): + node = node.reduced() + if not isinstance(node, vy_ast.Int): raise InvalidType("Tuple indexes must be literals", node) if node.value < 0: @@ -375,6 +377,7 @@ def validate_index_type(self, node): raise ArrayIndexException("Index out of range", node) def get_subscripted_type(self, node): + node = node.reduced() return self.member_types[node.value] def compare_type(self, other): diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 7b0b43990f..be80a200ed 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -195,8 +195,7 @@ def get_index_value(node: vy_ast.VyperNode) -> int: # TODO: revisit this! from vyper.semantics.analysis.utils import get_possible_types_from_node - if node.has_folded_value: - node = node.get_folded_value() + node = node.reduced() if not isinstance(node, vy_ast.Int): # even though the subscript is an invalid type, first check if it's a valid _something_