From 56811e6d3c62ae017f6eb298fb553f7d1e77cc96 Mon Sep 17 00:00:00 2001 From: Johan Alfven Date: Mon, 27 Mar 2023 11:33:50 +0200 Subject: MLBEDSW-7439: Add support for input dims < 4 for ArgMax - Updated ARG_MAX to support IFM rank less than 4 - Regenerated SUPPORTED_OPS.md Change-Id: Icd8e72733279413cbea49021325e1ab06fdc6011 Signed-off-by: Johan Alfven --- ethosu/vela/tflite_supported_operators.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) (limited to 'ethosu/vela/tflite_supported_operators.py') diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index fd9a9c20..66b9e944 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -316,7 +316,6 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) # ArgMax specific checks: - self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_input_dimensions) self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis) self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth) @@ -879,16 +878,10 @@ class TFLiteSupportedOperators: inp_dims = len(op.inputs[0].shape) axis = op.inputs[1].values return ( - axis in (3, -1), + axis in (inp_dims - 1, -1), f"Axis is {axis} and number of input dimensions is {inp_dims}", ) - @staticmethod - def constraint_argmax_input_dimensions(op): - "Number of input dimensions must be 4" - inp_dims = len(op.inputs[0].shape) - return inp_dims == 4, f"Number of input dimensions is {inp_dims}" - @staticmethod def constraint_argmax_depth(op): "IFM depth must be no greater than 127" -- cgit v1.2.1