about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDwight Lidman <dwight.lidman@arm.com>2021-04-28 10:55:46 +0200
committerDwight Lidman <dwight.lidman@arm.com>2021-04-29 07:55:26 +0000
commitdec6fbcb16fa2f3d7254c4beb3235ab50f72a923 (patch)
tree3657084dd74701207654e673b82dab3793afdcc3
parent2f75457df27da610afdf01b1c86535030b022a45 (diff)
downloadethos-u-vela-dec6fbcb16fa2f3d7254c4beb3235ab50f72a923.tar.gz
MLBEDSW-4501: Support MEAN single axis variation
When a MEAN operator with a single reduction axis specifies the axis index attribute as an array with a single element rather than a scalar index, the operator is placed on the CPU even though it is technically supported. This commit fixes this issue and also adds some new tests for the axis constraints.

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: Ia287f3b9cc80a805e972cd4b2962e52526a8dc16
-rw-r--r--ethosu/vela/graph_optimiser.py4
-rw-r--r--ethosu/vela/supported_operators.py8
-rw-r--r--ethosu/vela/test/test_supported_operators.py23
3 files changed, 26 insertions, 9 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 642f1349..7c60368d 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1472,8 +1472,8 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
dims = len(shape)
# Height and width axes have different index depending on dimensions
- if axis.shape == []: # single axis
- axis = int(axis.values)
+ if len(axis.shape) <= 1: # single axis
+ axis = int(axis.values) if len(axis.shape) == 0 else axis.values[0]
if dims in (2, 3):
if axis == 0:
h, w = shape[axis], 1
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 5bf2c459..dfa27199 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -1040,11 +1040,11 @@ class SupportedOperators:
def constraint_mean_axis(op):
"Axis indices must correspond to height and width axes"
dims = len(op.inputs[0].shape)
- axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
if dims == 2 or dims == 3:
- valid = axis in (0, 1, [0, 1], [1, 0])
+ valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
elif dims == 4:
- valid = axis in (1, 2, [1, 2], [2, 1])
+ valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
return valid, f"Axis is {axis}"
@classmethod
@@ -1082,7 +1082,7 @@ class SupportedOperators:
keep_dims is set to True and
IFM datatype is int8"""
shape = op.ifm.shape
- axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
# doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
# and constraint_mean_height_width_product
if (
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 355b472c..666a5ecc 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -840,12 +840,15 @@ def test_constraint_keep_dims_ifm_ofm():
assert support.is_operator_supported(op)
-def create_mean(input_shape, output_shape, indices, datatype, attrs):
+def create_mean(input_shape, output_shape, axis, datatype, attrs):
ifm = Tensor(input_shape, datatype, "in")
ifm.quantization = testutil.default_quant_params()
- indices = create_const_tensor("indices", [len(indices)], DataType.int32, indices, np.uint8)
ofm = Tensor(output_shape, datatype, "out")
ofm.quantization = testutil.default_quant_params()
+ if type(axis) is list:
+ indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+ elif type(axis) is int:
+ indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
return op
@@ -859,8 +862,22 @@ def test_mean_dtype():
def test_mean_axis():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
def test_mean_hw_product():