from ..utils import is_const_scalar, ParallelMode
from collections import namedtuple
from .compiler import compile_cuda
from jinja2 import Environment, PackageLoader
# Records built from the mapping returned by ``stmt.gen_code(ctx)``. Each one
# is created with ``**`` unpacking (see gen_edge_info / gen_node_info /
# gen_agg_info), so the field names must match that mapping's keys exactly.
EdgeInfo = namedtuple('EdgeInfo', 'load compute inner_write')
NodeInfo = namedtuple('NodeInfo', 'load compute inner_write')
AggInfo = namedtuple('AggInfo', 'init compute inner_write outter_write')

# Description of one kernel argument: a name, a type string and an ``is_ptr``
# flag (see gen_arg_info).
ArgInfo = namedtuple('ArgInfo', 'name type is_ptr')

# Running counter used by gen_arg_info to name constant scalar arguments
# ('c0', 'c1', ...). It is module-global and only ever incremented.
const_id = 0
def gen_arg_info(arg):
    """Build the ``ArgInfo`` record describing one kernel argument.

    Constant scalars (as decided by ``is_const_scalar``) receive a generated
    name ``c<N>``, where ``N`` is the current value of the module-level
    ``const_id`` counter; the counter is incremented as a side effect. They
    are flagged ``is_ptr=False``.

    Every other argument is named after its ``id`` attribute, typed by the
    string form of its ``dtype_str`` attribute, and flagged ``is_ptr=True``.

    Args:
        arg: A constant scalar, or an object exposing ``id`` and ``dtype_str``.

    Returns:
        ArgInfo: The populated record.
    """
    global const_id
    if not is_const_scalar(arg):
        return ArgInfo(name=arg.id, type=str(arg.dtype_str), is_ptr=True)

    # NOTE(review): str(type(arg)) yields text such as "<class 'int'>" rather
    # than a bare type name -- confirm this is what the templates expect.
    info = ArgInfo(name='c' + str(const_id), type=str(type(arg)), is_ptr=False)
    const_id += 1
    return info
def gen_agg_info(stmt, ctx):
    """Generate the ``AggInfo`` record for an aggregation statement.

    Args:
        stmt: Statement object; its ``gen_code(ctx)`` must return a mapping
            whose keys are exactly the ``AggInfo`` field names.
        ctx: Code-generation context forwarded to ``stmt.gen_code``.

    Returns:
        AggInfo: Record built from the generated mapping.

    Raises:
        NotImplementedError: If ``stmt.gen_code(ctx)`` returns a falsy value
            (``None`` or an empty mapping).
    """
    generated = stmt.gen_code(ctx)
    if not generated:
        # Format a single message; passing ``stmt`` as a second exception
        # argument would make the error print as a tuple.
        raise NotImplementedError('Cannot generate code for {!r}'.format(stmt))
    return AggInfo(**generated)
def gen_edge_info(stmt, ctx):
    """Generate the ``EdgeInfo`` record for a statement.

    Args:
        stmt: Statement object; its ``gen_code(ctx)`` must return a mapping
            whose keys are exactly the ``EdgeInfo`` field names.
        ctx: Code-generation context forwarded to ``stmt.gen_code``.

    Returns:
        EdgeInfo: Record built from the generated mapping.

    Raises:
        NotImplementedError: If ``stmt.gen_code(ctx)`` returns a falsy value
            (``None`` or an empty mapping).
    """
    generated = stmt.gen_code(ctx)
    if not generated:
        # Format a single message; passing ``stmt`` as a second exception
        # argument would make the error print as a tuple.
        raise NotImplementedError('Cannot generate code for {!r}'.format(stmt))
    return EdgeInfo(**generated)
def gen_node_info(stmt, ctx):
    """Generate the ``NodeInfo`` record for a statement.

    Args:
        stmt: Statement object; its ``gen_code(ctx)`` must return a mapping
            whose keys are exactly the ``NodeInfo`` field names.
        ctx: Code-generation context forwarded to ``stmt.gen_code``.

    Returns:
        NodeInfo: Record built from the generated mapping.

    Raises:
        NotImplementedError: If ``stmt.gen_code(ctx)`` returns a falsy value
            (``None`` or an empty mapping).
    """
    generated = stmt.gen_code(ctx)
    if not generated:
        # Format a single message; passing ``stmt`` as a second exception
        # argument would make the error print as a tuple.
        raise NotImplementedError('Cannot generate code for {!r}'.format(stmt))
    return NodeInfo(**generated)
