aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
authorRaul Farkas <raul.farkas@arm.com>2023-01-30 12:58:46 +0000
committerRaul Farkas <raul.farkas@arm.com>2023-05-10 13:34:42 +0100
commit10d6b3b3fa594b9aca4a72f002acea9f927f9c60 (patch)
tree3b5f71ad590c81e53bca82ab2ffb20196d2408e2 /ethosu/vela/tflite_supported_operators.py
parent69782af3ff2cda96dff09ad66799b3ac8f16c19d (diff)
downloadethos-u-vela-10d6b3b3fa594b9aca4a72f002acea9f927f9c60.tar.gz
MLBEDSW-7283: Add opt cases for strided CONV2D
* Implement a general optimization solution for strided CONV2D that supports a stride_w with no upper bound. * Implement filter zero padding to allow for optimization in those cases in which the filter width is not divisible by the stride width. E.g.: Filter width = 8, stride width = 3 -> Filter width = 8 + 1 (0 padding) = 9, stride width = 3 * Implement partial optimization to reduce the stride to hw supported strides (i.e. 2 and 3) when optimizing to reach a stride = 1 is not possible due to the IFM width not being divisible by the stride width. * Implement optimization for when SAME padding is used. If the pre-opt and post-opt padding do not match, add zero padding to the filter so that the post-opt IFM padding matches. Change-Id: Ia66b0d107281fa9993f6bf4d0c26627ee743253b Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 95c7de33..8e9ab12f 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -542,12 +542,11 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_conv_stride(op):
- "Stride values for height must be between 1 and 3 and for width between 1 and 4"
+ "Stride width must be greater than or equal to 1 and stride height must be between 1 and 3"
w, h = op.get_kernel_stride()
- stride_min_w_h = 1
- stride_max_w = 4
+ stride_min = 1
stride_max_h = 3
- valid = (stride_min_w_h <= w <= stride_max_w) and (stride_min_w_h <= h <= stride_max_h)
+ valid = (stride_min <= w) and (stride_min <= h <= stride_max_h)
return valid, f"Op has stride WxH as: {w}x{h}"
@staticmethod