aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py22
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py4
2 files changed, 23 insertions, 3 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 7a82d2c1..e7fd3073 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -121,14 +121,32 @@ def test_constraint_conv_pass():
def test_constraint_stride_type():
# Stride width and height must be integer types
- op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1.5, "stride_h": "1"}
assert not semantic_checker.is_operator_semantic_valid(op)
+def test_constraint_conv_groups_ifm_depth():
+ # Test IFM depth is a whole multiple of the filter kernel depth
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 5], weights_shape=[1, 1, 3, 5])
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 5], weights_shape=[1, 1, 4, 5])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_conv_groups_num_filters():
+ # Test number of filter kernels is equally divisible by the number of convolution groups
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 20], weights_shape=[1, 1, 3, 20])
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 21], weights_shape=[1, 1, 3, 21])
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
def test_constraint_dilation_type():
# Dilation width and height must be integer types
- op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
assert not semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index f2ad8586..f54211f0 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -16,6 +16,8 @@
#
# Description:
# Unit tests for tflite support_operators
+from typing import List
+
import numpy as np
import pytest
@@ -121,7 +123,7 @@ def test_constraint_conv_pass():
[[1, 8, 40, 8], 8, 1, True],
],
)
-def test_constraint_stride_range(ifm_shape: list[int], stride_w: int, stride_h: int, supported: bool):
+def test_constraint_stride_range(ifm_shape: List[int], stride_w: int, stride_h: int, supported: bool):
# Stride width and height must lie within a certain range
op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, ifm_shape, [1, 8, 8, 8], [1, 1, 1, 1])
op.attrs = {"stride_w": stride_w, "stride_h": stride_h}