aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py39
1 files changed, 30 insertions, 9 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 26ccfeb6..fd9a9c20 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -81,6 +81,8 @@ class TFLiteSupportedOperators:
| fc_vector_products
# Mean (converts to depthwise conv)
| set((Op.Mean,))
+ # ArgMax (converts to depthwise conv and maxpool)
+ | set((Op.ArgMax,))
)
unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
binary_elem_wise_min_max_ops = set(
@@ -106,15 +108,7 @@ class TFLiteSupportedOperators:
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
pad_ops = set((Op.Pad,))
supported_int32_tensor_ops = (
- set(
- (
- Op.ReduceSum,
- Op.CLZ,
- Op.Shape,
- )
- )
- | binary_elem_wise_add_mul_sub
- | binary_elem_wise_shift_ops
+ set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
)
relu_ops = set(
@@ -321,6 +315,11 @@ class TFLiteSupportedOperators:
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
+ # ArgMax specific checks:
+ self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_input_dimensions)
+ self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
+ self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -873,3 +872,25 @@ class TFLiteSupportedOperators:
extra = ", ".join(extra)
return valid, f"Op has non-const input(s): {extra}"
+
+ @staticmethod
+ def constraint_argmax_axis(op):
+ "Operation must be performed along the depth axis"
+ inp_dims = len(op.inputs[0].shape)
+ axis = op.inputs[1].values
+ return (
+ axis in (3, -1),
+ f"Axis is {axis} and number of input dimensions is {inp_dims}",
+ )
+
+ @staticmethod
+ def constraint_argmax_input_dimensions(op):
+ "Number of input dimensions must be 4"
+ inp_dims = len(op.inputs[0].shape)
+ return inp_dims == 4, f"Number of input dimensions is {inp_dims}"
+
+ @staticmethod
+ def constraint_argmax_depth(op):
+ "IFM depth must be no greater than 127"
+ ifm_depth = op.inputs[0].shape[-1]
+ return ifm_depth <= 127, f"IFM depth is {ifm_depth}"