diff options
author | Raul Farkas <raul.farkas@arm.com> | 2023-02-09 10:03:27 +0000 |
---|---|---|
committer | Raul Farkas <raul.farkas@arm.com> | 2023-02-09 14:47:59 +0000 |
commit | 59b9ab9121d17793b5a240f7c51028b6b37a7a6e (patch) | |
tree | 56fc4255bd7f6301f54e8b870d3607b08bd3dc0f | |
parent | 1c5904891b51ff8fa90c7fafbd067b39655d1505 (diff) | |
download | ethos-u-vela-59b9ab9121d17793b5a240f7c51028b6b37a7a6e.tar.gz |
MLBEDSW-7331: Reinstate max stride height constraint of 3 for Conv2D
Reinstate constraint for stride height to (1,3) instead of (1,4) for
Conv2D and update unit tests.
Change-Id: I17389ee040eeff0cea08279cab1c038e951569ea
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
-rw-r--r-- | SUPPORTED_OPS.md | 4 | ||||
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py | 13 | ||||
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 8 |
3 files changed, 19 insertions, 6 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md index 3d045923..1ed37bde 100644 --- a/SUPPORTED_OPS.md +++ b/SUPPORTED_OPS.md @@ -1,7 +1,7 @@ # Supported Ops This file was automatically generated by Vela using the `--supported-ops-report` parameter. -Vela version: `3.6.1.dev18+g34cbb970` +Vela version: `3.6.1.dev18+g090f18a5` This file complies with [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md) @@ -123,7 +123,7 @@ This is a list of constraints that the CONV_2D operator must satisfy in order to - Stride values for both width and height must be integer types - Dilation factor values for both width and height must be integer types -- Stride values for both width and height must be between 1 and 4 +- Stride values for height must be between 1 and 3 and for width between 1 and 4 - Dilated kernel height must be in the range [1, 64] - Product of dilated kernel width and height must be in the range [1, 4096] - Weight tensor must be 8-bit diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index efe0d000..2713adf9 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -107,7 +107,18 @@ def test_constraint_conv_pass(): @pytest.mark.parametrize( "stride_w, stride_h, supported", - [[0, 20, False], [4, 4, True], [4, 5, False], [5, 4, False], [3, 3, True], [1, 1, True], [2, 4, True]], + [ + [0, 20, False], + [4, 1, True], + [4, 2, True], + [2, 2, True], + [4, 4, False], + [4, 5, False], + [5, 4, False], + [3, 3, True], + [1, 1, True], + [2, 4, False], + ], ) def test_constraint_stride_range(stride_w: int, stride_h: int, supported: bool): # Stride width and height must lie within a certain range diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 2a1eba7d..26ccfeb6 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -537,10 +537,12 @@ class TFLiteSupportedOperators: @staticmethod def constraint_conv_stride(op): - "Stride values for both width and height must be between 1 and 4" + "Stride values for height must be between 1 and 3 and for width between 1 and 4" w, h = op.get_kernel_stride() - stride_min, stride_max = 1, 4 - valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max) + stride_min_w_h = 1 + stride_max_w = 4 + stride_max_h = 3 + valid = (stride_min_w_h <= w <= stride_max_w) and (stride_min_w_h <= h <= stride_max_h) return valid, f"Op has stride WxH as: {w}x{h}" @staticmethod |