Skip to content

Commit

Permalink
feat[venom]: avoid last swap for commutative ops (#4048)
Browse files Browse the repository at this point in the history
This commit implements a simple "last swap" avoidance for commutative ops.
Additionally, it renames then `get_inputs()` method to `get_input_variables()`
in `IRInstruction` for clarity and consistency
  • Loading branch information
harkal authored May 28, 2024
1 parent fe7d86b commit d6b300d
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 15 deletions.
2 changes: 1 addition & 1 deletion vyper/venom/analysis/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def analyze(self):
# dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...]
for bb in self.function.get_basic_blocks():
for inst in bb.instructions:
operands = inst.get_inputs()
operands = inst.get_input_variables()
res = inst.get_outputs()

for op in operands:
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/dup_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def analyze(self):
last_liveness = bb.out_vars
for inst in reversed(bb.instructions):
inst.dup_requirements = OrderedSet()
ops = inst.get_inputs()
ops = inst.get_input_variables()
for op in ops:
if op in last_liveness:
inst.dup_requirements.add(op)
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/liveness.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _calculate_liveness(self, bb: IRBasicBlock) -> bool:
orig_liveness = bb.instructions[0].liveness.copy()
liveness = bb.out_vars.copy()
for instruction in reversed(bb.instructions):
ins = instruction.get_inputs()
ins = instruction.get_input_variables()
outs = instruction.get_outputs()

if ins or outs:
Expand Down
4 changes: 2 additions & 2 deletions vyper/venom/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def get_non_label_operands(self) -> Iterator[IROperand]:
"""
return (op for op in self.operands if not isinstance(op, IRLabel))

def get_inputs(self) -> Iterator[IRVariable]:
def get_input_variables(self) -> Iterator[IRVariable]:
"""
Get all input operands for instruction.
"""
Expand Down Expand Up @@ -477,7 +477,7 @@ def get_assignments(self):
def get_uses(self) -> dict[IRVariable, OrderedSet[IRInstruction]]:
uses: dict[IRVariable, OrderedSet[IRInstruction]] = {}
for inst in self.instructions:
for op in inst.get_inputs():
for op in inst.get_input_variables():
if op not in uses:
uses[op] = OrderedSet()
uses[op].add(inst)
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/passes/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset:
self.inst_order[inst] = 0
return

for op in inst.get_inputs():
for op in inst.get_input_variables():
target = self.dfg.get_producing_instruction(op)
assert target is not None, f"no producing instruction for {op}"
if target.parent != inst.parent or target.fence_id != inst.fence_id:
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/passes/remove_unused_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _process_instruction(self, inst):
if len(uses) > 0:
return

for operand in inst.get_inputs():
for operand in inst.get_input_variables():
self.dfg.remove_use(operand, inst)
new_uses = self.dfg.get_uses(operand)
self.work_list.addmany(new_uses)
Expand Down
35 changes: 27 additions & 8 deletions vyper/venom/venom_to_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@
]
)

COMMUTATIVE_INSTRUCTIONS = frozenset(["add", "mul", "smul", "or", "xor", "and", "eq"])


_REVERT_POSTAMBLE = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"]


Expand Down Expand Up @@ -195,8 +198,14 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]:
return top_asm

def _stack_reorder(
self, assembly: list, stack: StackModel, stack_ops: list[IRVariable]
) -> None:
self, assembly: list, stack: StackModel, stack_ops: list[IROperand], dry_run: bool = False
) -> int:
cost = 0

if dry_run:
assert len(assembly) == 0, "Dry run should not work on assembly"
stack = stack.copy()

stack_ops_count = len(stack_ops)

counts = Counter(stack_ops)
Expand All @@ -216,8 +225,10 @@ def _stack_reorder(
if op == stack.peek(final_stack_depth):
continue

self.swap(assembly, stack, depth)
self.swap(assembly, stack, final_stack_depth)
cost += self.swap(assembly, stack, depth)
cost += self.swap(assembly, stack, final_stack_depth)

return cost

def _emit_input_operands(
self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel
Expand Down Expand Up @@ -376,7 +387,7 @@ def _generate_evm_for_instruction(

if opcode == "phi":
ret = inst.get_outputs()[0]
phis = list(inst.get_inputs())
phis = list(inst.get_input_variables())
depth = stack.get_phi_depth(phis)
# collapse the arguments to the phi node in the stack.
# example, for `%56 = %label1 %13 %label2 %14`, we will
Expand Down Expand Up @@ -406,9 +417,16 @@ def _generate_evm_for_instruction(
target_stack_list = list(target_stack)
self._stack_reorder(assembly, stack, target_stack_list)

if opcode in COMMUTATIVE_INSTRUCTIONS:
cost_no_swap = self._stack_reorder([], stack, operands, dry_run=True)
operands[-1], operands[-2] = operands[-2], operands[-1]
cost_with_swap = self._stack_reorder([], stack, operands, dry_run=True)
if cost_with_swap > cost_no_swap:
operands[-1], operands[-2] = operands[-2], operands[-1]

# final step to get the inputs to this instruction ordered
# correctly on the stack
self._stack_reorder(assembly, stack, operands) # type: ignore
self._stack_reorder(assembly, stack, operands)

# some instructions (i.e. invoke) need to do stack manipulations
# with the stack model containing the return value(s), so we fiddle
Expand Down Expand Up @@ -533,13 +551,14 @@ def pop(self, assembly, stack, num=1):
stack.pop(num)
assembly.extend(["POP"] * num)

def swap(self, assembly, stack, depth):
def swap(self, assembly, stack, depth) -> int:
# Swaps of the top is no op
if depth == 0:
return
return 0

stack.swap(depth)
assembly.append(_evm_swap_for(depth))
return 1

def dup(self, assembly, stack, depth):
stack.dup(depth)
Expand Down

0 comments on commit d6b300d

Please sign in to comment.