Skip to content

Commit

Permalink
feat[venom]: new DFTPass algorithm (#4255)
Browse files Browse the repository at this point in the history
this commit upgrades the DFT algorithm to allow for more instruction
movement and performs "multidimensional" fencing, which allows
instructions to be reordered across volatile instructions if there
is no effect barrier. since barriers do not truly live in the data
dependency graph, it introduces a heuristic which chooses which barrier
to recurse into first.

it also removes the use of order ids and sorting, which improves
performance.

---------

Co-authored-by: Charles Cooper <[email protected]>
Co-authored-by: HodanPlodky <[email protected]>
  • Loading branch information
3 people authored Nov 12, 2024
1 parent 48cb39b commit c32b9b4
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 66 deletions.
3 changes: 3 additions & 0 deletions vyper/ir/compile_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,9 @@ def _stack_peephole_opts(assembly):
if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS:
changed = True
del assembly[i]
if assembly[i] == "DUP1" and assembly[i + 1] == "SWAP1":
changed = True
del assembly[i + 1]
i += 1

return changed
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def request_analysis(self, analysis_cls: Type[IRAnalysis], *args, **kwargs):
if analysis_cls in self.analyses_cache:
return self.analyses_cache[analysis_cls]
analysis = analysis_cls(self, self.function)
self.analyses_cache[analysis_cls] = analysis
analysis.analyze(*args, **kwargs)

self.analyses_cache[analysis_cls] = analysis
return analysis

def invalidate_analysis(self, analysis_cls: Type[IRAnalysis]):
Expand Down
8 changes: 7 additions & 1 deletion vyper/venom/analysis/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from vyper.utils import OrderedSet
from vyper.venom.analysis.analysis import IRAnalysesCache, IRAnalysis
from vyper.venom.analysis.liveness import LivenessAnalysis
from vyper.venom.basicblock import IRInstruction, IRVariable
from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable
from vyper.venom.function import IRFunction


Expand All @@ -20,6 +20,12 @@ def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
def get_uses(self, op: IRVariable) -> OrderedSet[IRInstruction]:
return self._dfg_inputs.get(op, OrderedSet())

def get_uses_in_bb(self, op: IRVariable, bb: IRBasicBlock):
"""
Get uses of a given variable in a specific basic block.
"""
return [inst for inst in self.get_uses(op) if inst.parent == bb]

# the instruction which produces this variable.
def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]:
return self._dfg_outputs.get(op)
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/liveness.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _calculate_out_vars(self, bb: IRBasicBlock) -> bool:
Compute out_vars of basic block.
Returns True if out_vars changed
"""
out_vars = bb.out_vars
out_vars = bb.out_vars.copy()
bb.out_vars = OrderedSet()
for out_bb in bb.cfg_out:
target_vars = self.input_vars_from(bb, out_bb)
Expand Down
65 changes: 59 additions & 6 deletions vyper/venom/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ class IRInstruction:
# set of live variables at this instruction
liveness: OrderedSet[IRVariable]
parent: "IRBasicBlock"
fence_id: int
annotation: Optional[str]
ast_source: Optional[IRnode]
error_msg: Optional[str]
Expand All @@ -219,7 +218,6 @@ def __init__(
self.operands = list(operands) # in case we get an iterator
self.output = output
self.liveness = OrderedSet()
self.fence_id = -1
self.annotation = None
self.ast_source = None
self.error_msg = None
Expand All @@ -236,6 +234,22 @@ def is_commutative(self) -> bool:
def is_bb_terminator(self) -> bool:
return self.opcode in BB_TERMINATORS

@property
def is_phi(self) -> bool:
return self.opcode == "phi"

@property
def is_param(self) -> bool:
return self.opcode == "param"

@property
def is_pseudo(self) -> bool:
"""
Check if instruction is pseudo, i.e. not an actual instruction but
a construct for intermediate representation like phi and param.
"""
return self.is_phi or self.is_param

def get_read_effects(self):
return effects.reads.get(self.opcode, effects.EMPTY)

Expand Down Expand Up @@ -321,6 +335,20 @@ def get_ast_source(self) -> Optional[IRnode]:
return inst.ast_source
return self.parent.parent.ast_source

def str_short(self) -> str:
s = ""
if self.output:
s += f"{self.output} = "
opcode = f"{self.opcode} " if self.opcode != "store" else ""
s += opcode
operands = self.operands
if opcode not in ["jmp", "jnz", "invoke"]:
operands = list(reversed(operands))
s += ", ".join(
[(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in operands]
)
return s

def __repr__(self) -> str:
s = ""
if self.output:
Expand All @@ -337,10 +365,7 @@ def __repr__(self) -> str:
if self.annotation:
s += f" <{self.annotation}>"

if self.liveness:
return f"{s: <30} # {self.liveness}"

return s
return f"{s: <30}"


def _ir_operand_from_value(val: Any) -> IROperand:
Expand Down Expand Up @@ -477,6 +502,34 @@ def remove_instruction(self, instruction: IRInstruction) -> None:
def clear_instructions(self) -> None:
self.instructions = []

@property
def phi_instructions(self) -> Iterator[IRInstruction]:
for inst in self.instructions:
if inst.opcode == "phi":
yield inst
else:
return

@property
def non_phi_instructions(self) -> Iterator[IRInstruction]:
return (inst for inst in self.instructions if inst.opcode != "phi")

@property
def param_instructions(self) -> Iterator[IRInstruction]:
for inst in self.instructions:
if inst.opcode == "param":
yield inst
else:
return

@property
def pseudo_instructions(self) -> Iterator[IRInstruction]:
return (inst for inst in self.instructions if inst.is_pseudo)

@property
def body_instructions(self) -> Iterator[IRInstruction]:
return (inst for inst in self.instructions[:-1] if not inst.is_pseudo)

def replace_operands(self, replacements: dict) -> None:
"""
Update operands with replacements.
Expand Down
6 changes: 6 additions & 0 deletions vyper/venom/effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ class Effects(Flag):
BALANCE = auto()
EXTCODE = auto()

