Source code for stgraph.compiler.passes.fusion

from collections import deque
from .cf import CF
from .cse import CSE
from ..utils import ValType, ParallelMode, is_const_scalar, FusionType
from ..program import Program
import copy
from ..execution_unit import ExecutionUnit
from .dependency_analysis import dep_program
from datetime import datetime

from stgraph.compiler.debugging.stgraph_logger import print_log

class FusionStateMachine():
    """State machine deciding which statements may join a fused group.

    Every statement is mapped to a transition label by ``stmt_to_trans``:

    * ``'a2d'`` -- aggregation whose return variable is a destination variable
    * ``'a2s'`` -- any other aggregation
    * ``'e'``   -- edge-wise statement
    * ``'s'``   -- source-side statement
    * ``'d'``   -- destination-side statement

    ``state_trans`` maps ``state -> {label -> next state}``.  A label that is
    absent from the current state's table means the statement is rejected in
    that state.  The machine always starts in state ``0``.
    """

    state_trans = {
        0: {'a2d': 1, 'a2s': 1, 'd': 3, 's': 5, 'e': 1},
        1: {'e': 1, 's': 2, 'd': 2},
        2: {'s': 2, 'd': 2},
        3: {'d': 3, 'a2d': 1, 's': 4},
        4: {'s': 4, 'd': 4},
        5: {'s': 5, 'a2s': 1, 'd': 4},
    }

    def __init__(self, init_stmt=None):
        """Create a machine in state 0.

        Args:
            init_stmt: Optional statement that is checked with ``accept``.
        """
        self.cur = 0
        if init_stmt:
            # NOTE(review): accept() only checks the transition, it does not
            # move the machine, so the state stays 0 here.  Confirm whether
            # advance() was intended before changing this.
            self.accept(init_stmt)

    def accept(self, stmt):
        """Return True if ``stmt`` has a transition from the current state.

        The machine's state is not modified.
        """
        trans = FusionStateMachine.stmt_to_trans(stmt)
        return trans in FusionStateMachine.state_trans[self.cur]

    def advance(self, stmt):
        """Follow the transition for ``stmt`` if one exists.

        Returns:
            True if the state was updated, False if ``stmt`` is not accepted
            in the current state (the state is then left unchanged).
        """
        trans = FusionStateMachine.stmt_to_trans(stmt)
        next_state = FusionStateMachine.state_trans[self.cur].get(trans)
        if next_state is None:
            return False
        self.cur = next_state
        return True

    @staticmethod
    def stmt_to_trans(stmt):
        """Map a statement to its transition label.

        Raises:
            NotImplementedError: If the statement is none of aggregation,
                edge-wise, source-side or destination-side.
        """
        if stmt.is_agg():
            return 'a2d' if stmt.ret.is_dstvar() else 'a2s'
        if stmt.is_edgewise():
            return 'e'
        if stmt.is_src():
            return 's'
        if stmt.is_dst():
            return 'd'
        # str() is required: concatenating the statement object itself raises
        # TypeError and hides the intended NotImplementedError.
        raise NotImplementedError('Unknown stmt graph type ' + str(stmt))

    def current_fusion_type(self):
        """Return the ``FusionType`` that corresponds to the current state."""
        if self.cur in {3, 4, 5}:
            return FusionType.NN
        elif self.cur in {1, 2}:
            return FusionType.NEAN
        else:
            return FusionType.NOT_FUSIBLE
def mergable(prog1, prog2):
    """Tell whether two programs can be merged based on their inputs.

    Two programs are mergable when the set of input-variable ids of one of
    them is contained in the set of input-variable ids of the other.

    Args:
        prog1: Program exposing ``input_vars()``.
        prog2: Program exposing ``input_vars()``.

    Returns:
        True if either id set is a subset of the other, False otherwise.
    """
    first_ids = {v.id for v in prog1.input_vars()}
    second_ids = {v.id for v in prog2.input_vars()}
    # Containment in either direction is enough; a stricter alternative
    # would be to require first_ids == second_ids.
    return first_ids <= second_ids or second_ids <= first_ids
