aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_model_semantic.py
diff options
context:
space:
mode:
authorAlexander Hansson <Alexander.Hansson@arm.com>2023-06-27 12:36:25 +0000
committerJohan Alfven <johan.alfven@arm.com>2023-07-11 09:10:59 +0000
commit1d5e859973ff18f3e4285f0ca04251ca246a182c (patch)
treee25de299de17ac46269c003585f32b87cedfd137 /ethosu/vela/test/test_tflite_model_semantic.py
parentca9cc420984eba39b85885bf0d2d7b48bb920da9 (diff)
downloadethos-u-vela-1d5e859973ff18f3e4285f0ca04251ca246a182c.tar.gz
MLBEDSW-7652: Add mean support for batch and channel when shape is 1
- Add support for batch and depth channels when shape is 1 - Refactor reshaping in convert_mean_to_depthwise_conv Signed-off-by: Alexander Hansson <Alexander.Hansson@arm.com> Change-Id: If663395934ab58c76ba92b6ebaaf484a389ae699
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index ebfdbf3f..7a82d2c1 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -506,14 +506,21 @@ def test_mean_dtype():
def test_mean_axis():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
- assert not semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
+ op = create_mean([1, 6, 6, 1], [1, 1, 1, 1], [3], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})