aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_supported_operators.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 832d60f..cd331fd 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -953,3 +953,47 @@ def test_constraint_keep_dims_ifm_ofm():
assert not support.is_operator_supported(op)
op.attrs["keep_num_dims"] = False
assert support.is_operator_supported(op)
+
+
+def create_mean(input_shape, output_shape, indices, datatype, attrs):
+ ifm = Tensor(input_shape, datatype, "in")
+ ifm.quantization = testutil.default_quant_params()
+ indices = create_const_tensor("indices", [len(indices)], DataType.int32, indices, np.uint8)
+ ofm = Tensor(output_shape, datatype, "out")
+ ofm.quantization = testutil.default_quant_params()
+ op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
+ return op
+
+
+def test_mean_dtype():
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op.ifm.dtype = DataType.int16
+ op.ofm.dtype = DataType.int16
+ assert not support.is_operator_supported(op)
+
+
+def test_mean_properties():
+ op = create_mean([1, 6, 6, 256], [1, 1, 256], [1, 2], DataType.uint8, {})
+ assert support.is_operator_supported(op)
+ op.ifm.quantization.zero_point = 55
+ assert not support.is_operator_supported(op)
+
+
+def test_mean_axis():
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+
+
+def test_mean_hw_product():
+ op = create_mean([1, 64, 64, 16], [1, 1, 16], [1, 2], DataType.uint8, {})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 65, 64, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)
+
+
+def test_mean_hw_product_int8():
+ op = create_mean([1, 16, 16, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 16, 17, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)