Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--  ethosu/vela/supported_operators.py  29
1 file changed, 9 insertions(+), 20 deletions(-)
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index dfa27199..c993da13 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -25,31 +25,19 @@ from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
from .operation import Padding
+from .supported_operators_util import docstring_format_args
+from .supported_operators_util import list_formatter
from .tensor import check_quantized_tens_scaling_equal
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
-# Custom decorator function to allow formatting docstrings containing "{}"
-def docstring_format_args(args):
- def docstring(func):
- func.__doc__ = func.__doc__.format(*args)
- return func
-
- return docstring
-
-
-def _list_formatter(arg):
- # Order and join into a string representation
- return ", ".join(sorted(map(str, arg)))
-
-
def _optype_formatter(op_list):
# Convert internal op types to external names
output = map(optype_to_builtintype, op_list)
# Remove UNKNOWNs
output = (x for x in output if x is not BUILTIN_OPERATOR_UNKNOWN)
- return _list_formatter(output)
+ return list_formatter(output)
class SupportedOperators:
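
Note that the two helpers removed in this hunk are relocated rather than deleted: they move to supported_operators_util, with `_list_formatter` losing its leading underscore on the way. For reference, a minimal sketch of the relocated pair, reconstructed from the removed code above (assuming supported_operators_util exposes them unchanged):

```python
# supported_operators_util.py (sketch, mirroring the code removed above)

def docstring_format_args(args):
    """Custom decorator allowing docstrings containing "{}" to be formatted."""
    def docstring(func):
        func.__doc__ = func.__doc__.format(*args)
        return func

    return docstring


def list_formatter(arg):
    # Order and join into a string representation
    return ", ".join(sorted(map(str, arg)))
```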
@@ -88,7 +76,8 @@ class SupportedOperators:
supported_int32_tensor_ops = (
set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
)
- relu_ops = Op.op_set(Op.is_relu_op)
+
+ relu_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip,))
activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
npu_post_ops = (
# activation functions
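
A side note on the `relu_ops` change in the hunk above: the predicate-driven lookup is replaced by an explicit set, so the group of relu-like operators is now spelled out at the definition site. A hedged sketch of the before/after, assuming `Op.op_set(pred)` collects every `Op` member satisfying the predicate:

```python
from ethosu.vela.operation import Op  # assumes the vela package is importable

# Before: derived by filtering the whole Op enum through a predicate,
# so the membership could change if another op were later tagged relu-like.
relu_ops_before = Op.op_set(Op.is_relu_op)

# After: the same four operators listed explicitly at the definition site.
relu_ops_after = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip,))
```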
@@ -354,7 +343,7 @@ class SupportedOperators:
return valid, ", ".join(extra)
@classmethod
- @docstring_format_args([_list_formatter(supported_op_dtypes)])
+ @docstring_format_args([list_formatter(supported_op_dtypes)])
def constraint_tens_dtype(cls, op):
"Tensors must be of type: {}"
valid = True
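
These constraint methods use their docstrings as the user-facing reason strings when an operator is rejected, which is why the `{}` placeholder must be filled in at class-definition time rather than at call time. A small demonstration reusing the helpers sketched after the first hunk, with illustrative dtype values that are not vela's actual supported set:

```python
supported_op_dtypes = {"int8", "uint8", "int16", "int32"}  # illustrative only

@docstring_format_args([list_formatter(supported_op_dtypes)])
def constraint_tens_dtype(op):
    "Tensors must be of type: {}"
    ...  # real dtype check elided

print(constraint_tens_dtype.__doc__)
# Tensors must be of type: int16, int32, int8, uint8
```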
@@ -463,7 +452,7 @@ class SupportedOperators:
return res
@classmethod
- @docstring_format_args([_list_formatter(supported_faf_dtypes)])
+ @docstring_format_args([list_formatter(supported_faf_dtypes)])
def constraint_faf_type(cls, op):
"If a fused activation function is present, the Output tensor must be one of type: {}"
if op.activation is None:
@@ -549,7 +538,7 @@ class SupportedOperators:
return valid, f"Tensor '{weights.name}' has the sum of weights: {limit}"
@classmethod
- @docstring_format_args([_list_formatter(supported_bias_dtypes)])
+ @docstring_format_args([list_formatter(supported_bias_dtypes)])
def constraint_bias_type(cls, op):
"Optional Bias tensor must be of type: {}"
bias = op.bias
@@ -832,7 +821,7 @@ class SupportedOperators:
return valid, f"The pad tensor has the shape: {op.inputs[1].shape}"
@classmethod
- @docstring_format_args([_list_formatter(supported_pad_dtypes)])
+ @docstring_format_args([list_formatter(supported_pad_dtypes)])
def constraint_pad_type(cls, op):
"Pad tensor must be of type: {}"
pad_tensor = op.inputs[1]