aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py20
1 files changed, 17 insertions, 3 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 46f7a5d3..ccf61042 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -25,6 +25,7 @@ from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
from .tensor import check_quantized_tens_scaling_equal
+from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
@@ -37,6 +38,15 @@ def docstring_format_args(args):
return docstring
+def _optype_formatter(op_list):
+ # Convert internal op types to external names
+ output = map(optype_to_builtintype, op_list)
+ # Remove UNKNOWNs
+ output = (x for x in output if x is not BUILTIN_OPERATOR_UNKNOWN)
+ # Order alphabetically
+ return sorted(output)
+
+
class SupportedOperators:
# Categorised lists of supported operators
npu_pre_ops = set((Op.SplitSliceRead,))
@@ -99,6 +109,10 @@ class SupportedOperators:
filter_range = (1, 8)
filter_height_range = (1, 256)
filter_product_range = (1, 256 * 256)
+ # Ordered, external names of op types for the constraint reasons
+ docstring_shapeless_input_ops = _optype_formatter(shapeless_input_ops)
+ docstring_supported_int32_tensor_ops = _optype_formatter(supported_int32_tensor_ops)
+ docstring_supported_fused_activations = _optype_formatter(supported_fused_activations)
def __init__(self):
# Setup the generic constraints. Note: the order matters
@@ -279,7 +293,7 @@ class SupportedOperators:
return valid, f"Output Tensor '{ofm.name}' is scalar"
@classmethod
- @docstring_format_args([shapeless_input_ops])
+ @docstring_format_args([docstring_shapeless_input_ops])
def constraint_tens_input_scalar(cls, op):
"Scalar Input tensors are only valid for op type: {}"
valid = True
@@ -320,7 +334,7 @@ class SupportedOperators:
return valid, ", ".join(extra)
@classmethod
- @docstring_format_args([supported_int32_tensor_ops])
+ @docstring_format_args([docstring_supported_int32_tensor_ops])
def constraint_tens_int32_ops(cls, op):
"Tensors which are int32 are only valid when op type is: {}"
valid = True
@@ -377,7 +391,7 @@ class SupportedOperators:
return valid, ", ".join(extra)
@classmethod
- @docstring_format_args([supported_fused_activations])
+ @docstring_format_args([docstring_supported_fused_activations])
def constraint_faf(cls, op):
"The fused activation function (if present) must be one of type: {}"
if op.activation is None: