from sympy import *
import re
from collections import deque
from ..utils import is_const_scalar, bcast_dim, infer_val_type, var_prefix
from .dependency_analysis import dep_program
from .cse import CSE
from .dce import DCE
from ..program import Stmt, Var
from ..schema import Schema
from stgraph.compiler.debugging.stgraph_logger import print_log
def execute_sym_program(prog, sym_table, rmv_list):
    """Symbolically evaluate ``prog`` and record each result in ``sym_table``.

    Statements are visited in order. Multiplications and divisions are
    replayed on the symbolic operands; sum-like statements are treated as the
    identity on their first argument. The first statement of any other kind
    stops the evaluation.

    Args:
        prog: Iterable of statements exposing ``op_name``, ``args`` and ``ret``.
        sym_table: Mapping from variable id to its symbolic expression.
            Updated in place with one entry per evaluated statement; must
            already hold every non-constant argument that is read.
        rmv_list: List that collects the sum statements whose op name does
            not contain ``agg``. Updated in place.

    Raises:
        KeyError: If a non-constant argument has no entry in ``sym_table``.
    """

    def operand(arg):
        # Constant scalars take part in the arithmetic as themselves; any
        # other argument is replaced by its recorded symbolic expression.
        return arg if is_const_scalar(arg) else sym_table[arg.id]

    for stmt in prog:
        op = stmt.op_name.lower()
        if 'mul' in op:
            lhs = operand(stmt.args[0])
            rhs = operand(stmt.args[1])
            sym_table[stmt.ret.id] = lhs * rhs
        elif 'div' in op:
            lhs = operand(stmt.args[0])
            rhs = operand(stmt.args[1])
            sym_table[stmt.ret.id] = lhs / rhs
        elif 'sum' in op:
            # sum and aggsum are fusion breakers. Provided only mul/div
            # surround them, distributivity and associativity of sum over mul
            # let the symbolic result ignore the sum operator altogether.
            if 'agg' not in op:
                rmv_list.append(stmt)
            sym_table[stmt.ret.id] = operand(stmt.args[0])
        else:
            print_log(f'[red bold]Peephole[/red bold]: Early stopping due to enouncter {str(stmt)}')
            break
def generate_stmts_from_expr(expr, var_table):
    """Lower a flat ``*``/``/`` expression string into binary broadcast statements.

    The expression is consumed strictly left to right, so ``"a/b*c"`` becomes
    ``(a / b) * c``. The result of every created statement is registered in
    ``var_table`` under its id so that it can be the left operand of the next
    statement.

    Args:
        expr: Operand tokens separated by ``*`` or ``/``, optionally prefixed
            by a single ``-``. Every operand token must be a key of
            ``var_table``.
        var_table: Mapping from variable id to variable object. Mutated in
            place with the intermediate results.

    Returns:
        The created statements in evaluation order. When ``expr`` starts with
        ``-`` the last statement multiplies the final value by ``-1``. The
        list is empty when ``expr`` holds no operator and no leading minus.

    Raises:
        NotImplementedError: If an operator other than ``*`` or ``/`` is met
            (for instance a ``-`` that is not the leading sign).
        KeyError: If an operand token is not a key of ``var_table``.
    """
    op_names = {'*': 'Mul', '/': 'TrueDiv'}

    # startswith() instead of expr[0] so an empty expression does not raise.
    negate = expr.startswith('-')
    if negate:
        expr = expr[1:]

    # The capture group makes re.split keep the operators, which yields an
    # alternating operand/operator/operand/... token sequence.
    tokens = deque(re.split(r'([-*/])', expr))
    stmts = []
    while len(tokens) > 1:
        lhs = var_table[tokens.popleft()]
        op = tokens.popleft()
        rhs = var_table[tokens.popleft()]
        op_name = op_names.get(op)
        if op_name is None:
            raise NotImplementedError(
                f'op {op!r} is not supported for PH optimization')
        stmts.append(Stmt.create_binary_bcast_stmt(Schema(op_name), args=[lhs, rhs]))
        result = stmts[-1].ret
        var_table[result.id] = result
        # Feed the intermediate result back in as the next left operand.
        tokens.appendleft(result.id)

    if negate:
        # With no operator in the expression there is no previous statement;
        # the single remaining token is then the operand to negate.
        value = stmts[-1].ret if stmts else var_table[tokens[0]]
        stmts.append(Stmt.create_binary_bcast_stmt(Schema('Mul'), args=[-1, value]))
    return stmts