def merge_program(prog_list):
    """Merge programs that are pairwise ``mergable`` and run CSE on the results.

    For every index ``i`` the later programs ``j > i`` that are mergable with
    program ``i`` are collected.  Walking the indices in order, each program
    that has not yet been absorbed starts a new merged ``Program`` and absorbs
    all of its collected partners.

    Args:
        prog_list: Sequence of programs to merge.

    Returns:
        List of newly created merged programs, each passed through ``CSE``.
    """
    count = len(prog_list)

    # lead index -> set of later indices that are mergable with it
    partners = {}
    for lead in range(count):
        partners[lead] = {
            other
            for other in range(lead + 1, count)
            if mergable(prog_list[lead], prog_list[other])
        }

    absorbed = set()
    fused = []
    for lead, group in partners.items():
        if lead in absorbed:
            continue
        absorbed.add(lead)
        combined = Program()
        fused.append(combined)
        combined.copy_append_prog(prog_list[lead])
        # NOTE(review): a partner that was already absorbed by an earlier
        # lead is appended again here -- confirm this duplication is intended.
        for other in group:
            combined.copy_append_prog(prog_list[other])
            absorbed.add(other)

    # Remove redundant computation in each fused program.
    for combined in fused:
        CSE(combined)
    return fused
def find_var(var, prog_list):
    """Look up the return variable with ``var.id`` across several programs.

    Programs are queried in order through ``find_ret_var_by_id`` and the
    first truthy result wins; later programs are not queried after a hit.

    Args:
        var: Object exposing an ``id`` attribute.
        prog_list: Iterable of programs to search.

    Returns:
        The first truthy lookup result, or None if no program has one.
    """
    lookups = (prog.find_ret_var_by_id(var.id) for prog in prog_list)
    return next((hit for hit in lookups if hit), None)
def fusable(downstream_s, upstream_s, stmt2state_machine, new_fsm):
    """Decide whether ``upstream_s`` can be fused with ``downstream_s``.

    When fusion is accepted through the state machine, ``stmt2state_machine``
    is updated so that ``upstream_s`` maps to the advanced machine.

    Args:
        downstream_s: Statement that consumes the upstream result.
        upstream_s: Candidate statement to fuse into the downstream group.
        stmt2state_machine: Dict mapping statements to their state machine;
            must already contain ``downstream_s`` when both statements are
            supported.
        new_fsm: If true, a deep copy of the downstream machine is advanced
            and assigned to ``upstream_s``; otherwise the downstream machine
            itself is advanced and shared.

    Returns:
        True if the two statements may be fused, False otherwise.
    """
    # Type casts are fusion breakers.
    if 'gtypecast' in upstream_s.op_name.lower():
        return False

    down_supported = downstream_s.is_supported()
    up_supported = upstream_s.is_supported()

    if not (down_supported and up_supported):
        # NOTE(review): this branch is read as the alternative to the
        # "both supported" case -- confirm against the original indentation.
        return bool(downstream_s.is_nodewise() and upstream_s.is_nodewise())

    machine = stmt2state_machine[downstream_s]
    if not machine.accept(upstream_s):
        return False

    if new_fsm:
        # Branch off an independent machine so sibling operands do not
        # influence each other's state.
        machine = copy.deepcopy(machine)
    machine.advance(upstream_s)
    stmt2state_machine[upstream_s] = machine
    return True
