From 2d82a74937edeed5e9d4c0c8cecd78a0d70530fa Mon Sep 17 00:00:00 2001
From: Charles Cooper <cooper.charles.m@gmail.com>
Date: Mon, 17 Jun 2024 04:10:01 -0700
Subject: [PATCH] feat[test]: add more coverage to `abi_decode` fuzzer tests
 (#4153)

fuzz with `unwrap_tuple=False`
add fuzzing for structs

follow up to 69e5c0541a9b23
---
 .../builtins/codegen/test_abi_decode_fuzz.py  | 124 +++++++++++++++---
 vyper/semantics/types/user.py                 |  11 +-
 2 files changed, 115 insertions(+), 20 deletions(-)

diff --git a/tests/functional/builtins/codegen/test_abi_decode_fuzz.py b/tests/functional/builtins/codegen/test_abi_decode_fuzz.py
index d12b2cde7e..e215002446 100644
--- a/tests/functional/builtins/codegen/test_abi_decode_fuzz.py
+++ b/tests/functional/builtins/codegen/test_abi_decode_fuzz.py
@@ -18,12 +18,12 @@
     IntegerT,
     SArrayT,
     StringT,
+    StructT,
     TupleT,
     VyperType,
     _get_primitive_types,
     _get_sequence_types,
 )
-from vyper.semantics.types.shortcuts import UINT256_T
 
 from .abi_decode import DecodeError, spec_decode
 
@@ -39,7 +39,7 @@
         continue
     type_ctors.append(t)
 
-complex_static_ctors = [SArrayT, TupleT]
+complex_static_ctors = [SArrayT, TupleT, StructT]
 complex_dynamic_ctors = [DArrayT]
 leaf_ctors = [t for t in type_ctors if t not in _get_sequence_types().values()]
 static_leaf_ctors = [t for t in leaf_ctors if t._is_prim_word]
@@ -50,10 +50,12 @@
 
 @st.composite
 # max type nesting
-def vyper_type(draw, nesting=3, skip=None):
+def vyper_type(draw, nesting=3, skip=None, source_fragments=None):
     assert nesting >= 0
 
     skip = skip or []
+    if source_fragments is None:
+        source_fragments = []
 
     st_leaves = st.one_of(st.sampled_from(dynamic_leaf_ctors), st.sampled_from(static_leaf_ctors))
     st_complex = st.one_of(
@@ -71,39 +73,52 @@ def vyper_type(draw, nesting=3, skip=None):
     # note: maybe st.deferred is good here, we could define it with
     # mutual recursion
     def _go(skip=skip):
-        return draw(vyper_type(nesting=nesting - 1, skip=skip))
+        _, typ = draw(vyper_type(nesting=nesting - 1, skip=skip, source_fragments=source_fragments))
+        return typ
+
+    def finalize(typ):
+        return source_fragments, typ
 
     if t in (BytesT, StringT):
         # arbitrary max_value
         bound = draw(st.integers(min_value=1, max_value=1024))
-        return t(bound)
+        return finalize(t(bound))
 
     if t == SArrayT:
         subtype = _go(skip=[TupleT, BytesT, StringT])
         bound = draw(st.integers(min_value=1, max_value=6))
-        return t(subtype, bound)
+        return finalize(t(subtype, bound))
     if t == DArrayT:
         subtype = _go(skip=[TupleT])
         bound = draw(st.integers(min_value=1, max_value=16))
-        return t(subtype, bound)
+        return finalize(t(subtype, bound))
 
     if t == TupleT:
         # zero-length tuples are not allowed in vyper
         n = draw(st.integers(min_value=1, max_value=6))
         subtypes = [_go() for _ in range(n)]
-        return TupleT(subtypes)
+        return finalize(TupleT(subtypes))
+
+    if t == StructT:
+        n = draw(st.integers(min_value=1, max_value=6))
+        subtypes = {f"x{i}": _go() for i in range(n)}
+        _id = len(source_fragments)  # poor man's unique id
+        name = f"MyStruct{_id}"
+        typ = StructT(name, subtypes)
+        source_fragments.append(typ.def_source_str())
+        return finalize(StructT(name, subtypes))
 
     if t in (BoolT, AddressT):
-        return t()
+        return finalize(t())
 
     if t == IntegerT:
         signed = draw(st.booleans())
         bits = 8 * draw(st.integers(min_value=1, max_value=32))
-        return t(signed, bits)
+        return finalize(t(signed, bits))
 
     if t == BytesM_T:
         m = draw(st.integers(min_value=1, max_value=32))
-        return t(m)
+        return finalize(t(m))
 
     raise RuntimeError("unreachable")
 
@@ -116,6 +131,9 @@ def _go(t):
     if isinstance(typ, TupleT):
         return tuple(_go(item_t) for item_t in typ.member_types)
 
+    if isinstance(typ, StructT):
+        return tuple(_go(item_t) for item_t in typ.tuple_members())
+
     if isinstance(typ, SArrayT):
         return [_go(typ.value_type) for _ in range(typ.length)]
 
@@ -294,6 +312,13 @@ def _finalize():  # little trick to save re-typing the arguments
         num_dynamic_types = sum(s.num_dynamic_types for s in substats)
         return _finalize()
 
+    if isinstance(typ, StructT):
+        substats = [_type_stats(t) for t in typ.tuple_members()]
+        nesting = 1 + max(s.nesting for s in substats)
+        breadth = max(len(typ.member_types), *[s.breadth for s in substats])
+        num_dynamic_types = sum(s.num_dynamic_types for s in substats)
+        return _finalize()
+
     if isinstance(typ, DArrayT):
         substat = _type_stats(typ.value_type)
         nesting = 1 + substat.nesting
@@ -332,8 +357,8 @@ def payload_copier(get_contract_from_ir):
 @pytest.mark.parametrize("_n", list(range(PARALLELISM)))
 @hp.given(typ=vyper_type())
 @hp.settings(max_examples=100, **_settings)
-@hp.example(typ=DArrayT(DArrayT(UINT256_T, 2), 2))
-def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier):
+def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier, env):
+    source_fragments, typ = typ
     # import time
     # t0 = time.time()
     # print("ENTER", typ)
