diff options
author | Tim Hall <tim.hall@arm.com> | 2023-06-27 12:07:49 +0100 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2023-07-12 11:31:33 +0100 |
commit | 9cf63a3612491198a39f6bd1f4a587589b3ac20a (patch) | |
tree | 336a326892458240720d38537f7ba5e55211defe /ethosu/vela/test | |
parent | da8741a14c3774d3161f59019d3003a2ee944400 (diff) | |
download | ethos-u-vela-9cf63a3612491198a39f6bd1f4a587589b3ac20a.tar.gz |
MLBEDSW-7756: MLCE: Grouped convolutions runtime problem
- Added graph optimiser function to convert convolution groups into
a split followed by separate convolutions and then a concat
- Added semantic check for convolution groups
- Added unit tests for convolution groups semantic checks
- Fixed a minor typing issue with test_constraint_stride_range
Change-Id: I78ade408aa23469a79c9f517c4751da8619b77a9
Signed-off-by: Tim Hall <tim.hall@arm.com>
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 22 | ||||
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py | 4 |
2 files changed, 23 insertions, 3 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 7a82d2c1..e7fd3073 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -121,14 +121,32 @@ def test_constraint_conv_pass(): def test_constraint_stride_type(): # Stride width and height must be integer types - op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]) op.attrs = {"stride_w": 1.5, "stride_h": "1"} assert not semantic_checker.is_operator_semantic_valid(op) +def test_constraint_conv_groups_ifm_depth(): + # Test IFM depth is a whole multiple of the filter kernel depth + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 5], weights_shape=[1, 1, 3, 5]) + assert semantic_checker.is_operator_semantic_valid(op) + + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 5], weights_shape=[1, 1, 4, 5]) + assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_conv_groups_num_filters(): + # Test number of filter kernels is equally divisible by the number of convolution groups + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 20], weights_shape=[1, 1, 3, 20]) + assert semantic_checker.is_operator_semantic_valid(op) + + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 21], weights_shape=[1, 1, 3, 21]) + assert not semantic_checker.is_operator_semantic_valid(op) + + def test_constraint_dilation_type(): # Dilation width and height must be integer types - op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1]) op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"} assert not semantic_checker.is_operator_semantic_valid(op) diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index f2ad8586..f54211f0 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -16,6 +16,8 @@ # # Description: # Unit tests for tflite support_operators +from typing import List + import numpy as np import pytest @@ -121,7 +123,7 @@ def test_constraint_conv_pass(): [[1, 8, 40, 8], 8, 1, True], ], ) -def test_constraint_stride_range(ifm_shape: list[int], stride_w: int, stride_h: int, supported: bool): +def test_constraint_stride_range(ifm_shape: List[int], stride_w: int, stride_h: int, supported: bool): # Stride width and height must lie within a certain range op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, ifm_shape, [1, 8, 8, 8], [1, 1, 1, 1]) op.attrs = {"stride_w": stride_w, "stride_h": stride_h} |