Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--  ethosu/vela/supported_operators.py  137
1 file changed, 82 insertions(+), 55 deletions(-)
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 88e10835..4e989124 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -122,16 +122,18 @@ class SupportedOperators:
self.supported_operator_restrictions.update(
{op: self.check_activation_ops for op in SupportedOperators.activation_ops}
)
- # Setup the generic constraints
+ # Setup the generic constraints. Note: the order matters
self.generic_constraints = []
self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
- self.generic_constraints.append(SupportedOperators.constraint_tens_shapeless)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_output_shapeless)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_input_shapeless)
self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_int32_ops)
self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
- self.generic_constraints.append(SupportedOperators.constraint_faf)
- self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
+ self.generic_constraints.append(SupportedOperators.constraint_faf)

def is_operator_supported(self, op):
if op.type not in SupportedOperators.supported_operators:
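The reordering in the first hunk is what the new "order matters" comment refers to: constraint_tens_quant_none_check now runs before constraint_tens_quant_scale, because the reworked scale check (in the last hunk below) dereferences tens.quantization without a None guard. A minimal sketch of that dependency, using toy classes rather than Vela's own:

import numpy as np

# Toy stand-ins, not Vela's classes: just enough structure to show why the
# quantization-presence check must precede the scale check.
class Quantization:
    def __init__(self, scale_f32):
        self.scale_f32 = scale_f32

class Tensor:
    def __init__(self, name, quantization=None):
        self.name = name
        self.quantization = quantization

def quant_none_check(tensors):
    "Tensors must have quantization parameters"
    bad = [t.name for t in tensors if t.quantization is None]
    return not bad, ", ".join(bad)

def quant_scale_check(tensors):
    "Tensors with quantization scales must be finite"
    # Mirrors the new constraint_tens_quant_scale: t.quantization is
    # dereferenced directly, so this is only safe after quant_none_check.
    bad = [t.name for t in tensors
           if t.quantization.scale_f32 is not None
           and np.isinf(t.quantization.scale_f32).any()]
    return not bad, ", ".join(bad)

constraints = [quant_none_check, quant_scale_check]  # this order is load-bearing
for check in constraints:
    valid, extra = check([Tensor("ifm")])
    if not valid:
        print("failed: {} ({})".format(check.__doc__, extra))
        break  # reversed, quant_scale_check would raise AttributeError first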
@@ -140,7 +142,7 @@ class SupportedOperators:
for constraint in self.generic_constraints:
valid, extra = constraint(op)
if not valid:
- print('Warning: "{}" is not supported on the NPU. Placing on CPU instead'.format(op.type))
+ print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
print(" - {}".format(constraint.__doc__))
if extra:
print(" {}".format(extra))
@@ -154,89 +156,93 @@ class SupportedOperators:
"Input(s) and Output Tensors must have a defined shape"
valid = True
extra = []
- for tens in op.inputs + op.outputs:
- if tens:
- valid &= tens.has_fully_defined_shape()
- extra.append("shape={}".format(tens.shape))
- return valid, " ".join(extra)
+ tensors = [tens for tens in op.inputs + op.outputs if tens]
+ for tens in tensors:
+ if not tens.has_fully_defined_shape():
+ valid = False
+ extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
+ return valid, ", ".join(extra)

- @classmethod
- @docstring_format_args([shapeless_input_ops])
- def constraint_tens_shapeless(cls, op):
- "Scalar or Broadcasting Tensors are only valid for Input Tensors, and when op type is: {}"
+ @staticmethod
+ def constraint_tens_output_shapeless(op):
+ "Scalar or Broadcasting Tensors are only valid for Input Tensors"
valid = True
extra = []
- for tens in op.inputs:
- if tens and tens.shape == []:
- valid &= op.type in cls.shapeless_input_ops
- extra.append("shape={}".format(tens.shape))
for tens in op.outputs:
if tens.shape == []:
valid = False
- extra.append("shape={}".format(tens.shape))
- return valid, " ".join(extra)
+ extra.append("Output Tensor '{}' is shapeless".format(tens.name))
+ return valid, ", ".join(extra)
+
+ @classmethod
+ @docstring_format_args([shapeless_input_ops])
+ def constraint_tens_input_shapeless(cls, op):
+ "Scalar or Broadcasting Input Tensors are only valid for op type: {}"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
+ valid = False
+ extra.append(tens.name)
+ extra = "Op '{}' has shapeless input tensor(s): {}".format(op.name, ", ".join(extra))
+ return valid, extra
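The old constraint_tens_shapeless is split in two above: a shapeless ([]) output tensor is now always rejected, while shapeless inputs remain valid only for op types in shapeless_input_ops. A standalone sketch of the split; the whitelist below is illustrative, not Vela's actual shapeless_input_ops:

class Tensor:
    def __init__(self, name, shape):
        self.name = name
        self.shape = shape

shapeless_input_ops = {"Mul", "Split", "SplitV"}  # illustrative whitelist

def output_shapeless_check(op_type, outputs):
    # Shapeless outputs are invalid regardless of op type.
    bad = [t.name for t in outputs if t.shape == []]
    return not bad, bad

def input_shapeless_check(op_type, inputs):
    # Shapeless inputs are valid only for whitelisted op types.
    bad = [t.name for t in inputs
           if t and t.shape == [] and op_type not in shapeless_input_ops]
    return not bad, bad

print(input_shapeless_check("Mul", [Tensor("scalar", [])]))     # (True, [])
print(input_shapeless_check("Conv2D", [Tensor("scalar", [])]))  # (False, ['scalar'])
print(output_shapeless_check("Mul", [Tensor("out", [])]))       # (False, ['out'])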
@staticmethod
def constraint_tens_shape_size(op):
"Input(s) and Output Tensors must not be greater than 4D"
valid = True
extra = []
- for tens in op.inputs + op.outputs:
- if tens:
- valid &= len(tens.shape) <= 4
- extra.append("shape={}".format(tens.shape))
- return valid, " ".join(extra)
+ tensors = [tens for tens in op.inputs + op.outputs if tens]
+ for tens in tensors:
+ if len(tens.shape) > 4:
+ valid = False
+ extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
+ return valid, ", ".join(extra)

@classmethod
- @docstring_format_args([supported_dtypes, supported_int32_tensor_ops])
+ @docstring_format_args([supported_dtypes])
def constraint_tens_dtype(cls, op):
- "Tensors must be of type: {}. Tensors which are int32 are only valid when op type is: {}"
+ "Tensors must be of type: {}"
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
tensors = tensors if tensors else op.inputs
for tens in tensors:
- if tens.dtype == DataType.int32:
- valid &= op.type in cls.supported_int32_tensor_ops
- else:
- valid &= tens.dtype in cls.supported_dtypes
- extra.append("dtype={}".format(tens.dtype))
- return valid, " ".join(extra)
+ if tens.dtype not in cls.supported_dtypes:
+ valid = False
+ extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
+ return valid, ", ".join(extra)
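constraint_tens_dtype above loses its int32 special case; the whitelist moves into the new constraint_tens_int32_ops. Because an int32 tensor now has to pass both checks, the split only preserves the old behaviour if int32 is itself a member of supported_dtypes (as it is in Vela). A sketch with illustrative sets:

supported_dtypes = {"uint8", "int8", "int16", "int32"}    # illustrative
supported_int32_tensor_ops = {"ReduceSum", "SHL", "CLZ"}  # illustrative

def dtype_check(op_type, dtypes):
    # The general check: every tensor dtype must be supported.
    bad = [d for d in dtypes if d not in supported_dtypes]
    return not bad, bad

def int32_ops_check(op_type, dtypes):
    # The extra check: int32 tensors are only allowed on whitelisted ops.
    bad = [d for d in dtypes
           if d == "int32" and op_type not in supported_int32_tensor_ops]
    return not bad, bad

for op_type in ("ReduceSum", "Conv2D"):
    print(op_type, dtype_check(op_type, ["int32"]), int32_ops_check(op_type, ["int32"]))
# ReduceSum (True, []) (True, [])
# Conv2D (True, []) (False, ['int32'])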
@classmethod
- @docstring_format_args(tens_dim_range)
- def constraint_tens_dimension(cls, op):
- "Tensor dimensions must be in the range {}-{} (inclusive)"
- tens_min, tens_max = cls.tens_dim_range
+ @docstring_format_args([supported_int32_tensor_ops])
+ def constraint_tens_int32_ops(cls, op):
+ "Tensors which are int32 are only valid when op type is: {}"
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
tensors = tensors if tensors else op.inputs
for tens in tensors:
- valid &= all(tens_min <= dim <= tens_max for dim in tens.shape)
- extra.append("shape={}".format(tens.shape))
- return valid, " ".join(extra)
-
- @classmethod
- @docstring_format_args([supported_fused_activations])
- def constraint_faf(cls, op):
- "The fused activation function (if present) must be one of type: {}"
- faf = op.activation
- valid = (faf is None) or (faf in cls.supported_fused_activations)
- extra = "fused_activation_function={}".format(faf)
+ if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
+ valid = False
+ extra.append(tens.name)
+ extra = "Op '{}' has int32 tensor(s): {}".format(op.name, ", ".join(extra))
return valid, extra

- @staticmethod
- def constraint_tens_quant_scale(op):
- "Tensors with quantization scales must be finite"
+ @classmethod
+ @docstring_format_args(tens_dim_range)
+ def constraint_tens_dimension(cls, op):
+ "Tensor dimensions must be in the range {}-{} (inclusive)"
+ tens_min, tens_max = cls.tens_dim_range
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ tensors = tensors if tensors else op.inputs
for tens in tensors:
- if tens.quantization is not None and tens.quantization.scale_f32 is not None:
- valid &= not np.isinf(tens.quantization.scale_f32).any()
- extra.append("quantization.scale_f32={}".format(tens.quantization.scale_f32))
- return valid, " ".join(extra)
+ if not all(tens_min <= dim <= tens_max for dim in tens.shape):
+ valid = False
+ extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
+ return valid, ", ".join(extra)

@staticmethod
def constraint_tens_quant_none_check(op):
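Every constraint rewritten in the hunk above follows the same reporting pattern: rather than accumulating valid &= condition and echoing the shape or dtype of every tensor, only failing tensors are flagged, each named in the extra string. A self-contained sketch of that pattern, modelled on constraint_tens_shape_size (a free function and toy Tensor here, where Vela uses a static method taking op):

class Tensor:
    def __init__(self, name, shape):
        self.name = name
        self.shape = shape

def constraint_tens_shape_size(inputs, outputs):
    "Input(s) and Output Tensors must not be greater than 4D"
    valid = True
    extra = []
    tensors = [tens for tens in inputs + outputs if tens]
    for tens in tensors:
        if len(tens.shape) > 4:
            valid = False
            extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
    return valid, ", ".join(extra)

valid, extra = constraint_tens_shape_size(
    [Tensor("ifm", [1, 8, 8, 16, 2])], [Tensor("ofm", [1, 8, 8, 16])]
)
print(valid, "|", extra)
# False | Tensor 'ifm' has shape: [1, 8, 8, 16, 2]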
@@ -250,6 +256,27 @@ class SupportedOperators:
extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
return valid, ", ".join(extra)

+ @staticmethod
+ def constraint_tens_quant_scale(op):
+ "Tensors with quantization scales must be finite"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ for tens in tensors:
+ if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
+ valid = False
+ extra.append("Tensor '{}' has quantization scale: {}".format(tens.name, tens.quantization.scale_f32))
+ return valid, ", ".join(extra)
+
+ @classmethod
+ @docstring_format_args([supported_fused_activations])
+ def constraint_faf(cls, op):
+ "The fused activation function (if present) must be one of type: {}"
+ faf = op.activation
+ valid = (faf is None) or (faf in cls.supported_fused_activations)
+ extra = "Op '{}' has its fused activation function as: {}".format(op.name, faf)
+ return valid, extra
+
@classmethod
def check_convolution_restrictions(cls, op):
# check stride
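The constraint docstrings double as the user-facing messages printed by is_operator_supported, and docstring_format_args bakes class-level lists such as supported_fused_activations and tens_dim_range into them at class-definition time. The decorator below is a hedged sketch consistent with that usage; Vela's real docstring_format_args is defined elsewhere and may differ in detail:

def docstring_format_args(args):
    # Format the wrapped function's docstring with the given arguments so
    # that __doc__ can be printed directly as the failure message.
    def decorator(func):
        func.__doc__ = func.__doc__.format(*args)
        return func
    return decorator

supported_fused_activations = ["Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid"]  # illustrative

@docstring_format_args([supported_fused_activations])
def constraint_faf(op):
    "The fused activation function (if present) must be one of type: {}"
    faf = op.activation
    valid = (faf is None) or (faf in supported_fused_activations)
    extra = "Op '{}' has its fused activation function as: {}".format(op.name, faf)
    return valid, extra

class Op:  # hypothetical stand-in
    name = "conv_1"
    activation = "Swish"  # not in the whitelist above

print(constraint_faf.__doc__)
print(constraint_faf(Op()))
# (False, "Op 'conv_1' has its fused activation function as: Swish")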