From 48db56c8011af4411004834781172d9c6009fdcf Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 26 Mar 2024 20:32:01 -0400 Subject: [PATCH] simplify logic for extract32 --- vyper/builtins/functions.py | 85 +++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 47 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 814a7da322..efb8a30c04 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -8,6 +8,7 @@ from vyper.codegen.abi_encoder import abi_encode from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import ( + LOAD, STORE, IRnode, add_ofst, @@ -884,57 +885,47 @@ def infer_kwarg_types(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - sub, index = args + bytez, index = args ret_type = kwargs["output_type"] - scale = sub.location.word_scale - load_op = sub.location.load_op - - # TODO rewrite all this with bitshifts - # Special case: index known to be a multiple of 32 - if isinstance(index.value, int) and not index.value % 32: - with sub.cache_when_complex("_sub") as (b1, sub): - length = get_bytearray_length(sub) - idx = ["div", clamp2(0, index, ["sub", length, 32], signed=True), 32] - ret = [load_op, ["add", sub, ["add", scale, ["mul", scale, idx]]]] - o = IRnode.from_list( - b1.resolve(ret), typ=ret_type, annotation="extracting 32 bytes" - ) - return IRnode.from_list(clamp_basetype(o), typ=ret_type) + def finalize(ret): + annotation = "extract32" + ret = IRnode.from_list(ret, typ=ret_type, annotation=annotation) + return clamp_basetype(ret) - # General case - with sub.cache_when_complex("_sub") as (b1, sub): - length = get_bytearray_length(sub) + with bytez.cache_when_complex("_sub") as (b1, bytez): + # merge + length = get_bytearray_length(bytez) + index = clamp("lt", index, ["sub", length, 32]) with index.cache_when_complex("_index") as (b2, index): - idx = clamp2(0, index, ["sub", length, 32], signed=True) - mi32 = IRnode.from_list(["mod", idx, 32]) - di32 = IRnode.from_list(["div", idx, 32]) - - with mi32.cache_when_complex("_mi32") as (b3, mi32), di32.cache_when_complex( - "_di32" - ) as (b4, di32): - left_payload = [load_op, add_ofst(sub, ["add", scale, ["mul", scale, di32]])] - left_bytes = shl(["mul", 8, mi32], left_payload) - - right_payload = [ - load_op, - add_ofst(sub, ["add", scale, ["mul", scale, ["add", di32, 1]]]), - ] - right_bytes = shr(["mul", 8, ["sub", 32, mi32]], right_payload) - - ret = [ - "if", - mi32, - ["add", left_bytes, right_bytes], - [load_op, add_ofst(sub, ["add", scale, ["mul", scale, di32]])], - ] - o = IRnode.from_list( - b1.resolve(b2.resolve(b3.resolve(b4.resolve(ret)))), - typ=ret_type, - annotation="extracting 32 bytes", - ) - - return IRnode.from_list(clamp_basetype(o), typ=ret_type) + assert not index.typ.is_signed + + # "easy" case, byte- addressed locations: + if bytez.location.word_scale == 32: + word = LOAD(add_ofst(bytes_data_ptr(bytez), index)) + return finalize(b1.resolve(b2.resolve(word))) + + # storage and transient storage, word-addressed + assert bytez.location.word_scale == 1 + + slot = IRnode.from_list(["div", index, 32]) + # byte offset within the slot + byte_ofst = IRnode.from_list(["mod", index, 32]) + + with byte_ofst.cache_when_complex("byte_ofst") as ( + b3, + byte_ofst, + ), slot.cache_when_complex("slot") as (b4, slot): + # perform two loads and merge + w1 = LOAD(add_ofst(bytes_data_ptr(bytez), slot)) + w2 = LOAD(add_ofst(bytes_data_ptr(bytez), ["add", slot, 1])) + + left_bytes = shl(["mul", 8, byte_ofst], w1) + right_bytes = shr(["mul", 8, ["sub", 32, byte_ofst]], w2) + merged = ["or", left_bytes, right_bytes] + + ret = ["if", byte_ofst, merged, left_bytes] + return finalize(b1.resolve(b2.resolve(b3.resolve(b4.resolve(ret))))) class AsWeiValue(BuiltinFunctionT):