aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r--ethosu/vela/test/test_supported_operators.py23
1 files changed, 20 insertions, 3 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 355b472c..666a5ecc 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -840,12 +840,15 @@ def test_constraint_keep_dims_ifm_ofm():
assert support.is_operator_supported(op)
-def create_mean(input_shape, output_shape, indices, datatype, attrs):
+def create_mean(input_shape, output_shape, axis, datatype, attrs):
ifm = Tensor(input_shape, datatype, "in")
ifm.quantization = testutil.default_quant_params()
- indices = create_const_tensor("indices", [len(indices)], DataType.int32, indices, np.uint8)
ofm = Tensor(output_shape, datatype, "out")
ofm.quantization = testutil.default_quant_params()
+ if type(axis) is list:
+ indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+ elif type(axis) is int:
+ indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
return op
@@ -859,8 +862,22 @@ def test_mean_dtype():
def test_mean_axis():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
def test_mean_hw_product():