aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2022-11-11 18:19:53 +0000
committerTim Hall <tim.hall@arm.com>2022-11-15 17:52:08 +0000
commitea4ba666c035827aabe9a807503c185a6a9d3f0f (patch)
treedcfe7fb07e7c02fd011904d46e371108883101ff /ethosu/vela/tflite_supported_operators.py
parent16da6abddb1d791b4068b1d088beb3c5589fa722 (diff)
downloadethos-u-vela-ea4ba666c035827aabe9a807503c185a6a9d3f0f.tar.gz
MLBEDSW-6905: Add dilation greater than 2 support
- Added graph optimisation pass to support dilations greater than 2 in either dimension - Removed supported operators restrictions - Removed erroneous dilation on TRANSPOSE_CONV - Updated unit tests and documentation Signed-off-by: Tim Hall <tim.hall@arm.com> Change-Id: Ide302374b0d5eff25c20501383a63f6aa7625c52
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py19
1 files changed, 4 insertions, 15 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index fd8bbeef..abbfb171 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -189,7 +189,6 @@ class TFLiteSupportedOperators:
# Defined ranges for allowed values:
tens_dim_range = (1, 65535)
stride_range = (1, 3)
- dilation_range = (1, 2)
dilated_height_range = (1, 64)
dilated_product_range = (1, 64 * 64)
weights_limit = 127 * 65536
@@ -225,8 +224,10 @@ class TFLiteSupportedOperators:
# Conv-like checks:
for op_type in TFLiteSupportedOperators.convolution_like_ops:
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_dilation_range)
+ if op_type not in TFLiteSupportedOperators.transpose_convolution_ops:
+ # Transpose Conv has a specific stride constraint (see constraint_tconv_stride below)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
+
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_dilated_height_range)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_dilated_product_range)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_type)
@@ -234,9 +235,6 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_limit)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit)
- # Remove stride contraint from Transpose Conv because it has a specific one (see below)
- for op_type in TFLiteSupportedOperators.transpose_convolution_ops:
- self.specific_constraints[op_type].remove(TFLiteSupportedOperators.constraint_stride_range)
# Transpose Conv specific checks:
for op_type in TFLiteSupportedOperators.transpose_convolution_ops:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_tconv_stride)
@@ -434,15 +432,6 @@ class TFLiteSupportedOperators:
return valid, f"Op has stride WxH as: {w}x{h}"
@classmethod
- @docstring_format_args(dilation_range)
- def constraint_dilation_range(cls, op):
- "Dilation factor values for both width and height must be in the range [{}, {}]"
- w, h = op.get_kernel_dilation()
- dilation_min, dilation_max = cls.dilation_range
- valid = (dilation_min <= w <= dilation_max) and (dilation_min <= h <= dilation_max)
- return valid, f"Op has dilation factor WxH as: {w}x{h}"
-
- @classmethod
@docstring_format_args(dilated_height_range)
def constraint_dilated_height_range(cls, op):
"Dilated kernel height must be in the range [{}, {}]"