aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test
diff options
context:
space:
mode:
authorRaul Farkas <raul.farkas@arm.com>2023-01-24 16:29:06 +0000
committerRaul Farkas <raul.farkas@arm.com>2023-02-07 15:55:53 +0000
commit090f18a55fcd4f7ae8ca1ae633418d05c62cbb6e (patch)
tree0d88ac2cf3253af50f63c507d8b397831bd32b7a /ethosu/vela/test
parent12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f (diff)
downloadethos-u-vela-090f18a55fcd4f7ae8ca1ae633418d05c62cbb6e.tar.gz
MLBEDSW-7237: CONV_2D stride 4 optimisation
* Extend stride range from (1,3) to (1,4) * Add stride 4 support when optimising CONV_2D * Add some tests for various strides Change-Id: Iddaeb42c4a6e02695ecdd3740bc8b9dd59a7eb3c Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 6a0b58e3..efe0d000 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -17,6 +17,7 @@
# Description:
# Unit tests for tflite support_operators
import numpy as np
+import pytest
from ethosu.vela.data_type import DataType
from ethosu.vela.operation import ActivationFunction
@@ -104,11 +105,15 @@ def test_constraint_conv_pass():
assert support.is_operator_supported(op)
-def test_constraint_stride_range():
+@pytest.mark.parametrize(
+ "stride_w, stride_h, supported",
+ [[0, 20, False], [4, 4, True], [4, 5, False], [5, 4, False], [3, 3, True], [1, 1, True], [2, 4, True]],
+)
+def test_constraint_stride_range(stride_w: int, stride_h: int, supported: bool):
# Stride width and height must lie within a certain range
- op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 0, "stride_h": 20}
- assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], [1, 1, 1, 1])
+ op.attrs = {"stride_w": stride_w, "stride_h": stride_h}
+ assert support.is_operator_supported(op) == supported
def test_constraint_dilated_height_range():