def gen_code(exe_units, index_type, graph_type):
    '''Generate CUDA code by instantiating the code templates.

    For every execution unit that is flagged ``compiled`` this collects the
    kernel-argument records and the per-statement code records, bundles them
    into a template config dict, and finally hands all configs to
    ``gen_cuda``.

    Args:
        exe_units: A single execution unit or a list of units. Units whose
            ``compiled`` attribute is falsy are skipped.
        index_type: Forwarded to ``unit.create_context`` and stored in the
            config under ``'index_type'``.
        graph_type: Stored in the config under ``'graph_type'``; ``gen_cuda``
            uses it to pick the template.

    Returns:
        Whatever ``gen_cuda`` returns for the collected configs.

    Raises:
        NotImplementedError: Propagated from the ``gen_*_info`` helpers when
            a statement yields no code, or from ``gen_cuda`` when no template
            matches.
    '''
    if not isinstance(exe_units, list):
        exe_units = [exe_units]

    configs = []
    for unit in exe_units:
        if not unit.compiled:
            continue

        arginfos = [gen_arg_info(var) for var in unit.kernel_args()]
        ctx = unit.create_context(index_type)

        edgeinfos = []
        agginfos = []
        nodeinfos = []
        # Flips to True once an aggregation statement has been seen; it
        # decides where node-wise statements are emitted.
        after_agg = False
        for stmt in unit.program:
            ctx.set_stmt_ctx(stmt)
            if stmt.is_agg():
                agginfos.append(gen_agg_info(stmt, ctx))
                after_agg = True
            elif stmt.is_edgewise():
                edgeinfos.append(gen_edge_info(stmt, ctx))
            elif stmt.is_nodewise():
                if after_agg:
                    nodeinfos.append(gen_node_info(stmt, ctx))
                else:
                    # Node-wise statements that precede any aggregation are
                    # emitted together with the edge-wise ones.
                    edgeinfos.append(gen_edge_info(stmt, ctx))

        dst_parallel = unit.parallel_mode() == ParallelMode.DstParallel
        if dst_parallel:
            row_offset, col_index = 'dst_id', 'src_id'
            outer_var_init = ctx.dst_var_offset_init
            inner_var_init = ctx.src_var_offset_init
        else:
            row_offset, col_index = 'src_id', 'dst_id'
            outer_var_init = ctx.src_var_offset_init
            inner_var_init = ctx.dst_var_offset_init

        configs.append({
            'kernel_name': unit.kernel_name,
            'index_type': index_type,
            'args': arginfos,
            'edges': edgeinfos,
            'aggs': agginfos,
            'nodes': nodeinfos,
            'row_offset': row_offset,
            'init_outter_offset': ctx.param_offset_init + outer_var_init,
            'col_index': col_index,
            'init_inner_offset': inner_var_init + ctx.edge_var_offset_init,
            'template_name': ctx.template_name,
            'graph_type': graph_type,
        })
    return gen_cuda(configs)
def render_template(config, template_name):
    """Render one Jinja template with the given config.

    The template is looked up as ``fa/<template_name>.jinja`` through a
    ``PackageLoader`` bound to the ``stgraph.compiler.code_gen`` package.

    Args:
        config: Mapping whose items are passed to the template as keyword
            variables.
        template_name: Template file name without the ``fa/`` prefix and the
            ``.jinja`` suffix.

    Returns:
        The rendered template text.
    """
    env = Environment(
        loader=PackageLoader("stgraph.compiler.code_gen"),
    )
    template = env.get_template("fa/{}.jinja".format(template_name))
    return template.render(**config)
def gen_cuda(configs):
    """Render every config with its template and compile the combined source.

    Only the ``'fa'`` template family is implemented. For it, the template
    ``tpl_fa_<graph_type>`` is rendered for each supported graph type. The
    rendered texts are concatenated in config order and passed to
    ``compile_cuda``.

    Args:
        configs: Iterable of config dicts, each providing at least
            ``'template_name'`` and (for known template families)
            ``'graph_type'``.

    Returns:
        Whatever ``compile_cuda`` returns for the concatenated source.

    Raises:
        NotImplementedError: If a config names an unknown template family,
            the unimplemented ``'v2'`` family, or an unsupported graph type.
    """
    # Graph types that have a matching "tpl_fa_<graph_type>" template. A tuple
    # keeps the check equality-based, like the comparisons it replaces.
    fa_graph_types = (
        'csr', 'csr_unsorted',
        'pcsr', 'pcsr_unsorted',
        'gpma', 'gpma_unsorted',
    )

    rendered = []
    for config in configs:
        template_name = config['template_name']
        if template_name not in ('fa', 'v2'):
            raise NotImplementedError('{} Template not supported'.format(template_name))

        graph_type = config['graph_type']
        if template_name == 'v2' or graph_type not in fa_graph_types:
            raise NotImplementedError('{} Template for {} is not supported'.format(template_name, graph_type))

        rendered.append(render_template(config, 'tpl_fa_{}'.format(graph_type)))
    return compile_cuda(''.join(rendered))