diff options
author | Raul Farkas <raul.farkas@arm.com> | 2023-01-30 12:58:46 +0000 |
---|---|---|
committer | Raul Farkas <raul.farkas@arm.com> | 2023-05-10 13:34:42 +0100 |
commit | 10d6b3b3fa594b9aca4a72f002acea9f927f9c60 (patch) | |
tree | 3b5f71ad590c81e53bca82ab2ffb20196d2408e2 /ethosu/vela/tflite_supported_operators.py | |
parent | 69782af3ff2cda96dff09ad66799b3ac8f16c19d (diff) | |
download | ethos-u-vela-10d6b3b3fa594b9aca4a72f002acea9f927f9c60.tar.gz |
MLBEDSW-7283: Add opt cases for strided CONV2D
* Implement a general optimization solution for strided CONV2D that
supports a stride_w with no upper bound.
* Implement filter zero padding to allow for optimization in those cases
in which the filter width is not divisible by the stride width.
E.g.: Filter width = 8, stride width = 3 ->
Filter width = 8 + 1 (0 padding) = 9, stride width = 3
* Implement partial optimization to reduce the stride to hw supported
strides (i.e. 2 and 3) when optimizing to reach a stride = 1 is not
possible due to the IFM width not being divisible by the stride width.
* Implement optimization for when SAME padding is used. If the pre-opt
and post-opt padding do not match, add zero padding to the filter so
that the post-opt IFM padding matches.
Change-Id: Ia66b0d107281fa9993f6bf4d0c26627ee743253b
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 95c7de33..8e9ab12f 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -542,12 +542,11 @@ class TFLiteSupportedOperators: @staticmethod def constraint_conv_stride(op): - "Stride values for height must be between 1 and 3 and for width between 1 and 4" + "Stride width must be greater than or equal to 1 and stride height must be between 1 and 3" w, h = op.get_kernel_stride() - stride_min_w_h = 1 - stride_max_w = 4 + stride_min = 1 stride_max_h = 3 - valid = (stride_min_w_h <= w <= stride_max_w) and (stride_min_w_h <= h <= stride_max_h) + valid = (stride_min <= w) and (stride_min <= h <= stride_max_h) return valid, f"Op has stride WxH as: {w}x{h}" @staticmethod |