diff options
author | Alexander Hansson <Alexander.Hansson@arm.com> | 2023-06-27 12:36:25 +0000 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2023-07-11 09:10:59 +0000 |
commit | 1d5e859973ff18f3e4285f0ca04251ca246a182c (patch) | |
tree | e25de299de17ac46269c003585f32b87cedfd137 /ethosu/vela/test/test_tflite_model_semantic.py | |
parent | ca9cc420984eba39b85885bf0d2d7b48bb920da9 (diff) | |
download | ethos-u-vela-1d5e859973ff18f3e4285f0ca04251ca246a182c.tar.gz |
MLBEDSW-7652: Add mean support for batch and channel when shape is 1
- Add support for batch and depth channels when shape is 1
- Refactor reshaping in convert_mean_to_depthwise_conv
Signed-off-by: Alexander Hansson <Alexander.Hansson@arm.com>
Change-Id: If663395934ab58c76ba92b6ebaaf484a389ae699
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index ebfdbf3f..7a82d2c1 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -506,14 +506,21 @@ def test_mean_dtype(): def test_mean_axis(): - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True}) - assert not semantic_checker.is_operator_semantic_valid(op) op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True}) assert not semantic_checker.is_operator_semantic_valid(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True}) + op = create_mean([1, 6, 6, 1], [1, 1, 1, 1], [3], DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + + op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0], DataType.int8, {"keep_dims": True}) assert not semantic_checker.is_operator_semantic_valid(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + + op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) assert not semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) assert semantic_checker.is_operator_semantic_valid(op) op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True}) |