Source code for stgraph.compiler.execution_unit

"""The fundamental execution unit of STGraph"""

import math
import snoop
from .code_gen.cuda_driver import *
from .code_gen.kernel_context import KernelContext, LinearizedKernelContext
from .utils import is_const_scalar, ParallelMode, MAX_THREAD_PER_BLOCK, MAX_BLOCK 
from .code_gen.cuda_error import ASSERT_DRV

from stgraph.compiler.debugging.stgraph_logger import print_log

# TODO: remove
import numpy as np

class ExecutionUnit(object):
    """A group of program statements that is compiled and launched as one kernel.

    Parameters
    ----------
    args : set
        Variables consumed by the unit.
    tmps : set
        Temporary variables of the unit.
    prog : iterable of statements
        The program; every statement exposes ``args`` and ``ret``.
    compiled : bool, optional
        Stored as-is and exposed through the ``compiled`` property.

    Notes
    -----
    Each unit receives a unique kernel name ``'K<n>'`` derived from the
    class-wide ``unit_count``; hashing and equality are based on that name.
    """

    # Number of units created so far; used to build unique kernel names.
    unit_count = 0

    # Kernel wrapper created by ``prepare_compiled_kernel``. The class-level
    # default lets ``kernel_run`` fail with its explanatory assertion (rather
    # than an AttributeError) when no kernel has been prepared yet.
    _K = None

    def __init__(self, args, tmps, prog, compiled=False):
        self._args = args
        self._tmps = tmps
        self._compiled = compiled
        self._prog = prog
        self._rets = set()
        self._kernel_name = 'K' + str(ExecutionUnit.unit_count)
        self._parallel_mode = None
        self._unit_rets_cached = None
        self._unit_args_cached = None
        self._max_dims_cached = None
        self._parent_units = set()
        # NOTE(review): feature_size() is a product that starts at 1, so for
        # non-negative dims this always selects the 'fa' template; confirm
        # that the 'v2' path is intentionally disabled.
        if self.feature_size() >= 0:
            self._template_name = 'fa'
        else:
            self._template_name = 'v2'
        ExecutionUnit.unit_count += 1

    def __str__(self):
        return '\n-----------\nparallel_mode:{p_mode}\nargs:{args}\nrets:{rets}\ntmps:{tmps}\ncompiled:{compiled}\nprog:{prog}' \
            .format(args=str(self._args),
                    rets=str(self._rets),
                    compiled=str(self._compiled),
                    prog=str(self._prog),
                    p_mode=self._parallel_mode,
                    tmps=self._tmps)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash(self._kernel_name)

    def __eq__(self, other):
        return isinstance(other, ExecutionUnit) and self._kernel_name == other._kernel_name

    def use_fa_tmpl(self):
        """Return True when this unit uses the 'fa' kernel template."""
        return self._template_name == 'fa'

    def create_context(self, index_type):
        """Create the code-generation context matching the unit's template."""
        if self.use_fa_tmpl():
            return LinearizedKernelContext(self, index_type)
        return KernelContext(self, index_type)

    def set_parallel_mode(self, mode):
        """Set the parallel mode; ``mode`` must be a ``ParallelMode``."""
        assert isinstance(mode, ParallelMode)
        self._parallel_mode = mode

    def parallel_mode(self):
        """Return the current parallel mode (``None`` until one is set)."""
        return self._parallel_mode

    def max_dims(self):
        """Return the element-wise maximum shape over the unit's variables.

        The result is cached. The first variable visited fixes the rank
        (1 or 2) of the result; variables of the other rank are skipped.

        Raises
        ------
        NotImplementedError
            If a variable's shape has a rank other than 1 or 2.
        """
        if not self._max_dims_cached:
            # Assumption: vars are at most two dimensions
            for var in self.get_all_vars():
                shape = var.var_shape
                rank = len(shape)
                if rank not in (1, 2):
                    raise NotImplementedError('Have not suported the case when local var dim larger than 2')
                if self._max_dims_cached and len(self._max_dims_cached) != rank:
                    # Inconsistent feature rank: ignored rather than rejected.
                    continue
                if not self._max_dims_cached:
                    self._max_dims_cached = [1] * rank
                self._max_dims_cached = [max(self._max_dims_cached[i], shape[i])
                                         for i in range(rank)]
        return self._max_dims_cached

    def feature_size(self):
        """Return the product of all entries of ``max_dims()``."""
        size = 1
        for dim in self.max_dims():
            size = size * dim
        return size

    def calculate_kernel_params_fa(self, num_nodes):
        """Compute launch parameters for the 'fa' template.

        Returns
        -------
        tuple of int
            ``(nblks, nthrs, thrs_per_group, nodes_per_blk)``.
        """
        feat_size = self.feature_size()
        min_threads = 64
        max_threads = 256
        if feat_size >= min_threads:
            # Wide features: one node per block, up to max_threads threads.
            nthrs = min(max_threads, feat_size)
            thrs_per_group = nthrs
            nodes_per_blk = 1
            nblks = num_nodes
        else:
            # Narrow features: share a min_threads-wide block between nodes.
            nthrs = min_threads
            thrs_per_group = max(1, self.first_pow2_less_than_n(feat_size, nthrs))
            nodes_per_blk = max(2, nthrs // thrs_per_group)
            nblks = (num_nodes + nodes_per_blk - 1) // nodes_per_blk
        return int(nblks), int(nthrs), int(thrs_per_group), int(nodes_per_blk)

    def calculate_kernel_params(self, num_nodes):
        """Return ``(launch_params, tile_sizes)`` for the 'v2' template."""
        kernel_params = self.calculate_kernel_launch_params(num_nodes)
        # The last two launch parameters are the block dimensions (x, y).
        tile_sizes = self.compute_tile_sizes(kernel_params[-2], kernel_params[-1])
        return kernel_params, tile_sizes

    def first_pow2_less_than_n(self, n, upper_bound):
        """Halve ``upper_bound`` until it is no larger than ``n`` and return it.

        Despite the name, the result may equal ``n``. The loop does not
        terminate for negative ``n`` because the bound gets stuck at 0.
        """
        while upper_bound > n:
            upper_bound = upper_bound // 2
        return upper_bound

    def calculate_kernel_launch_params(self, num_nodes):
        """Return ``(gdim_x, gdim_y, bdim_x, bdim_y)`` for the 'v2' template."""
        max_dims = self.max_dims()
        total_dim = 1
        for dim in max_dims:
            total_dim = total_dim * dim
        has_second_dim = len(max_dims) != 1
        if total_dim < MAX_THREAD_PER_BLOCK:
            # The whole feature shape fits into one block.
            bdim_x = max_dims[-1]
            bdim_y = max_dims[-2] if has_second_dim else 1
            gdim_x = 1
        else:
            # Cap each block dimension at 32 and cover the rest with grid x.
            bdim_x = min(max_dims[-1], 32)
            bdim_y = min(max_dims[-2], 32) if has_second_dim else 1
            gdim_x = (max_dims[-1] + bdim_x - 1) // bdim_x
        gdim_y = min(num_nodes, MAX_BLOCK)
        return (gdim_x, gdim_y, bdim_x, bdim_y)

    def compute_tile_sizes(self, blockDimx, blockDimy):
        """Choose a tile shape ``[x, y]`` covering one warp of 32 threads.

        When the block holds fewer than two warps no tiling is done and
        ``[blockDimx, blockDimy]`` is returned unchanged.

        Raises
        ------
        AssertionError
            If the selected tile does not contain exactly 32 threads.
        """
        WARP_SIZE = 32
        thread_dims = [blockDimx, blockDimy]
        if blockDimx * blockDimy < 2 * WARP_SIZE:
            return thread_dims  # means no tiling at all

        # Count accesses to kernel arguments, split by whether the variable
        # has the unit's full (max) shape ("A") or a different shape ("b").
        full_shape = self.max_dims()
        kernel_args = self.kernel_args()
        N_Ar = 0
        N_Aw = 0
        N_br = 0
        N_bw = 0
        for stmt in self._prog:
            for arg in stmt.args:
                if arg in kernel_args:
                    if arg.var_shape == full_shape:
                        N_Ar += 1
                    else:
                        N_br += 1
            if stmt.ret in kernel_args:
                if stmt.ret.var_shape == full_shape:
                    N_Aw += 1
                else:
                    N_bw += 1

        n = thread_dims[-1]
        cof = N_Ar if n <= WARP_SIZE else WARP_SIZE
        if N_bw > 0:
            # Evaluate the cost expression at both candidate widths and keep
            # the cheaper candidate (the width itself, not its cost).
            x1, x2 = self.nearest_pow2(math.sqrt(cof * n / N_bw))
            t1 = N_Ar * n / x1 + N_bw * x1
            t2 = N_Ar * n / x2 + N_bw * x2
            tile_sizex = x1 if t1 <= t2 else x2
        else:
            tile_sizex = min(WARP_SIZE, blockDimx)
        tile_sizey = int(WARP_SIZE / tile_sizex)
        assert tile_sizex * tile_sizey == WARP_SIZE
        # Trade height for width until the tile fits the block's y extent.
        while tile_sizey > blockDimy:
            tile_sizey = tile_sizey / 2
            tile_sizex = tile_sizex * 2
        return [int(tile_sizex), int(tile_sizey)]

    def nearest_pow2(self, targ):
        """Return ``(i, i / 2)`` for the powers of two bracketing ``targ``.

        ``i`` is the smallest power of two (at least 1) with
        ``2 * i >= targ``. The second element is a float.
        """
        i = 1
        while 2 * i < targ:
            i = i * 2
        return i, i / 2

    def unit_args(self):
        """Return the non-constant-scalar arguments sorted by ``id`` (cached)."""
        if not self._unit_args_cached:
            candidates = [arg for arg in self._args if not is_const_scalar(arg)]
            self._unit_args_cached = sorted(candidates, key=lambda x: x.id)
        return self._unit_args_cached

    def unit_rets(self):
        """Return the non-constant-scalar return values sorted by ``id`` (cached)."""
        if not self._unit_rets_cached:
            candidates = [ret for ret in self._rets if not is_const_scalar(ret)]
            self._unit_rets_cached = sorted(candidates, key=lambda x: x.id)
        return self._unit_rets_cached

    def all_rets(self):
        """Return the set of return variables of every program statement."""
        return {stmt.ret for stmt in self.program}

    def kernel_args(self):
        """Return ``unit_args()`` followed by ``unit_rets()``."""
        return self.unit_args() + self.unit_rets()

    def materilized_vars(self):
        """Return args and rets for a compiled unit, otherwise all variables."""
        if self.compiled:
            return self._rets.union(self._args)
        return self.get_all_vars()

    def get_all_args(self):
        '''
        return the set of all non-constant-scalar vars used as arguments
        in the program of this exec unit
        '''
        var_set = set()
        for stmt in self._prog:
            for var in stmt.args:
                if not is_const_scalar(var):
                    var_set.add(var)
        return var_set

    def get_all_vars(self):
        '''
        return the set of all vars used/returned in the program of this
        exec unit (constant scalar arguments are excluded)
        '''
        var_set = set()
        for stmt in self._prog:
            for var in stmt.args:
                if not is_const_scalar(var):
                    var_set.add(var)
            var_set.add(stmt.ret)
        return var_set

    def add_ret_val(self, ret_val):
        """Register ``ret_val`` as a return value of this unit."""
        self._rets.add(ret_val)
        # Drop the sorted cache so unit_rets() picks up the new value.
        self._unit_rets_cached = None

    def max_ret_id(self):
        """Return the largest ``int_id`` among ``unit_rets()``."""
        return sorted([ret.int_id for ret in self.unit_rets()])[-1]

    def _graph_pointers(self, graph):
        """Return ``(row_offsets, col_indices, eids, node_ids)`` pointers.

        The forward arrays are used in ``DstParallel`` mode and the backward
        arrays in every other mode.
        """
        if self.parallel_mode() == ParallelMode.DstParallel:
            return (graph.fwd_row_offset_ptr,
                    graph.fwd_column_indices_ptr,
                    graph.fwd_eids_ptr,
                    graph.fwd_node_ids_ptr)
        return (graph.bwd_row_offset_ptr,
                graph.bwd_column_indices_ptr,
                graph.bwd_eids_ptr,
                graph.bwd_node_ids_ptr)

    def prepare_compiled_kernel(self, graph, compiled_module):
        """Create the kernel wrapper for this unit and store it in ``self._K``.

        Raises
        ------
        NotImplementedError
            If ``max_dims()`` has more than two entries.
        """
        row_offsets_ptr, col_indices_ptr, eids_ptr, node_ids_ptr = self._graph_pointers(graph)

        # Kernels always receive two dims; a 1-D shape is padded to [1, d].
        dims = self.max_dims()
        if len(dims) == 1:
            max_dims = [1, dims[-1]]
        elif len(dims) == 2:
            max_dims = dims
        else:
            raise NotImplementedError('Feature dimension larger than 2 are not supported.')

        num_nodes = graph.get_num_nodes()
        if self.use_fa_tmpl():
            launch_config = self.calculate_kernel_params_fa(num_nodes)
            print_log(f'[yellow bold]Execution Unit[/yellow bold]: Generating FA Kernel with num_nodes: {str(num_nodes)}, launch_config: {str(launch_config)}')
            self._K = FeatureAdaptiveKernel(num_nodes,
                                            row_offsets_ptr,
                                            col_indices_ptr,
                                            eids_ptr,
                                            node_ids_ptr,
                                            max_dims,
                                            self._kernel_name,
                                            compiled_module,
                                            launch_config)
        else:
            launch_config, tile_sizes = self.calculate_kernel_params(num_nodes)
            print_log(f'[yellow bold]Execution Unit[/yellow bold]: Generating V2 Kernel with num_nodes: {str(num_nodes)}, launch_config: {str(launch_config)}, tile_size: {str(tile_sizes)}, max_dims: {str(max_dims)}')
            self._K = V2Kernel(num_nodes,
                               row_offsets_ptr,
                               col_indices_ptr,
                               eids_ptr,
                               max_dims,
                               self._kernel_name,
                               compiled_module,
                               launch_config,
                               tile_sizes)

    def reset_graph_info(self, graph):
        """Point the prepared kernel at the arrays of ``graph``."""
        row_offsets_ptr, col_indices_ptr, eids_ptr, node_ids_ptr = self._graph_pointers(graph)
        self._K.reset_graph_info(graph.get_num_nodes(),
                                 row_offsets_ptr,
                                 col_indices_ptr,
                                 eids_ptr,
                                 node_ids_ptr)

    def kernel_run(self, tensor_list):
        """Launch the prepared kernel with ``tensor_list`` as its leading arguments."""
        assert self._K, 'Must call prepare_compiled_kernel before call kernel_run.'
        self._K.run(tensor_list)

    def merge_with_independent_unit(self, other):
        """Merge ``other`` into this unit in place and return ``self``.

        Arguments, return values and temporaries are unioned, the statements
        of ``other`` are inserted before this unit's first aggregation
        statement, and the parallel mode is re-derived from the aggregations.
        """
        # union their inputs and outputs (cached views become invalid)
        self._unit_args_cached = None
        self._unit_rets_cached = None
        self._max_dims_cached = None
        self._args = self._args.union(other._args)
        self._rets = self._rets.union(other._rets)
        self._tmps = self._tmps.union(other._tmps)

        # merge stmts; first_agg is None when there is no aggregation stmt
        first_agg = next((s for s in self._prog if s.is_agg()), None)
        self._prog.insert_stmts_before(first_agg, list(other._prog))

        # adjust parallel mode by majority vote over aggregation statements
        dst_parallel_count = 0
        for s in self._prog:
            if s.is_agg():
                dst_parallel_count += 1 if s.ret.is_dstvar() else -1
        if dst_parallel_count >= 0:
            self._parallel_mode = ParallelMode.DstParallel
        else:
            self._parallel_mode = ParallelMode.SrcParallel
        return self

    def depends_on(self, other):
        """Return True if ``other`` is this unit or one of its ancestors."""
        if self == other or self.is_child_of(other):
            return True
        return any(u.depends_on(other) for u in self._parent_units)

    def add_parent_unit(self, parent):
        """Record ``parent`` as a direct parent of this unit."""
        self._parent_units.add(parent)

    def has_parent(self):
        """Return True if at least one parent unit has been recorded."""
        return len(self._parent_units) > 0

    def is_child_of(self, other):
        """Return True if ``other`` is a direct parent of this unit."""
        return other in self._parent_units

    @property
    def program(self):
        return self._prog

    @property
    def tmps(self):
        return self._tmps

    @property
    def compiled(self):
        return self._compiled

    @property
    def kernel_name(self):
        return self._kernel_name
class Kernel():
    """Base class of the kernel wrappers launched through ``cuLaunchKernel``.

    Subclasses are expected to set these attributes in ``__init__``:

    * ``K`` - function handle obtained from ``cuModuleGetFunction``
    * ``launch_config`` - six launch dimensions, passed to ``cuLaunchKernel``
      in order
    * ``const_kernel_args`` - ctypes values appended after the tensor
      arguments on every launch
    * ``const_kernel_ptrs`` - ``c_void_p`` addresses of the entries of
      ``const_kernel_args``, in the same order
    """

    def _store_const_args(self, new_args):
        """Replace the leading constant arguments and refresh their pointers."""
        for slot, arg in enumerate(new_args):
            self.const_kernel_args[slot] = arg
            # The stored pointer must reference the new ctypes object.
            self.const_kernel_ptrs[slot] = c_void_p(addressof(arg))

    def reset_graph_info(self, num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, node_ids_ptr):
        """Update the graph-dependent constant arguments.

        This base implementation assumes the layout used by
        ``FeatureAdaptiveKernel``: row offsets, edge ids, column indices,
        node ids and then ``num_nodes``.
        """
        self._store_const_args([
            c_void_p(row_offsets_ptr),
            c_void_p(eids_ptr),
            c_void_p(col_indices_ptr),
            c_void_p(node_ids_ptr),
            c_int(num_nodes),
        ])

    def run(self, tensor_list):
        """Launch the kernel with ``tensor_list`` followed by the constant args.

        ``ASSERT_DRV`` is applied to the status returned by the launch.
        """
        kernel_ptrs = [c_void_p(addressof(arg)) for arg in tensor_list] + self.const_kernel_ptrs
        params = (c_void_p * len(kernel_ptrs))(*kernel_ptrs)
        ret = cuLaunchKernel(self.K,
                             self.launch_config[0],
                             self.launch_config[1],
                             self.launch_config[2],
                             self.launch_config[3],
                             self.launch_config[4],
                             self.launch_config[5],
                             0,
                             None,
                             params,
                             0)
        ASSERT_DRV(ret)


class V2Kernel(Kernel):
    r"""The Version 2 Kernel

    This class contains the parameters for the second version of the kernel
    written for STGraph

    Parameters
    ----------
    num_nodes : int
        Number of nodes present in the graph
    row_offsets_ptr : c_type
        Pointer to the row offset array
    col_indices_ptr : c_type
        Pointer to the column indices array
    eids_ptr : c_type
        Pointer to the edge id array
    max_dims : list[int]
        Two feature dimensions; passed to the kernel as
        ``max_dims[1], max_dims[0]``
    kernel_name : str
        Name of the function looked up in ``compiled_module``
    compiled_module
        Module handle passed to ``cuModuleGetFunction``
    launch_config : sequence of int
        ``(gdim_x, gdim_y, bdim_x, bdim_y)``
    tile_sizes : sequence of int
        Two tile sizes passed to the kernel as scalars

    Attributes
    ----------
    scalar_args : list[c_types]
        List of the scalar arguments passed to the kernel
    launch_config : list[int]
        List of the kernel launch configurations
    """

    def __init__(self, num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, max_dims, kernel_name, compiled_module, launch_config, tile_sizes):
        self.scalar_args = [c_int(num_nodes),
                            c_int(max_dims[1]),
                            c_int(max_dims[0]),
                            c_int(tile_sizes[0]),
                            c_int(tile_sizes[1])]
        graph_args = [c_void_p(row_offsets_ptr),
                      c_void_p(eids_ptr),
                      c_void_p(col_indices_ptr)]
        self.const_kernel_args = graph_args + self.scalar_args
        self.const_kernel_ptrs = [c_void_p(addressof(v)) for v in self.const_kernel_args]
        ret, self.K = cuModuleGetFunction(compiled_module, kernel_name.encode())
        ASSERT_DRV(ret)
        # Grid is (gdim_x, gdim_y, 1), block is (bdim_x, bdim_y, 1).
        self.launch_config = launch_config[0], launch_config[1], 1, launch_config[2], launch_config[3], 1

    def reset_graph_info(self, num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, node_ids_ptr):
        """Update the graph-dependent constant arguments.

        This kernel takes no node-id array, so ``node_ids_ptr`` is accepted
        for interface compatibility and ignored; ``num_nodes`` lives in
        slot 3, directly after the three graph pointers.
        """
        self._store_const_args([
            c_void_p(row_offsets_ptr),
            c_void_p(eids_ptr),
            c_void_p(col_indices_ptr),
            c_int(num_nodes),
        ])


class FeatureAdaptiveKernel(Kernel):
    """Kernel wrapper for the 'fa' template.

    The constant arguments are the row-offset, edge-id, column-index and
    node-id pointers followed by ``num_nodes``, ``max_dims[1]``,
    ``max_dims[0]``, ``launch_config[2]`` and ``launch_config[3]``.
    The launch uses a one-dimensional grid of ``launch_config[0]`` blocks
    with ``launch_config[1]`` threads each.
    """

    def __init__(self, num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, node_ids_ptr, max_dims, kernel_name, compiled_module, launch_config):
        self.scalar_args = [c_int(num_nodes),
                            c_int(max_dims[1]),
                            c_int(max_dims[0]),
                            c_int(launch_config[2]),
                            c_int(launch_config[3])]
        graph_args = [c_void_p(row_offsets_ptr),
                      c_void_p(eids_ptr),
                      c_void_p(col_indices_ptr),
                      c_void_p(node_ids_ptr)]
        self.const_kernel_args = graph_args + self.scalar_args
        self.const_kernel_ptrs = [c_void_p(addressof(v)) for v in self.const_kernel_args]
        ret, self.K = cuModuleGetFunction(compiled_module, kernel_name.encode())
        ASSERT_DRV(ret)
        self.launch_config = launch_config[0], 1, 1, launch_config[1], 1, 1