stgraph.compiler package
Subpackages
- stgraph.compiler.backend package
- stgraph.compiler.code_gen package
- Subpackages
- Submodules
- stgraph.compiler.code_gen.code_gen module
- stgraph.compiler.code_gen.compiler module
- stgraph.compiler.code_gen.cuda_check module
- stgraph.compiler.code_gen.cuda_driver module
- stgraph.compiler.code_gen.cuda_error module
- stgraph.compiler.code_gen.device_info module
- stgraph.compiler.code_gen.kernel_context module
- KernelContext
- KernelContext.eq_dim()
- KernelContext.get_offset_key()
- KernelContext.init_offset_cache()
- KernelContext.kernel_argument_used_in_stmt()
- KernelContext.matrix_var_offset()
- KernelContext.query_offset()
- KernelContext.scalar_var_offset()
- KernelContext.set_stmt_ctx()
- KernelContext.vector_var_offset()
- KernelContext.write_inner()
- LinearizedKernelContext
- StmtGenCtx
- Module contents
- stgraph.compiler.debugging package
- stgraph.compiler.op package
- stgraph.compiler.passes package
- Submodules
- stgraph.compiler.passes.cf module
- stgraph.compiler.passes.cse module
- stgraph.compiler.passes.dce module
- stgraph.compiler.passes.dependency_analysis module
- stgraph.compiler.passes.fusion module
- stgraph.compiler.passes.mem_planning module
- stgraph.compiler.passes.peephole module
- stgraph.compiler.passes.visualize module
- Module contents
- stgraph.compiler.val package
Submodules
stgraph.compiler.autodiff module
The auto-differentiation module
- stgraph.compiler.autodiff.diff(vars, grads, forward_units, fprog)
The forward graph differentiator
For each var, we find the statement that computes it, then use that statement together with the var's gradient to generate the statements that calculate or accumulate the gradient of each of its inputs.
We need to synchronize the gradients of the same var in order to keep the statements ordered by return var id. To do that, we use processed_count == forward_num_users_of_var as the condition for pushing a var onto the queue. If processed_count < forward_num_users_of_var, we must wait for more of its gradients to be computed before propagating back through that variable. If the var is a stopping var, there is no need to propagate further, as the task will be delegated to the backend system.
- Parameters:
vars (Var) – The vars that have gradients to propagate back, determined by zoomOut
grads – The corresponding gradient for each var
forward_units – Forward execution units
- Returns:
The differentiated program to compute the gradients
- Return type:
BProg
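The queue discipline described for diff can be illustrated on a toy expression graph. This is a minimal sketch, not STGraph's implementation: FORWARD, local_grads and diff_sketch are hypothetical names, variables are plain strings rather than Var objects, and gradients are scalars. A var enters the queue only once its processed count reaches its number of forward users, and stopping vars are not propagated further.

```python
from collections import defaultdict, deque

# Hypothetical toy IR (not STGraph's Stmt/Var classes): each non-leaf var maps
# to the statement that produced it, written as (op_name, [input_vars]).
FORWARD = {
    "c": ("mul", ["a", "b"]),  # c = a * b
    "d": ("add", ["c", "a"]),  # d = c + a
}


def local_grads(op, inputs, values, grad_out):
    """Return (input_var, gradient contribution) pairs for one statement."""
    x, y = inputs
    if op == "add":
        return [(x, grad_out), (y, grad_out)]
    if op == "mul":
        return [(x, grad_out * values[y]), (y, grad_out * values[x])]
    raise ValueError(f"no gradient rule for {op!r}")


def diff_sketch(output, values, stopping=frozenset()):
    # Number of forward statements that consume each var.
    num_users = defaultdict(int)
    for _, inputs in FORWARD.values():
        for var in inputs:
            num_users[var] += 1

    grads = defaultdict(float)
    processed = defaultdict(int)
    grads[output] = 1.0
    queue = deque([output])
    while queue:
        var = queue.popleft()
        # Leaves have no producing statement; stopping vars are left to the backend.
        if var in stopping or var not in FORWARD:
            continue
        op, inputs = FORWARD[var]
        for inp, g in local_grads(op, inputs, values, grads[var]):
            grads[inp] += g  # accumulate contributions from every user
            processed[inp] += 1
            # Propagate through inp only after all of its users have contributed.
            if processed[inp] == num_users[inp]:
                queue.append(inp)
    return dict(grads)


print(diff_sketch("d", {"a": 2.0, "b": 3.0}))
# -> {'d': 1.0, 'c': 1.0, 'a': 4.0, 'b': 2.0}
```

Here a feeds both c and d, so it is enqueued only after both contributions (1 directly from d, and b = 3 through c) have been summed.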
stgraph.compiler.execution_unit module
The fundamental execution unit of STGraph
- class stgraph.compiler.execution_unit.ExecutionUnit(args, tmps, prog, compiled=False)
Bases:
object
- property compiled
- property kernel_name
- property program
- property tmps
- unit_count = 0
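The class-level unit_count = 0 attribute suggests a counter shared by all execution units. A minimal sketch, assuming (this page does not confirm it) that the counter is used to give each unit a unique kernel_name; ExecutionUnitSketch and the "K<n>" naming scheme are hypothetical.

```python
class ExecutionUnitSketch:
    # Class-level counter shared by all instances, mirroring the documented
    # `unit_count = 0` attribute. Using it to number kernels is an assumption.
    unit_count = 0

    def __init__(self, args, tmps, prog, compiled=False):
        self._args = args
        self._tmps = tmps
        self._prog = prog
        self._compiled = compiled
        # Take the next id from the shared counter, then advance it.
        self._unit_id = ExecutionUnitSketch.unit_count
        ExecutionUnitSketch.unit_count += 1

    @property
    def kernel_name(self):
        # Hypothetical naming scheme: one unique kernel name per unit.
        return f"K{self._unit_id}"

    @property
    def compiled(self):
        return self._compiled

    @property
    def program(self):
        return self._prog

    @property
    def tmps(self):
        return self._tmps


first = ExecutionUnitSketch(args=[], tmps=[], prog=None)
second = ExecutionUnitSketch(args=[], tmps=[], prog=None)
print(first.kernel_name, second.kernel_name)  # -> K0 K1
```

Because the counter lives on the class rather than the instance, every new unit sees the value left by the previous one.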
- class stgraph.compiler.execution_unit.FeatureAdaptiveKernel(num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, node_ids_ptr, max_dims, kernel_name, compiled_module, launch_config)
Bases:
Kernel
- class stgraph.compiler.execution_unit.V2Kernel(num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, max_dims, kernel_name, compiled_module, launch_config, tile_sizes)
Bases:
Kernel
The Version 2 Kernel
This class contains the parameters for the second version of the kernel written for STGraph.
- Parameters:
num_nodes (int) – Number of nodes present in the graph
row_offsets_ptr (c_type) – Pointer to the row offsets array
col_indices_ptr (c_type) – Pointer to the column indices array
- scalar_args
List of the scalar arguments passed to the kernel
- Type:
list[c_types]
- launch_config¶
List of the kernel launch configurations
- Type:
list[int]
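The scalar_args and launch_config attributes describe what is handed to the GPU at launch time. The sketch below assumes a CUDA driver style launch (cuLaunchKernel, whose kernelParams argument is an array of pointers to the argument values) and a six-integer grid/block layout. build_scalar_args, build_launch_config and kernel_params are hypothetical helpers, the pointer values are made up, and STGraph's actual argument order and launch layout are not documented on this page.

```python
import ctypes


def build_scalar_args(num_nodes, row_offsets_ptr, col_indices_ptr, eids_ptr, max_dims):
    # Hypothetical packing: counts as C ints, device addresses as void pointers.
    return [
        ctypes.c_int(num_nodes),
        ctypes.c_void_p(row_offsets_ptr),
        ctypes.c_void_p(col_indices_ptr),
        ctypes.c_void_p(eids_ptr),
        ctypes.c_int(max_dims),
    ]


def build_launch_config(num_nodes, threads_per_block=256):
    # Hypothetical 1-D layout: [grid_x, grid_y, grid_z, block_x, block_y, block_z],
    # with enough blocks to give every node a thread (ceiling division).
    blocks = (num_nodes + threads_per_block - 1) // threads_per_block
    return [blocks, 1, 1, threads_per_block, 1, 1]


def kernel_params(scalar_args):
    # cuLaunchKernel takes `void **kernelParams`: an array holding the
    # address of each argument value (not the values themselves).
    return (ctypes.c_void_p * len(scalar_args))(
        *[ctypes.addressof(arg) for arg in scalar_args]
    )


args = build_scalar_args(1000, 0x7F0000000000, 0x7F0000100000, 0x7F0000200000, 64)
print(build_launch_config(1000))  # -> [4, 1, 1, 256, 1, 1]
print(len(kernel_params(args)))   # -> 5
```

The ctypes values in args must stay alive while the pointer array is in use, since the array stores only their addresses.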
stgraph.compiler.executor module
- class stgraph.compiler.executor.Executor(graph, forward_exec_units, backward_exec_units, compiled_module, rets)
Bases:
object
stgraph.compiler.node module
stgraph.compiler.program module
- class stgraph.compiler.program.Stmt(op_schema, args, ret, op_type, callback)
Bases:
object
- class StmtInfo(op_schema, args)
Bases:
tuple
- args
Alias for field number 1
- op_schema
Alias for field number 0
- property op_name
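Stmt.StmtInfo is documented as a tuple whose fields alias positions 0 and 1, which is the behaviour of a collections.namedtuple. A minimal sketch reproducing that layout; the schema string and argument names are placeholders, not real STGraph values.

```python
from collections import namedtuple

# Field order follows the documentation: op_schema is field 0, args is field 1.
StmtInfo = namedtuple("StmtInfo", ["op_schema", "args"])

# "relu_schema" and the argument names are placeholder values.
info = StmtInfo(op_schema="relu_schema", args=("x",))
assert info[0] is info.op_schema and info[1] is info.args

# Being a tuple, it unpacks positionally and can be used as a dict key.
schema, stmt_args = info
seen = {info: "stmt_0"}
print(schema, stmt_args)  # -> relu_schema ('x',)
```

Since equal field values give equal tuples, two statements with the same schema and arguments map to the same key, which is convenient for lookups such as deduplication.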
- class stgraph.compiler.program.Var(var_id, val_type, var_shape, var_dtype, device, requires_grad)
Bases:
object
- classmethod create_var(var_shape=None, var_dtype=None, val_type=None, var_id=None, device=None, requires_grad=True)
- property device
- property id
- property int_id
- property requires_grad
- property stmt
- property users
- property val_type
- property var_dtype
- property var_shape
stgraph.compiler.registry module
- class stgraph.compiler.registry.AddOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
BinaryOpImpl
- class stgraph.compiler.registry.AggMaxOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
OpImpl
- class stgraph.compiler.registry.AggSumOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
OpImpl
- class stgraph.compiler.registry.BackwardAMaxOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
OpImpl
- class stgraph.compiler.registry.BackwardLeakyReluOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
OpImpl
- class stgraph.compiler.registry.BackwardReluOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
OpImpl
- class stgraph.compiler.registry.BinaryOpImpl(fstmt, create_var_cb, create_stmt_cb)
Bases:
OpImpl
- class stgraph.compiler.registry.GTypeCastOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
OpImpl
- class stgraph.compiler.registry.GradInfo(targ, args, grad_x, op_schema)
Bases:
tuple
- args
Alias for field number 1
- grad_x
Alias for field number 2
- op_schema
Alias for field number 3
- targ
Alias for field number 0
- class stgraph.compiler.registry.LeakyReluOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
OpImpl
- class stgraph.compiler.registry.MulOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
BinaryOpImpl
- class stgraph.compiler.registry.OpImpl(fstmt, create_var_cb, create_stmt_cb)
Bases:
ABC
New ops must inherit from this class and be named “XXXOp”.
- property args
- property op_schema
- property ret
- class stgraph.compiler.registry.SubOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
BinaryOpImpl
- class stgraph.compiler.registry.TrueDivOp(fstmt, create_var_cb, create_stmt_cb)
Bases:
BinaryOpImpl
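The OpImpl docstring implies that ops are identified by their “XXXOp” class name. One way such a convention can work is sketched below, assuming automatic registration through __init_subclass__. OpImplSketch, BinaryOpImplSketch and OP_REGISTRY are hypothetical, and STGraph may discover its ops differently. The sketch also shows why an intermediate base such as BinaryOpImpl, whose name does not end in "Op", would not be picked up as an op itself.

```python
from abc import ABC

# Hypothetical registry: maps the "XXX" part of an "XXXOp" class name to the class.
OP_REGISTRY = {}


class OpImplSketch(ABC):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register only classes that follow the "XXXOp" naming convention.
        if cls.__name__.endswith("Op"):
            OP_REGISTRY[cls.__name__[: -len("Op")]] = cls

    def __init__(self, fstmt, create_var_cb, create_stmt_cb):
        self.fstmt = fstmt
        self.create_var_cb = create_var_cb
        self.create_stmt_cb = create_stmt_cb


class BinaryOpImplSketch(OpImplSketch):
    """Shared base for binary ops; its name does not end in "Op", so it is skipped."""


class AddOp(BinaryOpImplSketch):
    pass


class AggSumOp(OpImplSketch):
    pass


print(sorted(OP_REGISTRY))  # -> ['Add', 'AggSum']
op = OP_REGISTRY["Add"](fstmt=None, create_var_cb=None, create_stmt_cb=None)
```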
stgraph.compiler.schema module
stgraph.compiler.stgraph module
- class stgraph.compiler.stgraph.STGraph(backend_framework: STGraphBackend)
Bases:
object
stgraph.compiler.utils module
- class stgraph.compiler.utils.EdgeDirection(value)
Bases:
Enum
An enumeration.
- IN = 0
- OUT = 1
- class stgraph.compiler.utils.FusionType(value)
Bases:
Enum
An enumeration.
- NEAN = 0
- NN = 1
- NOT_FUSIBLE = 2
- class stgraph.compiler.utils.OpType(value)
Bases:
Enum
An enumeration.
- A = 2
- D = 3
- E = 1
- S = 0
- class stgraph.compiler.utils.ParallelMode(value)
Bases:
Enum
An enumeration.
- DstParallel = 1
- SrcParallel = 0
- class stgraph.compiler.utils.ValType(value)
Bases:
Enum
An enumeration.
- DEST = 1
- EDGE = 2
- PARAM = 3
- SRC = 0
- class stgraph.compiler.utils.WriteLocation(value)
Bases:
Enum
An enumeration.
- INNER = 0
- NONE = 2
- OUTER = 1
- class stgraph.compiler.utils.WriteType(value)
Bases:
Enum
An enumeration.
- ADD = 0
- ASSIGN = 2
- ATOMIC = 1
- NONE = 3
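The enum values above can be reproduced with Python's enum module. To show how they might interact, the sketch adds a hypothetical pick_write_type rule (an assumption, not STGraph's documented behaviour): under source-parallel traversal several threads may update the same destination node, so such writes would need WriteType.ATOMIC, while a node owned by a single thread can use a plain WriteType.ADD.

```python
from enum import Enum


# Values copied from the documented enums.
class ParallelMode(Enum):
    SrcParallel = 0
    DstParallel = 1


class ValType(Enum):
    SRC = 0
    DEST = 1
    EDGE = 2
    PARAM = 3


class WriteType(Enum):
    ADD = 0
    ATOMIC = 1
    ASSIGN = 2
    NONE = 3


def pick_write_type(mode, target):
    """Hypothetical rule, not STGraph's: choose how a kernel writes into `target`."""
    if target is ValType.EDGE:
        return WriteType.ASSIGN  # assumed: each edge is handled by exactly one thread
    # A node value is race-free only if the parallel mode gives it a single owner.
    owned = (mode is ParallelMode.DstParallel and target is ValType.DEST) or (
        mode is ParallelMode.SrcParallel and target is ValType.SRC
    )
    return WriteType.ADD if owned else WriteType.ATOMIC


print(pick_write_type(ParallelMode.SrcParallel, ValType.DEST))  # -> WriteType.ATOMIC
print(pick_write_type(ParallelMode.DstParallel, ValType.DEST))  # -> WriteType.ADD
```

Members can also be looked up by value, as in ValType(1), or by name, as in WriteType["NONE"].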
Module contents
The backbone of STGraph