diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index a1fcf6ab..f2c2eb9e 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -24,6 +24,7 @@ from .data_type import DataType from .numeric_util import is_integer from .operation import get_slice_offsets from .operation import Op +from .operation import Padding from .tensor import check_quantized_tens_scaling_equal from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN from .tflite_mapping import optype_to_builtintype @@ -569,7 +570,7 @@ class SupportedOperators: @staticmethod def constraint_tconv_same(op): "SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride" - if op.attrs["padding"] == b"SAME": + if op.attrs["padding"] == Padding.SAME: w = op.kernel.stride.x h = op.kernel.stride.y ifm_shape = op.ifm.shape @@ -582,7 +583,7 @@ class SupportedOperators: def constraint_tconv_valid(op): """VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride, minus difference between kernel size and stride""" - if op.attrs["padding"] == b"VALID": + if op.attrs["padding"] == Padding.VALID: s_w = op.kernel.stride.x s_h = op.kernel.stride.y k_w = op.kernel.width @@ -626,7 +627,7 @@ class SupportedOperators: @docstring_format_args(filter_range) def constraint_filter_range(cls, op): "Kernel filter values for both width and height must be in the range [{}, {}]" - if op.attrs["padding"] == b"SAME": + if op.attrs["padding"] == Padding.SAME: w = op.kernel.width h = op.kernel.height filter_min, filter_max = cls.filter_range @@ -656,7 +657,7 @@ class SupportedOperators: @docstring_format_args(filter_height_range) def constraint_filter_height_range_valid_pad(op): "VALID padding: Kernel filter height must be in the range [{}, {}]" - if op.attrs["padding"] == b"VALID": + if op.attrs["padding"] == Padding.VALID: return SupportedOperators.constraint_filter_height_range(op) return True, "Op has padding=SAME" @@ -664,7 +665,7 @@ class SupportedOperators: @docstring_format_args(filter_product_range) def constraint_filter_product_range_valid_pad(op): "VALID padding: Product of kernel filter width and height must be in the range [{}, {}]" - if op.attrs["padding"] == b"VALID": + if op.attrs["padding"] == Padding.VALID: return SupportedOperators.constraint_filter_product_range(op) return True, "Op has padding=SAME" |