From 212ff59b05696ac6c058bc1fb45b86c5f54f7101 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 12 Oct 2024 05:48:58 +0800 Subject: [PATCH] feat[lang]: add native hex string literals (#4271) This commit adds support for custom hex literals like `x"a8b9"`. This hex literal is equivalent to the bytestring `b"\xa8\xb9"`. This syntax is a break from Python, which does not have this type of literal, but it is helpful in the smart contract space because literals are often copy-pasted around as hex. The approach taken in this PR is to add another custom sub-parser in `pre_parser.py` which keeps track of where the hex literals are in the original source code. It additionally refactors the signature of `pre_parse()` to return a `PreParseResult`, which bundles all the data structures returned by `pre_parse()` together. --------- Co-authored-by: Charles Cooper --- docs/types.rst | 3 +- tests/functional/codegen/types/test_bytes.py | 22 ++++ tests/functional/grammar/test_grammar.py | 4 +- tests/functional/syntax/test_bytes.py | 9 ++ .../ast/test_annotate_and_optimize_ast.py | 8 +- tests/unit/ast/test_pre_parser.py | 4 +- vyper/ast/grammar.lark | 2 +- vyper/ast/nodes.py | 19 +++ vyper/ast/nodes.pyi | 4 + vyper/ast/parse.py | 59 +++++----- vyper/ast/pre_parser.py | 108 ++++++++++++++---- vyper/codegen/expr.py | 6 + vyper/semantics/types/bytestrings.py | 2 +- vyper/typing.py | 1 - 14 files changed, 189 insertions(+), 62 deletions(-) diff --git a/docs/types.rst b/docs/types.rst index 752e06b14f..807c83848f 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -359,11 +359,12 @@ A byte array with a max size. The syntax being ``Bytes[maxLen]``, where ``maxLen`` is an integer which denotes the maximum number of bytes. On the ABI level the Fixed-size bytes array is annotated as ``bytes``. -Bytes literals may be given as bytes strings. +Bytes literals may be given as bytes strings or as hex strings. ..
code-block:: vyper bytes_string: Bytes[100] = b"\x01" + bytes_string: Bytes[100] = x"01" .. index:: !string diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index a5b119f143..6473be4348 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -259,6 +259,28 @@ def test2(l: bytes{m} = {vyper_literal}) -> bool: assert c.test2(vyper_literal) is True +@pytest.mark.parametrize("m,val", [(2, "ab"), (3, "ab"), (4, "abcd")]) +def test_native_hex_literals(get_contract, m, val): + vyper_literal = bytes.fromhex(val) + code = f""" +@external +def test() -> bool: + l: Bytes[{m}] = x"{val}" + return l == {vyper_literal} + +@external +def test2(l: Bytes[{m}] = x"{val}") -> bool: + return l == {vyper_literal} + """ + print(code) + + c = get_contract(code) + + assert c.test() is True + assert c.test2() is True + assert c.test2(vyper_literal) is True + + def test_zero_padding_with_private(get_contract): code = """ counter: uint256 diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 2af5385b3d..0ff8c23477 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -102,6 +102,6 @@ def has_no_docstrings(c): max_examples=500, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much] ) def test_grammar_bruteforce(code): - _, _, _, reformatted_code = pre_parse(code + "\n") - tree = parse_to_ast(reformatted_code) + pre_parse_result = pre_parse(code + "\n") + tree = parse_to_ast(pre_parse_result.reformatted_code) assert isinstance(tree, Module) diff --git a/tests/functional/syntax/test_bytes.py b/tests/functional/syntax/test_bytes.py index 0ca3b27fee..9df2962f2e 100644 --- a/tests/functional/syntax/test_bytes.py +++ b/tests/functional/syntax/test_bytes.py @@ -80,6 +80,15 @@ def test() -> Bytes[1]: ( """ @external +def test() -> Bytes[2]: + a: Bytes[2] = x"abc" + return a 
+ """, + SyntaxException, + ), + ( + """ +@external def foo(): a: Bytes = b"abc" """, diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 7e1641e49e..39ea899bd9 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -28,12 +28,12 @@ def foo() -> int128: def get_contract_info(source_code): - _, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) - py_ast = python_ast.parse(reformatted_code) + pre_parse_result = pre_parse(source_code) + py_ast = python_ast.parse(pre_parse_result.reformatted_code) - annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types) + annotate_python_ast(py_ast, pre_parse_result.reformatted_code, pre_parse_result) - return py_ast, reformatted_code + return py_ast, pre_parse_result.reformatted_code def test_it_annotates_ast_with_source_code(): diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index da7d72b8ec..4190725f7e 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -174,9 +174,9 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): mock_version("0.3.10") - settings, _, _, _ = pre_parse(code) + pre_parse_result = pre_parse(code) - assert settings == pre_parse_settings + assert pre_parse_result.settings == pre_parse_settings compiler_data = CompilerData(code) diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 97f9f70e24..bc2f9ba77c 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -318,7 +318,7 @@ COMMENT: /#[^\n\r]*/ _NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+ -STRING: /b?("(?!"").*?(? 
Union[Dict, List]: @@ -118,8 +116,7 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: def annotate_python_ast( parsed_ast: python_ast.AST, vyper_source: str, - modification_offsets: ModificationOffsets, - for_loop_annotations: dict, + pre_parse_result: PreParseResult, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -133,11 +130,8 @@ def annotate_python_ast( The AST to be annotated and optimized. vyper_source: str The original vyper source code - loop_var_annotations: dict - A mapping of line numbers of `For` nodes to the tokens of the type - annotation of the iterator extracted during pre-parsing. - modification_offsets : dict - A mapping of class names to their original class types. + pre_parse_result: PreParseResult + Outputs from pre-parsing. Returns ------- @@ -148,8 +142,7 @@ def annotate_python_ast( tokens.mark_tokens(parsed_ast) visitor = AnnotatingVisitor( vyper_source, - modification_offsets, - for_loop_annotations, + pre_parse_result, tokens, source_id, module_path=module_path, @@ -162,14 +155,12 @@ def annotate_python_ast( class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str - _modification_offsets: ModificationOffsets - _loop_var_annotations: dict[int, dict[str, Any]] + _pre_parse_result: PreParseResult def __init__( self, source_code: str, - modification_offsets: ModificationOffsets, - for_loop_annotations: dict, + pre_parse_result: PreParseResult, tokens: asttokens.ASTTokens, source_id: int, module_path: Optional[str] = None, @@ -180,8 +171,7 @@ def __init__( self._module_path = module_path self._resolved_path = resolved_path self._source_code = source_code - self._modification_offsets = modification_offsets - self._for_loop_annotations = for_loop_annotations + self._pre_parse_result = pre_parse_result self.counter: int = 0 @@ -275,7 +265,7 @@ def visit_ClassDef(self, node): """ self.generic_visit(node) - node.ast_type = 
self._modification_offsets[(node.lineno, node.col_offset)] + node.ast_type = self._pre_parse_result.modification_offsets[(node.lineno, node.col_offset)] return node def visit_For(self, node): @@ -283,7 +273,8 @@ def visit_For(self, node): Visit a For node, splicing in the loop variable annotation provided by the pre-parser """ - annotation_tokens = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + key = (node.lineno, node.col_offset) + annotation_tokens = self._pre_parse_result.for_loop_annotations.pop(key) if not annotation_tokens: # a common case for people migrating to 0.4.0, provide a more @@ -350,14 +341,15 @@ def visit_Expr(self, node): if isinstance(node.value, python_ast.Yield): # CMC 2024-03-03 consider unremoving this from the enclosing Expr node = node.value - node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] + key = (node.lineno, node.col_offset) + node.ast_type = self._pre_parse_result.modification_offsets[key] return node def visit_Await(self, node): start_pos = node.lineno, node.col_offset # grab these before generic_visit modifies them self.generic_visit(node) - node.ast_type = self._modification_offsets[start_pos] + node.ast_type = self._pre_parse_result.modification_offsets[start_pos] return node def visit_Call(self, node): @@ -401,7 +393,18 @@ def visit_Constant(self, node): if node.value is None or isinstance(node.value, bool): node.ast_type = "NameConstant" elif isinstance(node.value, str): - node.ast_type = "Str" + key = (node.lineno, node.col_offset) + if key in self._pre_parse_result.native_hex_literal_locations: + if len(node.value) % 2 != 0: + raise SyntaxException( + "Native hex string must have an even number of characters", + self._source_code, + node.lineno, + node.col_offset, + ) + node.ast_type = "HexBytes" + else: + node.ast_type = "Str" elif isinstance(node.value, bytes): node.ast_type = "Bytes" elif isinstance(node.value, Ellipsis.__class__): diff --git a/vyper/ast/pre_parser.py 
b/vyper/ast/pre_parser.py index b12aecd0bf..07ba1d2d0d 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -2,7 +2,7 @@ import io import re from collections import defaultdict -from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize +from tokenize import COMMENT, NAME, OP, STRING, TokenError, TokenInfo, tokenize, untokenize from packaging.specifiers import InvalidSpecifier, SpecifierSet @@ -12,7 +12,7 @@ # evm-version pragma from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import StructureException, SyntaxException, VersionException -from vyper.typing import ModificationOffsets, ParserPosition +from vyper.typing import ParserPosition def validate_version_pragma(version_str: str, full_source_code: str, start: ParserPosition) -> None: @@ -48,7 +48,7 @@ def validate_version_pragma(version_str: str, full_source_code: str, start: Pars ) -class ForParserState(enum.Enum): +class ParserState(enum.Enum): NOT_RUNNING = enum.auto() START_SOON = enum.auto() RUNNING = enum.auto() @@ -63,7 +63,7 @@ def __init__(self, code): self.annotations = {} self._current_annotation = None - self._state = ForParserState.NOT_RUNNING + self._state = ParserState.NOT_RUNNING self._current_for_loop = None def consume(self, token): @@ -71,15 +71,15 @@ def consume(self, token): if token.type == NAME and token.string == "for": # note: self._state should be NOT_RUNNING here, but we don't sanity # check here as that should be an error the parser will handle. 
- self._state = ForParserState.START_SOON + self._state = ParserState.START_SOON self._current_for_loop = token.start - if self._state == ForParserState.NOT_RUNNING: + if self._state == ParserState.NOT_RUNNING: return False # state machine: start slurping tokens if token.type == OP and token.string == ":": - self._state = ForParserState.RUNNING + self._state = ParserState.RUNNING # sanity check -- this should never really happen, but if it does, # try to raise an exception which pinpoints the source. @@ -93,12 +93,12 @@ def consume(self, token): # state machine: end slurping tokens if token.type == NAME and token.string == "in": - self._state = ForParserState.NOT_RUNNING + self._state = ParserState.NOT_RUNNING self.annotations[self._current_for_loop] = self._current_annotation or [] self._current_annotation = None return False - if self._state != ForParserState.RUNNING: + if self._state != ParserState.RUNNING: return False # slurp the token @@ -106,6 +106,42 @@ def consume(self, token): return True +class HexStringParser: + def __init__(self): + self.locations = [] + self._current_x = None + self._state = ParserState.NOT_RUNNING + + def consume(self, token, result): + # prepare to check if the next token is a STRING + if token.type == NAME and token.string == "x": + self._state = ParserState.RUNNING + self._current_x = token + return True + + if self._state == ParserState.NOT_RUNNING: + return False + + if self._state == ParserState.RUNNING: + current_x = self._current_x + self._current_x = None + self._state = ParserState.NOT_RUNNING + + toks = [current_x] + + # drop the leading x token if the next token is a STRING to avoid a python + # parser error + if token.type == STRING: + self.locations.append(current_x.start) + toks = [TokenInfo(STRING, token.string, current_x.start, token.end, token.line)] + result.extend(toks) + return True + + result.extend(toks) + + return False + + # compound statements that are replaced with `class` # TODO remove enum in favor of 
flag VYPER_CLASS_TYPES = { @@ -122,7 +158,34 @@ def consume(self, token): CUSTOM_EXPRESSION_TYPES = {"extcall": "ExtCall", "staticcall": "StaticCall"} -def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: +class PreParseResult: + # Compilation settings based on the directives in the source code + settings: Settings + # A mapping of line/column offsets of rewritten nodes to their original vyper AST types. + modification_offsets: dict[tuple[int, int], str] + # A mapping of line/column offsets of `For` nodes to the annotation of the for loop target + for_loop_annotations: dict[tuple[int, int], list[TokenInfo]] + # A list of line/column offsets of native hex literals + native_hex_literal_locations: list[tuple[int, int]] + # Reformatted python source string. + reformatted_code: str + + def __init__( + self, + settings, + modification_offsets, + for_loop_annotations, + native_hex_literal_locations, + reformatted_code, + ): + self.settings = settings + self.modification_offsets = modification_offsets + self.for_loop_annotations = for_loop_annotations + self.native_hex_literal_locations = native_hex_literal_locations + self.reformatted_code = reformatted_code + + +def pre_parse(code: str) -> PreParseResult: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -144,19 +207,14 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: Returns ------- - Settings - Compilation settings based on the directives in the source code - ModificationOffsets - A mapping of class names to their original class types. - dict[tuple[int, int], list[TokenInfo]] - A mapping of line/column offsets of `For` nodes to the annotation of the for loop target - str - Reformatted python source string.
+ PreParseResult + Outputs for transforming the python AST to vyper AST """ - result = [] - modification_offsets: ModificationOffsets = {} + result: list[TokenInfo] = [] + modification_offsets: dict[tuple[int, int], str] = {} settings = Settings() for_parser = ForParser(code) + native_hex_parser = HexStringParser() _col_adjustments: dict[int, int] = defaultdict(lambda: 0) @@ -264,7 +322,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) - if not for_parser.consume(token): + if not for_parser.consume(token) and not native_hex_parser.consume(token, result): result.extend(toks) except TokenError as e: @@ -274,4 +332,10 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: for k, v in for_parser.annotations.items(): for_loop_annotations[k] = v.copy() - return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") + return PreParseResult( + settings, + modification_offsets, + for_loop_annotations, + native_hex_parser.locations, + untokenize(result).decode("utf-8"), + ) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 0b3b29b9d0..cd51966710 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -140,6 +140,12 @@ def parse_Str(self): # Byte literals def parse_Bytes(self): + return self._parse_bytes() + + def parse_HexBytes(self): + return self._parse_bytes() + + def _parse_bytes(self): bytez = self.expr.value bytez_length = len(self.expr.value) typ = BytesT(bytez_length) diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index cd330681cf..02e3bb213f 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -159,7 +159,7 @@ class BytesT(_BytestringT): typeclass = "bytes" _id = "Bytes" - _valid_literal = (vy_ast.Bytes,) + _valid_literal = (vy_ast.Bytes, vy_ast.HexBytes) 
@property def abi_type(self) -> ABIType: diff --git a/vyper/typing.py b/vyper/typing.py index ad3964dff9..108c0605bb 100644 --- a/vyper/typing.py +++ b/vyper/typing.py @@ -1,7 +1,6 @@ from typing import Dict, Optional, Sequence, Tuple, Union # Parser -ModificationOffsets = Dict[Tuple[int, int], str] ParserPosition = Tuple[int, int] # Compiler