aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela
diff options
context:
space:
mode:
authorRaul Farkas <raul.farkas@arm.com>2023-02-09 10:03:27 +0000
committerRaul Farkas <raul.farkas@arm.com>2023-02-09 14:47:59 +0000
commit59b9ab9121d17793b5a240f7c51028b6b37a7a6e (patch)
tree56fc4255bd7f6301f54e8b870d3607b08bd3dc0f /ethosu/vela
parent1c5904891b51ff8fa90c7fafbd067b39655d1505 (diff)
downloadethos-u-vela-59b9ab9121d17793b5a240f7c51028b6b37a7a6e.tar.gz
MLBEDSW-7331: Reinstate max stride height constraint of 3 for Conv2D
Reinstate constraint for stride height to (1,3) instead of (1,4) for Conv2D and update unit tests. Change-Id: I17389ee040eeff0cea08279cab1c038e951569ea Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela')
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py13
-rw-r--r--ethosu/vela/tflite_supported_operators.py8
2 files changed, 17 insertions, 4 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index efe0d000..2713adf9 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -107,7 +107,18 @@ def test_constraint_conv_pass():
@pytest.mark.parametrize(
"stride_w, stride_h, supported",
- [[0, 20, False], [4, 4, True], [4, 5, False], [5, 4, False], [3, 3, True], [1, 1, True], [2, 4, True]],
+ [
+ [0, 20, False],
+ [4, 1, True],
+ [4, 2, True],
+ [2, 2, True],
+ [4, 4, False],
+ [4, 5, False],
+ [5, 4, False],
+ [3, 3, True],
+ [1, 1, True],
+ [2, 4, False],
+ ],
)
def test_constraint_stride_range(stride_w: int, stride_h: int, supported: bool):
# Stride width and height must lie within a certain range
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 2a1eba7d..26ccfeb6 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -537,10 +537,12 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_conv_stride(op):
- "Stride values for both width and height must be between 1 and 4"
+ "Stride values for height must be between 1 and 3 and for width between 1 and 4"
w, h = op.get_kernel_stride()
- stride_min, stride_max = 1, 4
- valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max)
+ stride_min_w_h = 1
+ stride_max_w = 4
+ stride_max_h = 3
+ valid = (stride_min_w_h <= w <= stride_max_w) and (stride_min_w_h <= h <= stride_max_h)
return valid, f"Op has stride WxH as: {w}x{h}"
@staticmethod