aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
authorRaul Farkas <raul.farkas@arm.com>2023-05-16 17:18:31 +0100
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-06-16 12:25:03 +0000
commit3b64f068db4ea8e954a1b472de169dd423b8c049 (patch)
treecbd0c98da22bb62473daf08fdb6b53209ef6d971 /ethosu/vela/tflite_supported_operators.py
parent5d24821355ea5c3af1d069fd50864c5f2f0effd3 (diff)
downloadethos-u-vela-3b64f068db4ea8e954a1b472de169dd423b8c049.tar.gz
MLBEDSW-7648: Fix bug with filter padding in conv2d
* Fix bug that caused filter padding to not be added proportionally compared to the hardware padding added to IFM. * Update needed_total_padding function that calculates hardware padding to also account for the cases in which IFM width is not divisible by the stride width. * Update supported ops constraint on strides for conv2d to mark ops with stride width > 3 and IFM width that is not divisible by the optimization resize factor as not supported. * Update unit tests that verify correct functionality when checking whether ops are supported or not. Change-Id: I62f14cca890b779ca787a9603fa37c873ad522f8 Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 0dfdc666..25b68970 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -29,6 +29,7 @@ from .supported_operators_util import list_formatter
from .tensor import check_quantized_tens_scaling_equal
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
+from .utils import calc_resize_factor
def _optype_formatter(op_list):
@@ -545,11 +546,18 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_conv_stride(op):
- "Stride width must be greater than or equal to 1 and stride height must be between 1 and 3"
+ """Stride width must be greater than or equal to 1.
+ For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3.
+ Stride height must be between 1 and 3."""
w, h = op.get_kernel_stride()
stride_min = 1
stride_max_h = 3
- valid = (stride_min <= w) and (stride_min <= h <= stride_max_h)
+ ifm_width = op.ifm.shape[2]
+ _, optimized_stride = calc_resize_factor(ifm_width, w) if w > 1 else (1, w)
+ # Optimized stride indicates the final Conv2D stride width after all optimizations are performed
+ can_optimize_stride_width_gt_3 = optimized_stride <= 3
+ valid = (stride_min <= w) and (stride_min <= h <= stride_max_h) and can_optimize_stride_width_gt_3
+
return valid, f"Op has stride WxH as: {w}x{h}"
@staticmethod