Source code for stgraph.compiler.registry

import abc
import sys, inspect
from collections import namedtuple
from .utils import ValType,is_const_scalar, WriteType, WriteLocation, infer_val_type
from .schema import Schema

from stgraph.compiler.debugging.stgraph_logger import print_log

# Lightweight record type with fields targ, args, grad_x and op_schema.
GradInfo = namedtuple('GradInfo', 'targ args grad_x op_schema')

# Lower-cased op name -> OpImpl subclass; populated by register_ops().
impl_registry = dict()
# Lower-cased op name -> backend callback; populated by register_or_look_up_backend_cb().
cb_registry = dict()
# Suffix appended to variable ids to name the temporaries declared in generated code.
TMP_SUFFIX = '_tmp'

def register_or_look_up_backend_cb(stmt, cb):
    '''Register ``cb`` for ``stmt``'s op, or fetch a previously registered one.

    The registry key is ``stmt.op_name`` lower-cased.

    If ``cb`` is truthy, it is stored only when the op has no callback yet,
    and ``cb`` itself is returned either way. If ``cb`` is falsy, the
    registered callback is returned, or ``None`` when there is none.
    '''
    key = stmt.op_name.lower()
    if not cb:
        return cb_registry.get(key)
    cb_registry.setdefault(key, cb)
    return cb
def look_up_registry(stmt):
    '''Return the OpImpl subclass registered for ``stmt``'s op, or ``None``.

    The lookup key is ``stmt.op_name`` lower-cased. A miss is logged only for
    edge-wise statements that also have no backend callback in
    ``cb_registry``. Node-wise ops do not necessarily get generated
    implementations, because backends can support them.
    '''
    op_name = stmt.op_name.lower()
    if op_name in impl_registry:
        return impl_registry[op_name]
    if stmt.is_edgewise() and op_name not in cb_registry:
        print_log(f'[magenta bold]Registry[/magenta bold]: EdgeType op {op_name} is not registered in any registry!')
    return None
