diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 0dfdc666..25b68970 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -29,6 +29,7 @@ from .supported_operators_util import list_formatter from .tensor import check_quantized_tens_scaling_equal from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN from .tflite_mapping import optype_to_builtintype +from .utils import calc_resize_factor def _optype_formatter(op_list): @@ -545,11 +546,18 @@ class TFLiteSupportedOperators: @staticmethod def constraint_conv_stride(op): - "Stride width must be greater than or equal to 1 and stride height must be between 1 and 3" + """Stride width must be greater than or equal to 1. + For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3. + Stride height must be between 1 and 3.""" w, h = op.get_kernel_stride() stride_min = 1 stride_max_h = 3 - valid = (stride_min <= w) and (stride_min <= h <= stride_max_h) + ifm_width = op.ifm.shape[2] + _, optimized_stride = calc_resize_factor(ifm_width, w) if w > 1 else (1, w) + # Optimized stride indicates the final Conv2D stride width after all optimizations are performed + can_optimize_stride_width_gt_3 = optimized_stride <= 3 + valid = (stride_min <= w) and (stride_min <= h <= stride_max_h) and can_optimize_stride_width_gt_3 + return valid, f"Op has stride WxH as: {w}x{h}" @staticmethod |