Skip to content

Commit

Permalink
simplify logic for extract32
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Mar 27, 2024
1 parent 29f11e2 commit 48db56c
Showing 1 changed file with 38 additions and 47 deletions.
85 changes: 38 additions & 47 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.context import Context, VariableRecord
from vyper.codegen.core import (
LOAD,
STORE,
IRnode,
add_ofst,
Expand Down Expand Up @@ -884,57 +885,47 @@ def infer_kwarg_types(self, node):

@process_inputs
def build_IR(self, expr, args, kwargs, context):
sub, index = args
bytez, index = args
ret_type = kwargs["output_type"]

scale = sub.location.word_scale
load_op = sub.location.load_op

# TODO rewrite all this with bitshifts
# Special case: index known to be a multiple of 32
if isinstance(index.value, int) and not index.value % 32:
with sub.cache_when_complex("_sub") as (b1, sub):
length = get_bytearray_length(sub)
idx = ["div", clamp2(0, index, ["sub", length, 32], signed=True), 32]
ret = [load_op, ["add", sub, ["add", scale, ["mul", scale, idx]]]]
o = IRnode.from_list(
b1.resolve(ret), typ=ret_type, annotation="extracting 32 bytes"
)
return IRnode.from_list(clamp_basetype(o), typ=ret_type)
def finalize(ret):
annotation = "extract32"
ret = IRnode.from_list(ret, typ=ret_type, annotation=annotation)
return clamp_basetype(ret)

# General case
with sub.cache_when_complex("_sub") as (b1, sub):
length = get_bytearray_length(sub)
with bytez.cache_when_complex("_sub") as (b1, bytez):
# merge
length = get_bytearray_length(bytez)
index = clamp("lt", index, ["sub", length, 32])
with index.cache_when_complex("_index") as (b2, index):
idx = clamp2(0, index, ["sub", length, 32], signed=True)
mi32 = IRnode.from_list(["mod", idx, 32])
di32 = IRnode.from_list(["div", idx, 32])

with mi32.cache_when_complex("_mi32") as (b3, mi32), di32.cache_when_complex(
"_di32"
) as (b4, di32):
left_payload = [load_op, add_ofst(sub, ["add", scale, ["mul", scale, di32]])]
left_bytes = shl(["mul", 8, mi32], left_payload)

right_payload = [
load_op,
add_ofst(sub, ["add", scale, ["mul", scale, ["add", di32, 1]]]),
]
right_bytes = shr(["mul", 8, ["sub", 32, mi32]], right_payload)

ret = [
"if",
mi32,
["add", left_bytes, right_bytes],
[load_op, add_ofst(sub, ["add", scale, ["mul", scale, di32]])],
]
o = IRnode.from_list(
b1.resolve(b2.resolve(b3.resolve(b4.resolve(ret)))),
typ=ret_type,
annotation="extracting 32 bytes",
)

return IRnode.from_list(clamp_basetype(o), typ=ret_type)
assert not index.typ.is_signed

# "easy" case, byte- addressed locations:
if bytez.location.word_scale == 32:
word = LOAD(add_ofst(bytes_data_ptr(bytez), index))
return finalize(b1.resolve(b2.resolve(word)))

# storage and transient storage, word-addressed
assert bytez.location.word_scale == 1

slot = IRnode.from_list(["div", index, 32])
# byte offset within the slot
byte_ofst = IRnode.from_list(["mod", index, 32])

with byte_ofst.cache_when_complex("byte_ofst") as (
b3,
byte_ofst,
), slot.cache_when_complex("slot") as (b4, slot):
# perform two loads and merge
w1 = LOAD(add_ofst(bytes_data_ptr(bytez), slot))
w2 = LOAD(add_ofst(bytes_data_ptr(bytez), ["add", slot, 1]))

left_bytes = shl(["mul", 8, byte_ofst], w1)
right_bytes = shr(["mul", 8, ["sub", 32, byte_ofst]], w2)
merged = ["or", left_bytes, right_bytes]

ret = ["if", byte_ofst, merged, left_bytes]
return finalize(b1.resolve(b2.resolve(b3.resolve(b4.resolve(ret)))))


class AsWeiValue(BuiltinFunctionT):
Expand Down

0 comments on commit 48db56c

Please sign in to comment.