Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[lang]: remove VyperNode __hash__() and __eq__() implementations #4433

Merged
merged 9 commits into from
Jan 12, 2025
25 changes: 25 additions & 0 deletions tests/ast_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from vyper.ast.nodes import VyperNode


def deepequals(node: VyperNode, other: VyperNode):
# checks two nodes are recursively equal, ignoring metadata
# like line info.
if not isinstance(other, type(node)):
return False

if isinstance(node, list):
if len(node) != len(other):
return False
return all(deepequals(a, b) for a, b in zip(node, other))

if not isinstance(node, VyperNode):
return node == other

if getattr(node, "node_id", None) != getattr(other, "node_id", None):
return False
for field_name in (i for i in node.get_fields() if i not in VyperNode.__slots__):
lhs = getattr(node, field_name, None)
rhs = getattr(other, field_name, None)
if not deepequals(lhs, rhs):
return False
return True
3 changes: 2 additions & 1 deletion tests/unit/ast/nodes/test_binary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from tests.ast_utils import deepequals
from vyper import ast as vy_ast
from vyper.exceptions import SyntaxException

Expand All @@ -18,7 +19,7 @@ def x():
"""
)

assert expected == mutated
assert deepequals(expected, mutated)


def test_binary_length():
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/ast/nodes/test_compare_nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tests.ast_utils import deepequals
from vyper import ast as vy_ast


Expand All @@ -6,33 +7,33 @@ def test_compare_different_node_clases():
left = vyper_ast.body[0].target
right = vyper_ast.body[0].value

assert left != right
assert not deepequals(left, right)


def test_compare_different_nodes_same_class():
vyper_ast = vy_ast.parse_to_ast("[1, 2]")
left, right = vyper_ast.body[0].value.elements

assert left != right
assert not deepequals(left, right)


def test_compare_different_nodes_same_value():
vyper_ast = vy_ast.parse_to_ast("[1, 1]")
left, right = vyper_ast.body[0].value.elements

assert left != right
assert not deepequals(left, right)


def test_compare_similar_node():
# test equality without node_ids
left = vy_ast.Int(value=1)
right = vy_ast.Int(value=1)

assert left == right
assert deepequals(left, right)


def test_compare_same_node():
vyper_ast = vy_ast.parse_to_ast("42")
node = vyper_ast.body[0].value

assert node == node
assert deepequals(node, node)
3 changes: 2 additions & 1 deletion tests/unit/ast/test_ast_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json

from tests.ast_utils import deepequals
from vyper import compiler
from vyper.ast.nodes import NODE_SRC_ATTRIBUTES
from vyper.ast.parse import parse_to_ast
Expand Down Expand Up @@ -138,7 +139,7 @@ def test() -> int128:
new_dict = json.loads(out_json)
new_ast = dict_to_ast(new_dict)

assert new_ast == original_ast
assert deepequals(new_ast, original_ast)


# strip source annotations like lineno, we don't care for inspecting
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/ast/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tests.ast_utils import deepequals
from vyper.ast.parse import parse_to_ast


Expand All @@ -12,7 +13,7 @@ def test() -> int128:
ast1 = parse_to_ast(code)
ast2 = parse_to_ast("\n \n" + code + "\n\n")

assert ast1 == ast2
assert deepequals(ast1, ast2)


def test_ast_unequal():
Expand All @@ -32,4 +33,4 @@ def test() -> int128:
ast1 = parse_to_ast(code1)
ast2 = parse_to_ast(code2)

assert ast1 != ast2
assert not deepequals(ast1, ast2)
16 changes: 0 additions & 16 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,26 +331,10 @@ def get_fields(cls) -> set:
slot_fields = [x for i in cls.__mro__ for x in getattr(i, "__slots__", [])]
return set(i for i in slot_fields if not i.startswith("_"))

def __hash__(self):
values = [getattr(self, i, None) for i in VyperNode._public_slots]
return hash(tuple(values))

def __deepcopy__(self, memo):
# default implementation of deepcopy is a hotspot
return pickle.loads(pickle.dumps(self))

def __eq__(self, other):
# CMC 2024-03-03 I'm not sure it makes much sense to compare AST
# nodes, especially if they come from other modules
if not isinstance(other, type(self)):
return False
if getattr(other, "node_id", None) != getattr(self, "node_id", None):
return False
for field_name in (i for i in self.get_fields() if i not in VyperNode.__slots__):
if getattr(self, field_name, None) != getattr(other, field_name, None):
return False
return True

def __repr__(self):
cls = type(self)
class_repr = f"{cls.__module__}.{cls.__qualname__}"
Expand Down
10 changes: 5 additions & 5 deletions vyper/semantics/analysis/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def push_path(self, module_ast: vy_ast.Module) -> None:

def pop_path(self, expected: vy_ast.Module) -> None:
popped = self._path.pop()
if expected != popped:
raise CompilerPanic("unreachable")
assert expected is popped, "unreachable"
self._imports.pop()

@contextlib.contextmanager
Expand All @@ -78,7 +77,7 @@ def __init__(self, input_bundle: InputBundle, graph: _ImportGraph):
self.graph = graph
self._ast_of: dict[int, vy_ast.Module] = {}

self.seen: set[int] = set()
self.seen: set[vy_ast.Module] = set()

self._integrity_sum = None

Expand All @@ -103,7 +102,7 @@ def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module):
return sha256sum("".join(acc))

def _resolve_imports_r(self, module_ast: vy_ast.Module):
if id(module_ast) in self.seen:
if module_ast in self.seen:
return
with self.graph.enter_path(module_ast):
for node in module_ast.body:
Expand All @@ -112,7 +111,8 @@ def _resolve_imports_r(self, module_ast: vy_ast.Module):
self._handle_Import(node)
elif isinstance(node, vy_ast.ImportFrom):
self._handle_ImportFrom(node)
self.seen.add(id(module_ast))

self.seen.add(module_ast)

def _handle_Import(self, node: vy_ast.Import):
# import x.y[name] as y[alias]
Expand Down
Loading