diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index ea39b478..2a1eba7d 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -223,12 +223,12 @@ class TFLiteSupportedOperators: # Setup specific constraints. Note: the order matters self.specific_constraints = defaultdict(list) + # Conv specific ops: + for op_type in TFLiteSupportedOperators.convolution_ops: + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_conv_stride) + # Conv-like checks: for op_type in TFLiteSupportedOperators.convolution_like_ops: - if op_type not in TFLiteSupportedOperators.transpose_convolution_ops: - # Transpose Conv has a specific stride constraint (see constraint_tconv_stride below) - self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range) - self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_dilated_height_range) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_dilated_product_range) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_type) @@ -237,6 +237,7 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_shape) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit) + # Transpose Conv specific checks: for op_type in TFLiteSupportedOperators.transpose_convolution_ops: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_tconv_stride) @@ -244,6 +245,7 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_tconv_valid) # Depthwise Conv specific checks: for op_type in TFLiteSupportedOperators.depthwise_convolution_ops: + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depthwise_conv_stride) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier) # Pooling checks: @@ -534,6 +536,22 @@ class TFLiteSupportedOperators: return True, "Op has depth_multiplier=1" @staticmethod + def constraint_conv_stride(op): + "Stride values for both width and height must be between 1 and 4" + w, h = op.get_kernel_stride() + stride_min, stride_max = 1, 4 + valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max) + return valid, f"Op has stride WxH as: {w}x{h}" + + @staticmethod + def constraint_depthwise_conv_stride(op): + "Stride values for both width and height must be between 1 and 3" + w, h = op.get_kernel_stride() + stride_min, stride_max = 1, 3 + valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max) + return valid, f"Op has stride WxH as: {w}x{h}" + + @staticmethod def constraint_tconv_stride(op): "Stride values for both width and height must be 2" w = op.kernel.stride.x |