From 07bfffa508f306bdbc56250de1a6c7ddc72dfaad Mon Sep 17 00:00:00 2001 From: cyberthirst Date: Fri, 29 Mar 2024 11:02:50 +0100 Subject: [PATCH] add complex transient storage tests --- .../codegen/features/test_transient.py | 194 ++++++++++++++++-- 1 file changed, 175 insertions(+), 19 deletions(-) diff --git a/tests/functional/codegen/features/test_transient.py b/tests/functional/codegen/features/test_transient.py index 2168ef78c6..05088a8965 100644 --- a/tests/functional/codegen/features/test_transient.py +++ b/tests/functional/codegen/features/test_transient.py @@ -159,6 +159,35 @@ def foo(_a: uint256, _b: uint256, _c: address, _d: int256) -> MyStruct: assert c.foo(*values) == values +def test_complex_struct_transient(get_contract): + code = """ +struct MyStruct: + a: address + b: MyStruct2 + c: DynArray[DynArray[uint256, 3], 3] + +struct MyStruct2: + a: DynArray[uint256, 3] + +my_struct: public(transient(MyStruct)) + +@external +def foo(_a: address, _b: MyStruct2, _c: DynArray[DynArray[uint256, 3], 3]) -> MyStruct: + self.my_struct = MyStruct( + a=_a, + b=_b, + c=_c, + ) + return self.my_struct + """ + values = ("0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE", ([1],), [[3,4], [5,6]]) + + c = get_contract(code) + assert c.foo(*values) == values + assert c.my_struct() == (None, ([],), []) + assert c.foo(*values) == values + + def test_complex_transient_modifiable(get_contract): code = """ struct MyStruct: @@ -200,6 +229,53 @@ def foo(_a: uint256, _b: uint256, _c: uint256) -> uint256[3]: assert c.foo(*values) == list(values) +def test_hashmap_transient(get_contract): + code = """ +my_map: public(transient(HashMap[uint256, uint256])) + +@external +def foo(k: uint256, v: uint256) -> uint256: + self.my_map[k] = v + return self.my_map[k] + """ + c = get_contract(code) + for v in range(5): + for k in range(5): + assert c.foo(k, v) == v + assert c.my_map(k) == 0 + + +def test_complex_hashmap_transient(get_contract): + code = """ +struct MyStruct: + a: uint256 + b: DynArray[uint256, 3] + +my_map: public(transient(HashMap[uint256, MyStruct])) +my_res: public(HashMap[uint256, MyStruct]) + +@external +def do_side_effects(): + a: DynArray[uint256, 3] = [1, 2, 3] + s: MyStruct = MyStruct(a=100, b=a) + for i: uint256 in range(2): + for j: uint256 in range(3): + s.b[j] = i + j + s.a = i + self.my_map[i] = s + self.my_res[i] = self.my_map[i] + """ + c = get_contract(code) + c.do_side_effects(transact={}) + for i in range(2): + assert c.my_res(i)[0] == i + assert c.my_map(i)[0] == 0 + for j in range(3): + print(c.my_res(i)[1]) + assert c.my_res(i)[1][j] == i + j + assert c.my_map(i)[1] == [] + + def test_dynarray_transient(get_contract): code = """ my_list: public(transient(DynArray[uint256, 3])) @@ -248,11 +324,7 @@ def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256: def test_nested_dynarray_transient(get_contract): - code = """ -my_list: public(transient(DynArray[DynArray[DynArray[int128, 3], 3], 3])) - -@external -def get_my_list(x: int128, y: int128, z: int128) -> DynArray[DynArray[DynArray[int128, 3], 3], 3]: + set_list = """ self.my_list = [ [[x, y, z], [y, z, x], [z, y, x]], [ @@ -265,25 +337,29 @@ def get_my_list(x: int128, y: int128, z: int128) -> DynArray[DynArray[DynArray[i [z * (-2), y * (-3), x * (-4)], [z * (-y), y * (-x), x * (-z)], ], - ] + ] + """ + code = f""" +interface Iface: + def my_list(x: uint256, y: uint256, z: uint256) -> int128: view + +my_list: public(transient(DynArray[DynArray[DynArray[int128, 3], 3], 3])) + +@external +def get_my_list(x: int128, y: int128, z: int128) -> DynArray[DynArray[DynArray[int128, 3], 3], 3]: + {set_list} return self.my_list @external def get_idx_two(x: int128, y: int128, z: int128) -> int128: - self.my_list = [ - [[x, y, z], [y, z, x], [z, y, x]], - [ - [x * 1000 + y, y * 1000 + z, z * 1000 + x], - [- (x * 1000 + y), - (y * 1000 + z), - (z * 1000 + x)], - [- (x * 1000) + y, - (y * 1000) + z, - (z * 1000) + x], - ], - [ - [z * 2, y * 3, x * 4], - [z * (-2), y * (-3), x * (-4)], - [z * (-y), y * (-x), x * (-z)], - ], - ] + {set_list} return self.my_list[2][2][2] + +@external +def get_idx_two_using_getter(x: int128, y: int128, z: int128) -> int128: + {set_list} + #return self.my_list[2][2][2] + return staticcall Iface(self).my_list(2, 2, 2) """ values = (37, 41, 73) expected_values = [ @@ -299,6 +375,9 @@ def get_idx_two(x: int128, y: int128, z: int128) -> int128: assert c.get_idx_two(*values) == expected_values[2][2][2] with pytest.raises(TransactionFailed): c.my_list(0, 0, 0) + assert c.get_idx_two_using_getter(*values) == expected_values[2][2][2] + with pytest.raises(TransactionFailed): + c.my_list(0, 0, 0) @pytest.mark.parametrize("n", range(5)) @@ -383,3 +462,80 @@ def bar(i: uint256, a: address) -> uint256: value = 333 assert c2.bar(value, c1.address) == value assert c1.get_x() == 0 + + +def test_modules_transient(get_contract, make_input_bundle): + lib1 = """ +counter: transient(uint256) + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: transient(uint256) +counter2: public(uint256) + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib2 +import lib1 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@external +def foo() -> (uint256, uint256): + lib1.counter = 2 + lib2.foo() + lib2.counter = 10 + return lib1.counter, lib2.counter + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == [3, 10] + + +def test_complex_modules_transient(get_contract, make_input_bundle): + lib1 = """ +l: transient(uint256[3]) + """ + lib2 = """ +import lib1 + +uses: lib1 + +struct MyStruct: + a: uint256 + d: uint256 + +s: transient(MyStruct) + +@internal +def foo(): + self.s = MyStruct(a=lib1.l[0], d=lib1.l[1]) + """ + main = """ +import lib2 +import lib1 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +my_map: HashMap[uint256, uint256] + +@external +def foo() -> (uint256[3], lib2.MyStruct, uint256): + lib1.l = [1, 2, 3] + lib2.foo() + self.my_map[0] = 42 + return lib1.l, lib2.s, self.my_map[0] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == [[1, 2, 3], (1, 2), 42]