# Source code for stgraph.compiler.passes.visualize

import networkx as nx
from ..utils import is_const_scalar, OpType

# Suffix counter shared by plot_programs and plot_exec_units: it is appended
# to every output SVG file name and incremented after each write, so
# successive plots are written to distinct files.
count = 0

# Graphviz node shapes: plain text for variables, boxes for statements.
var_shape, stmt_shape = 'plain', 'box'

# Statement fill colour keyed by op type. The legend drawn by plot_programs
# labels these 'src op' / 'edge op' / 'dst op' / 'agg op' respectively
# (presumably what S / E / D / A stand for -- confirm against OpType).
color_map = {
    OpType.S: 'lightgreen',
    OpType.E: 'lightblue',
    OpType.D: 'lightyellow',
    OpType.A: 'red',
}

def plog_program(prog, filename='egl-dag'):
    """Render a single program by delegating to ``plot_programs``.

    The (misspelled) name is kept so existing callers keep working; prefer
    the ``plot_program`` alias below.

    Args:
        prog: one program, i.e. an iterable of statements accepted by
            ``plot_programs``.
        filename: output path prefix forwarded to ``plot_programs``; the
            default matches that function's default.

    Side effects:
        Writes one SVG file and increments the module-level ``count``
        (both done by ``plot_programs``).
    """
    plot_programs([prog], filename)


# Correctly spelled alias; ``plog_program`` remains for backward compatibility.
plot_program = plog_program
def plot_programs(progs, filename='egl-dag'):
    """Render programs as a statement/variable DAG and write it as an SVG.

    Every statement becomes a filled box coloured by its op type (via
    ``color_map``). Every argument and return value becomes a plain variable
    node, wired ``arg -> stmt -> ret``. Statements that share an op name are
    made unique with a per-name suffix ``-0``, ``-1``, ... counted across all
    programs. Four unconnected legend nodes show the op-type colours.

    Args:
        progs: iterable of programs; each program is iterated to yield
            statements exposing ``op_name``, ``op_type``, ``args`` and
            ``ret``.
        filename: output path prefix; the file written is
            ``<filename><count>.svg``.

    Side effects:
        Writes one SVG file and increments the module-level ``count``.
    """
    global count
    G = nx.DiGraph()

    # Legend: one unconnected node per op type. Colours come from color_map
    # so the legend cannot drift out of sync with the statement nodes.
    legend = (
        ('src op', OpType.S),
        ('edge op', OpType.E),
        ('dst op', OpType.D),
        ('agg op', OpType.A),
    )
    for label, op_type in legend:
        G.add_node(label, shape=stmt_shape, color=color_map[op_type], style='filled')

    # Last suffix handed out per op name; the first occurrence gets 0.
    stmt_count = {}
    for prog in progs:
        for stmt in prog:
            # Constant scalar arguments are used directly as node keys;
            # every other argument is identified by its ``id``.
            arg_nodes = [arg if is_const_scalar(arg) else arg.id for arg in stmt.args]

            op_name = str(stmt.op_name)
            index = stmt_count.get(op_name, -1) + 1
            stmt_count[op_name] = index
            stmt_node = op_name + '-' + str(index)
            ret_node = stmt.ret.id

            # Add nodes
            for arg in arg_nodes:
                G.add_node(arg, shape=var_shape)
            G.add_node(stmt_node, shape=stmt_shape, color=color_map[stmt.op_type], style='filled')
            G.add_node(ret_node, shape=var_shape)

            # Add edges
            for arg in arg_nodes:
                G.add_edge(arg, stmt_node)
            G.add_edge(stmt_node, ret_node)

    p = nx.drawing.nx_pydot.to_pydot(G)
    p.write_svg(filename + str(count) + '.svg')
    count += 1
# Fill colour of an execution-unit node in plot_exec_units, keyed by the
# unit's ``compiled`` flag: green when compiled, blue otherwise.
compiled_color_map = {
    True: 'green',
    False: 'blue',
}
def plot_exec_units(units, filename='egl-fused-dag'):
    """Render execution units as a DAG of variables and units; write an SVG.

    Each unit becomes a filled box whose name/label is the unit's kernel name
    followed by the op names of its statements, one per line. The box is
    coloured by whether the unit was compiled (via ``compiled_color_map``).
    Unit arguments and return values become plain variable nodes, wired
    ``arg -> unit -> ret``. Two unconnected legend nodes show the colours.

    Args:
        units: iterable of execution units exposing ``kernel_name``,
            ``compiled``, ``_prog`` (iterable of statements with
            ``op_name``), ``unit_args()`` and ``unit_rets()``.
        filename: output path prefix; the file written is
            ``<filename><count>.svg``.

    Side effects:
        Writes one SVG file and increments the module-level ``count``.
    """
    global count
    G = nx.DiGraph()

    # Legend, coloured from compiled_color_map so it stays in sync with the
    # unit nodes below.
    G.add_node('fused-and-compiled', shape=stmt_shape,
               color=compiled_color_map[True], style='filled')
    G.add_node('not compiled', shape=stmt_shape,
               color=compiled_color_map[False], style='filled')

    for unit in units:
        # Constant scalar arguments are used directly as node keys; every
        # other argument is identified by its ``id``.
        arg_nodes = [arg if is_const_scalar(arg) else arg.id for arg in unit.unit_args()]
        # str() mirrors plot_programs, which also wraps op_name in str(); it
        # keeps str.join from raising TypeError if an op_name is not a str
        # (assumed possible -- TODO confirm against the statement class).
        unit_node = unit.kernel_name + '\n' + '\n'.join(
            [str(stmt.op_name) for stmt in unit._prog])
        ret_nodes = [ret.id for ret in unit.unit_rets()]

        # Add nodes
        for arg in arg_nodes:
            G.add_node(arg, shape=var_shape)
        G.add_node(unit_node, shape=stmt_shape,
                   color=compiled_color_map[unit.compiled], style='filled')
        for ret in ret_nodes:
            G.add_node(ret, shape=var_shape)

        # Add edges
        for arg in arg_nodes:
            G.add_edge(arg, unit_node)
        for ret in ret_nodes:
            G.add_edge(unit_node, ret)

    p = nx.drawing.nx_pydot.to_pydot(G)
    p.write_svg(filename + str(count) + '.svg')
    count += 1