@@ -350,12 +375,13 @@ def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier):
     # by bytes length check at function entry
     type_bound = wrapped_type.abi_type.size_bound()
     buffer_bound = type_bound + MAX_MUTATIONS
-    type_str = repr(typ)  # annotation in vyper code
-    # TODO: intrinsic decode from staticcall/extcall
-    # TODO: _abi_decode from other sources (staticcall/extcall?)
-    # TODO: dirty the buffer
-    # TODO: check unwrap_tuple=False
+
+    preamble = "\n\n".join(source_fragments)
+    type_str = str(typ)  # annotation in vyper code
+
     code = f"""
+{preamble}
+
 @external
 def run(xs: Bytes[{buffer_bound}]) -> {type_str}:
     ret: {type_str} = abi_decode(xs, {type_str})
@@ -375,6 +401,13 @@ def run3(xs: Bytes[{buffer_bound}], copier: Foo) -> {type_str}:
     assert len(xs) <= {type_bound}
     return (extcall copier.bar(xs))
     """
+    try:
+        c = get_contract(code)
+    except EvmError as e:
+        if env.contract_size_limit_error in str(e):
+            hp.assume(False)
+    # print(code)
+    hp.note(code)
     c = get_contract(code)
 
     @hp.given(data=payload_from(wrapped_type))
@@ -382,7 +415,6 @@ def run3(xs: Bytes[{buffer_bound}], copier: Foo) -> {type_str}:
     def _fuzz(data):
         hp.note(f"type: {typ}")
         hp.note(f"abi_t: {wrapped_type.abi_type.selector_name()}")
-        hp.note(code)
         hp.note(data.hex())
 
         try:
@@ -414,3 +446,57 @@ def _fuzz(data):
 
     # t1 = time.time()
     # print(f"elapsed {t1 - t0}s")
+
+
+@pytest.mark.parametrize("_n", list(range(PARALLELISM)))
+@hp.given(typ=vyper_type())
+@hp.settings(max_examples=100, **_settings)
+def test_abi_decode_no_wrap_fuzz(_n, typ, get_contract, tx_failed, env):
+    """Fuzz `abi_decode(..., unwrap_tuple=False)` against `spec_decode`."""
+    # `vyper_type()` yields a pair: the struct definitions (vyper source
+    # text) which the type refers to, and the type itself
+    source_fragments, typ = typ
+
+    stats = _type_stats(typ)
+    hp.target(stats.num_dynamic_types)
+
+    # add max_mutations bytes worth of padding so we don't just get caught
+    # by bytes length check at function entry
+    type_bound = typ.abi_type.size_bound()
+    buffer_bound = type_bound + MAX_MUTATIONS
+
+    type_str = str(typ)  # annotation in vyper code
+    preamble = "\n\n".join(source_fragments)
+
+    code = f"""
+{preamble}
+
+@external
+def run(xs: Bytes[{buffer_bound}]) -> {type_str}:
+    ret: {type_str} = abi_decode(xs, {type_str}, unwrap_tuple=False)
+    return ret
+    """
+    try:
+        c = get_contract(code)
+    except EvmError as e:
+        if env.contract_size_limit_error in str(e):
+            hp.assume(False)  # hit the contract size limit: discard this example
+        # any other deployment failure is a real error; swallowing it would
+        # leave `c` unbound for `_fuzz` below
+        raise
+
+    @hp.given(data=payload_from(typ))
+    @hp.settings(max_examples=100, **_settings)
+    def _fuzz(data):
+        hp.note(code)
+        hp.note(data.hex())
+        try:
+            expected = spec_decode(typ, data)
+            hp.note(f"expected {expected}")
+            assert expected == c.run(data)
+        except DecodeError:
+            hp.note("expect failure")
+            with tx_failed(EvmError):
+                c.run(data)
+
+    _fuzz()
diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py
index a6ee646e62..ca8e99bc92 100644
--- a/vyper/semantics/types/user.py
+++ b/vyper/semantics/types/user.py
@@ -371,8 +371,11 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT":
 
         return cls(struct_name, members, ast_def=base_node)
 
+    def __str__(self):
+        return f"{self._id}"  # just the struct's identifier (its name in source)
+
     def __repr__(self):
-        return f"{self._id} declaration object"
+        return f"{self._id} {self.members}"  # identifier followed by `self.members`
 
     def _try_fold(self, node):
         if len(node.args) != 1:
@@ -384,6 +387,12 @@ def _try_fold(self, node):
         # it can't be reduced, but this lets upstream code know it's constant
         return node
 
+    def def_source_str(self):
+        """Return vyper source text for a `struct` definition of this type."""
+        head = f"struct {self._id}:\n"
+        body = "".join(f"    {k}: {v}\n" for k, v in self.member_types.items())
+        return head + body
+
     @property
     def size_in_bytes(self):
         return sum(i.size_in_bytes for i in self.member_types.values())