-rw-r--r--  ethosu/vela/supported_operators.py            193
-rw-r--r--  ethosu/vela/test/test_supported_operators.py   76
2 files changed, 193 insertions, 76 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 867613cd..fbb306e8 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -22,6 +22,15 @@ from .data_type import DataType
from .operation import get_slice_offsets
+# Custom decorator function to allow formatting docstrings containing "{}"
+def docstring_format_args(args):
+ def docstring(func):
+ func.__doc__ = func.__doc__.format(*args)
+ return func
+
+ return docstring
+
+
def warn_cpu(op, msg):
print("Warning: {} {}, placing on CPU".format(op.type, msg))
@@ -61,6 +70,9 @@ class SupportedOperators:
)
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
+ supported_int32_tensor_ops = (
+ set(("Requantize", "ReduceSum", "CLZ",)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
+ )
activation_ops = set(
(
"QuantizedRelu",
@@ -90,6 +102,9 @@ class SupportedOperators:
shapeless_input_ops = set(("Split", "SplitV",)) | binary_elem_wise_main_ops
supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT",))
supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
+ supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+ # Defined ranges for allowed values:
+ tens_dim_range = (1, 65535)
def __init__(self):
# Setup supported operator restriction checkers
@@ -121,93 +136,119 @@ class SupportedOperators:
self.supported_operator_restrictions.update(
{op: self.check_activation_ops for op in SupportedOperators.activation_ops}
)
+ # Setup the generic constraints
+ self.generic_constraints = []
+ self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_shapeless)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
+ self.generic_constraints.append(SupportedOperators.constraint_faf)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
def is_operator_supported(self, op):
if op.type not in SupportedOperators.supported_operators:
return False
- if not self.check_generic_restrictions(op):
- return False
+ for constraint in self.generic_constraints:
+ valid, extra = constraint(op)
+ if not valid:
+ print('Warning: "{}" is not supported on the NPU. Placing on CPU instead'.format(op.type))
+ print(" - {}".format(constraint.__doc__))
+ if extra:
+ print(" {}".format(extra))
+ return False
if op.type in self.supported_operator_restrictions:
return self.supported_operator_restrictions[op.type](op)
return True
- @classmethod
- def check_generic_restrictions(cls, op):
- # check fully defined shapes
- for t in op.inputs:
- if not t:
- continue
- if not t.has_fully_defined_shape():
- print("Warning:", op.type, "has input(s) of undefined shape, placing on CPU")
- return False
- if t.shape == [] and op.type not in cls.shapeless_input_ops:
- print(
- "Warning:",
- op.type,
- "has input(s) of shape [].",
- "Scalar input or broadcasting is not supported for this operator,",
- "placing on CPU",
- )
- return False
- if len(t.shape) > 4:
- print("Warning:", op.type, "has input(s) of unsupported shape", t.shape, "placing on CPU")
- return False
- for t in op.outputs:
- if not t.has_fully_defined_shape():
- print("Warning:", op.type, "has output(s) of undefined shape, placing on CPU")
- return False
- if t.shape == []:
- print(
- "Warning:",
- op.type,
- "has output(s) of shape [].",
- "Scalar input or broadcasting is not supported for this operator,",
- "placing on CPU",
- )
- return False
- if len(t.shape) > 4:
- print("Warning:", op.type, "has output(s) of unsupported shape", t.shape, "placing on CPU")
- return False
+ @staticmethod
+ def constraint_tens_defined_shape(op):
+ "Input(s) and Output Tensors must have a defined shape"
+ valid = True
+ extra = []
+ for tens in op.inputs + op.outputs:
+ if tens:
+ valid &= tens.has_fully_defined_shape()
+ extra.append("shape={}".format(tens.shape))
+ return valid, " ".join(extra)
- # check data type
- tensors = [t for t in op.get_ifm_ifm2_weights_ofm() if t is not None]
- if not tensors:
- tensors = op.inputs
- for t in tensors:
- if not (t.dtype.type & BaseType.Int):
- return False
- if (
- t.element_size() > 2
- and op.type
- not in set(("Requantize", "ReduceSum", "CLZ",))
- | cls.binary_elem_wise_add_mul_sub
- | cls.binary_elem_wise_shift_ops
- ):
- return False
- # check size
- if any(dim > 65536 for dim in t.shape):
- return False
+ @classmethod
+ @docstring_format_args([shapeless_input_ops])
+ def constraint_tens_shapeless(cls, op):
+ "Scalar or Broadcasting Tensors are only valid for Input Tensors, and when op type is: {}"
+ valid = True
+ extra = []
+ for tens in op.inputs:
+ if tens and tens.shape == []:
+ valid &= op.type in cls.shapeless_input_ops
+ extra.append("shape={}".format(tens.shape))
+ for tens in op.outputs:
+ if tens.shape == []:
+ valid = False
+ extra.append("shape={}".format(tens.shape))
+ return valid, " ".join(extra)
+
+ @staticmethod
+ def constraint_tens_shape_size(op):
+ "Input(s) and Output Tensors must not be greater than 4D"
+ valid = True
+ extra = []
+ for tens in op.inputs + op.outputs:
+ if tens:
+ valid &= len(tens.shape) <= 4
+ extra.append("shape={}".format(tens.shape))
+ return valid, " ".join(extra)
- # check fused activations
- if (
- "fused_activation_function" in op.attrs
- and op.attrs["fused_activation_function"] is not None
- and op.attrs["fused_activation_function"] not in cls.supported_fused_activations
- ):
- return False
+ @classmethod
+ @docstring_format_args([supported_dtypes, supported_int32_tensor_ops])
+ def constraint_tens_dtype(cls, op):
+ "Tensors must be of type: {}. Tensors which are int32 are only valid when op type is: {}"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ tensors = tensors if tensors else op.inputs
+ for tens in tensors:
+ if tens.dtype == DataType.int32:
+ valid &= op.type in cls.supported_int32_tensor_ops
+ else:
+ valid &= tens.dtype in cls.supported_dtypes
+ extra.append("dtype={}".format(tens.dtype))
+ return valid, " ".join(extra)
- # check inf values
- for tens in op.get_ifm_ifm2_weights_ofm():
- if (
- (tens is not None)
- and (tens.quantization is not None)
- and (tens.quantization.scale_f32 is not None)
- and (np.isinf(tens.quantization.scale_f32).any())
- ):
- print("Warning:", op.type, "has inf valued tensor(s), placing on CPU")
- return False
+ @classmethod
+ @docstring_format_args(tens_dim_range)
+ def constraint_tens_dimension(cls, op):
+ "Tensor dimensions must be in the range {}-{} (inclusive)"
+ tens_min, tens_max = cls.tens_dim_range
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ tensors = tensors if tensors else op.inputs
+ for tens in tensors:
+ valid &= all(tens_min <= dim <= tens_max for dim in tens.shape)
+ extra.append("shape={}".format(tens.shape))
+ return valid, " ".join(extra)
- return True
+ @classmethod
+ @docstring_format_args([supported_fused_activations])
+ def constraint_faf(cls, op):
+ "The fused activation function (if present) must be one of type: {}"
+ faf = op.attrs.get("fused_activation_function")
+ valid = (faf is None) or (faf in cls.supported_fused_activations)
+ extra = "fused_activation_function={}".format(faf)
+ return valid, extra
+
+ @staticmethod
+ def constraint_tens_quant_scale(op):
+ "Tensors with quantization scales must be finite"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ for tens in tensors:
+ if tens.quantization is not None and tens.quantization.scale_f32 is not None:
+ valid &= not np.isinf(tens.quantization.scale_f32).any()
+ extra.append("quantization.scale_f32={}".format(tens.quantization.scale_f32))
+ return valid, " ".join(extra)
@classmethod
def check_convolution_restrictions(cls, op):
@@ -525,7 +566,7 @@ class SupportedOperators:
return False
# check shape
- if len(ifm_tensor.shape) > 4 or ifm_tensor.shape != ofm_tensor.shape:
+ if ifm_tensor.shape != ofm_tensor.shape:
return False
return True
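
Taken together, the refactor replaces the monolithic check_generic_restrictions with a list of small constraint functions, each returning a (valid, extra) tuple and each with a docstring that becomes the printed reason. A condensed, illustrative sketch of that pattern (OpStub and its shape-only tensors are hypothetical stand-ins, not the vela classes):

class OpStub:  # hypothetical stand-in for a vela Operation
    def __init__(self, op_type, input_shapes, output_shapes):
        self.type = op_type
        self.inputs = input_shapes    # shapes only, for brevity
        self.outputs = output_shapes

def constraint_tens_shape_size(op):
    "Input(s) and Output Tensors must not be greater than 4D"
    valid = True
    extra = []
    for shape in op.inputs + op.outputs:
        valid &= len(shape) <= 4
        extra.append("shape={}".format(shape))
    return valid, " ".join(extra)

def is_operator_supported(op, constraints):
    for constraint in constraints:
        valid, extra = constraint(op)
        if not valid:
            print('Warning: "{}" is not supported on the NPU. Placing on CPU instead'.format(op.type))
            print(" - {}".format(constraint.__doc__))
            if extra:
                print("   {}".format(extra))
            return False
    return True

# A 5D shape is rejected, and the constraint's docstring explains why
op = OpStub("Relu", [[1, 1, 8, 8, 8]], [[1, 1, 8, 8, 8]])
assert not is_operator_supported(op, [constraint_tens_shape_size])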
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index df310434..53c20927 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -16,9 +16,12 @@
#
# Description:
# Unit tests for support_operators
+import numpy as np
+
from ethosu.vela.data_type import DataType
from ethosu.vela.supported_operators import SupportedOperators
from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
@@ -84,3 +87,76 @@ def test_strided_slice():
op = create_strided_slice()
op.attrs["end_mask"] = 0
assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_defined_shape():
+ # Tensor shapes cannot contain None (i.e. shapes must be fully defined)
+ inp = Tensor([1, 8, None, 8], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_shapeless():
+ # Shapeless input is allowed only for certain op types:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # Shapeless output is not allowed at all:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [1, 8, 8, 8], [])
+ assert not support.is_operator_supported(op)
+ # Invalid shapeless input due to op type:
+ inp = Tensor([], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_shape_size():
+ # Tensors cannot be > 4D
+ inp = Tensor([1, 1, 8, 8, 8], DataType.uint8, "in")
+ out = Tensor([1, 1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_dtype():
+ # Tensors can only be of type uint8, int8, int16 or int32 (int32 only for selected op types)
+ inp = Tensor([1, 8, 8, 8], DataType.float32, "in")
+ out = Tensor([1, 8, 8, 8], DataType.float32, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+ # For int32, only select op types are allowed:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8], DataType.int32)
+ assert support.is_operator_supported(op)
+ inp = Tensor([1, 8, 8, 8], DataType.int32, "in")
+ out = Tensor([1, 8, 8, 8], DataType.int32, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_dimension():
+ # Tensor dimensions must be in the inclusive range 1-65535
+ inp = Tensor([1, 8, 8, 0], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 0], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+ inp = Tensor([1, 8, 8, 65536], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 65536], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_faf():
+ # The fused activation function, if set, must be one of the supported activation types
+ inp = Tensor([1, 8, 8, 8], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out, attrs={"fused_activation_function": "Conv2D"})
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_quant_scale():
+ # Quantization scale cannot be infinite
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ op.inputs[0].quantization = QuantizationParameters()
+ op.inputs[0].quantization.scale_f32 = np.inf
+ assert not support.is_operator_supported(op)
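
Because each generic constraint now returns a (valid, extra) tuple, a constraint can also be exercised directly instead of only through is_operator_supported; a hedged sketch reusing the setup of the test above (not part of this patch):

op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
op.inputs[0].quantization = QuantizationParameters()
op.inputs[0].quantization.scale_f32 = np.inf
valid, extra = SupportedOperators.constraint_tens_quant_scale(op)
assert not valid  # extra lists the tensors' scales, e.g. "quantization.scale_f32=inf"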