Source code for stgraph.compiler.stgraph

import functools
from collections import defaultdict, namedtuple
from collections.abc import Iterable

from .node import CentralNode
# from .op.op import create_op, AggMaxOp, AggMinOp, AggMeanOp
from .program import Var, Stmt, Program
from .passes import optimize, CF, fuse, visualize
from .schema import Schema
from .autodiff import diff
from .code_gen import code_gen 
from .executor import Executor
from .utils import var_prefix, cen_attr_postfix, inb_attr_postfix
import gc
import torch

from stgraph.compiler.backend.callback import STGraphBackend
from stgraph.compiler.utils import ValType
from stgraph.compiler.val.val_factory import ValFactory
from stgraph.compiler.op.op_factory import OpFactory

import snoop


class Context:
    """Callable that traces, compiles and then executes a user function.

    On the first call the wrapped function is run once against a symbolic
    ``CentralNode`` so that its operations are recorded into a ``Program``.
    The program is differentiated and compiled into an ``Executor`` that is
    cached and reused by every later call.

    Args:
        func: Function to trace; it is called with the central node as its
            only argument.
        nspace: Sequence of namespaces to symbolize while tracing. The entry
            at index 1 is treated as the backend module.
        run_cb: Callback invoked with the prepared executor on every call.
    """

    def __init__(self, func, nspace, run_cb):
        functools.update_wrapper(self, func)
        self._f = func
        self._nspace = nspace
        self._entry_count = 0
        self._run_cb = run_cb
        self.val_factory = ValFactory()
        self._op_factory = OpFactory()
        # Hold reference to parameters of current module to avoid repeated lookup
        self._input_cache = {}
        self._graph_info_cache = None
        self._executor_cache = None

    def __call__(self, **kwargs):
        """Run the compiled function.

        Args:
            **kwargs: ``g`` (the graph, required), ``n_feats`` (node feature
                dict, optional) and ``e_feats`` (edge feature dict, optional).

        Returns:
            The single element of the callback result when it has length 1,
            otherwise the callback result unchanged.
        """
        executor = self._setup_executor(**kwargs)
        ret = self._run_cb(executor)
        if len(ret) == 1:
            return ret[0]
        return ret

    def _setup_executor(self, **kwargs):
        """Build the executor on first use and refresh its inputs.

        Raises:
            NameError: If no graph was passed through the ``g`` keyword.
        """
        graph = kwargs.get('g', None)
        node_feats = kwargs.get('n_feats', {})
        edge_feats = kwargs.get('e_feats', {})
        # Test for absence explicitly so that a supplied graph object which
        # happens to evaluate as falsy is not rejected.
        if graph is None:
            raise NameError('Need to provide the graph as one of keyward arguments')
        if self._entry_count == 0:
            fprog = Program()
            ret = self._trace(node_feats, edge_feats, self._input_cache, fprog)
            # print('TracedProgram' + str(fprog), 'Ret value:', ret)
            # pretty_print_GIR(fprog,"TGCN GIR")
            self._executor_cache = self._diff_then_compile(ret, fprog, graph)
        # Every node feature is registered twice: once as seen from the
        # central node and once as seen from its in-neighbours.
        for name, value in node_feats.items():
            self._input_cache[var_prefix + name + cen_attr_postfix] = value
            self._input_cache[var_prefix + name + inb_attr_postfix] = value
        for name, value in edge_feats.items():
            self._input_cache[var_prefix + name] = value
        self._executor_cache.restart(self._input_cache, graph)
        self._entry_count += 1
        return self._executor_cache

    def _trace(self, nfeats, efeats, input_cache, fprog):
        """Run the wrapped function once on symbolic values.

        The namespaces are patched only for the duration of the call; they
        are restored even when patching or the wrapped function raises, so a
        failed trace cannot leave the backend module in a symbolized state.

        Returns:
            A list holding the ``var`` of each traced output.

        Raises:
            NameError: If the wrapped function returned ``None``.
        """
        backend = self._find_backend()
        central_node = self._init_central_node(nfeats, efeats, fprog, backend)
        # pretty_print_Central_Node(central_node=central_node, print_tensors=False)
        old_libs = defaultdict(dict)
        try:
            self._monkey_patch_namespace(old_libs, input_cache, fprog, backend)
            ret = self._f(central_node)
        finally:
            # old_libs records exactly what has been replaced so far, so this
            # also undoes a patch that failed half-way through.
            self._remove_patch(old_libs, backend)
            self._destroy_central_node(central_node, nfeats, efeats)
        # Identity check: the traced value may overload ``==``.
        if ret is None:
            raise NameError('Ret is none. Execution is aborted')
        # A plain tuple/list of outputs has no ``var`` attribute of its own;
        # collect the ``var`` of every element instead.
        if isinstance(ret, (list, tuple)):
            return [item.var for item in ret]
        return [ret.var] if not isinstance(ret, Iterable) else ret.var

    def _diff_then_compile(self, out_set, fprog, graph):
        """Optimize, fuse, differentiate and compile the traced program.

        Args:
            out_set: Iterable of output vars produced by ``_trace``.
            fprog: The traced forward program.
            graph: Graph object; must provide ``graph_type()``.

        Returns:
            An ``Executor`` built from the forward and backward units.
        """
        optimize(fprog)
        outputs = list(out_set)
        forward_exe_units = fuse([fprog], outputs)
        # One gradient var per output, mirroring its shape, dtype, kind and device.
        grads = [
            Var.create_var(var_shape=out.var_shape,
                           var_dtype=out.var_dtype,
                           val_type=out.val_type,
                           device=out.device)
            for out in outputs
        ]
        backward_exe_units = diff(outputs, grads, forward_exe_units, fprog)
        # visualize.plot_exec_units(forward_exe_units + backward_exe_units)

        # NOTE: The last parameter here was ('int' if graph.nbits == 32 else 'long long int') but we changed
        # it to just 'int' since that should be sufficient for all use case that we can think of now
        compiled_module = code_gen.gen_code(forward_exe_units + backward_exe_units,
                                            'int', graph.graph_type())
        return Executor(graph, forward_exe_units, backward_exe_units, compiled_module, outputs)

    def _init_central_node(self, nfeats, efeats, fprog, backend):
        """Create a ``CentralNode`` carrying symbolic feature attributes.

        Each node feature is attached to the central node and to every entry
        of ``innbs``; each edge feature is attached to every entry of
        ``inedges``.
        """
        cen = CentralNode()
        if nfeats:
            for name, value in nfeats.items():
                dst_node_val = self.val_factory.create(
                    ValType.DEST, value, backend,
                    id=name + cen_attr_postfix, fprog=fprog, reduce_dim=True)
                setattr(cen, name, dst_node_val)
                for nb in cen.innbs:
                    src_node_val = self.val_factory.create(
                        ValType.SRC, value, backend,
                        id=name + inb_attr_postfix, fprog=fprog, reduce_dim=True)
                    setattr(nb, name, src_node_val)
        if efeats:
            for name, value in efeats.items():
                for edge in cen.inedges:
                    edge_val = self.val_factory.create(
                        ValType.EDGE, value, backend,
                        id=name, fprog=fprog, reduce_dim=True)
                    setattr(edge, name, edge_val)
        return cen

    def _destroy_central_node(self, cen, nfeats, efeats):
        """Remove the feature attributes added by ``_init_central_node``."""
        if nfeats:
            for name in nfeats:
                delattr(cen, name)
                for nb in cen.innbs:
                    delattr(nb, name)
        if efeats:
            for name in efeats:
                for edge in cen.inedges:
                    delattr(edge, name)

    def _monkey_patch_namespace(self, old_libs, input_cache, fprog, backend):
        """Replace namespace members with symbolic counterparts.

        For a namespace whose ``__dict__`` has a ``__name__`` entry, every
        member whose type name contains ``function`` is replaced by an op
        from the op factory. For any other namespace, the entries of its
        ``_parameters*`` and ``_buffers*`` dicts are replaced by PARAM values
        (and the originals stored in ``input_cache``), and the entries of its
        ``_modules*`` dicts are replaced by ops.

        Args:
            old_libs: ``defaultdict(dict)`` that receives every original
                object, keyed by ``_mapping_key``; used by ``_remove_patch``.
            input_cache: Dict that receives the original parameters/buffers.
            fprog: Program the symbolic values record into.
            backend: ``(name, module)`` tuple from ``_find_backend``.

        Raises:
            KeyError: If an entry has already been saved in ``old_libs``.
            NotImplementedError: If the backend is not ``'torch'``.
        """
        if backend[0] != 'torch':
            raise NotImplementedError('Backend ' + backend[0] + ' is not supported yet!')

        for i, nspace in enumerate(self._nspace):
            if '__name__' in nspace.__dict__:
                # symbolizing functions for torch namespace
                k = self._mapping_key(i, 'function')
                # Snapshot the items: the namespace is rewritten while walking it.
                for key, member in list(nspace.__dict__.items()):
                    if 'function' not in str(type(member)):
                        continue
                    if key in old_libs[k]:
                        raise KeyError('Found', key, ' already in old_libs')
                    old_libs[k][key] = member
                    nspace.__dict__[key] = self._op_factory.create(member, backend[0], fprog)
                continue

            # Dealing with module
            for key, store in nspace.__dict__.items():
                k = self._mapping_key(i, key)
                # symbolizing parameters for self namespace
                if key.startswith('_parameters'):
                    for mkey in list(store.keys()):
                        if mkey in old_libs[k]:
                            raise KeyError('Found', key, ' already in old_libs')
                        old_libs[k][mkey] = store[mkey]
                        input_cache[var_prefix + mkey] = store[mkey]
                        store[mkey] = self.val_factory.create(
                            ValType.PARAM, store[mkey], backend,
                            id=mkey, fprog=fprog, reduce_dim=False)
                # symbolizing buffers for self namespace
                if key.startswith('_buffers'):
                    for mkey in list(store.keys()):
                        if mkey in old_libs[k]:
                            raise KeyError('Found', key, 'already in old_libs')
                        old_libs[k][mkey] = store[mkey]
                        input_cache[var_prefix + mkey] = store[mkey]
                        store[mkey] = self.val_factory.create(
                            ValType.PARAM, store[mkey], backend,
                            id=mkey, fprog=fprog, reduce_dim=False)
                # symbolizing modules for self namespace
                if key.startswith('_modules'):
                    for mkey in list(store.keys()):
                        if mkey in old_libs[k]:
                            raise KeyError('Found', key, ' already in old_libs')
                        old_libs[k][mkey] = store[mkey]
                        store[mkey] = self._op_factory.create(store[mkey], backend[0], fprog)

    def _remove_patch(self, old_libs, backend):
        """Restore everything recorded in ``old_libs``.

        Args:
            old_libs: ``defaultdict(dict)`` filled by
                ``_monkey_patch_namespace``.
            backend: ``(name, module)`` tuple from ``_find_backend``.

        Raises:
            NotImplementedError: If the backend is not ``'torch'``.
        """
        if backend[0] != 'torch':
            raise NotImplementedError('Backend ' + backend[0] + ' is not supported yet!')

        for i, nspace in enumerate(self._nspace):
            if '__name__' in nspace.__dict__:
                # desymbolizing functions: restore every named namespace that
                # was patched, not only those whose name contains 'torch';
                # old_libs holds entries solely for members that were replaced.
                k = self._mapping_key(i, 'function')
                for key, original in old_libs[k].items():
                    nspace.__dict__[key] = original
                continue

            # desymbolizing parameters, buffers and modules for self namespace
            for key, store in nspace.__dict__.items():
                if not key.startswith(('_parameters', '_modules', '_buffers')):
                    continue
                k = self._mapping_key(i, key)
                for mkey, original in old_libs[k].items():
                    store[mkey] = original

    def _find_backend(self):
        """Find the backend framework being used.

        Returns:
            A tuple containing the name and module instance of the backend
            being used, taken from index 1 of the namespace list.
        """
        backend_module = self._nspace[1]
        backend_name = backend_module.__name__
        return (backend_name, backend_module)

    def _mapping_key(self, name_space_id, original_key):
        """Return the ``old_libs`` key for a namespace index and member key."""
        return str(name_space_id) + str(original_key)