class OpImpl(abc.ABC):
    '''Base class for op implementations.

    New ops need to inherit from this class with name "XXXOp". register_ops()
    registers every module-level class whose name ends in "Op", keyed by its
    lower-cased name without that suffix.
    '''

    def __init__(self, fstmt, create_var_cb, create_stmt_cb):
        # fstmt: the forward statement implemented by this object.
        # create_var_cb / create_stmt_cb: factories used to build the new
        # variables and statements emitted (e.g.) for gradients.
        self.fstmt = fstmt
        self.create_var = create_var_cb
        self.create_stmt = create_stmt_cb

    def grad(self, y, grad_y):
        '''Build gradient statements for each differentiable argument.

        Constant-scalar arguments and arguments whose ``requires_grad`` is
        false are skipped.

        Returns:
            list of ``(x, stmts)`` pairs, where ``stmts`` is the result of
            ``grad_impl`` for argument ``x``.
        '''
        ret_list = []
        for pos, x in enumerate(self.fstmt.args):
            if not is_const_scalar(x) and x.requires_grad:
                ret_stmts = self.grad_impl(pos, x, y, grad_y)
                ret_list.append((x, ret_stmts))
        return ret_list

    def gen_var(self, var, ctx):
        '''Return the code expression used to reference ``var``.

        - Constant scalars are emitted literally.
        - This op's own result is its ``TMP_SUFFIX`` temporary. It is prefixed
          with its dtype (i.e. declared here), unless the statement producing
          it has "agg" in its op name.
        - Kernel arguments are indexed with their offset from ``ctx``.
        - Any other variable is referenced through its temporary.
        '''
        kctx = ctx
        stmt_ctx = ctx.cur_stmt_ctx
        if is_const_scalar(var):
            return str(var)
        if var == self.ret:
            prefix = '' if 'agg' in var.stmt.op_name.lower() else var.dtype_str + ' '
            return prefix + var.id + TMP_SUFFIX
        if var in stmt_ctx.kernel_arguments:
            return var.id + kctx.query_offset(var)
        return var.id + TMP_SUFFIX

    def gen_write(self, ctx):
        '''Return ``(key, code)`` writing this op's temporary to its output.

        ``key`` is ``'outter_write'`` for ``WriteLocation.OUTER`` and
        ``'inner_write'`` otherwise. ``WriteType.NONE`` yields empty code.
        Atomic writes emit ``atomic<Op>(ptr, value)``, choosing ``<Op>`` from
        the op name: sum -> Add, max -> Max, min -> Min, and mean -> Add with
        an extra ``/(end_off-start_off)`` divisor.

        Raises:
            NotImplementedError: for an atomic mean not written at the outer
                location, or an atomic write for any other op.
        '''
        kctx = ctx
        ctx = kctx.cur_stmt_ctx
        if ctx.write_type == WriteType.NONE:
            return ('inner_write', '')
        val = ''
        var = self.ret.id + kctx.query_offset(self.ret)
        delta = self.ret.id + TMP_SUFFIX
        if ctx.write_type == WriteType.ADD:
            val = '{var} += {delta};'.format(var=var, delta=delta)
        elif ctx.write_type == WriteType.ATOMIC:
            # Replace "var[offset]" with "var+offset": the atomic takes a pointer.
            var_split = var.split('[')
            new_var = var_split[0] + '+' + var_split[1][:-1]
            op_name = self.fstmt.op_name.lower()
            divisor = ''
            if 'sum' in op_name:
                op = 'Add'
            elif 'max' in op_name:
                op = 'Max'
            elif 'min' in op_name:
                op = 'Min'
            elif 'mean' in op_name:
                op = 'Add'
                divisor = '/(end_off-start_off)'
                if ctx.write_location != WriteLocation.OUTER:
                    raise NotImplementedError('Cannot support inner write of mean result due to unknown number of edges')
            else:
                # A single formatted message (passing several args would render as a tuple).
                raise NotImplementedError(f'Atomic instruction for {self.fstmt.op_name} is not implemented')
            val = 'atomic{op}({var}, {delta}{divisor});'.format(var=new_var, delta=delta, op=op, divisor=divisor)
        elif ctx.write_type == WriteType.ASSIGN:
            val = '{var} = {delta};'.format(var=var, delta=delta)
        key = 'inner_write'
        if ctx.write_location == WriteLocation.OUTER:
            key = 'outter_write'
        return key, val

    def gen_load(self, ctx):
        '''Return ``('load', code)`` declaring temporaries for kernel arguments.

        Each argument of this op that is a kernel argument and is not yet in
        ``ctx.loaded_args`` gets a ``<dtype> <id>_tmp = <id><offset>;``
        declaration and is added to ``ctx.loaded_args``.
        '''
        k = 'load'
        v = ''
        for arg in self.args:
            if arg in ctx.cur_stmt_ctx.kernel_arguments and arg not in ctx.loaded_args:
                v += '{type} {var_tmp} = {var}; '.format(type=arg.dtype_str, var_tmp=arg.id + TMP_SUFFIX, var=arg.id + ctx.query_offset(arg))
                ctx.loaded_args.add(arg)
        return k, v.strip(' ')

    def gen_edge_info_map(self, ctx):
        '''Return the code-section map for an edge-wise op.

        ``compute`` and ``load`` start empty (gen_load is not called here).
        The write-back code from gen_write is stored under its key.
        '''
        m = {'compute': '', 'load': ''}
        key, val = self.gen_write(ctx)
        m[key] = val
        return m

    def gen_agg_info_map(self, ctx):
        '''Return the code-section map for an aggregation op.

        Sections are init, compute, inner_write and outter_write. The
        write-back code from gen_write is stored under its key.
        '''
        m = {'init': '', 'compute': '', 'inner_write': '', 'outter_write': ''}
        key, val = self.gen_write(ctx)
        m[key] = val
        return m

    def create_var_like(self, x):
        '''Create a new variable with ``x``'s shape, dtype, val_type and device.'''
        return self.create_var(var_shape=x.var_shape, var_dtype=x.var_dtype, val_type=x.val_type, device=x.device)

    def multiply_grad(self, dzdy, dydx, x):
        '''Return statements computing ``dzdx = dzdy * dydx`` shaped like ``x``.

        ``dydx`` may be a constant scalar. The product's shape is the
        element-wise max of the operand shapes. If it differs from ``x``'s
        shape in exactly one dimension, a keep-dim ``Sum`` over that dimension
        is appended to reduce it back.

        Raises:
            NotImplementedError: if the product and ``x`` differ in rank or
                in more than one dimension.
        '''
        dim_size = len(dzdy.var_shape)
        assert dim_size == len(x.var_shape)
        if is_const_scalar(dydx):
            max_dim = dzdy.var_shape
        else:
            max_dim = [max(a, b) for a, b in zip(dzdy.var_shape, dydx.var_shape)]
        if len(x.var_shape) != len(max_dim):
            raise NotImplementedError('Multiply grad has not supported input and gradient that have different dimension')
        ret = [self.create_stmt(Schema('Mul'), args=[dzdy, dydx],
                                ret=self.create_var(var_shape=max_dim, var_dtype=dzdy.var_dtype,
                                                    val_type=infer_val_type([dzdy, dydx]), device=dzdy.device))]
        # Compare element-wise rather than with `!=` on the containers: if
        # var_shape is a tuple it never equals the list max_dim, which would
        # otherwise add a spurious Sum over dim -1.
        diff_dims = [i for i, (m, s) in enumerate(zip(max_dim, x.var_shape)) if m != s]
        if not diff_dims:
            return ret
        if len(diff_dims) > 1:
            raise NotImplementedError('Multiply grad has not supported input and gradient that have more than 1 different dims')
        ret.append(self.create_stmt(Schema('Sum', dim=diff_dims[0], keep_dim=True), args=[ret[-1].ret], ret=self.create_var_like(x)))
        return ret

    @abc.abstractmethod
    def grad_impl(self, pos, x, y, grad_y):
        '''Return the gradient statements for argument ``x`` at position ``pos``.'''

    @abc.abstractmethod
    def gen_code(self, ctx):
        '''Return the cuda code that corresponds to this op.'''

    @property
    def args(self):
        '''Arguments of the forward statement.'''
        return self.fstmt.args

    @property
    def ret(self):
        '''Result variable of the forward statement.'''
        return self.fstmt.ret

    @property
    def op_schema(self):
        '''Schema of the forward statement.'''
        return self.fstmt.op_schema
