diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 29 |
1 files changed, 9 insertions, 20 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index dfa27199..c993da13 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -25,31 +25,19 @@ from .numeric_util import is_integer from .operation import get_slice_offsets from .operation import Op from .operation import Padding +from .supported_operators_util import docstring_format_args +from .supported_operators_util import list_formatter from .tensor import check_quantized_tens_scaling_equal from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN from .tflite_mapping import optype_to_builtintype -# Custom decorator function to allow formatting docstrings containing "{}" -def docstring_format_args(args): - def docstring(func): - func.__doc__ = func.__doc__.format(*args) - return func - - return docstring - - -def _list_formatter(arg): - # Order and join into a string representation - return ", ".join(sorted(map(str, arg))) - - def _optype_formatter(op_list): # Convert internal op types to external names output = map(optype_to_builtintype, op_list) # Remove UNKNOWNs output = (x for x in output if x is not BUILTIN_OPERATOR_UNKNOWN) - return _list_formatter(output) + return list_formatter(output) class SupportedOperators: @@ -88,7 +76,8 @@ class SupportedOperators: supported_int32_tensor_ops = ( set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops ) - relu_ops = Op.op_set(Op.is_relu_op) + + relu_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip,)) activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish)) npu_post_ops = ( # activation functions @@ -354,7 +343,7 @@ class SupportedOperators: return valid, ", ".join(extra) @classmethod - @docstring_format_args([_list_formatter(supported_op_dtypes)]) + @docstring_format_args([list_formatter(supported_op_dtypes)]) def constraint_tens_dtype(cls, op): "Tensors must be of type: {}" valid = True @@ -463,7 +452,7 @@ class SupportedOperators: return res @classmethod - @docstring_format_args([_list_formatter(supported_faf_dtypes)]) + @docstring_format_args([list_formatter(supported_faf_dtypes)]) def constraint_faf_type(cls, op): "If a fused activation function is present, the Output tensor must be one of type: {}" if op.activation is None: @@ -549,7 +538,7 @@ class SupportedOperators: return valid, f"Tensor '{weights.name}' has the sum of weights: {limit}" @classmethod - @docstring_format_args([_list_formatter(supported_bias_dtypes)]) + @docstring_format_args([list_formatter(supported_bias_dtypes)]) def constraint_bias_type(cls, op): "Optional Bias tensor must be of type: {}" bias = op.bias @@ -832,7 +821,7 @@ class SupportedOperators: return valid, f"The pad tensor has the shape: {op.inputs[1].shape}" @classmethod - @docstring_format_args([_list_formatter(supported_pad_dtypes)]) + @docstring_format_args([list_formatter(supported_pad_dtypes)]) def constraint_pad_type(cls, op): "Pad tensor must be of type: {}" pad_tensor = op.inputs[1] |