from enum import Enum
from collections.abc import Iterable
unused_ids = set()
val_seq = 0
MAX_THREAD_PER_BLOCK=1024
MAX_BLOCK=65535
var_prefix='V'
cen_attr_postfix='cen'
inb_attr_postfix='inb'
[docs]class EdgeDirection(Enum):
IN = 0
OUT = 1
[docs]class ValType(Enum):
SRC = 0
DEST = 1
EDGE = 2
PARAM = 3
[docs]class OpType(Enum):
S = 0
E = 1
A = 2
D = 3
[docs]class FusionType(Enum):
NEAN = 0
NN = 1
NOT_FUSIBLE = 2
[docs]class ParallelMode(Enum):
SrcParallel = 0
DstParallel = 1
[docs]class WriteType(Enum):
ADD = 0
ATOMIC = 1
ASSIGN = 2
NONE = 3
[docs]class WriteLocation(Enum):
INNER = 0
OUTER = 1
NONE = 2
[docs]def is_const_scalar(val):
return type(val) in (str, int, float, bool)
[docs]def infer_val_type(vals):
'''
When type of vals are different, we return edge type.
P type is considered same with every one.
'''
assert isinstance(vals, Iterable), 'vals must be iterable'
assert len(vals) >= 1
for i in range(len(vals)):
if not is_const_scalar(vals[i]) and vals[i].val_type != ValType.PARAM:
first_non_p_type = vals[i].val_type
diff_val_type = any(val.val_type != first_non_p_type for val in vals if not is_const_scalar(val) and val.val_type != ValType.PARAM )
if diff_val_type:
vtype = ValType.EDGE
else:
vtype = first_non_p_type
return vtype
[docs]def infer_op_type_from_args(op_schema, args):
if 'agg' in op_schema._op_name.lower():
return OpType.A
inf_val_type = infer_val_type(args)
if inf_val_type == ValType.EDGE:
return OpType.E
elif inf_val_type == ValType.SRC:
return OpType.S
elif inf_val_type == ValType.DEST:
return OpType.D
[docs]def any_var(var_list):
first_var = None
for var in var_list:
if not is_const_scalar(var):
first_var = var
return first_var
[docs]def bcast_dim(var_list):
first_var = any_var(var_list)
assert first_var != None
maxdim = [dim for dim in first_var.var_shape]
for i in range(0, len(var_list)):
if is_const_scalar(var_list[i]):
continue
assert len(maxdim) == len(var_list[i].var_shape), str(maxdim) + str(var_list[i].var_shape) + str(var_list)
for j in range(len(maxdim)):
maxdim[j] = max(maxdim[j], var_list[i].var_shape[j])
return maxdim