def __iter__(self):
# python3.10 doesn't have an iter implementation. we can
# remove this once we drop python3.10 support.
return (m for m in self.__class__.__members__.values() if m in self)


EMPTY = Effects(0)
ALL = ~EMPTY
Expand Down Expand Up @@ -68,6 +73,7 @@ class Effects(Flag):
"revert": MEMORY,
"return": MEMORY,
"sha3": MEMORY,
"sha3_64": MEMORY,
"msize": MSIZE,
}

Expand Down
167 changes: 112 additions & 55 deletions vyper/venom/passes/dft.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,138 @@
from collections import defaultdict

import vyper.venom.effects as effects
from vyper.utils import OrderedSet
from vyper.venom.analysis import DFGAnalysis
from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable
from vyper.venom.analysis import DFGAnalysis, IRAnalysesCache, LivenessAnalysis
from vyper.venom.basicblock import IRBasicBlock, IRInstruction
from vyper.venom.function import IRFunction
from vyper.venom.passes.base_pass import IRPass


class DFTPass(IRPass):
function: IRFunction
inst_order: dict[IRInstruction, int]
inst_order_num: int
inst_offspring: dict[IRInstruction, OrderedSet[IRInstruction]]
visited_instructions: OrderedSet[IRInstruction]
ida: dict[IRInstruction, OrderedSet[IRInstruction]]

def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
super().__init__(analyses_cache, function)
self.inst_offspring = {}

def run_pass(self) -> None:
self.inst_offspring = {}
self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet()

self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
basic_blocks = list(self.function.get_basic_blocks())

self.function.clear_basic_blocks()
for bb in basic_blocks:
self._process_basic_block(bb)

self.analyses_cache.invalidate_analysis(LivenessAnalysis)

def _process_basic_block(self, bb: IRBasicBlock) -> None:
self.function.append_basic_block(bb)

self._calculate_dependency_graphs(bb)
self.instructions = list(bb.pseudo_instructions)
non_phi_instructions = list(bb.non_phi_instructions)

self.visited_instructions = OrderedSet()
for inst in non_phi_instructions:
self._calculate_instruction_offspring(inst)

# Compute entry points in the graph of instruction dependencies
entry_instructions: OrderedSet[IRInstruction] = OrderedSet(non_phi_instructions)
for inst in non_phi_instructions:
to_remove = self.ida.get(inst, OrderedSet())
if len(to_remove) > 0:
entry_instructions.dropmany(to_remove)

entry_instructions_list = list(entry_instructions)

def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset: int = 0):
for op in inst.get_outputs():
assert isinstance(op, IRVariable), f"expected variable, got {op}"
uses = self.dfg.get_uses(op)
# Move the terminator instruction to the end of the list
self._move_terminator_to_end(entry_instructions_list)

for uses_this in uses:
if uses_this.parent != inst.parent or uses_this.fence_id != inst.fence_id:
# don't reorder across basic block or fence boundaries
continue
self.visited_instructions = OrderedSet()
for inst in entry_instructions_list:
self._process_instruction_r(self.instructions, inst)

# if the instruction is a terminator, we need to place
# it at the end of the basic block
# along with all the instructions that "lead" to it
self._process_instruction_r(bb, uses_this, offset)
bb.instructions = self.instructions
assert bb.is_terminated, f"Basic block should be terminated {bb}"

