diff options
author | Raul Farkas <raul.farkas@arm.com> | 2023-05-16 17:18:31 +0100 |
---|---|---|
committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2023-06-16 12:25:03 +0000 |
commit | 3b64f068db4ea8e954a1b472de169dd423b8c049 (patch) | |
tree | cbd0c98da22bb62473daf08fdb6b53209ef6d971 /ethosu/vela/tflite_supported_operators.py | |
parent | 5d24821355ea5c3af1d069fd50864c5f2f0effd3 (diff) | |
download | ethos-u-vela-3b64f068db4ea8e954a1b472de169dd423b8c049.tar.gz |
MLBEDSW-7648: Fix bug with filter padding in conv2d
* Fix bug that caused filter padding to not be added proportionally
compared to the hardware padding added to IFM.
* Update needed_total_padding function that calculates hardware padding
to also account for the cases in which IFM width is not divisible by
the stride width.
* Update supported ops constraint on strides for conv2d to mark ops with
stride width > 3 and IFM width that is not divisible by the
optimization resize factor as not supported.
* Update unit tests that verify correct functionality when checking
whether ops are supported or not.
Change-Id: I62f14cca890b779ca787a9603fa37c873ad522f8
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 0dfdc666..25b68970 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -29,6 +29,7 @@ from .supported_operators_util import list_formatter from .tensor import check_quantized_tens_scaling_equal from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN from .tflite_mapping import optype_to_builtintype +from .utils import calc_resize_factor def _optype_formatter(op_list): @@ -545,11 +546,18 @@ class TFLiteSupportedOperators: @staticmethod def constraint_conv_stride(op): - "Stride width must be greater than or equal to 1 and stride height must be between 1 and 3" + """Stride width must be greater than or equal to 1. + For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3. + Stride height must be between 1 and 3.""" w, h = op.get_kernel_stride() stride_min = 1 stride_max_h = 3 - valid = (stride_min <= w) and (stride_min <= h <= stride_max_h) + ifm_width = op.ifm.shape[2] + _, optimized_stride = calc_resize_factor(ifm_width, w) if w > 1 else (1, w) + # Optimized stride indicates the final Conv2D stride width after all optimizations are performed + can_optimize_stride_width_gt_3 = optimized_stride <= 3 + valid = (stride_min <= w) and (stride_min <= h <= stride_max_h) and can_optimize_stride_width_gt_3 + return valid, f"Op has stride WxH as: {w}x{h}" @staticmethod |