class STGraph:
    """Entry point that turns user functions into cached ``Context`` objects.

    Args:
        backend_framework: Backend description; its ``backend_cb`` is used as
            the run callback and its ``backend_module`` is added to the
            namespace list of every compiled function.
    """

    def __init__(self, backend_framework: STGraphBackend):
        self._ctx_map = {}
        self._backend_framework = backend_framework
        self._run_cb = backend_framework.backend_cb

    def compile(self, gnn_module, hetero_graph=False):
        """Return a decorator that wraps a function in a ``Context``.

        Contexts are cached by the decorated function's ``__name__``, so
        decorating a function whose name has been seen before returns the
        existing context instead of building a new one.

        Args:
            gnn_module: Object placed first in the namespace list.
            hetero_graph: Must be falsy; heterogeneous graphs are rejected
                when a new context would have to be created.

        Returns:
            A decorator mapping a function to its ``Context``.

        Raises:
            NotImplementedError: From the decorator, when ``hetero_graph`` is
                truthy and no context is cached for the function name.
        """
        # The namespace list holds the GNN module followed by the backend module.
        namespace = [gnn_module, self._backend_framework.backend_module]

        def wrapper(func):
            name = func.__name__
            if name in self._ctx_map:
                return self._ctx_map[name]
            if hetero_graph:
                raise NotImplementedError('Heterogeneous graph is not supported yet')
            self._ctx_map[name] = Context(func, namespace, self._run_cb)
            return self._ctx_map[name]

        return wrapper