class BinaryOpImpl(OpImpl):
    '''Common base for two-operand ops.

    Subclasses supply ``_grad_impl``. When the forward statement is
    edge-wise but the differentiated operand is a node variable, the
    gradient statements are followed by an ``AggSum`` into a variable shaped
    like that operand.
    '''

    def grad_impl(self, pos, x, y, grad_y):
        stmts = self._grad_impl(pos, x, y, grad_y)
        needs_agg = x.is_nodevar() and self.fstmt.is_edgewise()
        if not needs_agg:
            return stmts
        stmts.append(self.create_stmt(Schema('AggSum'), args=[stmts[-1].ret], ret=self.create_var_like(x)))
        return stmts

    @abc.abstractmethod
    def _grad_impl(self, pos, x, y, grad_y):
        '''Operator-specific gradient for operand ``pos`` of the binary op.'''
class AddOp(BinaryOpImpl):
    '''Element-wise addition ``y = a + b``.'''

    def _grad_impl(self, pos, x, y, grad_y):
        '''y = x + k => dydx = 1, for either operand.'''
        return self.multiply_grad(dzdy=grad_y, dydx=1, x=x)

    def gen_code(self, ctx):
        '''Emit ``ret = a + b;`` together with the edge write-back sections.'''
        assert len(self.args) == 2
        lhs, rhs = (self.gen_var(operand, ctx) for operand in self.args)
        out = self.gen_var(self.ret, ctx)
        info = self.gen_edge_info_map(ctx)
        info['compute'] = f'{out} = {lhs} + {rhs};'
        return info
#TODO: This was added by us
class SubOp(BinaryOpImpl):
    '''Element-wise subtraction ``y = a - b``.'''

    def _grad_impl(self, pos, x, y, grad_y):
        '''y = a - b => dy/da = 1, dy/db = -1.

        The subtrahend (position 1) has a negative local gradient. Using +1
        for both operands would flip the sign of the gradient flowing to ``b``.
        '''
        assert pos < 2, 'Sub dealing with two operands'
        dydx = 1 if pos == 0 else -1
        return self.multiply_grad(dzdy=grad_y, dydx=dydx, x=x)

    def gen_code(self, ctx):
        '''Emit ``ret = a - b;`` together with the edge write-back sections.'''
        assert len(self.args) == 2
        val0 = self.gen_var(self.args[0], ctx)
        val1 = self.gen_var(self.args[1], ctx)
        ret = self.gen_var(self.ret, ctx)
        gen_info = self.gen_edge_info_map(ctx)
        gen_info['compute'] = '{ret} = {val0} - {val1};'.format(ret=ret, val0=val0, val1=val1)
        return gen_info
