diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 7a82d2c1..e7fd3073 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -121,14 +121,32 @@ def test_constraint_conv_pass(): def test_constraint_stride_type(): # Stride width and height must be integer types - op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]) op.attrs = {"stride_w": 1.5, "stride_h": "1"} assert not semantic_checker.is_operator_semantic_valid(op) +def test_constraint_conv_groups_ifm_depth(): + # Test IFM depth is a whole multiple of the filter kernel depth + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 5], weights_shape=[1, 1, 3, 5]) + assert semantic_checker.is_operator_semantic_valid(op) + + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 5], weights_shape=[1, 1, 4, 5]) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_conv_groups_num_filters(): + # Test number of filter kernels is equally divisible by the number of convolution groups + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 20], weights_shape=[1, 1, 3, 20]) + assert semantic_checker.is_operator_semantic_valid(op) + + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 21], weights_shape=[1, 1, 3, 21]) + assert not semantic_checker.is_operator_semantic_valid(op) + + def test_constraint_dilation_type(): # Dilation width and height must be integer types - op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]) op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"} assert not semantic_checker.is_operator_semantic_valid(op) |