aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py9
1 files changed, 1 insertions, 8 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index fd9a9c20..66b9e944 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -316,7 +316,6 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
# ArgMax specific checks:
- self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_input_dimensions)
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
@@ -879,17 +878,11 @@ class TFLiteSupportedOperators:
inp_dims = len(op.inputs[0].shape)
axis = op.inputs[1].values
return (
- axis in (3, -1),
+ axis in (inp_dims - 1, -1),
f"Axis is {axis} and number of input dimensions is {inp_dims}",
)
@staticmethod
- def constraint_argmax_input_dimensions(op):
- "Number of input dimensions must be 4"
- inp_dims = len(op.inputs[0].shape)
- return inp_dims == 4, f"Number of input dimensions is {inp_dims}"
-
- @staticmethod
def constraint_argmax_depth(op):
"IFM depth must be no greater than 127"
ifm_depth = op.inputs[0].shape[-1]