diff options
author | Dwight Lidman <dwight.lidman@arm.com> | 2021-04-28 10:55:46 +0200 |
---|---|---|
committer | Dwight Lidman <dwight.lidman@arm.com> | 2021-04-29 07:55:26 +0000 |
commit | dec6fbcb16fa2f3d7254c4beb3235ab50f72a923 (patch) | |
tree | 3657084dd74701207654e673b82dab3793afdcc3 /ethosu/vela/test | |
parent | 2f75457df27da610afdf01b1c86535030b022a45 (diff) | |
download | ethos-u-vela-dec6fbcb16fa2f3d7254c4beb3235ab50f72a923.tar.gz |
MLBEDSW-4501: Support MEAN single axis variation
When a MEAN operator with a single reduction axis
specifies the axis index attribute as an array with
a single element rather than a scalar index, the
operator is placed on the CPU even though it is
technically supported.
This commit fixes this issue and also adds some new
tests for the axis constraints.
Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: Ia287f3b9cc80a805e972cd4b2962e52526a8dc16
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r-- | ethosu/vela/test/test_supported_operators.py | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 355b472c..666a5ecc 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -840,12 +840,15 @@ def test_constraint_keep_dims_ifm_ofm(): assert support.is_operator_supported(op) -def create_mean(input_shape, output_shape, indices, datatype, attrs): +def create_mean(input_shape, output_shape, axis, datatype, attrs): ifm = Tensor(input_shape, datatype, "in") ifm.quantization = testutil.default_quant_params() - indices = create_const_tensor("indices", [len(indices)], DataType.int32, indices, np.uint8) ofm = Tensor(output_shape, datatype, "out") ofm.quantization = testutil.default_quant_params() + if type(axis) is list: + indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8) + elif type(axis) is int: + indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8) op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs) return op @@ -859,8 +862,22 @@ def test_mean_dtype(): def test_mean_axis(): - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True}) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True}) + assert not support.is_operator_supported(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True}) + assert not support.is_operator_supported(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True}) assert not support.is_operator_supported(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) + assert not support.is_operator_supported(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) + assert support.is_operator_supported(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True}) + assert support.is_operator_supported(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True}) + assert support.is_operator_supported(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True}) + assert support.is_operator_supported(op) def test_mean_hw_product(): |