def _move_terminator_to_end(self, instructions: list[IRInstruction]) -> None:
terminator = next((inst for inst in instructions if inst.is_bb_terminator), None)
if terminator is None:
raise ValueError(f"Basic block should have a terminator instruction {self.function}")
instructions.remove(terminator)
instructions.append(terminator)

def _process_instruction_r(self, instructions: list[IRInstruction], inst: IRInstruction):
if inst in self.visited_instructions:
return
self.visited_instructions.add(inst)
self.inst_order_num += 1

if inst.is_bb_terminator:
offset = len(bb.instructions)

if inst.opcode == "phi":
# phi instructions stay at the beginning of the basic block
# and no input processing is needed
# bb.instructions.append(inst)
self.inst_order[inst] = 0
if inst.is_pseudo:
return

for op in inst.get_input_variables():
target = self.dfg.get_producing_instruction(op)
assert target is not None, f"no producing instruction for {op}"
if target.parent != inst.parent or target.fence_id != inst.fence_id:
# don't reorder across basic block or fence boundaries
continue
self._process_instruction_r(bb, target, offset)
children = list(self.ida[inst])

self.inst_order[inst] = self.inst_order_num + offset
def key(x):
cost = inst.operands.index(x.output) if x.output in inst.operands else 0
return cost - len(self.inst_offspring[x]) * 0.5

def _process_basic_block(self, bb: IRBasicBlock) -> None:
self.function.append_basic_block(bb)
# heuristic: sort by size of child dependency graph
children.sort(key=key)

for inst in bb.instructions:
inst.fence_id = self.fence_id
if inst.is_volatile:
self.fence_id += 1
for dep_inst in children:
self._process_instruction_r(instructions, dep_inst)

# We go throught the instructions and calculate the order in which they should be executed
# based on the data flow graph. This order is stored in the inst_order dictionary.
# We then sort the instructions based on this order.
self.inst_order = {}
self.inst_order_num = 0
for inst in bb.instructions:
self._process_instruction_r(bb, inst)
instructions.append(inst)

bb.instructions.sort(key=lambda x: self.inst_order[x])
def _calculate_dependency_graphs(self, bb: IRBasicBlock) -> None:
# ida: instruction dependency analysis
self.ida = defaultdict(OrderedSet)

def run_pass(self) -> None:
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
non_phis = list(bb.non_phi_instructions)

self.fence_id = 0
self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet()
#
# Compute dependency graph
#
last_write_effects: dict[effects.Effects, IRInstruction] = {}
last_read_effects: dict[effects.Effects, IRInstruction] = {}

basic_blocks = list(self.function.get_basic_blocks())
for inst in non_phis:
for op in inst.operands:
dep = self.dfg.get_producing_instruction(op)
if dep is not None and dep.parent == bb:
self.ida[inst].add(dep)

self.function.clear_basic_blocks()
for bb in basic_blocks:
self._process_basic_block(bb)
write_effects = inst.get_write_effects()
read_effects = inst.get_read_effects()

for write_effect in write_effects:
if write_effect in last_read_effects:
self.ida[inst].add(last_read_effects[write_effect])
last_write_effects[write_effect] = inst

for read_effect in read_effects:
if read_effect in last_write_effects and last_write_effects[read_effect] != inst:
self.ida[inst].add(last_write_effects[read_effect])
last_read_effects[read_effect] = inst

def _calculate_instruction_offspring(self, inst: IRInstruction):
if inst in self.inst_offspring:
return self.inst_offspring[inst]

self.inst_offspring[inst] = self.ida[inst].copy()

deps = self.ida[inst]
for dep_inst in deps:
assert inst.parent == dep_inst.parent
if dep_inst.opcode == "store":
continue
res = self._calculate_instruction_offspring(dep_inst)
self.inst_offspring[inst] |= res

return self.inst_offspring[inst]
8 changes: 6 additions & 2 deletions vyper/venom/venom_to_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,10 +600,14 @@ def dup(self, assembly, stack, depth):
assembly.append(_evm_dup_for(depth))

def swap_op(self, assembly, stack, op):
return self.swap(assembly, stack, stack.get_depth(op))
depth = stack.get_depth(op)
assert depth is not StackModel.NOT_IN_STACK, f"Cannot swap non-existent operand {op}"
return self.swap(assembly, stack, depth)

def dup_op(self, assembly, stack, op):
self.dup(assembly, stack, stack.get_depth(op))
depth = stack.get_depth(op)
assert depth is not StackModel.NOT_IN_STACK, f"Cannot dup non-existent operand {op}"
self.dup(assembly, stack, depth)


def _evm_swap_for(depth: int) -> str:
Expand Down

0 comments on commit c32b9b4

Please sign in to comment.