aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/graph_optimiser.py4
-rw-r--r--ethosu/vela/supported_operators.py8
-rw-r--r--ethosu/vela/test/test_supported_operators.py23
3 files changed, 26 insertions, 9 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 642f1349..7c60368d 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1472,8 +1472,8 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
dims = len(shape)
# Height and width axes have different index depending on dimensions
- if axis.shape == []: # single axis
- axis = int(axis.values)
+ if len(axis.shape) <= 1: # single axis
+ axis = int(axis.values) if len(axis.shape) == 0 else axis.values[0]
if dims in (2, 3):
if axis == 0:
h, w = shape[axis], 1
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 5bf2c459..dfa27199 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -1040,11 +1040,11 @@ class SupportedOperators:
def constraint_mean_axis(op):
"Axis indices must correspond to height and width axes"
dims = len(op.inputs[0].shape)
- axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
if dims == 2 or dims == 3:
- valid = axis in (0, 1, [0, 1], [1, 0])
+ valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
elif dims == 4:
- valid = axis in (1, 2, [1, 2], [2, 1])
+ valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
return valid, f"Axis is {axis}"
@classmethod
@@ -1082,7 +1082,7 @@ class SupportedOperators:
keep_dims is set to True and
IFM datatype is int8"""
shape = op.ifm.shape
- axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
# doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
# and constraint_mean_height_width_product
if (
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 355b472c..666a5ecc 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -840,12 +840,15 @@ def test_constraint_keep_dims_ifm_ofm():
assert support.is_operator_supported(op)
-def create_mean(input_shape, output_shape, indices, datatype, attrs):
+def create_mean(input_shape, output_shape, axis, datatype, attrs):
ifm = Tensor(input_shape, datatype, "in")
ifm.quantization = testutil.default_quant_params()
- indices = create_const_tensor("indices", [len(indices)], DataType.int32, indices, np.uint8)
ofm = Tensor(output_shape, datatype, "out")
ofm.quantization = testutil.default_quant_params()
+ if type(axis) is list:
+ indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+ elif type(axis) is int:
+ indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
return op
@@ -859,8 +862,22 @@ def test_mean_dtype():
def test_mean_axis():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
def test_mean_hw_product():