From 56811e6d3c62ae017f6eb298fb553f7d1e77cc96 Mon Sep 17 00:00:00 2001 From: Johan Alfven Date: Mon, 27 Mar 2023 11:33:50 +0200 Subject: MLBEDSW-7439: Add support for input dims < 4 for ArgMax - Updated ARG_MAX to support IFM rank less than 4 - Regenerated SUPPORTED_OPS.md Change-Id: Icd8e72733279413cbea49021325e1ab06fdc6011 Signed-off-by: Johan Alfven --- ethosu/vela/tflite_graph_optimiser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'ethosu/vela/tflite_graph_optimiser.py') diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 518b6db0..e0c7fd2c 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -41,6 +41,7 @@ from .graph_optimiser_util import needed_total_padding from .graph_optimiser_util import set_ifm_ofm_op_shapes from .graph_optimiser_util import set_tensor_equivalence from .numeric_util import clamp_sigmoid +from .numeric_util import full_shape from .numeric_util import round_away_zero from .operation import create_activation_function from .operation import ExplicitScaling @@ -524,7 +525,7 @@ def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng): op.name = "depthwise_conv_SHL_7" op.type = Op.DepthwiseConv2DBias op.attrs.update(dw_op_attrs) - n, h, w, c = ifm.shape + n, h, w, c = full_shape(4, ifm.shape, 1) shape = [1, 1, 1, c] kernel = np.dstack([2**7] * c) op.inputs = [] -- cgit v1.2.1