# Source code for stgraph.compiler.backend.callback

from abc import ABC, abstractmethod

class STGraphBackend(ABC):
    """Abstract base for backends that supply callbacks to an executor.

    Concrete subclasses must implement :meth:`new_zeros_call_back` and
    :meth:`tensor_raw_ptr`. :meth:`backend_cb` registers both on an executor
    and then runs the executor with ``self.kernel_wrapper``.

    The three attributes below start out as ``None``; nothing in this class
    assigns them afterwards.
    """

    def __init__(self):
        """Initialise every backend attribute to ``None``."""
        # Presumably an identifier for the backend -- confirm in subclasses.
        self.backend_name = None
        # Presumably the imported framework module -- confirm in subclasses.
        self.backend_module = None
        # Passed unchanged to ``executor.execute`` by ``backend_cb``.
        self.kernel_wrapper = None

    @abstractmethod
    def new_zeros_call_back(self, size, dtype, device, requires_grad=True):
        """Callback registered via ``executor.set_new_zeros_cb``.

        Abstract; subclasses provide the behaviour. Judging by the name, it
        is expected to allocate a zero-filled tensor -- confirm against a
        concrete backend.

        Args:
            size: Requested size (exact form is not visible here).
            dtype: Requested element type.
            device: Target device.
            requires_grad: Defaults to ``True``.
        """

    @abstractmethod
    def tensor_raw_ptr(self, tensor):
        """Callback registered via ``executor.set_raw_ptr_cb``.

        Abstract; subclasses provide the behaviour. Judging by the name, it
        is expected to return a raw pointer for ``tensor`` -- confirm against
        a concrete backend.

        Args:
            tensor: The backend tensor object to inspect.
        """

    def backend_cb(self, executor):
        """Register this backend's callbacks on ``executor`` and run it.

        Args:
            executor: Object exposing ``set_new_zeros_cb``,
                ``set_raw_ptr_cb`` and ``execute``.

        Returns:
            Whatever ``executor.execute(self.kernel_wrapper)`` returns.
        """
        # Both callbacks are installed before execute() is invoked.
        executor.set_new_zeros_cb(self.new_zeros_call_back)
        executor.set_raw_ptr_cb(self.tensor_raw_ptr)
        return executor.execute(self.kernel_wrapper)