aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
authorMichael McGeagh <michael.mcgeagh@arm.com>2020-12-14 15:51:20 +0000
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-12-17 17:16:02 +0000
commit168954814fb6a1cc5e7b2d44784b24402ef30199 (patch)
tree35693aeee7c291695ba83f27db7f8d81272b787c /ethosu/vela/supported_operators.py
parentf842b69d007e70d70fc5cef3b6f1f50b4cabbd90 (diff)
downloadethos-u-vela-168954814fb6a1cc5e7b2d44784b24402ef30199.tar.gz
MLBEDSW-3694 Replace padding with enum
Use an Enum instead of a bytestring to specify VALID or SAME padding Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com> Change-Id: I4e87f8c32b3bfac176d822a68de061e85a558fce
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index a1fcf6ab..f2c2eb9e 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -24,6 +24,7 @@ from .data_type import DataType
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
+from .operation import Padding
from .tensor import check_quantized_tens_scaling_equal
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
@@ -569,7 +570,7 @@ class SupportedOperators:
@staticmethod
def constraint_tconv_same(op):
"SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride"
- if op.attrs["padding"] == b"SAME":
+ if op.attrs["padding"] == Padding.SAME:
w = op.kernel.stride.x
h = op.kernel.stride.y
ifm_shape = op.ifm.shape
@@ -582,7 +583,7 @@ class SupportedOperators:
def constraint_tconv_valid(op):
"""VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,
minus difference between kernel size and stride"""
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
s_w = op.kernel.stride.x
s_h = op.kernel.stride.y
k_w = op.kernel.width
@@ -626,7 +627,7 @@ class SupportedOperators:
@docstring_format_args(filter_range)
def constraint_filter_range(cls, op):
"Kernel filter values for both width and height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"SAME":
+ if op.attrs["padding"] == Padding.SAME:
w = op.kernel.width
h = op.kernel.height
filter_min, filter_max = cls.filter_range
@@ -656,7 +657,7 @@ class SupportedOperators:
@docstring_format_args(filter_height_range)
def constraint_filter_height_range_valid_pad(op):
"VALID padding: Kernel filter height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
return SupportedOperators.constraint_filter_height_range(op)
return True, "Op has padding=SAME"
@@ -664,7 +665,7 @@ class SupportedOperators:
@docstring_format_args(filter_product_range)
def constraint_filter_product_range_valid_pad(op):
"VALID padding: Product of kernel filter width and height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
return SupportedOperators.constraint_filter_product_range(op)
return True, "Op has padding=SAME"