Source code for stgraph.compiler.val.val_factory

from stgraph.compiler.utils import ValType
from stgraph.compiler.val.pytorch.torch_val import TorchVal

class ValFactory:
    """Factory that builds backend-specific Val objects.

    The backend descriptor passed to ``create`` selects the concrete Val
    class via ``get_val_backend``. That class is then constructed with the
    remaining arguments as keyword arguments. Only the ``"torch"`` backend
    is currently wired up; any other key raises ``NotImplementedError``.
    """

    def __init__(self):
        """Create a factory; it keeps no state of its own."""

    def create(self, type: ValType, tensor, backend, id, fprog, reduce_dim):
        """Instantiate the Val class matching ``backend``.

        Args:
            type: ValType for the new value; forwarded to the Val class as
                the ``val_type`` keyword.
            tensor: Forwarded unchanged as ``tensor``.
            backend: Backend descriptor. It is used to pick the Val class
                and is also forwarded unchanged as ``backend``.
            id: Forwarded unchanged as ``id``.
            fprog: Forwarded unchanged as ``fprog``.
            reduce_dim: Forwarded unchanged as ``reduce_dim``.

        Returns:
            An instance of the class returned by ``get_val_backend(backend)``.

        Raises:
            NotImplementedError: If the backend key is not supported
                (propagated from ``get_val_backend``).
        """
        val_cls = self.get_val_backend(backend)
        # Every argument is passed by keyword, so the Val class's
        # positional parameter order does not matter here.
        return val_cls(
            backend=backend,
            tensor=tensor,
            val_type=type,
            id=id,
            fprog=fprog,
            reduce_dim=reduce_dim,
        )

    def get_val_backend(self, backend):
        """Map a backend descriptor to the Val class that implements it.

        ``backend`` must unpack into exactly two items (otherwise the
        unpacking raises ``ValueError``). Only the first item, the backend
        key, is inspected; the second is ignored here.

        Returns:
            ``TorchVal`` when the key equals ``"torch"``.

        Raises:
            NotImplementedError: For any other key.
        """
        key, _ignored = backend
        if key != "torch":
            raise NotImplementedError(
                f"Backend support for {key} has not been implemented yet"
            )
        return TorchVal