stgraph.compiler package

Subpackages

Submodules

stgraph.compiler.autodiff module

The auto-differentiation module

stgraph.compiler.autodiff.diff(vars, grads, forward_units, fprog)[source]

The forward graph differentiator

For each var, we find the statement that computes it, then use the var itself together with its gradient to generate the statements that calculate or accumulate the gradient of each of its inputs.

We need to synchronize the gradients of the same var so that the statements stay ordered by return var id. To do so, a var is pushed onto the queue only when processed_count == forward_num_users_of_var. If processed_count is still smaller than that count, more of the var's gradients must be computed before propagating back through it. If the var is one of the stopping vars, there is no need to propagate further, as the task is delegated to the backend system.

Parameters:
  • vars (Var) – The vars whose gradients need to be propagated back, as determined by zoomOut

  • grads – The corresponding gradient for each var

  • forward_units – Forward execution units

Returns:

The differentiated program to compute the gradients

Return type:

BProg
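The queue-and-counter scheme described above can be sketched in plain Python. The (ret, op, args) statement tuples and the reverse_order helper are hypothetical illustrations, not STGraph's Stmt/Var API; the sketch also assumes every statement contributes to an output and that the output vars have no forward users of their own.

```python
from collections import defaultdict, deque


def reverse_order(stmts, outputs, stopping_vars=frozenset()):
    """Hypothetical sketch: order in which each var's gradient becomes complete.

    stmts is a list of (ret, op, args) tuples describing a forward program.
    Assumes every statement feeds an output and outputs have no forward users.
    """
    producer = {ret: args for ret, _op, args in stmts}
    num_users = defaultdict(int)  # plays the role of forward_num_users_of_var
    for _ret, _op, args in stmts:
        for arg in args:
            num_users[arg] += 1

    processed = defaultdict(int)  # plays the role of processed_count
    queue = deque(outputs)
    order = []
    while queue:
        var = queue.popleft()
        order.append(var)  # every gradient contribution to var is accumulated now
        if var in stopping_vars or var not in producer:
            continue  # stopping var (left to the backend) or a graph input
        for arg in producer[var]:
            # the statement computing this grad contribution would be emitted here
            processed[arg] += 1
            if processed[arg] == num_users[arg]:
                queue.append(arg)  # all users of arg have contributed
    return order
```

For a = input, b = exp(a), c = a * b, d = b + c with output d, the order is d, c, b, a: b is only queued after both c and d have contributed to its gradient, and a only after both b and c.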

stgraph.compiler.execution_unit module

The fundamental execution unit of STGraph

class stgraph.compiler.execution_unit.ExecutionUnit(args, tmps, prog, compiled=False)[source]

Bases: object

add_parent_unit(parent)[source]
add_ret_val(ret_val)[source]
all_rets()[source]
calculate_kernel_launch_params(num_nodes)[source]
calculate_kernel_params(num_nodes)[source]
calculate_kernel_params_fa(num_nodes)[source]
property compiled
compute_tile_sizes(blockDimx, blockDimy)[source]
create_context(index_type)[source]
depends_on(other)[source]
feature_size()[source]
first_pow2_less_than_n(n, upper_bound)[source]
get_all_args()[source]

return the set of all vars used/returned in the program of this exec unit

get_all_vars()[source]

return the set of all vars used/returned in the program of this exec unit

has_parent()[source]
is_child_of(other)[source]
kernel_args()[source]
property kernel_name
kernel_run(tensor_list)[source]
materilized_vars()[source]
max_dims()[source]
max_ret_id()[source]
merge_with_independent_unit(other)[source]
nearest_pow2(targ)[source]
parallel_mode()[source]
prepare_compiled_kernel(graph, compiled_module)[source]
property program
reset_graph_info(graph)[source]
set_parallel_mode(mode)[source]
property tmps
unit_args()[source]
unit_count = 0
unit_rets()[source]
use_fa_tmpl()[source]
class stgraph.compiler.execution_unit.FeatureAdaptiveKernel(num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, node_ids_ptr, max_dims, kernel_name, compiled_module, launch_config)[source]

Bases: Kernel

class stgraph.compiler.execution_unit.Kernel[source]

Bases: object

reset_graph_info(num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, node_ids_ptr)[source]
run(tensor_list)[source]
class stgraph.compiler.execution_unit.V2Kernel(num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, max_dims, kernel_name, compiled_module, launch_config, tile_sizes)[source]

Bases: Kernel

The Version 2 Kernel

This class contains the parameters for the second version of the kernel written for STGraph

Parameters:
  • num_nodes (int) – Number of nodes present in the graph

  • row_offsets_ptr (c_type) – Pointer to the row offset array

  • col_indices_ptr (c_type) – Pointer to the column indices array

scalar_args

List of the scalar arguments passed to the kernel

Type:

list[c_types]

launch_config

List of the kernel launch configurations

Type:

list[int]
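V2Kernel receives pointers to a row offset array, a column indices array and an edge id array, which suggests a CSR-style adjacency. The sketch below builds such arrays in plain Python from an edge list; to_csr is a hypothetical helper, and indexing rows by destination node (in-edges) is an assumption, not something this reference states.

```python
def to_csr(num_nodes, edges):
    """Hypothetical helper: CSR arrays from (src, dst) pairs, rows indexed by dst."""
    counts = [0] * num_nodes
    for _src, dst in edges:
        counts[dst] += 1

    # row_offsets[v]..row_offsets[v + 1] delimits the in-edges of node v
    row_offsets = [0] * (num_nodes + 1)
    for v in range(num_nodes):
        row_offsets[v + 1] = row_offsets[v] + counts[v]

    col_indices = [0] * len(edges)
    eids = [0] * len(edges)
    cursor = row_offsets[:num_nodes]  # slicing copies: next free slot per row
    for eid, (src, dst) in enumerate(edges):
        pos = cursor[dst]
        col_indices[pos] = src  # neighbour (source node) of dst
        eids[pos] = eid  # original edge id, so edge features can be looked up
        cursor[dst] += 1
    return row_offsets, col_indices, eids
```

For example, with 3 nodes and edges (0, 1), (2, 1), (1, 0), (0, 2), to_csr returns row_offsets [0, 1, 3, 4], col_indices [1, 0, 2, 0] and eids [2, 0, 1, 3]; the in-neighbours of node 1 are col_indices[1:3] == [0, 2].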

stgraph.compiler.executor module

class stgraph.compiler.executor.ExeState[source]

Bases: object

all_bu_executed()[source]
clear_current_tensor_state()[source]
is_executed_bu(bu)[source]
reset(input_map, f_merged_units, bunits)[source]
track_executed_bu(bu)[source]
track_tensor(key, val)[source]
class stgraph.compiler.executor.Executor(graph, forward_exec_units, backward_exec_units, compiled_module, rets)[source]

Bases: object

backward_cb(kid, grad_list)[source]

FuncWrapper calls this function during the backward pass

construct_backward_mappping(funits, bunits)[source]
create_tensor_for_grad_vars(var_list, tensor_map)[source]
create_tensor_for_vars(var_list)[source]
execute(FuncWrapper)[source]

Execute the forward pass

execute_compiled(uid, FuncWrapper)[source]
execute_prog(units)[source]
execute_unit(unit, tensor_list)[source]
forward_cb(uid, kernel_args, rets, tensor_list)[source]

FuncWrapper calls this function during the forward pass

merge_units(exec_units)[source]
restart(input_map, graph=None)[source]
set_new_zeros_cb(cb)[source]
set_raw_ptr_cb(cb)[source]
class stgraph.compiler.executor.MergedUnit(units)[source]

Bases: object

append(unit)[source]
compiled()[source]
joint_args()[source]
joint_inputs()[source]
joint_rets()[source]
kernel_arg_list()[source]
last()[source]
union_of_rets()[source]
class stgraph.compiler.executor.Stack(val=None)[source]

Bases: object

pop()[source]
print()[source]
push(val)[source]
top()[source]

stgraph.compiler.node module

class stgraph.compiler.node.CentralNode[source]

Bases: object

update_allnode(feat_map)[source]
class stgraph.compiler.node.NbEdge(center, direction, nbnodes)[source]

Bases: object

class stgraph.compiler.node.NbNode(center, direction)[source]

Bases: object

stgraph.compiler.program module

class stgraph.compiler.program.Program[source]

Bases: object

append_stmt(stmt)[source]
begin()[source]
clear_stmts()[source]
copy_append_prog(other_prog)[source]
copy_append_stmt(stmt)[source]
copy_append_stmts(stmts)[source]
copy_prepend_prog(other_prog)[source]
copy_prepend_stmt(stmt)[source]
deepcopy()[source]
empty()[source]
end()[source]
find_ret_var_by_id(targ_id)[source]
find_var_by_id(targ_id)[source]
has_stmt(stmt)[source]
input_vars()[source]
insert_stmts_before(stmt, stmts_list)[source]
last_stmt()[source]
prepend_stmt(stmt)[source]
resort_vars()[source]
set_input(var_name)[source]
class stgraph.compiler.program.Stmt(op_schema, args, ret, op_type, callback)[source]

Bases: object

class StmtInfo(op_schema, args)

Bases: tuple

args

Alias for field number 1

op_schema

Alias for field number 0

copy()[source]
classmethod create_add_stmt(args)[source]
classmethod create_binary_bcast_stmt(op_schema, args, callback=None)[source]
classmethod create_mul_stmt(args)[source]
classmethod create_stmt(op_schema=None, args=None, ret=None, callback=None)[source]
depends_on(stmt)[source]
execute(args, **kargs)[source]
gen_code(ctx)[source]
grad(y, grad_y)[source]
insert_after(new_stmt)[source]
insert_before(new_stmt)[source]
is_agg()[source]
is_dst()[source]
is_edgewise()[source]
is_element_wise_fusable()[source]
is_nodewise()[source]
is_src()[source]
is_supported()[source]
property op_name
print_stmt_args()[source]
remove_cur()[source]
replace_arg_with(old, new, propogate_shape)[source]
shape_propogation()[source]
stmt_info()[source]
type_eq(var1, var2)[source]
class stgraph.compiler.program.Var(var_id, val_type, var_shape, var_dtype, device, requires_grad)[source]

Bases: object

add_user(stmt)[source]
classmethod copy(other)[source]
classmethod create_var(var_shape=None, var_dtype=None, val_type=None, var_id=None, device=None, requires_grad=True)[source]
detach_from_stmt()[source]
property device
property id
property int_id
is_dstvar()[source]
is_edgevar()[source]
is_nodevar()[source]
is_srcvar()[source]
replace_all_uses_with(other_var, propogate_shape)[source]
replace_grad(new_grad)[source]
property requires_grad
rmv_user(stmt)[source]
set_grad(other_var)[source]
set_to_be_grad_of(other_var)[source]
property stmt
used_by(stmt)[source]
property users
property val_type
property var_dtype
property var_shape

stgraph.compiler.registry module

class stgraph.compiler.registry.AddOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: BinaryOpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

class stgraph.compiler.registry.AggMaxOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

gen_init(var)[source]
grad_impl(pos, x, y, grad_y)[source]

y = AggMax(x) => dydx = (x)

class stgraph.compiler.registry.AggSumOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

gen_init(var)[source]
grad_impl(pos, x, y, grad_y)[source]

y = AggSum(x) => dydx = Bcast(x)
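The AggSum rule can be illustrated on plain lists: each node sums the values on its incoming edges, so the gradient flowing back to an edge is simply the gradient of its destination node, broadcast over that node's in-edges. This is a minimal sketch assuming one scalar per edge and an explicit edge_dst list; agg_sum and agg_sum_grad are hypothetical names, not STGraph functions.

```python
def agg_sum(edge_vals, edge_dst, num_nodes):
    """y[v] = sum of x[e] over the edges e whose destination is v."""
    out = [0.0] * num_nodes
    for val, dst in zip(edge_vals, edge_dst):
        out[dst] += val
    return out


def agg_sum_grad(grad_y, edge_dst):
    """dL/dx[e] = dL/dy[dst(e)]: broadcast each node's gradient to its in-edges."""
    # dy[v]/dx[e] is 1 when dst(e) == v and 0 otherwise
    return [grad_y[dst] for dst in edge_dst]
```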

class stgraph.compiler.registry.BackwardAMaxOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

grad_impl(pos, x, y, grad_y)[source]
class stgraph.compiler.registry.BackwardLeakyReluOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

grad_impl(pos, x, y, grad_y)[source]

return a map with keys: args, grad_x and op_schema

class stgraph.compiler.registry.BackwardReluOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

grad_impl(pos, x, y, grad_y)[source]

return a map with keys: args, grad_x and op_schema

class stgraph.compiler.registry.BinaryOpImpl(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

grad_impl(pos, x, y, grad_y)[source]

return a map with keys: args, grad_x and op_schema
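BinaryOpImpl.grad_impl is shared by AddOp, SubOp, MulOp and TrueDivOp and takes the argument position pos. The sketch below shows the standard derivative each of these ops needs, on plain floats; binary_grad is hypothetical, and reading pos as a 0-based argument index is an assumption. The real method returns gradient information (args, grad_x, op_schema) rather than a number.

```python
def binary_grad(op, pos, a, b, grad_y):
    """Hypothetical: dL/d(arg at pos) for y = op(a, b), given grad_y = dL/dy."""
    if op == "add":  # y = a + b
        return grad_y
    if op == "sub":  # y = a - b
        return grad_y if pos == 0 else -grad_y
    if op == "mul":  # y = a * b: multiply by the other operand
        return grad_y * (b if pos == 0 else a)
    if op == "truediv":  # y = a / b
        return grad_y / b if pos == 0 else -grad_y * a / (b * b)
    raise ValueError(f"unsupported op: {op}")
```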

class stgraph.compiler.registry.ExpOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

grad_impl(pos, x, y, grad_y)[source]

y = exp(x) => dydx = exp(x) = y
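Because dy/dx = exp(x) = y, the backward pass can reuse the saved forward output instead of recomputing the exponential. A minimal element-wise sketch on Python lists; exp_forward and exp_grad are hypothetical names.

```python
import math


def exp_forward(x):
    return [math.exp(v) for v in x]


def exp_grad(y, grad_y):
    """Chain rule for y = exp(x): dL/dx = dL/dy * exp(x) = dL/dy * y."""
    # Reusing the forward output y avoids a second call to exp
    return [gy * yv for gy, yv in zip(grad_y, y)]
```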

class stgraph.compiler.registry.GTypeCastOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

grad_impl(pos, x, y, grad_y)[source]

return a map with keys: args, grad_x and op_schema

class stgraph.compiler.registry.GradInfo(targ, args, grad_x, op_schema)

Bases: tuple

args

Alias for field number 1

grad_x

Alias for field number 2

op_schema

Alias for field number 3

targ

Alias for field number 0

class stgraph.compiler.registry.LeakyReluOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

grad_impl(pos, x, y, grad_y)[source]

y = leaky_relu(x) => dydx = backward_leaky_relu(x)
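The backward_leaky_relu step scales the incoming gradient by 1 where x is positive and by the negative slope elsewhere. A minimal sketch on Python lists; the default negative_slope of 0.01 is an assumption (it matches PyTorch's default, but the slope STGraph uses is not stated here), and treating x == 0 as the negative branch is also a choice made for illustration.

```python
def leaky_relu(x, negative_slope=0.01):
    # negative_slope default is an assumed value
    return [v if v > 0 else negative_slope * v for v in x]


def backward_leaky_relu(x, grad_y, negative_slope=0.01):
    """dL/dx = dL/dy * (1 if x > 0 else negative_slope); x == 0 uses the slope."""
    return [gy if v > 0 else negative_slope * gy for v, gy in zip(x, grad_y)]
```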

class stgraph.compiler.registry.MulOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: BinaryOpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

class stgraph.compiler.registry.OpImpl(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: ABC

New ops must inherit from this class and be named “XXXOp”.

property args
create_var_like(x)[source]
gen_agg_info_map(ctx)[source]
abstract gen_code(ctx)[source]

return the CUDA code corresponding to this op

gen_edge_info_map(ctx)[source]
gen_load(ctx)[source]
gen_var(var, ctx)[source]
gen_write(ctx)[source]
grad(y, grad_y)[source]
abstract grad_impl(pos, x, y, grad_y)[source]

return a map with keys: args, grad_x and op_schema

multiply_grad(dzdy, dydx, x)[source]
property op_schema
property ret
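The “XXXOp” naming convention makes it possible to discover ops by walking the subclasses of OpImpl. The sketch below shows one way such a registry could work; the simplified OpImpl stand-in, all_subclasses and register_ops_sketch are hypothetical, and how STGraph's register_ops actually keys its ops is not documented here.

```python
from abc import ABC, abstractmethod


class OpImpl(ABC):
    """Simplified stand-in (the real OpImpl takes fstmt and two callbacks)."""

    @abstractmethod
    def gen_code(self, ctx):
        """Return the CUDA code corresponding to this op."""

    @abstractmethod
    def grad_impl(self, pos, x, y, grad_y):
        """Return gradient information for the argument at position pos."""


class BinaryOpImpl(OpImpl):
    # Intermediate base: its name does not end in "Op", so it is not registered
    def grad_impl(self, pos, x, y, grad_y):
        return grad_y  # placeholder


class AddOp(BinaryOpImpl):
    def gen_code(self, ctx):
        return "c = a + b;"  # placeholder CUDA snippet


class ReluOp(OpImpl):
    def gen_code(self, ctx):
        return "y = x > 0 ? x : 0;"  # placeholder CUDA snippet

    def grad_impl(self, pos, x, y, grad_y):
        return None  # placeholder


def all_subclasses(cls):
    """Yield every direct and indirect subclass of cls."""
    for sub in cls.__subclasses__():
        yield sub
        yield from all_subclasses(sub)


def register_ops_sketch():
    """Hypothetical: map 'XXXOp' class names to lowercase op names ('AddOp' -> 'add')."""
    registry = {}
    for cls in all_subclasses(OpImpl):
        name = cls.__name__
        if name.endswith("Op"):
            registry[name[: -len("Op")].lower()] = cls
    return registry
```

The recursive walk matters because ops such as AddOp derive from BinaryOpImpl rather than directly from OpImpl, so a single __subclasses__() call would miss them.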
class stgraph.compiler.registry.ReluOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: OpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

grad_impl(pos, x, y, grad_y)[source]

return a map with keys: args, grad_x and op_schema

class stgraph.compiler.registry.SubOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: BinaryOpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

class stgraph.compiler.registry.TrueDivOp(fstmt, create_var_cb, create_stmt_cb)[source]

Bases: BinaryOpImpl

gen_code(ctx)[source]

return the CUDA code corresponding to this op

stgraph.compiler.registry.look_up_registry(stmt)[source]
stgraph.compiler.registry.register_ops()[source]
stgraph.compiler.registry.register_or_look_up_backend_cb(stmt, cb)[source]

stgraph.compiler.schema module

class stgraph.compiler.schema.Schema(op_name, **kargs)[source]

Bases: object

Schema of an op

stgraph.compiler.stgraph module

class stgraph.compiler.stgraph.Context(func, nspace, run_cb)[source]

Bases: object

class stgraph.compiler.stgraph.STGraph(backend_framework: STGraphBackend)[source]

Bases: object

compile(gnn_module, hetero_graph=False)[source]

stgraph.compiler.utils module

class stgraph.compiler.utils.EdgeDirection(value)[source]

Bases: Enum

An enumeration.

IN = 0
OUT = 1
class stgraph.compiler.utils.FusionType(value)[source]

Bases: Enum

An enumeration.

NEAN = 0
NN = 1
NOT_FUSIBLE = 2
class stgraph.compiler.utils.OpType(value)[source]

Bases: Enum

An enumeration.

A = 2
D = 3
E = 1
S = 0
class stgraph.compiler.utils.ParallelMode(value)[source]

Bases: Enum

An enumeration.

DstParallel = 1
SrcParallel = 0
class stgraph.compiler.utils.ValType(value)[source]

Bases: Enum

An enumeration.

DEST = 1
EDGE = 2
PARAM = 3
SRC = 0
class stgraph.compiler.utils.WriteLocation(value)[source]

Bases: Enum

An enumeration.

INNER = 0
NONE = 2
OUTER = 1
class stgraph.compiler.utils.WriteType(value)[source]

Bases: Enum

An enumeration.

ADD = 0
ASSIGN = 2
ATOMIC = 1
NONE = 3
stgraph.compiler.utils.any_var(var_list)[source]
stgraph.compiler.utils.bcast_dim(var_list)[source]
stgraph.compiler.utils.infer_op_type_from_args(op_schema, args)[source]
stgraph.compiler.utils.infer_val_type(vals)[source]

When the types of vals differ, the edge type is returned. The P type is treated as matching every other type.
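The rule above can be sketched over ValType members, whose names and values are copied from the ValType enum in this module. infer_val_type_sketch is hypothetical: the real function takes vals (presumably Var objects), and reading “P type” as ValType.PARAM, as well as returning PARAM for an all-PARAM input, are assumptions.

```python
from enum import Enum


class ValType(Enum):
    # Same members and values as stgraph.compiler.utils.ValType
    SRC = 0
    DEST = 1
    EDGE = 2
    PARAM = 3


def infer_val_type_sketch(val_types):
    """Hypothetical reading of infer_val_type on a list of ValType members."""
    # PARAM matches every other type, so it is ignored when comparing
    others = {t for t in val_types if t is not ValType.PARAM}
    if not others:
        return ValType.PARAM  # assumption: all-PARAM (or empty) input stays PARAM
    if len(others) == 1:
        return others.pop()  # all non-PARAM values agree
    return ValType.EDGE  # mixed types, e.g. SRC and DEST, yield an edge value
```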

stgraph.compiler.utils.is_const_scalar(val)[source]

Module contents

The backbone of STGraph