def shape_propogation(s):
    """Recompute the shape of ``s``'s result and push a change to its users.

    Statements whose op name contains ``agg`` are left untouched, which also
    ends the propagation along that path. For any other statement the result
    shape is set to ``bcast_dim(s.args)``; when that changes the stored shape,
    every user of the result is revisited recursively.

    Args:
        s: Statement exposing ``op_name``, ``args`` and ``ret`` (with
            ``var_shape`` and ``users``).
    """
    if 'agg' in s.op_name.lower():
        # Effectively, this merges the removed sum op into the aggsum op.
        return

    new_shape = bcast_dim(s.args)
    result = s.ret
    # NOTE(review): the listing this was restored from carried no indentation;
    # recursing only when the shape actually changed is the assumed nesting.
    if new_shape != result.var_shape:
        result.var_shape = new_shape
        for user in result.users:
            shape_propogation(user)
def sum_propogation(sum_stmt):
    """Bypass ``sum_stmt`` by wiring its users directly to its first argument.

    Every argument slot of every user that holds the sum's result is
    overwritten with the sum's input, after which the user's shape is
    recomputed through ``shape_propogation``. The sum statement itself is not
    removed here.

    Args:
        sum_stmt: Statement exposing ``ret`` (with ``users``) and ``args``.
    """
    produced = sum_stmt.ret
    source = sum_stmt.args[0]
    for user in produced.users:
        for slot, current in enumerate(user.args):
            if current == produced:
                # Replace the sum result with the sum's own input.
                user.args[slot] = source
        # NOTE(review): assumed to run once per user, after all of its slots
        # have been rewired (indentation was absent in the listing).
        shape_propogation(user)
def PH(BProg, known_vars, output_vars):
    """Peephole optimization pass. Modifies the program in place.

    It replaces unfusable expressions ('sum' and 'aggsum') in the forward and
    backward programs with expressions over ``known_vars`` by applying
    mathematically equivalent transformations.

    Steps:
        1. Collect the results of sum/agg statements that are not outputs.
        2. Symbolically evaluate the programs that produce the known and the
           candidate variables (see ``execute_sym_program``).
        3. Keep each (candidate, known) pair whose quotient is a product that
           mentions fewer variables than the candidate expression minus one.
        4. For each kept pair, emit statements for ``quotient * known``,
           redirect the candidate's uses to them, run DCE, and bypass the sum
           statements recorded for that candidate.

    Args:
        BProg: Program to rewrite. Iterated for statements and mutated through
            ``insert_stmts_before`` and ``resort_vars``.
        known_vars: Set of variables assumed to be available. Set difference
            is applied to it.
        output_vars: Program outputs. Never picked as candidates and passed
            on to ``DCE``.
    """
    # Find candidate chain breakers.
    candidate_vars = set()
    var_table = {}
    for stmt in BProg:
        op = stmt.op_name.lower()
        if ('sum' in op or 'agg' in op) and stmt.ret not in output_vars:
            candidate_vars.add(stmt.ret)
            var_table[stmt.ret.id] = stmt.ret

    # Every known variable starts out as a plain symbol named after its id.
    sym_table = {}
    for known in known_vars:
        sym_table[known.id] = Symbol(known.id)
        var_table[known.id] = known

    # sum_map[var id] lists the sum statements met while evaluating that var.
    sum_map = {}
    for known in known_vars:
        dep_prog = dep_program(known, known_vars - {known})
        met_sums = []
        sum_map[known.id] = met_sums
        execute_sym_program(dep_prog, sym_table, met_sums)

    for candidate in candidate_vars:
        dep_prog = dep_program(candidate, known_vars)
        met_sums = []
        sum_map[candidate.id] = met_sums
        execute_sym_program(dep_prog, sym_table, met_sums)

    simplifiable_expression = []
    for cv in candidate_vars:
        for kv in known_vars:
            if cv.id not in sym_table or kv.id not in sym_table:
                continue
            cand_expr = sym_table[cv.id]
            known_expr = sym_table[kv.id]
            quotient = cand_expr / known_expr
            is_product = 'mul' in str(type(quotient)).lower()
            # Worth rewriting only if dividing out the known variable removes
            # more than one variable reference from the printed expression.
            if is_product and str(quotient).count(var_prefix) < str(cand_expr).count(var_prefix) - 1:
                simplifiable_expression.append(
                    (cv.id, str(quotient) + '*' + kv.id, sum_map[cv.id]))

    for target_id, expr, bypassed_sums in simplifiable_expression:
        target_var = var_table[target_id]
        stmts = generate_stmts_from_expr(expr, var_table)
        BProg.insert_stmts_before(target_var.stmt, stmts)
        # The last generated statement takes over from the candidate.
        target_var.replace_all_uses_with(stmts[-1].ret, propogate_shape=False)
        DCE(BProg, output_vars)
        for sum_stmt in bypassed_sums:
            # Only sums still found in the program are handled; presumably
            # the DCE run above has already removed the others.
            for st in BProg:
                if st == sum_stmt:
                    sum_propogation(st)
                    st.remove_cur()
    # NOTE(review): assumed to run once after all rewrites (indentation was
    # absent in the listing this was restored from).
    BProg.resort_vars()