From dec6fbcb16fa2f3d7254c4beb3235ab50f72a923 Mon Sep 17 00:00:00 2001 From: Dwight Lidman Date: Wed, 28 Apr 2021 10:55:46 +0200 Subject: MLBEDSW-4501: Support MEAN single axis variation When a MEAN operator with a single reduction axis specifies the axis index attribute as an array with a single element rather than a scalar index, the operator is placed on the CPU even though it is technically supported. This commit fixes this issue and also adds some new tests for the axis constraints. Signed-off-by: Dwight Lidman Change-Id: Ia287f3b9cc80a805e972cd4b2962e52526a8dc16 --- ethosu/vela/supported_operators.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'ethosu/vela/supported_operators.py') diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 5bf2c459..dfa27199 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -1040,11 +1040,11 @@ class SupportedOperators: def constraint_mean_axis(op): "Axis indices must correspond to height and width axes" dims = len(op.inputs[0].shape) - axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values) + axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values) if dims == 2 or dims == 3: - valid = axis in (0, 1, [0, 1], [1, 0]) + valid = axis in (0, 1, [0], [1], [0, 1], [1, 0]) elif dims == 4: - valid = axis in (1, 2, [1, 2], [2, 1]) + valid = axis in (1, 2, [1], [2], [1, 2], [2, 1]) return valid, f"Axis is {axis}" @classmethod @@ -1082,7 +1082,7 @@ class SupportedOperators: keep_dims is set to True and IFM datatype is int8""" shape = op.ifm.shape - axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values) + axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values) # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool # and constraint_mean_height_width_product if ( -- cgit v1.2.1