From 090f18a55fcd4f7ae8ca1ae633418d05c62cbb6e Mon Sep 17 00:00:00 2001 From: Raul Farkas Date: Tue, 24 Jan 2023 16:29:06 +0000 Subject: MLBEDSW-7237: CONV_2D stride 4 optimisation * Extend stride range from (1,3) to (1,4) * Add stride 4 support when optimising CONV_2D * Add some tests for various strides Change-Id: Iddaeb42c4a6e02695ecdd3740bc8b9dd59a7eb3c Signed-off-by: Raul Farkas --- ethosu/vela/test/test_tflite_supported_operators.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'ethosu/vela/test') diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index 6a0b58e3..efe0d000 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -17,6 +17,7 @@ # Description: # Unit tests for tflite support_operators import numpy as np +import pytest from ethosu.vela.data_type import DataType from ethosu.vela.operation import ActivationFunction @@ -104,11 +105,15 @@ def test_constraint_conv_pass(): assert support.is_operator_supported(op) -def test_constraint_stride_range(): +@pytest.mark.parametrize( + "stride_w, stride_h, supported", + [[0, 20, False], [4, 4, True], [4, 5, False], [5, 4, False], [3, 3, True], [1, 1, True], [2, 4, True]], +) +def test_constraint_stride_range(stride_w: int, stride_h: int, supported: bool): # Stride width and height must lie within a certain range - op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8]) - op.attrs = {"stride_w": 0, "stride_h": 20} - assert not support.is_operator_supported(op) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], [1, 1, 1, 1]) + op.attrs = {"stride_w": stride_w, "stride_h": stride_h} + assert support.is_operator_supported(op) == supported def test_constraint_dilated_height_range(): -- cgit v1.2.1