Skip to content

Commit

Permalink
feat[lang]: add native hex string literals (vyperlang#4271)
Browse files Browse the repository at this point in the history
This commit adds support for custom hex literals like `x"a8b9"`. this
hex literal is equivalent to the bytestring `b"\xa8\xb9"`. this syntax
is a break from Python, which does not have this type of literal, but
it is helpful in the smart contract space because often literals are
copy-pasted around as hex.

the approach taken in this PR is to add another custom sub-parser in
`pre_parse.py` which keeps track of where the hex literals are in
the original source code. it additionally refactors the signature of
`pre_parse()` to return a `PreParseResult`, which bundles all the data
structures returned by `pre_parse()` together.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
tserg and charles-cooper authored Oct 11, 2024
1 parent 4845fd4 commit 212ff59
Show file tree
Hide file tree
Showing 14 changed files with 189 additions and 62 deletions.
3 changes: 2 additions & 1 deletion docs/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,12 @@ A byte array with a max size.
The syntax being ``Bytes[maxLen]``, where ``maxLen`` is an integer which denotes the maximum number of bytes.
On the ABI level the Fixed-size bytes array is annotated as ``bytes``.

Bytes literals may be given as bytes strings.
Bytes literals may be given as bytes strings or as hex strings.

.. code-block:: vyper
bytes_string: Bytes[100] = b"\x01"
bytes_string: Bytes[100] = x"01"
.. index:: !string

Expand Down
22 changes: 22 additions & 0 deletions tests/functional/codegen/types/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,28 @@ def test2(l: bytes{m} = {vyper_literal}) -> bool:
assert c.test2(vyper_literal) is True


@pytest.mark.parametrize("m,val", [(2, "ab"), (3, "ab"), (4, "abcd")])
def test_native_hex_literals(get_contract, m, val):
vyper_literal = bytes.fromhex(val)
code = f"""
@external
def test() -> bool:
l: Bytes[{m}] = x"{val}"
return l == {vyper_literal}
@external
def test2(l: Bytes[{m}] = x"{val}") -> bool:
return l == {vyper_literal}
"""
print(code)

c = get_contract(code)

assert c.test() is True
assert c.test2() is True
assert c.test2(vyper_literal) is True


def test_zero_padding_with_private(get_contract):
code = """
counter: uint256
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/grammar/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,6 @@ def has_no_docstrings(c):
max_examples=500, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much]
)
def test_grammar_bruteforce(code):
_, _, _, reformatted_code = pre_parse(code + "\n")
tree = parse_to_ast(reformatted_code)
pre_parse_result = pre_parse(code + "\n")
tree = parse_to_ast(pre_parse_result.reformatted_code)
assert isinstance(tree, Module)
9 changes: 9 additions & 0 deletions tests/functional/syntax/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ def test() -> Bytes[1]:
(
"""
@external
def test() -> Bytes[2]:
a: Bytes[2] = x"abc"
return a
""",
SyntaxException,
),
(
"""
@external
def foo():
a: Bytes = b"abc"
""",
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/ast/test_annotate_and_optimize_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def foo() -> int128:


def get_contract_info(source_code):
_, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code)
py_ast = python_ast.parse(reformatted_code)
pre_parse_result = pre_parse(source_code)
py_ast = python_ast.parse(pre_parse_result.reformatted_code)

annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types)
annotate_python_ast(py_ast, pre_parse_result.reformatted_code, pre_parse_result)

return py_ast, reformatted_code
return py_ast, pre_parse_result.reformatted_code


def test_it_annotates_ast_with_source_code():
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/ast/test_pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version):
@pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples)
def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version):
mock_version("0.3.10")
settings, _, _, _ = pre_parse(code)
pre_parse_result = pre_parse(code)

assert settings == pre_parse_settings
assert pre_parse_result.settings == pre_parse_settings

compiler_data = CompilerData(code)

Expand Down
2 changes: 1 addition & 1 deletion vyper/ast/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ COMMENT: /#[^\n\r]*/
_NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+


STRING: /b?("(?!"").*?(?<!\\)(\\\\)*?"|'(?!'').*?(?<!\\)(\\\\)*?')/i
STRING: /x?b?("(?!"").*?(?<!\\)(\\\\)*?"|'(?!'').*?(?<!\\)(\\\\)*?')/i
DOCSTRING: /(""".*?(?<!\\)(\\\\)*?"""|'''.*?(?<!\\)(\\\\)*?''')/is

DEC_NUMBER: /0|[1-9]\d*/i
Expand Down
19 changes: 19 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,25 @@ def s(self):
return self.value


class HexBytes(Constant):
__slots__ = ()
_translated_fields = {"s": "value"}

def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
super().__init__(parent, **kwargs)
if isinstance(self.value, str):
self.value = bytes.fromhex(self.value)

def to_dict(self):
ast_dict = super().to_dict()
ast_dict["value"] = f"0x{self.value.hex()}"
return ast_dict

@property
def s(self):
return self.value


class List(ExprNode):
__slots__ = ("elements",)
_translated_fields = {"elts": "elements"}
Expand Down
4 changes: 4 additions & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ class Bytes(Constant):
@property
def s(self): ...

class HexBytes(Constant):
@property
def s(self): ...

class NameConstant(Constant): ...
class Ellipsis(Constant): ...

Expand Down
59 changes: 31 additions & 28 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import asttokens

from vyper.ast import nodes as vy_ast
from vyper.ast.pre_parser import pre_parse
from vyper.ast.pre_parser import PreParseResult, pre_parse
from vyper.compiler.settings import Settings
from vyper.exceptions import CompilerPanic, ParserException, SyntaxException
from vyper.typing import ModificationOffsets
from vyper.utils import sha256sum, vyper_warn


Expand Down Expand Up @@ -55,9 +54,9 @@ def parse_to_ast_with_settings(
"""
if "\x00" in vyper_source:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
settings, class_types, for_loop_annotations, python_source = pre_parse(vyper_source)
pre_parse_result = pre_parse(vyper_source)
try:
py_ast = python_ast.parse(python_source)
py_ast = python_ast.parse(pre_parse_result.reformatted_code)
except SyntaxError as e:
# TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors
raise SyntaxException(str(e), vyper_source, e.lineno, e.offset) from None
Expand All @@ -73,21 +72,20 @@ def parse_to_ast_with_settings(
annotate_python_ast(
py_ast,
vyper_source,
class_types,
for_loop_annotations,
pre_parse_result,
source_id=source_id,
module_path=module_path,
resolved_path=resolved_path,
)

# postcondition: consumed all the for loop annotations
assert len(for_loop_annotations) == 0
assert len(pre_parse_result.for_loop_annotations) == 0

# Convert to Vyper AST.
module = vy_ast.get_node(py_ast)
assert isinstance(module, vy_ast.Module) # mypy hint

return settings, module
return pre_parse_result.settings, module


def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]:
Expand Down Expand Up @@ -118,8 +116,7 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]:
def annotate_python_ast(
parsed_ast: python_ast.AST,
vyper_source: str,
modification_offsets: ModificationOffsets,
for_loop_annotations: dict,
pre_parse_result: PreParseResult,
source_id: int = 0,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
Expand All @@ -133,11 +130,8 @@ def annotate_python_ast(
The AST to be annotated and optimized.
vyper_source: str
The original vyper source code
loop_var_annotations: dict
A mapping of line numbers of `For` nodes to the tokens of the type
annotation of the iterator extracted during pre-parsing.
modification_offsets : dict
A mapping of class names to their original class types.
pre_parse_result: PreParseResult
Outputs from pre-parsing.
Returns
-------
Expand All @@ -148,8 +142,7 @@ def annotate_python_ast(
tokens.mark_tokens(parsed_ast)
visitor = AnnotatingVisitor(
vyper_source,
modification_offsets,
for_loop_annotations,
pre_parse_result,
tokens,
source_id,
module_path=module_path,
Expand All @@ -162,14 +155,12 @@ def annotate_python_ast(

class AnnotatingVisitor(python_ast.NodeTransformer):
_source_code: str
_modification_offsets: ModificationOffsets
_loop_var_annotations: dict[int, dict[str, Any]]
_pre_parse_result: PreParseResult

def __init__(
self,
source_code: str,
modification_offsets: ModificationOffsets,
for_loop_annotations: dict,
pre_parse_result: PreParseResult,
tokens: asttokens.ASTTokens,
source_id: int,
module_path: Optional[str] = None,
Expand All @@ -180,8 +171,7 @@ def __init__(
self._module_path = module_path
self._resolved_path = resolved_path
self._source_code = source_code
self._modification_offsets = modification_offsets
self._for_loop_annotations = for_loop_annotations
self._pre_parse_result = pre_parse_result

self.counter: int = 0

Expand Down Expand Up @@ -275,15 +265,16 @@ def visit_ClassDef(self, node):
"""
self.generic_visit(node)

node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)]
node.ast_type = self._pre_parse_result.modification_offsets[(node.lineno, node.col_offset)]
return node

def visit_For(self, node):
"""
Visit a For node, splicing in the loop variable annotation provided by
the pre-parser
"""
annotation_tokens = self._for_loop_annotations.pop((node.lineno, node.col_offset))
key = (node.lineno, node.col_offset)
annotation_tokens = self._pre_parse_result.for_loop_annotations.pop(key)

if not annotation_tokens:
# a common case for people migrating to 0.4.0, provide a more
Expand Down Expand Up @@ -350,14 +341,15 @@ def visit_Expr(self, node):
if isinstance(node.value, python_ast.Yield):
# CMC 2024-03-03 consider unremoving this from the enclosing Expr
node = node.value
node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)]
key = (node.lineno, node.col_offset)
node.ast_type = self._pre_parse_result.modification_offsets[key]

return node

def visit_Await(self, node):
start_pos = node.lineno, node.col_offset # grab these before generic_visit modifies them
self.generic_visit(node)
node.ast_type = self._modification_offsets[start_pos]
node.ast_type = self._pre_parse_result.modification_offsets[start_pos]
return node

def visit_Call(self, node):
Expand Down Expand Up @@ -401,7 +393,18 @@ def visit_Constant(self, node):
if node.value is None or isinstance(node.value, bool):
node.ast_type = "NameConstant"
elif isinstance(node.value, str):
node.ast_type = "Str"
key = (node.lineno, node.col_offset)
if key in self._pre_parse_result.native_hex_literal_locations:
if len(node.value) % 2 != 0:
raise SyntaxException(
"Native hex string must have an even number of characters",
self._source_code,
node.lineno,
node.col_offset,
)
node.ast_type = "HexBytes"
else:
node.ast_type = "Str"
elif isinstance(node.value, bytes):
node.ast_type = "Bytes"
elif isinstance(node.value, Ellipsis.__class__):
Expand Down
Loading

0 comments on commit 212ff59

Please sign in to comment.