diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 25b68970..a24eebc5 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -220,7 +220,7 @@ class TFLiteSupportedOperators: # Conv specific ops: for op_type in TFLiteSupportedOperators.convolution_ops: - self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_conv_stride) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_width_no_upper_limit) # Conv-like checks: for op_type in TFLiteSupportedOperators.convolution_like_ops: @@ -244,10 +244,11 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier) # Pooling checks: - for op_type in TFLiteSupportedOperators.pooling_ops: + for op_type in TFLiteSupportedOperators.pooling_ops - TFLiteSupportedOperators.avg_pooling_ops: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range) # AVG pooling specific checks: for op_type in TFLiteSupportedOperators.avg_pooling_ops: + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range_no_padding) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_range) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_height_range_valid_pad) self.specific_constraints[op_type].append( @@ -545,7 +546,7 @@ class TFLiteSupportedOperators: return True, "Op has depth_multiplier=1" @staticmethod - def constraint_conv_stride(op): + def constraint_stride_width_no_upper_limit(op): """Stride width must be greater than or equal to 1. For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3. Stride height must be between 1 and 3.""" @@ -561,6 +562,17 @@ class TFLiteSupportedOperators: return valid, f"Op has stride WxH as: {w}x{h}" @staticmethod + def constraint_stride_range_no_padding(op): + """Stride width must be greater than or equal to 1. + For stride width greater than 3, valid padding needs to be used.""" + w, _ = op.get_kernel_stride() + valid, message = TFLiteSupportedOperators.constraint_stride_width_no_upper_limit(op) + padding = op.attrs.get("padding", None) + is_optimized_with_valid_padding = padding in (None, Padding.VALID) or w <= 3 + valid = valid and is_optimized_with_valid_padding + return valid, f"{message}, padding: {padding}" + + @staticmethod def constraint_depthwise_conv_stride(op): "Stride values for both width and height must be between 1 and 3" w, h = op.get_kernel_stride() @@ -614,10 +626,11 @@ class TFLiteSupportedOperators: def constraint_filter_range(cls, op): "Kernel filter values for both width and height must be in the range [{}, {}]" if op.attrs["padding"] == Padding.SAME: + sw, _ = op.get_kernel_stride() w = op.kernel.width h = op.kernel.height filter_min, filter_max = cls.filter_range - valid = (filter_min <= w <= filter_max) and (filter_min <= h <= filter_max) + valid = ((filter_min <= w <= filter_max) or sw == w) and (filter_min <= h <= filter_max) return valid, f"Op has kernel filter WxH as: {w}x{h}" return True, "Op has padding=VALID" |