aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index ebfdbf3f..7a82d2c1 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -506,14 +506,21 @@ def test_mean_dtype():
def test_mean_axis():
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
- assert not semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
+ op = create_mean([1, 6, 6, 1], [1, 1, 1, 1], [3], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
- op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
+ op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+ assert semantic_checker.is_operator_semantic_valid(op)
+
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})