aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
authorRaul Farkas <raul.farkas@arm.com>2023-05-09 09:09:17 +0100
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-06-16 12:26:19 +0000
commit3e7157ba59f12aa0d277a9b3a7cb3f8a19267338 (patch)
treea06a40060fc1d9a44b3688fea916b61d26c56a65 /ethosu/vela/tflite_supported_operators.py
parent3b64f068db4ea8e954a1b472de169dd423b8c049 (diff)
downloadethos-u-vela-3e7157ba59f12aa0d277a9b3a7cb3f8a19267338.tar.gz
MLBEDSW-7315: Add support for AvgPool with stride_width > 3
* Convert AvgPool with stride_width > 3 and Valid padding to Conv2D to optimize it to run on NPU. Change-Id: I06ab412357f0b09b1498f9019a9d1963a324ad34 Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25b68970..a24eebc5 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -220,7 +220,7 @@ class TFLiteSupportedOperators:
# Conv specific ops:
for op_type in TFLiteSupportedOperators.convolution_ops:
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_conv_stride)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_width_no_upper_limit)
# Conv-like checks:
for op_type in TFLiteSupportedOperators.convolution_like_ops:
@@ -244,10 +244,11 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier)
# Pooling checks:
- for op_type in TFLiteSupportedOperators.pooling_ops:
+ for op_type in TFLiteSupportedOperators.pooling_ops - TFLiteSupportedOperators.avg_pooling_ops:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
# AVG pooling specific checks:
for op_type in TFLiteSupportedOperators.avg_pooling_ops:
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range_no_padding)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_range)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_height_range_valid_pad)
self.specific_constraints[op_type].append(
@@ -545,7 +546,7 @@ class TFLiteSupportedOperators:
return True, "Op has depth_multiplier=1"
@staticmethod
- def constraint_conv_stride(op):
+ def constraint_stride_width_no_upper_limit(op):
"""Stride width must be greater than or equal to 1.
For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3.
Stride height must be between 1 and 3."""
@@ -561,6 +562,17 @@ class TFLiteSupportedOperators:
return valid, f"Op has stride WxH as: {w}x{h}"
@staticmethod
+ def constraint_stride_range_no_padding(op):
+ """Stride width must be greater than or equal to 1.
+ For stride width greater than 3, valid padding needs to be used."""
+ w, _ = op.get_kernel_stride()
+ valid, message = TFLiteSupportedOperators.constraint_stride_width_no_upper_limit(op)
+ padding = op.attrs.get("padding", None)
+ is_optimized_with_valid_padding = padding in (None, Padding.VALID) or w <= 3
+ valid = valid and is_optimized_with_valid_padding
+ return valid, f"{message}, padding: {padding}"
+
+ @staticmethod
def constraint_depthwise_conv_stride(op):
"Stride values for both width and height must be between 1 and 3"
w, h = op.get_kernel_stride()
@@ -614,10 +626,11 @@ class TFLiteSupportedOperators:
def constraint_filter_range(cls, op):
"Kernel filter values for both width and height must be in the range [{}, {}]"
if op.attrs["padding"] == Padding.SAME:
+ sw, _ = op.get_kernel_stride()
w = op.kernel.width
h = op.kernel.height
filter_min, filter_max = cls.filter_range
- valid = (filter_min <= w <= filter_max) and (filter_min <= h <= filter_max)
+ valid = ((filter_min <= w <= filter_max) or sw == w) and (filter_min <= h <= filter_max)
return valid, f"Op has kernel filter WxH as: {w}x{h}"
return True, "Op has padding=VALID"