class LeakyReluOp(OpImpl):
    '''Leaky ReLU. The schema's ``negative_slope`` parameter scales negative inputs.'''

    def grad_impl(self, pos, x, y, grad_y):
        '''y = leaky_relu(x) => dydx = backward_leaky_relu(x)'''
        local_grad = self.create_var_like(x)
        backward = self.create_stmt(Schema('BackwardLeakyRelu', **self.op_schema._params), args=[x], ret=local_grad)
        return [backward] + self.multiply_grad(grad_y, local_grad, x)

    def gen_code(self, ctx):
        '''Emit ``ret = x > 0 ? x : slope * x`` as a C conditional expression.'''
        operand = self.gen_var(self.args[0], ctx)
        out = self.gen_var(self.ret, ctx)
        info = self.gen_edge_info_map(ctx)
        slope = self.op_schema._params['negative_slope']
        info['compute'] = f'{out}={operand}>0?{operand}:{slope}*{operand};'
        return info
class ExpOp(OpImpl):
    '''Element-wise exponential ``y = exp(x)``.'''

    def grad_impl(self, pos, x, y, grad_y):
        '''y = exp(x) => dydx = exp(x) = y, so the forward result is reused.'''
        return self.multiply_grad(dzdy=grad_y, dydx=y, x=x)

    def gen_code(self, ctx):
        '''Emit ``ret = exp(x);`` together with the edge write-back sections.'''
        operand = self.gen_var(self.args[0], ctx)
        out = self.gen_var(self.ret, ctx)
        info = self.gen_edge_info_map(ctx)
        info['compute'] = f'{out} = exp({operand});'
        return info
class MulOp(BinaryOpImpl):
    '''Element-wise multiplication ``y = a * b``.'''

    def _grad_impl(self, pos, x, y, grad_y):
        '''y = x[0]*x[1] => dydx0 = x[1], dydx1 = x[0]'''
        assert pos < 2, 'Mul dealing with two operands'
        other_operand = self.args[1 - pos]
        return self.multiply_grad(dzdy=grad_y, dydx=other_operand, x=x)

    def gen_code(self, ctx):
        '''Emit ``ret = a*b;`` together with the edge write-back sections.'''
        lhs = self.gen_var(self.args[0], ctx)
        rhs = self.gen_var(self.args[1], ctx)
        out = self.gen_var(self.ret, ctx)
        info = self.gen_edge_info_map(ctx)
        info['compute'] = f'{out} = {lhs}*{rhs};'
        return info
class AggSumOp(OpImpl):
    '''Sum aggregation (``AggSum``).'''

    def grad_impl(self, pos, x, y, grad_y):
        '''y = AggSum(x) => dydx = Bcast(x)'''
        stmts = self.multiply_grad(dzdy=grad_y, dydx=1, x=x)
        if x.is_edgevar():
            return stmts
        stmts.append(self.create_stmt(Schema('AggSum'), args=[stmts[-1].ret], ret=self.create_var_like(x)))
        return stmts

    def gen_init(self, var):
        '''Return ``('init', code)`` declaring the accumulator initialised to 0.'''
        return 'init', ''.join([var.dtype_str, ' ', var.id, TMP_SUFFIX, ' = 0;'])

    def gen_code(self, ctx):
        '''Emit the accumulation: plain assignment for inner writes, ``+=`` otherwise.'''
        src = self.gen_var(self.args[0], ctx)
        acc = self.gen_var(self.ret, ctx)
        init_key, init_code = self.gen_init(self.ret)
        info = self.gen_agg_info_map(ctx)
        is_inner = ctx.cur_stmt_ctx.write_location == WriteLocation.INNER
        info['compute'] = f'{acc} = {src};' if is_inner else f'{acc} += {src};'
        info[init_key] = init_code
        return info
