diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 46f7a5d3..ccf61042 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -25,6 +25,7 @@ from .numeric_util import is_integer from .operation import get_slice_offsets from .operation import Op from .tensor import check_quantized_tens_scaling_equal +from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN from .tflite_mapping import optype_to_builtintype @@ -37,6 +38,15 @@ def docstring_format_args(args): return docstring +def _optype_formatter(op_list): + # Convert internal op types to external names + output = map(optype_to_builtintype, op_list) + # Remove UNKNOWNs + output = (x for x in output if x is not BUILTIN_OPERATOR_UNKNOWN) + # Order alphabetically + return sorted(output) + + class SupportedOperators: # Categorised lists of supported operators npu_pre_ops = set((Op.SplitSliceRead,)) @@ -99,6 +109,10 @@ class SupportedOperators: filter_range = (1, 8) filter_height_range = (1, 256) filter_product_range = (1, 256 * 256) + # Ordered, external names of op types for the constraint reasons + docstring_shapeless_input_ops = _optype_formatter(shapeless_input_ops) + docstring_supported_int32_tensor_ops = _optype_formatter(supported_int32_tensor_ops) + docstring_supported_fused_activations = _optype_formatter(supported_fused_activations) def __init__(self): # Setup the generic constraints. Note: the order matters @@ -279,7 +293,7 @@ class SupportedOperators: return valid, f"Output Tensor '{ofm.name}' is scalar" @classmethod - @docstring_format_args([shapeless_input_ops]) + @docstring_format_args([docstring_shapeless_input_ops]) def constraint_tens_input_scalar(cls, op): "Scalar Input tensors are only valid for op type: {}" valid = True @@ -320,7 +334,7 @@ class SupportedOperators: return valid, ", ".join(extra) @classmethod - @docstring_format_args([supported_int32_tensor_ops]) + @docstring_format_args([docstring_supported_int32_tensor_ops]) def constraint_tens_int32_ops(cls, op): "Tensors which are int32 are only valid when op type is: {}" valid = True @@ -377,7 +391,7 @@ class SupportedOperators: return valid, ", ".join(extra) @classmethod - @docstring_format_args([supported_fused_activations]) + @docstring_format_args([docstring_supported_fused_activations]) def constraint_faf(cls, op): "The fused activation function (if present) must be one of type: {}" if op.activation is None: |