def merge_stmt(cur_stmt, p_stmt, stmt2fused_prog, stmt2state_machine, stmt_stack, var_stack, prog_list, new_fsm):
    """Try to place ``cur_stmt`` and ``p_stmt`` into the same fused program.

    All results are communicated by mutating the passed-in containers.

    Args:
        cur_stmt: Statement currently being visited.
        p_stmt: Statement that ``cur_stmt`` is paired with (``fuse`` passes a
            statement that produces one of ``cur_stmt``'s arguments).
        stmt2fused_prog: Dict mapping a statement to the program it has been
            assigned to.
        stmt2state_machine: Dict mapping a statement to its fusion state
            machine; forwarded to ``fusable``.
        stmt_stack: Stack of statements still to be visited in the current
            traversal.
        var_stack: Stack of variables from which new traversals start.
        prog_list: List collecting every program created during fusion.
        new_fsm: Forwarded to ``fusable``.

    Returns:
        None.
    """
    # A type cast is never fused here: only continue the traversal at p_stmt.
    if 'gtypecast' in cur_stmt.op_name.lower():
        stmt_stack.append(p_stmt)
        return
    if fusable(cur_stmt, p_stmt, stmt2state_machine, new_fsm):
        if p_stmt in stmt2fused_prog:
            if cur_stmt in stmt2fused_prog:
                # Both statements already belong to a program.
                cur_prog = stmt2fused_prog[cur_stmt]
                if stmt2fused_prog[p_stmt] == cur_prog:
                    # Already in the same program, nothing to do.
                    return
                # Move everything from cur_stmt's program into p_stmt's
                # program, re-point the moved statements, then empty the old
                # program (it stays in prog_list).
                stmt2fused_prog[p_stmt].copy_append_prog(cur_prog)
                for stmt in cur_prog:
                    stmt2fused_prog[stmt] = stmt2fused_prog[p_stmt]
                cur_prog.clear_stmts()
            else:
                # Only p_stmt has a program: append cur_stmt to it.
                stmt2fused_prog[p_stmt].copy_append_stmt(cur_stmt)
                stmt2fused_prog[cur_stmt] = stmt2fused_prog[p_stmt]
        else:
            if cur_stmt in stmt2fused_prog:
                # Only cur_stmt has a program: put p_stmt in front of it.
                stmt2fused_prog[cur_stmt].copy_prepend_stmt(p_stmt)
                stmt2fused_prog[p_stmt] = stmt2fused_prog[cur_stmt]
            else:
                # Neither has a program yet: start a new one holding both.
                var_prog = Program()
                var_prog.copy_append_stmts([p_stmt, cur_stmt])
                stmt2fused_prog[p_stmt] = var_prog
                stmt2fused_prog[cur_stmt] = var_prog
                prog_list.append(var_prog)
            # p_stmt had no program before this call, so keep traversing
            # from it.
            # NOTE(review): nesting level reconstructed from a flattened
            # listing -- confirm this append belongs to the
            # "p_stmt not yet assigned" branch.
            stmt_stack.append(p_stmt)
    else:
        # Not fusable: cur_stmt gets a program of its own if it has none.
        if cur_stmt not in stmt2fused_prog:
            var_prog = Program()
            var_prog.copy_append_stmt(cur_stmt)
            stmt2fused_prog[cur_stmt] = var_prog
            prog_list.append(var_prog)
        # Schedule a separate traversal starting at p_stmt's return variable
        # unless p_stmt has already been assigned to a program.
        if p_stmt not in stmt2fused_prog:
            var_stack.append(p_stmt.ret)
def unit_independent(u1, u2):
    """Return True when neither unit depends on the other.

    Args:
        u1: Unit exposing ``depends_on(other)``.
        u2: Unit exposing ``depends_on(other)``.

    Returns:
        False as soon as a dependency in either direction is found,
        True otherwise.
    """
    has_dependency = u1.depends_on(u2) or u2.depends_on(u1)
    return not has_dependency
def merge_independent(exec_units):
    """Merge execution units that are mutually independent.

    For every unit ``i`` the later units ``j > i`` are collected that are
    independent of it (``unit_independent``) and have the same ``compiled``
    flag.  Walking the units in order, each unit that has not been merged
    into an earlier one absorbs all of its collected partners through
    ``merge_with_independent_unit``.

    Args:
        exec_units: Sequence of execution units.

    Returns:
        List of the surviving (target) units, in their original order.
    """
    total = len(exec_units)

    # Propose candidates: target index -> later indices it may absorb.
    candidates = {
        tar: [
            src
            for src in range(tar + 1, total)
            if unit_independent(exec_units[tar], exec_units[src])
            and exec_units[tar].compiled == exec_units[src].compiled
        ]
        for tar in range(total)
    }

    # Merge candidates.
    absorbed = set()
    survivors = []
    for tar, partners in candidates.items():
        if tar in absorbed:
            continue
        target = exec_units[tar]
        # NOTE(review): partners are not filtered against `absorbed`, so a
        # unit may be merged into more than one target, and partners are not
        # checked for independence among themselves -- confirm intended.
        for src in partners:
            target.merge_with_independent_unit(exec_units[src])
            absorbed.add(src)
        survivors.append(target)
    return survivors
def fuse(progs, outputs):
    '''
    Generate one/multiple execution units from one or more programs,
    which are used for code generation.

    The pass runs in these steps:
      1. If more than one program is given, merge them with merge_program.
      2. Starting from the output variables, walk the statement dependencies
         and group statements into fused programs via merge_stmt.
      3. Build one ExecutionUnit per non-empty fused program and connect the
         units through the variables they exchange.
      4. Mark requested outputs and non-sum aggregation results as unit
         return values, and set the parallel mode of a unit from the return
         variable of its aggregation statements.
      5. Sort the units by max_ret_id() and merge independent ones.

    Args:
        progs: Sequence of programs to fuse.
        outputs: Iterable of variables requested as final outputs.

    Returns:
        The list of execution units produced by merge_independent, or
        ``progs`` itself when it is empty.

    Raises:
        NotImplementedError: If a statement depends on more than two
            producing statements.
    '''
    if len(progs) == 0:
        return progs
    if len(progs) > 1:
        print_log("[orange1 bold]Fusion[/orange1 bold]: Program fusion started")
        progs = merge_program(progs)
        print_log("[orange1 bold]Fusion[/orange1 bold]: Program fusion completed")
    print_log("[orange1 bold]Fusion[/orange1 bold]: Operator fusion started")
    # Starting from each output var, fuse as many operators as possible according to dependencies.
    # Use DFS-manner to allow maximal locality of statements
    stmt2fused_prog = {}      # statement -> fused program it belongs to
    stmt2state_machine = {}   # statement -> FusionStateMachine
    prog_list = []            # every fused program created, in creation order
    var_list = []
    # Resolve each requested output to the matching return var in progs;
    # outputs that are not found are skipped.
    for var in outputs:
        ret = find_var(var, progs)
        if ret:
            var_list.append(ret)
    var_list.sort(key=lambda var: var.int_id)
    # deque.pop() takes from the right, so the var with the largest int_id
    # is processed first.
    var_stack = deque(var_list)
    while var_stack:
        var = var_stack.pop()
        stmt_stack = deque([var.stmt])
        while stmt_stack:
            cur_stmt = stmt_stack.pop()
            # Statements producing the non-constant arguments of cur_stmt.
            dep_stmts = []
            for arg in cur_stmt.args:
                if not is_const_scalar(arg) and arg.stmt is not None:
                    dep_stmts.append(arg.stmt)
            if len(dep_stmts) == 0:
                # Nothing upstream: give the statement its own program if it
                # has not been assigned to one yet.
                if cur_stmt not in stmt2fused_prog:
                    var_prog = Program()
                    var_prog.copy_append_stmt(cur_stmt)
                    stmt2fused_prog[cur_stmt] = var_prog
                    prog_list.append(var_prog)
                continue
            if cur_stmt not in stmt2state_machine:
                stmt2state_machine[cur_stmt] = FusionStateMachine(cur_stmt)
            if len(dep_stmts) == 1:
                p_stmt = dep_stmts[0]
                merge_stmt(cur_stmt, p_stmt, stmt2fused_prog, stmt2state_machine, stmt_stack, var_stack, prog_list, new_fsm=False)
            elif len(dep_stmts) == 2:
                l_stmt = dep_stmts[0]
                r_stmt = dep_stmts[1]
                # If one operand depends on the other, only the dependent one
                # is merged here, sharing the state machine.  Otherwise both
                # are merged, each with its own copy of the state machine.
                if l_stmt.depends_on(r_stmt):
                    merge_stmt(cur_stmt, l_stmt, stmt2fused_prog, stmt2state_machine, stmt_stack, var_stack, prog_list, new_fsm=False)
                elif r_stmt.depends_on(l_stmt):
                    merge_stmt(cur_stmt, r_stmt, stmt2fused_prog, stmt2state_machine, stmt_stack, var_stack, prog_list, new_fsm=False)
                else:
                    merge_stmt(cur_stmt, r_stmt, stmt2fused_prog, stmt2state_machine, stmt_stack, var_stack, prog_list, new_fsm=True)
                    merge_stmt(cur_stmt, l_stmt, stmt2fused_prog, stmt2state_machine, stmt_stack, var_stack, prog_list, new_fsm=True)
            else:
                raise NotImplementedError('Currenty we assume num of oprands of all operators is no larger than 2')
    print_log("[orange1 bold]Fusion[/orange1 bold]: Operator fusion completed")
    # Rebuild every fused program with its statements ordered by the int_id
    # of their return var.
    prog_l = []
    for p in prog_list:
        prog = Program()
        prog.copy_append_stmts(sorted(p, key=lambda x : x.ret.int_id))
        prog_l.append(prog)
    # Reverse the creation order and drop empty programs (merge_stmt empties
    # a program after moving its statements into another one).
    prog_blks = [prog for prog in reversed(prog_l) if len(prog) > 0]
    print_log("[orange1 bold]Fusion[/orange1 bold]: Constructing execution unit")
    exe_units = []
    for prog in prog_blks:
        args = set()        # vars read before being produced inside prog
        tmps = set()        # vars produced inside prog
        seen_vars = set()
        compiled = False    # becomes True once a non-nodewise stmt is seen
        for stmt in prog:
            if not stmt.is_nodewise():
                compiled = True
            for arg in stmt.args:
                if arg not in seen_vars and not is_const_scalar(arg):
                    seen_vars.add(arg)
                    args.add(arg)
            seen_vars.add(stmt.ret)
            tmps.add(stmt.ret)
        exe_units.append(ExecutionUnit(args, tmps, prog, compiled))
    # Connecting units by setting their ret vars
    for i,b in enumerate(exe_units):
        if i > 0:
            for arg in b._args:
                for j in range(i):
                    # An earlier unit produces this argument: expose it as a
                    # return value there and record the parent relation.
                    if arg in exe_units[j].tmps:
                        exe_units[j].add_ret_val(arg)
                        exe_units[i].add_parent_unit(exe_units[j])
    # Set the final outputs for each execution unit
    for unit in exe_units:
        for var in outputs:
            if var in unit.tmps:
                unit.add_ret_val(var)
    # Materialize aggregation results for backward use (s.b.j. to mem-planning) except sum
    for u in exe_units:
        for stmt in u.program:
            if stmt.is_agg():
                if 'sum' not in stmt.op_name.lower():
                    u.add_ret_val(stmt.ret)
                # NOTE(review): nesting reconstructed from a flattened
                # listing -- the parallel mode is assumed to be set for every
                # aggregation, not only for the non-sum ones; confirm.
                if stmt.ret.is_dstvar():
                    u.set_parallel_mode(ParallelMode.DstParallel)
                elif stmt.ret.is_srcvar():
                    u.set_parallel_mode(ParallelMode.SrcParallel)
    # Sort by return id in order to satisfy the dependency between execution unit
    # Correctness remains questionable
    exe_units.sort(key=lambda x:x.max_ret_id())
    exe_units = merge_independent(exe_units)
    print_log("[orange1 bold]Fusion[/orange1 bold]: Constructing execution unit completed")
    return exe_units