Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 193
1 file changed, 117 insertions, 76 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 867613cd..fbb306e8 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -22,6 +22,15 @@ from .data_type import DataType
 from .operation import get_slice_offsets
 
 
+# Custom decorator function to allow formatting docstrings containing "{}"
+def docstring_format_args(args):
+    def docstring(func):
+        func.__doc__ = func.__doc__.format(*args)
+        return func
+
+    return docstring
+
+
 def warn_cpu(op, msg):
     print("Warning: {} {}, placing on CPU".format(op.type, msg))
 
@@ -61,6 +70,9 @@ class SupportedOperators:
     )
     binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
+    supported_int32_tensor_ops = (
+        set(("Requantize", "ReduceSum", "CLZ",)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
+    )
     activation_ops = set(
         (
             "QuantizedRelu",
@@ -90,6 +102,9 @@ class SupportedOperators:
     shapeless_input_ops = set(("Split", "SplitV",)) | binary_elem_wise_main_ops
     supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT",))
     supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
+    supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+    # Defined ranges for allowed values:
+    tens_dim_range = (1, 65535)
 
     def __init__(self):
         # Setup supported operator restriction checkers
@@ -121,93 +136,119 @@ class SupportedOperators:
         self.supported_operator_restrictions.update(
             {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
         )
+        # Setup the generic constraints
+        self.generic_constraints = []
+        self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
+        self.generic_constraints.append(SupportedOperators.constraint_tens_shapeless)
+        self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
+        self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
+        self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
+        self.generic_constraints.append(SupportedOperators.constraint_faf)
+        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
 
     def is_operator_supported(self, op):
         if op.type not in SupportedOperators.supported_operators:
             return False
-        if not self.check_generic_restrictions(op):
-            return False
+        for constraint in self.generic_constraints:
+            valid, extra = constraint(op)
+            if not valid:
+                print('Warning: "{}" is not supported on the NPU. Placing on CPU instead'.format(op.type))
+                print(" - {}".format(constraint.__doc__))
+                if extra:
+                    print("   {}".format(extra))
+                return False
         if op.type in self.supported_operator_restrictions:
             return self.supported_operator_restrictions[op.type](op)
         return True
 
-    @classmethod
-    def check_generic_restrictions(cls, op):
-        # check fully defined shapes
-        for t in op.inputs:
-            if not t:
-                continue
-            if not t.has_fully_defined_shape():
-                print("Warning:", op.type, "has input(s) of undefined shape, placing on CPU")
-                return False
-            if t.shape == [] and op.type not in cls.shapeless_input_ops:
-                print(
-                    "Warning:",
-                    op.type,
-                    "has input(s) of shape [].",
-                    "Scalar input or broadcasting is not supported for this operator,",
-                    "placing on CPU",
-                )
-                return False
-            if len(t.shape) > 4:
-                print("Warning:", op.type, "has input(s) of unsupported shape", t.shape, "placing on CPU")
-                return False
-        for t in op.outputs:
-            if not t.has_fully_defined_shape():
-                print("Warning:", op.type, "has output(s) of undefined shape, placing on CPU")
-                return False
-            if t.shape == []:
-                print(
-                    "Warning:",
-                    op.type,
-                    "has output(s) of shape [].",
-                    "Scalar input or broadcasting is not supported for this operator,",
-                    "placing on CPU",
-                )
-                return False
-            if len(t.shape) > 4:
-                print("Warning:", op.type, "has output(s) of unsupported shape", t.shape, "placing on CPU")
-                return False
+    @staticmethod
+    def constraint_tens_defined_shape(op):
+        "Input(s) and Output Tensors must have a defined shape"
+        valid = True
+        extra = []
+        for tens in op.inputs + op.outputs:
+            if tens:
+                valid &= tens.has_fully_defined_shape()
+                extra.append("shape={}".format(tens.shape))
+        return valid, " ".join(extra)
 
-        # check data type
-        tensors = [t for t in op.get_ifm_ifm2_weights_ofm() if t is not None]
-        if not tensors:
-            tensors = op.inputs
-        for t in tensors:
-            if not (t.dtype.type & BaseType.Int):
-                return False
-            if (
-                t.element_size() > 2
-                and op.type
-                not in set(("Requantize", "ReduceSum", "CLZ",))
-                | cls.binary_elem_wise_add_mul_sub
-                | cls.binary_elem_wise_shift_ops
-            ):
-                return False
-            # check size
-            if any(dim > 65536 for dim in t.shape):
-                return False
+    @classmethod
+    @docstring_format_args([shapeless_input_ops])
+    def constraint_tens_shapeless(cls, op):
+        "Scalar or Broadcasting Tensors are only valid for Input Tensors, and when op type is: {}"
+        valid = True
+        extra = []
+        for tens in op.inputs:
+            if tens and tens.shape == []:
+                valid &= op.type in cls.shapeless_input_ops
+                extra.append("shape={}".format(tens.shape))
+        for tens in op.outputs:
+            if tens.shape == []:
+                valid = False
+                extra.append("shape={}".format(tens.shape))
+        return valid, " ".join(extra)
+
+    @staticmethod
+    def constraint_tens_shape_size(op):
+        "Input(s) and Output Tensors must not be greater than 4D"
+        valid = True
+        extra = []
+        for tens in op.inputs + op.outputs:
+            if tens:
+                valid &= len(tens.shape) <= 4
+                extra.append("shape={}".format(tens.shape))
+        return valid, " ".join(extra)
 
-        # check fused activations
-        if (
-            "fused_activation_function" in op.attrs
-            and op.attrs["fused_activation_function"] is not None
-            and op.attrs["fused_activation_function"] not in cls.supported_fused_activations
-        ):
-            return False
+    @classmethod
+    @docstring_format_args([supported_dtypes, supported_int32_tensor_ops])
+    def constraint_tens_dtype(cls, op):
+        "Tensors must be of type: {}. Tensors which are int32 are only valid when op type is: {}"
+        valid = True
+        extra = []
+        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+        tensors = tensors if tensors else op.inputs
+        for tens in tensors:
+            if tens.dtype == DataType.int32:
+                valid &= op.type in cls.supported_int32_tensor_ops
+            else:
+                valid &= tens.dtype in cls.supported_dtypes
+            extra.append("dtype={}".format(tens.dtype))
+        return valid, " ".join(extra)
 
-        # check inf values
-        for tens in op.get_ifm_ifm2_weights_ofm():
-            if (
-                (tens is not None)
-                and (tens.quantization is not None)
-                and (tens.quantization.scale_f32 is not None)
-                and (np.isinf(tens.quantization.scale_f32).any())
-            ):
-                print("Warning:", op.type, "has inf valued tensor(s), placing on CPU")
-                return False
+    @classmethod
+    @docstring_format_args(tens_dim_range)
+    def constraint_tens_dimension(cls, op):
+        "Tensor dimensions must be in the range {}-{} (inclusive)"
+        tens_min, tens_max = cls.tens_dim_range
+        valid = True
+        extra = []
+        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+        tensors = tensors if tensors else op.inputs
+        for tens in tensors:
+            valid &= all(tens_min <= dim <= tens_max for dim in tens.shape)
+            extra.append("shape={}".format(tens.shape))
+        return valid, " ".join(extra)
 
-        return True
+    @classmethod
+    @docstring_format_args([supported_fused_activations])
+    def constraint_faf(cls, op):
+        "The fused activation function (if present) must be one of type: {}"
+        faf = op.attrs.get("fused_activation_function")
+        valid = (faf is None) or (faf in cls.supported_fused_activations)
+        extra = "fused_activation_function={}".format(faf)
+        return valid, extra
+
+    @staticmethod
+    def constraint_tens_quant_scale(op):
+        "Tensors with quantization scales must be finite"
+        valid = True
+        extra = []
+        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+        for tens in tensors:
+            if tens.quantization is not None and tens.quantization.scale_f32 is not None:
+                valid &= not np.isinf(tens.quantization.scale_f32).any()
+                extra.append("quantization.scale_f32={}".format(tens.quantization.scale_f32))
+        return valid, " ".join(extra)
 
     @classmethod
     def check_convolution_restrictions(cls, op):
@@ -525,7 +566,7 @@ class SupportedOperators:
             return False
 
         # check shape
-        if len(ifm_tensor.shape) > 4 or ifm_tensor.shape != ofm_tensor.shape:
+        if ifm_tensor.shape != ofm_tensor.shape:
             return False
 
         return True
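For readers unfamiliar with the pattern this change introduces, below is a minimal, self-contained Python sketch of how the docstring_format_args decorator and the (valid, extra) constraint convention fit together. Only docstring_format_args is taken from the patch itself; DummyTensor, DummyOp, MAX_RANK and the simplified constraint are hypothetical stand-ins for Vela's real Tensor/Operation classes and constraint_tens_shape_size, not code from the repository.

# Illustrative sketch of the constraint pattern added in this patch.
# docstring_format_args is copied from the diff; everything else here is
# a hypothetical stand-in, not Vela code.

# Custom decorator function to allow formatting docstrings containing "{}"
def docstring_format_args(args):
    def docstring(func):
        func.__doc__ = func.__doc__.format(*args)
        return func

    return docstring


class DummyTensor:
    def __init__(self, shape):
        self.shape = shape


class DummyOp:
    def __init__(self, op_type, inputs, outputs):
        self.type = op_type
        self.inputs = inputs
        self.outputs = outputs


MAX_RANK = 4  # assumed stand-in for the 4D limit used by the generic checks


@docstring_format_args([MAX_RANK])
def constraint_shape_size(op):
    "Input(s) and Output Tensors must not be greater than {}D"
    valid = True
    extra = []
    for tens in op.inputs + op.outputs:
        valid &= len(tens.shape) <= MAX_RANK
        extra.append("shape={}".format(tens.shape))
    return valid, " ".join(extra)


# The same reporting loop is_operator_supported now uses: the failed
# constraint's formatted docstring doubles as the warning message.
op = DummyOp("Reshape", [DummyTensor([1, 2, 3, 4, 5])], [DummyTensor([1, 120])])
for constraint in [constraint_shape_size]:
    valid, extra = constraint(op)
    if not valid:
        print('Warning: "{}" is not supported on the NPU. Placing on CPU instead'.format(op.type))
        print(" - {}".format(constraint.__doc__))
        if extra:
            print("   {}".format(extra))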
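A detail worth noting about the decorator (an observation, not something stated in the patch): func.__doc__.format(*args) runs once, when the class body is evaluated, so the number of "{}" placeholders in each constraint docstring must match the number of arguments. @docstring_format_args(tens_dim_range) passes the tuple (1, 65535) directly, and unpacking fills the two placeholders in the range docstring; any literal braces in a future constraint description would have to be escaped as "{{" and "}}". A tiny check, reusing docstring_format_args from the sketch above:

tens_dim_range = (1, 65535)


@docstring_format_args(tens_dim_range)  # .format(*args) unpacks the tuple into both placeholders
def constraint_tens_dimension(op):
    "Tensor dimensions must be in the range {}-{} (inclusive)"


print(constraint_tens_dimension.__doc__)
# -> Tensor dimensions must be in the range 1-65535 (inclusive)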