aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25b68970..a24eebc5 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -220,7 +220,7 @@ class TFLiteSupportedOperators:
# Conv specific ops:
for op_type in TFLiteSupportedOperators.convolution_ops:
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_conv_stride)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_width_no_upper_limit)
# Conv-like checks:
for op_type in TFLiteSupportedOperators.convolution_like_ops:
@@ -244,10 +244,11 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier)
# Pooling checks:
- for op_type in TFLiteSupportedOperators.pooling_ops:
+ for op_type in TFLiteSupportedOperators.pooling_ops - TFLiteSupportedOperators.avg_pooling_ops:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
# AVG pooling specific checks:
for op_type in TFLiteSupportedOperators.avg_pooling_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range_no_padding)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_range)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_height_range_valid_pad)
self.specific_constraints[op_type].append(
@@ -545,7 +546,7 @@ class TFLiteSupportedOperators:
return True, "Op has depth_multiplier=1"
@staticmethod
- def constraint_conv_stride(op):
+ def constraint_stride_width_no_upper_limit(op):
"""Stride width must be greater than or equal to 1.
For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3.
Stride height must be between 1 and 3."""
@@ -561,6 +562,17 @@ class TFLiteSupportedOperators:
return valid, f"Op has stride WxH as: {w}x{h}"
@staticmethod
+ def constraint_stride_range_no_padding(op):
+ """Stride width must be greater than or equal to 1.
+ For stride width greater than 3, valid padding needs to be used."""
+ w, _ = op.get_kernel_stride()
+ valid, message = TFLiteSupportedOperators.constraint_stride_width_no_upper_limit(op)
+ padding = op.attrs.get("padding", None)
+ is_optimized_with_valid_padding = padding in (None, Padding.VALID) or w <= 3
+ valid = valid and is_optimized_with_valid_padding
+ return valid, f"{message}, padding: {padding}"
+
+ @staticmethod
def constraint_depthwise_conv_stride(op):
"Stride values for both width and height must be between 1 and 3"
w, h = op.get_kernel_stride()
@@ -614,10 +626,11 @@ class TFLiteSupportedOperators:
def constraint_filter_range(cls, op):
"Kernel filter values for both width and height must be in the range [{}, {}]"
if op.attrs["padding"] == Padding.SAME:
+ sw, _ = op.get_kernel_stride()
w = op.kernel.width
h = op.kernel.height
filter_min, filter_max = cls.filter_range
- valid = (filter_min <= w <= filter_max) and (filter_min <= h <= filter_max)
+ valid = ((filter_min <= w <= filter_max) or sw == w) and (filter_min <= h <= filter_max)
return valid, f"Op has kernel filter WxH as: {w}x{h}"
return True, "Op has padding=VALID"