class AggMaxOp(OpImpl):
    '''Max aggregation (``AggMax``).'''

    def grad_impl(self, pos, x, y, grad_y):
        '''y = AggMax(x) => dydx = BackwardAMax(x, y), i.e. 1 where x == y else 0.'''
        grad_stmt_list = []
        # More precisely, the type of ret should be of type.E but it's OK as long as we don't materialize it.
        ret = self.create_var_like(x)
        ret._val_type = ValType.EDGE
        grad_stmt_list.append(self.create_stmt(Schema('BackwardAMax'), args=[x, y], ret=ret))
        grad_stmt_list += self.multiply_grad(dzdy=grad_y, dydx=grad_stmt_list[-1].ret, x=x)
        if not x.is_edgevar():
            last_stmt = grad_stmt_list[-1]
            grad_stmt_list.append(self.create_stmt(Schema('AggSum'), args=[last_stmt.ret], ret=self.create_var_like(x)))
        return grad_stmt_list

    def gen_init(self, var):
        '''Return ``('init', code)`` declaring the accumulator initialised to -inf.

        0xff800000 is the IEEE-754 single-precision bit pattern of -inf.
        Assigning that integer literal directly to a float converts its
        numeric value (about 4.29e9), so the running max would start far too
        high. ``__int_as_float`` reinterprets the bits instead.
        NOTE(review): this assumes the generated code is compiled as CUDA
        device code, where ``__int_as_float`` is an intrinsic. The ``atomic*``
        calls emitted by gen_write suggest so; confirm against the kernel
        template.
        '''
        key = 'init'
        val = var.dtype_str + ' ' + var.id + TMP_SUFFIX + ' = __int_as_float(0xff800000);'
        return key, val

    def gen_code(self, ctx):
        '''Emit the max accumulation: plain assignment for inner writes, ``max`` otherwise.'''
        val0 = self.gen_var(self.args[0], ctx)
        ret = self.gen_var(self.ret, ctx)
        initk, initv = self.gen_init(self.ret)
        gen_info = self.gen_agg_info_map(ctx)
        if ctx.cur_stmt_ctx.write_location == WriteLocation.INNER:
            gen_info['compute'] = '{ret} = {val};'.format(ret=ret, val=val0)
        else:
            gen_info['compute'] = '{ret} = max({val}, {ret});'.format(ret=ret, val=val0)
        gen_info[initk] = initv
        return gen_info
class BackwardAMaxOp(OpImpl):
    '''Local gradient of AggMax: 1 where the forward input equals the max, else 0.'''

    def grad_impl(self, pos, x, y, grad_y):
        '''Second-order gradients are not supported.'''
        raise NotImplementedError('Grad of grad is not supported')

    def gen_code(self, ctx):
        '''Emit ``ret = x == y ? 1 : 0;`` (args: forward input, forward max).'''
        fwd_in, fwd_out = self.gen_var(self.args[0], ctx), self.gen_var(self.args[1], ctx)
        out = self.gen_var(self.ret, ctx)
        info = self.gen_edge_info_map(ctx)
        info['compute'] = f'{out} = {fwd_in} == {fwd_out} ? 1 : 0;'
        return info
class TrueDivOp(BinaryOpImpl):
    '''Element-wise division ``y = a / b``.'''

    def _grad_impl(self, pos, x, y, grad_y):
        '''y = x[0]/x[1] => dydx0 = 1/x[1], dydx1 = -x[0] / (x[1]*x[1]).

        Constant-scalar operands are folded into Python numbers. They must not
        be passed to ``create_var_like``, which reads attributes such as
        ``var_shape`` that plain scalars lack (e.g. ``x / 2`` or ``1 / x``).
        '''
        assert pos < 2, 'TrueDiv dealing with two operands'
        numer, denom = self.args[0], self.args[1]
        stmt_list = []
        if pos == 0:
            if is_const_scalar(denom):
                # Computed in Python with true division, so an integer
                # constant never becomes C integer division in generated code.
                dydx = 1.0 / denom
            else:
                dydx = self.create_var_like(denom)
                stmt_list.append(self.create_stmt(Schema('TrueDiv'), args=[1, denom], ret=dydx))
        else:
            # pos == 1: x is the denominator, which grad() only passes here
            # when it is not a constant scalar.
            denom_sq = self.create_stmt(Schema('Mul'), args=[denom, denom], ret=self.create_var_like(denom))
            stmt_list.append(denom_sq)
            if is_const_scalar(numer):
                neg_numer = -numer
            else:
                neg_stmt = self.create_stmt(Schema('Mul'), args=[-1, numer], ret=self.create_var_like(numer))
                stmt_list.append(neg_stmt)
                neg_numer = neg_stmt.ret
            dydx = self.create_var_like(y)
            stmt_list.append(self.create_stmt(Schema('TrueDiv'), args=[neg_numer, denom_sq.ret], ret=dydx))
        stmt_list += self.multiply_grad(dzdy=grad_y, dydx=dydx, x=x)
        return stmt_list

    def gen_code(self, ctx):
        '''Emit ``ret = a/b;`` together with the edge write-back sections.'''
        left = self.gen_var(self.args[0], ctx)
        right = self.gen_var(self.args[1], ctx)
        ret = self.gen_var(self.ret, ctx)
        gen_info = self.gen_edge_info_map(ctx)
        gen_info['compute'] = '{ret} = {left}/{right};'.format(ret=ret, left=left, right=right)
        return gen_info
