aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 0dfdc666..25b68970 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -29,6 +29,7 @@ from .supported_operators_util import list_formatter
from .tensor import check_quantized_tens_scaling_equal
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
+from .utils import calc_resize_factor
def _optype_formatter(op_list):
@@ -545,11 +546,18 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_conv_stride(op):
- "Stride width must be greater than or equal to 1 and stride height must be between 1 and 3"
+ """Stride width must be greater than or equal to 1.
+ For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3.
+ Stride height must be between 1 and 3."""
w, h = op.get_kernel_stride()
stride_min = 1
stride_max_h = 3
- valid = (stride_min <= w) and (stride_min <= h <= stride_max_h)
+ ifm_width = op.ifm.shape[2]
+ _, optimized_stride = calc_resize_factor(ifm_width, w) if w > 1 else (1, w)
+ # Optimized stride indicates the final Conv2D stride width after all optimizations are performed
+ can_optimize_stride_width_gt_3 = optimized_stride <= 3
+ valid = (stride_min <= w) and (stride_min <= h <= stride_max_h) and can_optimize_stride_width_gt_3
+
return valid, f"Op has stride WxH as: {w}x{h}"
@staticmethod