Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- ethosu/vela/supported_operators.py | 52
1 file changed, 23 insertions(+), 29 deletions(-)
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 5676ba1c..55e718e9 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -15,7 +15,8 @@
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
-from .data_type import BaseType, DataType
+from .data_type import BaseType
+from .data_type import DataType
class SupportedOperators:
@@ -51,17 +52,7 @@ class SupportedOperators:
self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs"))
self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum"))
self.binary_elem_wise_add_mul_sub = set(
- (
- "AddAct",
- "MulAct",
- "SubAct",
- "QuantizedAdd",
- "QuantizedSub",
- "QuantizedMul",
- "Mul",
- "Add",
- "Sub",
- )
+ ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",)
)
self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub
self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
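
For orientation, the collapsed tuple above feeds plain Python set unions; a minimal standalone sketch of how the operator sets compose (literal values copied from this constructor, runnable outside vela):

    unary_elem_wise_main_ops = {"LeakyRelu", "Abs"}
    binary_elem_wise_min_max_ops = {"Minimum", "Maximum"}
    binary_elem_wise_add_mul_sub = {"AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub"}
    # Union of the two binary subsets, then union with the unary ops.
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    assert "Maximum" in elem_wise_main_ops and "Abs" in elem_wise_main_ops
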
@@ -201,13 +192,13 @@ class SupportedOperators:
return False
elif op.attrs["padding"] == b"VALID":
kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
- if ((ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0))
- or (ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0))):
+ if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
+ ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
+ ):
return False
return self.check_convolution_restrictions(op)
-
def check_pooling_restrictions(self, op):
# check stride
if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
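
The VALID-padding branch in the preceding hunk checks each output extent against ifm * stride + max(kernel - stride, 0); a small sketch of that arithmetic with hypothetical sizes (the helper name is illustrative, not from the source):

    def expected_valid_dim(ifm_dim, stride, kernel):
        # Each input element contributes `stride` outputs, plus the kernel
        # overhang where the kernel is larger than the stride.
        return ifm_dim * stride + max(kernel - stride, 0)

    # e.g. ifm_h = 4, stride_h = 2, kernel_h = 3 -> 4 * 2 + max(1, 0) = 9
    assert expected_valid_dim(4, 2, 3) == 9
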
@@ -226,8 +217,9 @@ class SupportedOperators:
# check kernel size
if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8):
return False
- if (op.attrs["padding"] == b"VALID" and
- (op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256)):
+ if op.attrs["padding"] == b"VALID" and (
+ op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256
+ ):
return False
if op.type in self.max_pooling_ops:
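
Restated as a standalone predicate, the kernel-size limits in this hunk reduce to the following (a hypothetical helper for illustration; the real checks stay inline in check_pooling_restrictions, and paddings other than SAME/VALID are not modelled here):

    def pooling_kernel_ok(padding, filter_w, filter_h):
        # SAME padding: both filter dimensions capped at 8.
        if padding == b"SAME":
            return filter_w <= 8 and filter_h <= 8
        # VALID padding: filter area capped at 256 * 256, height at 256.
        if padding == b"VALID":
            return filter_w * filter_h <= 256 * 256 and filter_h <= 256
        return False

    assert pooling_kernel_ok(b"SAME", 8, 8)
    assert not pooling_kernel_ok(b"VALID", 300, 300)  # area 90000 > 65536
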
@@ -259,31 +251,33 @@ class SupportedOperators:
# check data type
ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
# input and output datatype must match for these operators
- if (op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops and
- ifm_tensor.dtype != ofm_tensor.dtype):
+ if (
+ op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops
+ and ifm_tensor.dtype != ofm_tensor.dtype
+ ):
return False
- if (op.type in self.binary_elem_wise_add_mul_sub):
+ if op.type in self.binary_elem_wise_add_mul_sub:
# both inputs must have same type
- if (ifm_tensor.dtype != ifm2_tensor.dtype):
+ if ifm_tensor.dtype != ifm2_tensor.dtype:
return False
# signed input check
- if (ifm_tensor.dtype.type & BaseType.Signed):
+ if ifm_tensor.dtype.type & BaseType.Signed:
# output must be signed
- if (ofm_tensor.dtype.type & BaseType.Unsigned):
+ if ofm_tensor.dtype.type & BaseType.Unsigned:
return False
# and 8, 16 or 32-bit
- if (ofm_tensor.element_size() not in (1, 2, 4)):
+ if ofm_tensor.element_size() not in (1, 2, 4):
return False
# unsigned input check, output must be same type or int32
- if (ifm_tensor.dtype.type & BaseType.Unsigned and not
- (ifm_tensor.dtype == ofm_tensor.dtype or
- ofm_tensor.dtype == DataType.int32)):
+ if ifm_tensor.dtype.type & BaseType.Unsigned and not (
+ ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
+ ):
return False
# check batch size
if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
-                return False
-        if op.type in self.binary_elem_wise_main_ops: # if op type is unary, ifm2_tensor is None
+            return False
+        if op.type in self.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
return False
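
Taken together, the signed/unsigned rules in this last hunk condense to the predicate below; a sketch that reuses only calls already present in this file (BaseType and DataType from .data_type, vela tensors with .dtype and .element_size()), assuming it runs inside the module where those imports exist:

    def elem_wise_dtypes_ok(ifm_tensor, ofm_tensor):
        # Assumes the module's existing imports: BaseType, DataType.
        # Signed input: output must be signed and 8, 16 or 32 bits wide.
        if ifm_tensor.dtype.type & BaseType.Signed:
            if ofm_tensor.dtype.type & BaseType.Unsigned:
                return False
            if ofm_tensor.element_size() not in (1, 2, 4):
                return False
        # Unsigned input: output must keep the input dtype or be int32.
        if ifm_tensor.dtype.type & BaseType.Unsigned and not (
            ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
        ):
            return False
        return True
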