class ReluOp(OpImpl):
    '''Element-wise ReLU ``y = max(x, 0)``.'''

    def grad_impl(self, pos, x, y, grad_y):
        '''Return one BackwardRelu statement taking the forward input and upstream grad.'''
        assert x.val_type == grad_y.val_type
        grad_x = self.create_var_like(x)
        return [self.create_stmt(Schema('BackwardRelu'), args=[x, grad_y], ret=grad_x)]

    def gen_code(self, ctx):
        '''Emit ``ret = x > 0 ? x : 0;`` together with the edge write-back sections.'''
        operand = self.gen_var(self.args[0], ctx)
        out = self.gen_var(self.ret, ctx)
        info = self.gen_edge_info_map(ctx)
        info['compute'] = f'{out} = {operand} > 0 ? {operand} : 0;'
        return info
class BackwardReluOp(OpImpl):
    '''Backward of ReLU: passes the upstream grad (arg 1) where the forward input (arg 0) > 0.'''

    def grad_impl(self, pos, x, y, grad_y):
        '''Second-order gradients are not supported.'''
        raise NotImplementedError('Grad for BackwardRelu is not implemented')

    def gen_code(self, ctx):
        '''Emit ``ret = x > 0 ? grad : 0;`` together with the edge write-back sections.'''
        # The placeholder is named, so the count must be passed as n=...;
        # a positional argument would raise KeyError('n') instead of AssertionError.
        assert len(self.args) == 2, 'backward relu takes two arguments but {n} are given'.format(n=len(self.args))
        inp0 = self.gen_var(self.args[0], ctx)
        inp1 = self.gen_var(self.args[1], ctx)
        ret = self.gen_var(self.ret, ctx)
        gen_info = self.gen_edge_info_map(ctx)
        gen_info['compute'] = '{ret} = {inp0} > 0 ? {inp1} : 0;'.format(ret=ret, inp0=inp0, inp1=inp1)
        return gen_info
class GTypeCastOp(OpImpl):
    '''Placeholder op: it supports neither differentiation nor code generation.'''

    def grad_impl(self, pos, x, y, grad_y):
        '''Always raises NotImplementedError.'''
        message = 'Grad for GTypeCast is not implemented'
        raise NotImplementedError(message)

    def gen_code(self, ctx):
        '''Always raises NotImplementedError.'''
        message = 'Cannot generate code for GTypeCast'
        raise NotImplementedError(message)
class BackwardLeakyReluOp(OpImpl):
    '''Local gradient of leaky ReLU: 1 for positive inputs, ``negative_slope`` otherwise.'''

    def grad_impl(self, pos, x, y, grad_y):
        '''Second-order gradients are not supported.'''
        raise NotImplementedError('Grad of grad is not supported')

    def gen_code(self, ctx):
        '''Emit ``ret = x>0?1:slope;`` together with the edge write-back sections.'''
        operand = self.gen_var(self.args[0], ctx)
        out = self.gen_var(self.ret, ctx)
        info = self.gen_edge_info_map(ctx)
        slope = self.op_schema._params['negative_slope']
        info['compute'] = f'{out} = {operand}>0?1:{slope};'
        return info
def register_ops():
    '''Register every class in this module whose name ends in ``Op``.

    The registry key is the class name with the trailing ``Op`` removed,
    lower-cased (e.g. ``TrueDivOp`` -> ``truediv``). Each registration is
    logged through print_log.
    '''
    suffix = 'Op'
    for name, obj in inspect.getmembers(sys.modules[__name__]):
        if inspect.isclass(obj) and name.endswith(suffix):
            # Strip only the trailing suffix. name.split('Op')[0] would cut
            # at the first 'Op' anywhere in the name (e.g. 'OpacityOp' -> '').
            key = name[:-len(suffix)].lower()
            print_log(f'[magenta bold]Registry[/magenta bold]: Registering {name} with key {key}')
            impl_registry[key] = obj
# Populate impl_registry with this module's *Op classes once, at import time.
register_ops()