diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index ebfdbf3f..7a82d2c1 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -506,14 +506,21 @@ def test_mean_dtype(): def test_mean_axis(): - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True}) - assert not semantic_checker.is_operator_semantic_valid(op) op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True}) assert not semantic_checker.is_operator_semantic_valid(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True}) + op = create_mean([1, 6, 6, 1], [1, 1, 1, 1], [3], DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + + op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0], DataType.int8, {"keep_dims": True}) assert not semantic_checker.is_operator_semantic_valid(op) - op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + + op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) assert not semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True}) + assert semantic_checker.is_operator_semantic_valid(op) + op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) assert semantic_checker.is_operator_semantic_valid(op) op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True}) |