Skip to content

Commit

Permalink
refactor[lang]: remove VyperNode __hash__() and __eq__() implemen…
Browse files Browse the repository at this point in the history
…tations (#4433)

it is a performance and correctness footgun for `VyperNode`'s hash and
eq implementations to recurse. for instance, two nodes from different
source files should never compare equal.

several tests rely on the recursive behavior of the eq implementation;
a utility function `deepequals()` is added in this PR so that tests
can perform the recursive check on AST nodes where needed. nowhere in
the compiler itself (`vyper/` directory) is the recursive definition
relied on.

this commit also slightly refactors the import analyzer so that it uses
the new hash and eq implementations instead of the previous workaround
to avoid recursion (which was to use `id(module_ast)`).
  • Loading branch information
charles-cooper authored Jan 12, 2025
1 parent f444c8f commit 9b5523e
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 30 deletions.
25 changes: 25 additions & 0 deletions tests/ast_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from vyper.ast.nodes import VyperNode


def deepequals(node: VyperNode, other: VyperNode):
    """Return whether ``node`` and ``other`` are recursively equal.

    Despite the annotations, the function recurses into field values, so
    it also handles lists and arbitrary non-node values:

    * if ``other`` is not an instance of ``type(node)`` the result is
      ``False`` (note the check is one-directional: ``other`` may be an
      instance of a subclass of ``type(node)``);
    * lists are equal when they have the same length and their elements
      are pairwise deep-equal;
    * any other non-``VyperNode`` value is compared with ``==``;
    * ``VyperNode`` instances must have matching ``node_id`` attributes
      (a missing ``node_id`` is read as ``None``), and every field from
      ``node.get_fields()`` that is not declared in
      ``VyperNode.__slots__`` must be deep-equal.

    Fields declared in ``VyperNode.__slots__`` (metadata like line info)
    are ignored by the per-field comparison.
    """
    if not isinstance(other, type(node)):
        return False

    if isinstance(node, list):
        # zip() alone would silently truncate, so check the lengths first
        return len(node) == len(other) and all(
            deepequals(left, right) for left, right in zip(node, other)
        )

    if not isinstance(node, VyperNode):
        # leaf value (str, int, None, ...): defer to its own equality
        return node == other

    if getattr(node, "node_id", None) != getattr(other, "node_id", None):
        return False

    ignored = VyperNode.__slots__
    for name in node.get_fields():
        if name in ignored:
            continue
        # attributes that are unset on either side are treated as None
        if not deepequals(getattr(node, name, None), getattr(other, name, None)):
            return False

    return True
3 changes: 2 additions & 1 deletion tests/unit/ast/nodes/test_binary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from tests.ast_utils import deepequals
from vyper import ast as vy_ast
from vyper.exceptions import SyntaxException

Expand All @@ -18,7 +19,7 @@ def x():
"""
)

assert expected == mutated
assert deepequals(expected, mutated)


def test_binary_length():
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/ast/nodes/test_compare_nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tests.ast_utils import deepequals
from vyper import ast as vy_ast


Expand All @@ -6,33 +7,33 @@ def test_compare_different_node_clases():
left = vyper_ast.body[0].target
right = vyper_ast.body[0].value

assert left != right
assert not deepequals(left, right)


def test_compare_different_nodes_same_class():
vyper_ast = vy_ast.parse_to_ast("[1, 2]")
left, right = vyper_ast.body[0].value.elements

assert left != right
assert not deepequals(left, right)


def test_compare_different_nodes_same_value():
vyper_ast = vy_ast.parse_to_ast("[1, 1]")
left, right = vyper_ast.body[0].value.elements

assert left != right
assert not deepequals(left, right)


def test_compare_similar_node():
# test equality without node_ids
left = vy_ast.Int(value=1)
right = vy_ast.Int(value=1)

assert left == right
assert deepequals(left, right)


def test_compare_same_node():
vyper_ast = vy_ast.parse_to_ast("42")
node = vyper_ast.body[0].value

assert node == node
assert deepequals(node, node)
3 changes: 2 additions & 1 deletion tests/unit/ast/test_ast_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json

from tests.ast_utils import deepequals
from vyper import compiler
from vyper.ast.nodes import NODE_SRC_ATTRIBUTES
from vyper.ast.parse import parse_to_ast
Expand Down Expand Up @@ -138,7 +139,7 @@ def test() -> int128:
new_dict = json.loads(out_json)
new_ast = dict_to_ast(new_dict)

assert new_ast == original_ast
assert deepequals(new_ast, original_ast)


# strip source annotations like lineno, we don't care for inspecting
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/ast/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tests.ast_utils import deepequals
from vyper.ast.parse import parse_to_ast


Expand All @@ -12,7 +13,7 @@ def test() -> int128:
ast1 = parse_to_ast(code)
ast2 = parse_to_ast("\n \n" + code + "\n\n")

assert ast1 == ast2
assert deepequals(ast1, ast2)


def test_ast_unequal():
Expand All @@ -32,4 +33,4 @@ def test() -> int128:
ast1 = parse_to_ast(code1)
ast2 = parse_to_ast(code2)

assert ast1 != ast2
assert not deepequals(ast1, ast2)
16 changes: 0 additions & 16 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,26 +331,10 @@ def get_fields(cls) -> set:
slot_fields = [x for i in cls.__mro__ for x in getattr(i, "__slots__", [])]
return set(i for i in slot_fields if not i.startswith("_"))

def __hash__(self):
values = [getattr(self, i, None) for i in VyperNode._public_slots]
return hash(tuple(values))

def __deepcopy__(self, memo):
# default implementation of deepcopy is a hotspot
return pickle.loads(pickle.dumps(self))

def __eq__(self, other):
# CMC 2024-03-03 I'm not sure it makes much sense to compare AST
# nodes, especially if they come from other modules
if not isinstance(other, type(self)):
return False
if getattr(other, "node_id", None) != getattr(self, "node_id", None):
return False
for field_name in (i for i in self.get_fields() if i not in VyperNode.__slots__):
if getattr(self, field_name, None) != getattr(other, field_name, None):
return False
return True

def __repr__(self):
cls = type(self)
class_repr = f"{cls.__module__}.{cls.__qualname__}"
Expand Down
10 changes: 5 additions & 5 deletions vyper/semantics/analysis/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def push_path(self, module_ast: vy_ast.Module) -> None:

def pop_path(self, expected: vy_ast.Module) -> None:
popped = self._path.pop()
if expected != popped:
raise CompilerPanic("unreachable")
assert expected is popped, "unreachable"
self._imports.pop()

@contextlib.contextmanager
Expand All @@ -78,7 +77,7 @@ def __init__(self, input_bundle: InputBundle, graph: _ImportGraph):
self.graph = graph
self._ast_of: dict[int, vy_ast.Module] = {}

self.seen: set[int] = set()
self.seen: set[vy_ast.Module] = set()

self._integrity_sum = None

Expand All @@ -103,7 +102,7 @@ def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module):
return sha256sum("".join(acc))

def _resolve_imports_r(self, module_ast: vy_ast.Module):
if id(module_ast) in self.seen:
if module_ast in self.seen:
return
with self.graph.enter_path(module_ast):
for node in module_ast.body:
Expand All @@ -112,7 +111,8 @@ def _resolve_imports_r(self, module_ast: vy_ast.Module):
self._handle_Import(node)
elif isinstance(node, vy_ast.ImportFrom):
self._handle_ImportFrom(node)
self.seen.add(id(module_ast))

self.seen.add(module_ast)

def _handle_Import(self, node: vy_ast.Import):
# import x.y[name] as y[alias]
Expand Down

0 comments on commit 9b5523e

Please sign in to comment.